diff --git a/.claude/scheduled_tasks.lock b/.claude/scheduled_tasks.lock new file mode 100644 index 000000000..df47ff816 --- /dev/null +++ b/.claude/scheduled_tasks.lock @@ -0,0 +1 @@ +{"sessionId":"d6574c47-eafc-4a94-9dce-f9ffea22b53c","pid":10111,"acquiredAt":1775248373916} \ No newline at end of file diff --git a/.superset/config.json b/.superset/config.json new file mode 100644 index 000000000..f806b5255 --- /dev/null +++ b/.superset/config.json @@ -0,0 +1,5 @@ +{ + "setup": [], + "teardown": [], + "run": [] +} diff --git a/seaweed-volume/Cargo.lock b/seaweed-volume/Cargo.lock index b5401c9a5..ad47d1d6f 100644 --- a/seaweed-volume/Cargo.lock +++ b/seaweed-volume/Cargo.lock @@ -2561,6 +2561,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" +[[package]] +name = "openssl-src" +version = "300.5.5+3.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f1787d533e03597a7934fd0a765f0d28e94ecc5fb7789f8053b1e699a56f709" +dependencies = [ + "cc", +] + [[package]] name = "openssl-sys" version = "0.9.111" @@ -2569,6 +2578,7 @@ checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" dependencies = [ "cc", "libc", + "openssl-src", "pkg-config", "vcpkg", ] @@ -4654,6 +4664,7 @@ dependencies = [ "memmap2", "mime_guess", "multer", + "openssl", "parking_lot 0.12.5", "pprof", "prometheus", diff --git a/weed/admin/dash/types.go b/weed/admin/dash/types.go index 965166de4..a7f22c541 100644 --- a/weed/admin/dash/types.go +++ b/weed/admin/dash/types.go @@ -447,11 +447,6 @@ type QueueStats = maintenance.QueueStats type WorkerDetailsData = maintenance.WorkerDetailsData type WorkerPerformance = maintenance.WorkerPerformance -// GetTaskIcon returns the icon CSS class for a task type from its UI provider -func GetTaskIcon(taskType MaintenanceTaskType) string { - return maintenance.GetTaskIcon(taskType) -} - // Status constants (these are still static) const ( TaskStatusPending = maintenance.TaskStatusPending diff --git a/weed/admin/handlers/cluster_handlers.go b/weed/admin/handlers/cluster_handlers.go index c5303458f..b3bcc4fe3 100644 --- a/weed/admin/handlers/cluster_handlers.go +++ b/weed/admin/handlers/cluster_handlers.go @@ -312,29 +312,6 @@ func (h *ClusterHandlers) ShowClusterFilers(w http.ResponseWriter, r *http.Reque } } -// ShowClusterBrokers renders the cluster message brokers page -func (h *ClusterHandlers) ShowClusterBrokers(w http.ResponseWriter, r *http.Request) { - // Get cluster brokers data - brokersData, err := h.adminServer.GetClusterBrokers() - if err != nil { - writeJSONError(w, http.StatusInternalServerError, "Failed to get cluster brokers: "+err.Error()) - return - } - - username := usernameOrDefault(r) - brokersData.Username = username - - // Render HTML template - w.Header().Set("Content-Type", "text/html") - brokersComponent := app.ClusterBrokers(*brokersData) - viewCtx := layout.NewViewContext(r, username, dash.CSRFTokenFromContext(r.Context())) - layoutComponent := layout.Layout(viewCtx, brokersComponent) - if err := layoutComponent.Render(r.Context(), w); err != nil { - writeJSONError(w, http.StatusInternalServerError, "Failed to render template: "+err.Error()) - return - } -} - // GetClusterTopology returns the cluster topology as JSON func (h *ClusterHandlers) GetClusterTopology(w http.ResponseWriter, r *http.Request) { topology, err := h.adminServer.GetClusterTopology() diff --git a/weed/admin/handlers/mq_handlers.go b/weed/admin/handlers/mq_handlers.go index 5efa3cc3a..6c6e46e57 100644 --- a/weed/admin/handlers/mq_handlers.go +++ b/weed/admin/handlers/mq_handlers.go @@ -78,34 +78,6 @@ func (h *MessageQueueHandlers) ShowTopics(w http.ResponseWriter, r *http.Request } } -// ShowSubscribers renders the message queue subscribers page -func (h *MessageQueueHandlers) ShowSubscribers(w http.ResponseWriter, r *http.Request) { - // Get subscribers data - subscribersData, err := h.adminServer.GetSubscribers() - if err != nil { - writeJSONError(w, http.StatusInternalServerError, "Failed to get subscribers: "+err.Error()) - return - } - - // Set username - username := dash.UsernameFromContext(r.Context()) - if username == "" { - username = "admin" - } - subscribersData.Username = username - - // Render HTML template - w.Header().Set("Content-Type", "text/html") - subscribersComponent := app.Subscribers(*subscribersData) - viewCtx := layout.NewViewContext(r, username, dash.CSRFTokenFromContext(r.Context())) - layoutComponent := layout.Layout(viewCtx, subscribersComponent) - err = layoutComponent.Render(r.Context(), w) - if err != nil { - writeJSONError(w, http.StatusInternalServerError, "Failed to render template: "+err.Error()) - return - } -} - // ShowTopicDetails renders the topic details page func (h *MessageQueueHandlers) ShowTopicDetails(w http.ResponseWriter, r *http.Request) { // Get topic parameters from URL diff --git a/weed/admin/maintenance/config_verification.go b/weed/admin/maintenance/config_verification.go deleted file mode 100644 index 0ac40aad1..000000000 --- a/weed/admin/maintenance/config_verification.go +++ /dev/null @@ -1,124 +0,0 @@ -package maintenance - -import ( - "fmt" - - "github.com/seaweedfs/seaweedfs/weed/pb/worker_pb" -) - -// VerifyProtobufConfig demonstrates that the protobuf configuration system is working -func VerifyProtobufConfig() error { - // Create configuration manager - configManager := NewMaintenanceConfigManager() - config := configManager.GetConfig() - - // Verify basic configuration - if !config.Enabled { - return fmt.Errorf("expected config to be enabled by default") - } - - if config.ScanIntervalSeconds != 30*60 { - return fmt.Errorf("expected scan interval to be 1800 seconds, got %d", config.ScanIntervalSeconds) - } - - // Verify policy configuration - if config.Policy == nil { - return fmt.Errorf("expected policy to be configured") - } - - if config.Policy.GlobalMaxConcurrent != 4 { - return fmt.Errorf("expected global max concurrent to be 4, got %d", config.Policy.GlobalMaxConcurrent) - } - - // Verify task policies - vacuumPolicy := config.Policy.TaskPolicies["vacuum"] - if vacuumPolicy == nil { - return fmt.Errorf("expected vacuum policy to be configured") - } - - if !vacuumPolicy.Enabled { - return fmt.Errorf("expected vacuum policy to be enabled") - } - - // Verify typed configuration access - vacuumConfig := vacuumPolicy.GetVacuumConfig() - if vacuumConfig == nil { - return fmt.Errorf("expected vacuum config to be accessible") - } - - if vacuumConfig.GarbageThreshold != 0.3 { - return fmt.Errorf("expected garbage threshold to be 0.3, got %f", vacuumConfig.GarbageThreshold) - } - - // Verify helper functions work - if !IsTaskEnabled(config.Policy, "vacuum") { - return fmt.Errorf("expected vacuum task to be enabled via helper function") - } - - maxConcurrent := GetMaxConcurrent(config.Policy, "vacuum") - if maxConcurrent != 2 { - return fmt.Errorf("expected vacuum max concurrent to be 2, got %d", maxConcurrent) - } - - // Verify erasure coding configuration - ecPolicy := config.Policy.TaskPolicies["erasure_coding"] - if ecPolicy == nil { - return fmt.Errorf("expected EC policy to be configured") - } - - ecConfig := ecPolicy.GetErasureCodingConfig() - if ecConfig == nil { - return fmt.Errorf("expected EC config to be accessible") - } - - // Verify configurable EC fields only - if ecConfig.FullnessRatio <= 0 || ecConfig.FullnessRatio > 1 { - return fmt.Errorf("expected EC config to have valid fullness ratio (0-1), got %f", ecConfig.FullnessRatio) - } - - return nil -} - -// GetProtobufConfigSummary returns a summary of the current protobuf configuration -func GetProtobufConfigSummary() string { - configManager := NewMaintenanceConfigManager() - config := configManager.GetConfig() - - summary := fmt.Sprintf("SeaweedFS Protobuf Maintenance Configuration:\n") - summary += fmt.Sprintf(" Enabled: %v\n", config.Enabled) - summary += fmt.Sprintf(" Scan Interval: %d seconds\n", config.ScanIntervalSeconds) - summary += fmt.Sprintf(" Max Retries: %d\n", config.MaxRetries) - summary += fmt.Sprintf(" Global Max Concurrent: %d\n", config.Policy.GlobalMaxConcurrent) - summary += fmt.Sprintf(" Task Policies: %d configured\n", len(config.Policy.TaskPolicies)) - - for taskType, policy := range config.Policy.TaskPolicies { - summary += fmt.Sprintf(" %s: enabled=%v, max_concurrent=%d\n", - taskType, policy.Enabled, policy.MaxConcurrent) - } - - return summary -} - -// CreateCustomConfig demonstrates creating a custom protobuf configuration -func CreateCustomConfig() *worker_pb.MaintenanceConfig { - return &worker_pb.MaintenanceConfig{ - Enabled: true, - ScanIntervalSeconds: 60 * 60, // 1 hour - MaxRetries: 5, - Policy: &worker_pb.MaintenancePolicy{ - GlobalMaxConcurrent: 8, - TaskPolicies: map[string]*worker_pb.TaskPolicy{ - "custom_vacuum": { - Enabled: true, - MaxConcurrent: 4, - TaskConfig: &worker_pb.TaskPolicy_VacuumConfig{ - VacuumConfig: &worker_pb.VacuumTaskConfig{ - GarbageThreshold: 0.5, - MinVolumeAgeHours: 48, - }, - }, - }, - }, - }, - } -} diff --git a/weed/admin/maintenance/maintenance_config_proto.go b/weed/admin/maintenance/maintenance_config_proto.go index 0d0bca7c6..4295a706f 100644 --- a/weed/admin/maintenance/maintenance_config_proto.go +++ b/weed/admin/maintenance/maintenance_config_proto.go @@ -1,24 +1,9 @@ package maintenance import ( - "fmt" - "time" - "github.com/seaweedfs/seaweedfs/weed/pb/worker_pb" ) -// MaintenanceConfigManager handles protobuf-based configuration -type MaintenanceConfigManager struct { - config *worker_pb.MaintenanceConfig -} - -// NewMaintenanceConfigManager creates a new config manager with defaults -func NewMaintenanceConfigManager() *MaintenanceConfigManager { - return &MaintenanceConfigManager{ - config: DefaultMaintenanceConfigProto(), - } -} - // DefaultMaintenanceConfigProto returns default configuration as protobuf func DefaultMaintenanceConfigProto() *worker_pb.MaintenanceConfig { return &worker_pb.MaintenanceConfig{ @@ -34,253 +19,3 @@ func DefaultMaintenanceConfigProto() *worker_pb.MaintenanceConfig { Policy: nil, } } - -// GetConfig returns the current configuration -func (mcm *MaintenanceConfigManager) GetConfig() *worker_pb.MaintenanceConfig { - return mcm.config -} - -// Type-safe configuration accessors - -// GetVacuumConfig returns vacuum-specific configuration for a task type -func (mcm *MaintenanceConfigManager) GetVacuumConfig(taskType string) *worker_pb.VacuumTaskConfig { - if policy := mcm.getTaskPolicy(taskType); policy != nil { - if vacuumConfig := policy.GetVacuumConfig(); vacuumConfig != nil { - return vacuumConfig - } - } - // Return defaults if not configured - return &worker_pb.VacuumTaskConfig{ - GarbageThreshold: 0.3, - MinVolumeAgeHours: 24, - } -} - -// GetErasureCodingConfig returns EC-specific configuration for a task type -func (mcm *MaintenanceConfigManager) GetErasureCodingConfig(taskType string) *worker_pb.ErasureCodingTaskConfig { - if policy := mcm.getTaskPolicy(taskType); policy != nil { - if ecConfig := policy.GetErasureCodingConfig(); ecConfig != nil { - return ecConfig - } - } - // Return defaults if not configured - return &worker_pb.ErasureCodingTaskConfig{ - FullnessRatio: 0.95, - QuietForSeconds: 3600, - MinVolumeSizeMb: 100, - CollectionFilter: "", - } -} - -// GetBalanceConfig returns balance-specific configuration for a task type -func (mcm *MaintenanceConfigManager) GetBalanceConfig(taskType string) *worker_pb.BalanceTaskConfig { - if policy := mcm.getTaskPolicy(taskType); policy != nil { - if balanceConfig := policy.GetBalanceConfig(); balanceConfig != nil { - return balanceConfig - } - } - // Return defaults if not configured - return &worker_pb.BalanceTaskConfig{ - ImbalanceThreshold: 0.2, - MinServerCount: 2, - } -} - -// GetReplicationConfig returns replication-specific configuration for a task type -func (mcm *MaintenanceConfigManager) GetReplicationConfig(taskType string) *worker_pb.ReplicationTaskConfig { - if policy := mcm.getTaskPolicy(taskType); policy != nil { - if replicationConfig := policy.GetReplicationConfig(); replicationConfig != nil { - return replicationConfig - } - } - // Return defaults if not configured - return &worker_pb.ReplicationTaskConfig{ - TargetReplicaCount: 2, - } -} - -// Typed convenience methods for getting task configurations - -// GetVacuumTaskConfigForType returns vacuum configuration for a specific task type -func (mcm *MaintenanceConfigManager) GetVacuumTaskConfigForType(taskType string) *worker_pb.VacuumTaskConfig { - return GetVacuumTaskConfig(mcm.config.Policy, MaintenanceTaskType(taskType)) -} - -// GetErasureCodingTaskConfigForType returns erasure coding configuration for a specific task type -func (mcm *MaintenanceConfigManager) GetErasureCodingTaskConfigForType(taskType string) *worker_pb.ErasureCodingTaskConfig { - return GetErasureCodingTaskConfig(mcm.config.Policy, MaintenanceTaskType(taskType)) -} - -// GetBalanceTaskConfigForType returns balance configuration for a specific task type -func (mcm *MaintenanceConfigManager) GetBalanceTaskConfigForType(taskType string) *worker_pb.BalanceTaskConfig { - return GetBalanceTaskConfig(mcm.config.Policy, MaintenanceTaskType(taskType)) -} - -// GetReplicationTaskConfigForType returns replication configuration for a specific task type -func (mcm *MaintenanceConfigManager) GetReplicationTaskConfigForType(taskType string) *worker_pb.ReplicationTaskConfig { - return GetReplicationTaskConfig(mcm.config.Policy, MaintenanceTaskType(taskType)) -} - -// Helper methods - -func (mcm *MaintenanceConfigManager) getTaskPolicy(taskType string) *worker_pb.TaskPolicy { - if mcm.config.Policy != nil && mcm.config.Policy.TaskPolicies != nil { - return mcm.config.Policy.TaskPolicies[taskType] - } - return nil -} - -// IsTaskEnabled returns whether a task type is enabled -func (mcm *MaintenanceConfigManager) IsTaskEnabled(taskType string) bool { - if policy := mcm.getTaskPolicy(taskType); policy != nil { - return policy.Enabled - } - return false -} - -// GetMaxConcurrent returns the max concurrent limit for a task type -func (mcm *MaintenanceConfigManager) GetMaxConcurrent(taskType string) int32 { - if policy := mcm.getTaskPolicy(taskType); policy != nil { - return policy.MaxConcurrent - } - return 1 // Default -} - -// GetRepeatInterval returns the repeat interval for a task type in seconds -func (mcm *MaintenanceConfigManager) GetRepeatInterval(taskType string) int32 { - if policy := mcm.getTaskPolicy(taskType); policy != nil { - return policy.RepeatIntervalSeconds - } - return mcm.config.Policy.DefaultRepeatIntervalSeconds -} - -// GetCheckInterval returns the check interval for a task type in seconds -func (mcm *MaintenanceConfigManager) GetCheckInterval(taskType string) int32 { - if policy := mcm.getTaskPolicy(taskType); policy != nil { - return policy.CheckIntervalSeconds - } - return mcm.config.Policy.DefaultCheckIntervalSeconds -} - -// Duration accessor methods - -// GetScanInterval returns the scan interval as a time.Duration -func (mcm *MaintenanceConfigManager) GetScanInterval() time.Duration { - return time.Duration(mcm.config.ScanIntervalSeconds) * time.Second -} - -// GetWorkerTimeout returns the worker timeout as a time.Duration -func (mcm *MaintenanceConfigManager) GetWorkerTimeout() time.Duration { - return time.Duration(mcm.config.WorkerTimeoutSeconds) * time.Second -} - -// GetTaskTimeout returns the task timeout as a time.Duration -func (mcm *MaintenanceConfigManager) GetTaskTimeout() time.Duration { - return time.Duration(mcm.config.TaskTimeoutSeconds) * time.Second -} - -// GetRetryDelay returns the retry delay as a time.Duration -func (mcm *MaintenanceConfigManager) GetRetryDelay() time.Duration { - return time.Duration(mcm.config.RetryDelaySeconds) * time.Second -} - -// GetCleanupInterval returns the cleanup interval as a time.Duration -func (mcm *MaintenanceConfigManager) GetCleanupInterval() time.Duration { - return time.Duration(mcm.config.CleanupIntervalSeconds) * time.Second -} - -// GetTaskRetention returns the task retention period as a time.Duration -func (mcm *MaintenanceConfigManager) GetTaskRetention() time.Duration { - return time.Duration(mcm.config.TaskRetentionSeconds) * time.Second -} - -// ValidateMaintenanceConfigWithSchema validates protobuf maintenance configuration using ConfigField rules -func ValidateMaintenanceConfigWithSchema(config *worker_pb.MaintenanceConfig) error { - if config == nil { - return fmt.Errorf("configuration cannot be nil") - } - - // Get the schema to access field validation rules - schema := GetMaintenanceConfigSchema() - - // Validate each field individually using the ConfigField rules - if err := validateFieldWithSchema(schema, "enabled", config.Enabled); err != nil { - return err - } - - if err := validateFieldWithSchema(schema, "scan_interval_seconds", int(config.ScanIntervalSeconds)); err != nil { - return err - } - - if err := validateFieldWithSchema(schema, "worker_timeout_seconds", int(config.WorkerTimeoutSeconds)); err != nil { - return err - } - - if err := validateFieldWithSchema(schema, "task_timeout_seconds", int(config.TaskTimeoutSeconds)); err != nil { - return err - } - - if err := validateFieldWithSchema(schema, "retry_delay_seconds", int(config.RetryDelaySeconds)); err != nil { - return err - } - - if err := validateFieldWithSchema(schema, "max_retries", int(config.MaxRetries)); err != nil { - return err - } - - if err := validateFieldWithSchema(schema, "cleanup_interval_seconds", int(config.CleanupIntervalSeconds)); err != nil { - return err - } - - if err := validateFieldWithSchema(schema, "task_retention_seconds", int(config.TaskRetentionSeconds)); err != nil { - return err - } - - // Validate policy fields if present - if config.Policy != nil { - // Note: These field names might need to be adjusted based on the actual schema - if err := validatePolicyField("global_max_concurrent", int(config.Policy.GlobalMaxConcurrent)); err != nil { - return err - } - - if err := validatePolicyField("default_repeat_interval_seconds", int(config.Policy.DefaultRepeatIntervalSeconds)); err != nil { - return err - } - - if err := validatePolicyField("default_check_interval_seconds", int(config.Policy.DefaultCheckIntervalSeconds)); err != nil { - return err - } - } - - return nil -} - -// validateFieldWithSchema validates a single field using its ConfigField definition -func validateFieldWithSchema(schema *MaintenanceConfigSchema, fieldName string, value interface{}) error { - field := schema.GetFieldByName(fieldName) - if field == nil { - // Field not in schema, skip validation - return nil - } - - return field.ValidateValue(value) -} - -// validatePolicyField validates policy fields (simplified validation for now) -func validatePolicyField(fieldName string, value int) error { - switch fieldName { - case "global_max_concurrent": - if value < 1 || value > 20 { - return fmt.Errorf("Global Max Concurrent must be between 1 and 20, got %d", value) - } - case "default_repeat_interval": - if value < 1 || value > 168 { - return fmt.Errorf("Default Repeat Interval must be between 1 and 168 hours, got %d", value) - } - case "default_check_interval": - if value < 1 || value > 168 { - return fmt.Errorf("Default Check Interval must be between 1 and 168 hours, got %d", value) - } - } - return nil -} diff --git a/weed/admin/maintenance/maintenance_queue.go b/weed/admin/maintenance/maintenance_queue.go index 28dbc1c5c..dc6546d40 100644 --- a/weed/admin/maintenance/maintenance_queue.go +++ b/weed/admin/maintenance/maintenance_queue.go @@ -1055,28 +1055,6 @@ func (mq *MaintenanceQueue) getMaxConcurrentForTaskType(taskType MaintenanceTask return 1 } -// getRunningTasks returns all currently running tasks -func (mq *MaintenanceQueue) getRunningTasks() []*MaintenanceTask { - var runningTasks []*MaintenanceTask - for _, task := range mq.tasks { - if task.Status == TaskStatusAssigned || task.Status == TaskStatusInProgress { - runningTasks = append(runningTasks, task) - } - } - return runningTasks -} - -// getAvailableWorkers returns all workers that can take more work -func (mq *MaintenanceQueue) getAvailableWorkers() []*MaintenanceWorker { - var availableWorkers []*MaintenanceWorker - for _, worker := range mq.workers { - if worker.Status == "active" && worker.CurrentLoad < worker.MaxConcurrent { - availableWorkers = append(availableWorkers, worker) - } - } - return availableWorkers -} - // trackPendingOperation adds a task to the pending operations tracker func (mq *MaintenanceQueue) trackPendingOperation(task *MaintenanceTask) { if mq.integration == nil { diff --git a/weed/admin/maintenance/maintenance_types.go b/weed/admin/maintenance/maintenance_types.go index 31c797e50..bb8f0a737 100644 --- a/weed/admin/maintenance/maintenance_types.go +++ b/weed/admin/maintenance/maintenance_types.go @@ -2,15 +2,11 @@ package maintenance import ( "html/template" - "sort" "sync" "time" - "github.com/seaweedfs/seaweedfs/weed/glog" "github.com/seaweedfs/seaweedfs/weed/pb/master_pb" "github.com/seaweedfs/seaweedfs/weed/pb/worker_pb" - "github.com/seaweedfs/seaweedfs/weed/worker/tasks" - "github.com/seaweedfs/seaweedfs/weed/worker/types" ) // AdminClient interface defines what the maintenance system needs from the admin server @@ -21,51 +17,6 @@ type AdminClient interface { // MaintenanceTaskType represents different types of maintenance operations type MaintenanceTaskType string -// GetRegisteredMaintenanceTaskTypes returns all registered task types as MaintenanceTaskType values -// sorted alphabetically for consistent menu ordering -func GetRegisteredMaintenanceTaskTypes() []MaintenanceTaskType { - typesRegistry := tasks.GetGlobalTypesRegistry() - var taskTypes []MaintenanceTaskType - - for workerTaskType := range typesRegistry.GetAllDetectors() { - maintenanceTaskType := MaintenanceTaskType(string(workerTaskType)) - taskTypes = append(taskTypes, maintenanceTaskType) - } - - // Sort task types alphabetically to ensure consistent menu ordering - sort.Slice(taskTypes, func(i, j int) bool { - return string(taskTypes[i]) < string(taskTypes[j]) - }) - - return taskTypes -} - -// GetMaintenanceTaskType returns a specific task type if it's registered, or empty string if not found -func GetMaintenanceTaskType(taskTypeName string) MaintenanceTaskType { - typesRegistry := tasks.GetGlobalTypesRegistry() - - for workerTaskType := range typesRegistry.GetAllDetectors() { - if string(workerTaskType) == taskTypeName { - return MaintenanceTaskType(taskTypeName) - } - } - - return MaintenanceTaskType("") -} - -// IsMaintenanceTaskTypeRegistered checks if a task type is registered -func IsMaintenanceTaskTypeRegistered(taskType MaintenanceTaskType) bool { - typesRegistry := tasks.GetGlobalTypesRegistry() - - for workerTaskType := range typesRegistry.GetAllDetectors() { - if string(workerTaskType) == string(taskType) { - return true - } - } - - return false -} - // MaintenanceTaskPriority represents task execution priority type MaintenanceTaskPriority int @@ -200,14 +151,6 @@ func GetTaskPolicy(mp *MaintenancePolicy, taskType MaintenanceTaskType) *TaskPol return mp.TaskPolicies[string(taskType)] } -// SetTaskPolicy sets the policy for a specific task type -func SetTaskPolicy(mp *MaintenancePolicy, taskType MaintenanceTaskType, policy *TaskPolicy) { - if mp.TaskPolicies == nil { - mp.TaskPolicies = make(map[string]*TaskPolicy) - } - mp.TaskPolicies[string(taskType)] = policy -} - // IsTaskEnabled returns whether a task type is enabled func IsTaskEnabled(mp *MaintenancePolicy, taskType MaintenanceTaskType) bool { policy := GetTaskPolicy(mp, taskType) @@ -235,84 +178,6 @@ func GetRepeatInterval(mp *MaintenancePolicy, taskType MaintenanceTaskType) int return int(policy.RepeatIntervalSeconds) } -// GetVacuumTaskConfig returns the vacuum task configuration -func GetVacuumTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType) *worker_pb.VacuumTaskConfig { - policy := GetTaskPolicy(mp, taskType) - if policy == nil { - return nil - } - return policy.GetVacuumConfig() -} - -// GetErasureCodingTaskConfig returns the erasure coding task configuration -func GetErasureCodingTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType) *worker_pb.ErasureCodingTaskConfig { - policy := GetTaskPolicy(mp, taskType) - if policy == nil { - return nil - } - return policy.GetErasureCodingConfig() -} - -// GetBalanceTaskConfig returns the balance task configuration -func GetBalanceTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType) *worker_pb.BalanceTaskConfig { - policy := GetTaskPolicy(mp, taskType) - if policy == nil { - return nil - } - return policy.GetBalanceConfig() -} - -// GetReplicationTaskConfig returns the replication task configuration -func GetReplicationTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType) *worker_pb.ReplicationTaskConfig { - policy := GetTaskPolicy(mp, taskType) - if policy == nil { - return nil - } - return policy.GetReplicationConfig() -} - -// Note: GetTaskConfig was removed - use typed getters: GetVacuumTaskConfig, GetErasureCodingTaskConfig, GetBalanceTaskConfig, or GetReplicationTaskConfig - -// SetVacuumTaskConfig sets the vacuum task configuration -func SetVacuumTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType, config *worker_pb.VacuumTaskConfig) { - policy := GetTaskPolicy(mp, taskType) - if policy != nil { - policy.TaskConfig = &worker_pb.TaskPolicy_VacuumConfig{ - VacuumConfig: config, - } - } -} - -// SetErasureCodingTaskConfig sets the erasure coding task configuration -func SetErasureCodingTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType, config *worker_pb.ErasureCodingTaskConfig) { - policy := GetTaskPolicy(mp, taskType) - if policy != nil { - policy.TaskConfig = &worker_pb.TaskPolicy_ErasureCodingConfig{ - ErasureCodingConfig: config, - } - } -} - -// SetBalanceTaskConfig sets the balance task configuration -func SetBalanceTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType, config *worker_pb.BalanceTaskConfig) { - policy := GetTaskPolicy(mp, taskType) - if policy != nil { - policy.TaskConfig = &worker_pb.TaskPolicy_BalanceConfig{ - BalanceConfig: config, - } - } -} - -// SetReplicationTaskConfig sets the replication task configuration -func SetReplicationTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType, config *worker_pb.ReplicationTaskConfig) { - policy := GetTaskPolicy(mp, taskType) - if policy != nil { - policy.TaskConfig = &worker_pb.TaskPolicy_ReplicationConfig{ - ReplicationConfig: config, - } - } -} - // SetTaskConfig sets a configuration value for a task type (legacy method - use typed setters above) // Note: SetTaskConfig was removed - use typed setters: SetVacuumTaskConfig, SetErasureCodingTaskConfig, SetBalanceTaskConfig, or SetReplicationTaskConfig @@ -475,180 +340,6 @@ type ClusterReplicationTask struct { Metadata map[string]string `json:"metadata,omitempty"` } -// BuildMaintenancePolicyFromTasks creates a maintenance policy with configurations -// from all registered tasks using their UI providers -func BuildMaintenancePolicyFromTasks() *MaintenancePolicy { - policy := &MaintenancePolicy{ - TaskPolicies: make(map[string]*TaskPolicy), - GlobalMaxConcurrent: 4, - DefaultRepeatIntervalSeconds: 6 * 3600, // 6 hours in seconds - DefaultCheckIntervalSeconds: 12 * 3600, // 12 hours in seconds - } - - // Get all registered task types from the UI registry - uiRegistry := tasks.GetGlobalUIRegistry() - typesRegistry := tasks.GetGlobalTypesRegistry() - - for taskType, provider := range uiRegistry.GetAllProviders() { - // Convert task type to maintenance task type - maintenanceTaskType := MaintenanceTaskType(string(taskType)) - - // Get the default configuration from the UI provider - defaultConfig := provider.GetCurrentConfig() - - // Create task policy from UI configuration - taskPolicy := &TaskPolicy{ - Enabled: true, // Default enabled - MaxConcurrent: 2, // Default concurrency - RepeatIntervalSeconds: policy.DefaultRepeatIntervalSeconds, - CheckIntervalSeconds: policy.DefaultCheckIntervalSeconds, - } - - // Extract configuration using TaskConfig interface - no more map conversions! - if taskConfig, ok := defaultConfig.(interface{ ToTaskPolicy() *worker_pb.TaskPolicy }); ok { - // Use protobuf directly for clean, type-safe config extraction - pbTaskPolicy := taskConfig.ToTaskPolicy() - taskPolicy.Enabled = pbTaskPolicy.Enabled - taskPolicy.MaxConcurrent = pbTaskPolicy.MaxConcurrent - if pbTaskPolicy.RepeatIntervalSeconds > 0 { - taskPolicy.RepeatIntervalSeconds = pbTaskPolicy.RepeatIntervalSeconds - } - if pbTaskPolicy.CheckIntervalSeconds > 0 { - taskPolicy.CheckIntervalSeconds = pbTaskPolicy.CheckIntervalSeconds - } - } - - // Also get defaults from scheduler if available (using types.TaskScheduler explicitly) - var scheduler types.TaskScheduler = typesRegistry.GetScheduler(taskType) - if scheduler != nil { - if taskPolicy.MaxConcurrent <= 0 { - taskPolicy.MaxConcurrent = int32(scheduler.GetMaxConcurrent()) - } - // Convert default repeat interval to seconds - if repeatInterval := scheduler.GetDefaultRepeatInterval(); repeatInterval > 0 { - taskPolicy.RepeatIntervalSeconds = int32(repeatInterval.Seconds()) - } - } - - // Also get defaults from detector if available (using types.TaskDetector explicitly) - var detector types.TaskDetector = typesRegistry.GetDetector(taskType) - if detector != nil { - // Convert scan interval to check interval (seconds) - if scanInterval := detector.ScanInterval(); scanInterval > 0 { - taskPolicy.CheckIntervalSeconds = int32(scanInterval.Seconds()) - } - } - - policy.TaskPolicies[string(maintenanceTaskType)] = taskPolicy - glog.V(3).Infof("Built policy for task type %s: enabled=%v, max_concurrent=%d", - maintenanceTaskType, taskPolicy.Enabled, taskPolicy.MaxConcurrent) - } - - glog.V(2).Infof("Built maintenance policy with %d task configurations", len(policy.TaskPolicies)) - return policy -} - -// SetPolicyFromTasks sets the maintenance policy from registered tasks -func SetPolicyFromTasks(policy *MaintenancePolicy) { - if policy == nil { - return - } - - // Build new policy from tasks - newPolicy := BuildMaintenancePolicyFromTasks() - - // Copy task policies - policy.TaskPolicies = newPolicy.TaskPolicies - - glog.V(1).Infof("Updated maintenance policy with %d task configurations from registered tasks", len(policy.TaskPolicies)) -} - -// GetTaskIcon returns the icon CSS class for a task type from its UI provider -func GetTaskIcon(taskType MaintenanceTaskType) string { - typesRegistry := tasks.GetGlobalTypesRegistry() - uiRegistry := tasks.GetGlobalUIRegistry() - - // Convert MaintenanceTaskType to TaskType - for workerTaskType := range typesRegistry.GetAllDetectors() { - if string(workerTaskType) == string(taskType) { - // Get the UI provider for this task type - provider := uiRegistry.GetProvider(workerTaskType) - if provider != nil { - return provider.GetIcon() - } - break - } - } - - // Default icon if no UI provider found - return "fas fa-cog text-muted" -} - -// GetTaskDisplayName returns the display name for a task type from its UI provider -func GetTaskDisplayName(taskType MaintenanceTaskType) string { - typesRegistry := tasks.GetGlobalTypesRegistry() - uiRegistry := tasks.GetGlobalUIRegistry() - - // Convert MaintenanceTaskType to TaskType - for workerTaskType := range typesRegistry.GetAllDetectors() { - if string(workerTaskType) == string(taskType) { - // Get the UI provider for this task type - provider := uiRegistry.GetProvider(workerTaskType) - if provider != nil { - return provider.GetDisplayName() - } - break - } - } - - // Fallback to the task type string - return string(taskType) -} - -// GetTaskDescription returns the description for a task type from its UI provider -func GetTaskDescription(taskType MaintenanceTaskType) string { - typesRegistry := tasks.GetGlobalTypesRegistry() - uiRegistry := tasks.GetGlobalUIRegistry() - - // Convert MaintenanceTaskType to TaskType - for workerTaskType := range typesRegistry.GetAllDetectors() { - if string(workerTaskType) == string(taskType) { - // Get the UI provider for this task type - provider := uiRegistry.GetProvider(workerTaskType) - if provider != nil { - return provider.GetDescription() - } - break - } - } - - // Fallback to a generic description - return "Configure detailed settings for " + string(taskType) + " tasks." -} - -// BuildMaintenanceMenuItems creates menu items for all registered task types -func BuildMaintenanceMenuItems() []*MaintenanceMenuItem { - var menuItems []*MaintenanceMenuItem - - // Get all registered task types - registeredTypes := GetRegisteredMaintenanceTaskTypes() - - for _, taskType := range registeredTypes { - menuItem := &MaintenanceMenuItem{ - TaskType: taskType, - DisplayName: GetTaskDisplayName(taskType), - Description: GetTaskDescription(taskType), - Icon: GetTaskIcon(taskType), - IsEnabled: IsMaintenanceTaskTypeRegistered(taskType), - Path: "/maintenance/config/" + string(taskType), - } - - menuItems = append(menuItems, menuItem) - } - - return menuItems -} - // Helper functions to extract configuration fields // Note: Removed getVacuumConfigField, getErasureCodingConfigField, getBalanceConfigField, getReplicationConfigField diff --git a/weed/admin/maintenance/maintenance_worker.go b/weed/admin/maintenance/maintenance_worker.go deleted file mode 100644 index e4a6b4cf6..000000000 --- a/weed/admin/maintenance/maintenance_worker.go +++ /dev/null @@ -1,421 +0,0 @@ -package maintenance - -import ( - "context" - "fmt" - "os" - "sync" - "time" - - "github.com/seaweedfs/seaweedfs/weed/glog" - "github.com/seaweedfs/seaweedfs/weed/worker" - "github.com/seaweedfs/seaweedfs/weed/worker/tasks" - "github.com/seaweedfs/seaweedfs/weed/worker/types" - - // Import task packages to trigger their auto-registration - _ "github.com/seaweedfs/seaweedfs/weed/worker/tasks/balance" - _ "github.com/seaweedfs/seaweedfs/weed/worker/tasks/erasure_coding" - _ "github.com/seaweedfs/seaweedfs/weed/worker/tasks/vacuum" -) - -// MaintenanceWorkerService manages maintenance task execution -// TaskExecutor defines the function signature for task execution -type TaskExecutor func(*MaintenanceWorkerService, *MaintenanceTask) error - -// TaskExecutorFactory creates a task executor for a given worker service -type TaskExecutorFactory func() TaskExecutor - -// Global registry for task executor factories -var taskExecutorFactories = make(map[MaintenanceTaskType]TaskExecutorFactory) -var executorRegistryMutex sync.RWMutex -var executorRegistryInitOnce sync.Once - -// initializeExecutorFactories dynamically registers executor factories for all auto-registered task types -func initializeExecutorFactories() { - executorRegistryInitOnce.Do(func() { - // Get all registered task types from the global registry - typesRegistry := tasks.GetGlobalTypesRegistry() - - var taskTypes []MaintenanceTaskType - for workerTaskType := range typesRegistry.GetAllDetectors() { - // Convert types.TaskType to MaintenanceTaskType by string conversion - maintenanceTaskType := MaintenanceTaskType(string(workerTaskType)) - taskTypes = append(taskTypes, maintenanceTaskType) - } - - // Register generic executor for all task types - for _, taskType := range taskTypes { - RegisterTaskExecutorFactory(taskType, createGenericTaskExecutor) - } - - glog.V(1).Infof("Dynamically registered generic task executor for %d task types: %v", len(taskTypes), taskTypes) - }) -} - -// RegisterTaskExecutorFactory registers a factory function for creating task executors -func RegisterTaskExecutorFactory(taskType MaintenanceTaskType, factory TaskExecutorFactory) { - executorRegistryMutex.Lock() - defer executorRegistryMutex.Unlock() - taskExecutorFactories[taskType] = factory - glog.V(2).Infof("Registered executor factory for task type: %s", taskType) -} - -// GetTaskExecutorFactory returns the factory for a task type -func GetTaskExecutorFactory(taskType MaintenanceTaskType) (TaskExecutorFactory, bool) { - // Ensure executor factories are initialized - initializeExecutorFactories() - - executorRegistryMutex.RLock() - defer executorRegistryMutex.RUnlock() - factory, exists := taskExecutorFactories[taskType] - return factory, exists -} - -// GetSupportedExecutorTaskTypes returns all task types with registered executor factories -func GetSupportedExecutorTaskTypes() []MaintenanceTaskType { - // Ensure executor factories are initialized - initializeExecutorFactories() - - executorRegistryMutex.RLock() - defer executorRegistryMutex.RUnlock() - - taskTypes := make([]MaintenanceTaskType, 0, len(taskExecutorFactories)) - for taskType := range taskExecutorFactories { - taskTypes = append(taskTypes, taskType) - } - return taskTypes -} - -// createGenericTaskExecutor creates a generic task executor that uses the task registry -func createGenericTaskExecutor() TaskExecutor { - return func(mws *MaintenanceWorkerService, task *MaintenanceTask) error { - return mws.executeGenericTask(task) - } -} - -// init does minimal initialization - actual registration happens lazily -func init() { - // Executor factory registration will happen lazily when first accessed - glog.V(1).Infof("Maintenance worker initialized - executor factories will be registered on first access") -} - -type MaintenanceWorkerService struct { - workerID string - address string - adminServer string - capabilities []MaintenanceTaskType - maxConcurrent int - currentTasks map[string]*MaintenanceTask - queue *MaintenanceQueue - adminClient AdminClient - running bool - stopChan chan struct{} - - // Task execution registry - taskExecutors map[MaintenanceTaskType]TaskExecutor - - // Task registry for creating task instances - taskRegistry *tasks.TaskRegistry -} - -// NewMaintenanceWorkerService creates a new maintenance worker service -func NewMaintenanceWorkerService(workerID, address, adminServer string) *MaintenanceWorkerService { - // Get all registered maintenance task types dynamically - capabilities := GetRegisteredMaintenanceTaskTypes() - - worker := &MaintenanceWorkerService{ - workerID: workerID, - address: address, - adminServer: adminServer, - capabilities: capabilities, - maxConcurrent: 2, // Default concurrent task limit - currentTasks: make(map[string]*MaintenanceTask), - stopChan: make(chan struct{}), - taskExecutors: make(map[MaintenanceTaskType]TaskExecutor), - taskRegistry: tasks.GetGlobalTaskRegistry(), // Use global registry with auto-registered tasks - } - - // Initialize task executor registry - worker.initializeTaskExecutors() - - glog.V(1).Infof("Created maintenance worker with %d registered task types", len(worker.taskRegistry.GetAll())) - - return worker -} - -// executeGenericTask executes a task using the task registry instead of hardcoded methods -func (mws *MaintenanceWorkerService) executeGenericTask(task *MaintenanceTask) error { - glog.V(2).Infof("Executing generic task %s: %s for volume %d", task.ID, task.Type, task.VolumeID) - - // Validate that task has proper typed parameters - if task.TypedParams == nil { - return fmt.Errorf("task %s has no typed parameters - task was not properly planned (insufficient destinations)", task.ID) - } - - // Convert MaintenanceTask to types.TaskType - taskType := types.TaskType(string(task.Type)) - - // Create task instance using the registry - taskInstance, err := mws.taskRegistry.Get(taskType).Create(task.TypedParams) - if err != nil { - return fmt.Errorf("failed to create task instance: %w", err) - } - - // Update progress to show task has started - mws.updateTaskProgress(task.ID, 5) - - // Execute the task - err = taskInstance.Execute(context.Background(), task.TypedParams) - if err != nil { - return fmt.Errorf("task execution failed: %w", err) - } - - // Update progress to show completion - mws.updateTaskProgress(task.ID, 100) - - glog.V(2).Infof("Generic task %s completed successfully", task.ID) - return nil -} - -// initializeTaskExecutors sets up the task execution registry dynamically -func (mws *MaintenanceWorkerService) initializeTaskExecutors() { - mws.taskExecutors = make(map[MaintenanceTaskType]TaskExecutor) - - // Get all registered executor factories and create executors - executorRegistryMutex.RLock() - defer executorRegistryMutex.RUnlock() - - for taskType, factory := range taskExecutorFactories { - executor := factory() - mws.taskExecutors[taskType] = executor - glog.V(3).Infof("Initialized executor for task type: %s", taskType) - } - - glog.V(2).Infof("Initialized %d task executors", len(mws.taskExecutors)) -} - -// RegisterTaskExecutor allows dynamic registration of new task executors -func (mws *MaintenanceWorkerService) RegisterTaskExecutor(taskType MaintenanceTaskType, executor TaskExecutor) { - if mws.taskExecutors == nil { - mws.taskExecutors = make(map[MaintenanceTaskType]TaskExecutor) - } - mws.taskExecutors[taskType] = executor - glog.V(1).Infof("Registered executor for task type: %s", taskType) -} - -// GetSupportedTaskTypes returns all task types that this worker can execute -func (mws *MaintenanceWorkerService) GetSupportedTaskTypes() []MaintenanceTaskType { - return GetSupportedExecutorTaskTypes() -} - -// Start begins the worker service -func (mws *MaintenanceWorkerService) Start() error { - mws.running = true - - // Register with admin server - worker := &MaintenanceWorker{ - ID: mws.workerID, - Address: mws.address, - Capabilities: mws.capabilities, - MaxConcurrent: mws.maxConcurrent, - } - - if mws.queue != nil { - mws.queue.RegisterWorker(worker) - } - - // Start worker loop - go mws.workerLoop() - - glog.Infof("Maintenance worker %s started at %s", mws.workerID, mws.address) - return nil -} - -// Stop terminates the worker service -func (mws *MaintenanceWorkerService) Stop() { - mws.running = false - close(mws.stopChan) - - // Wait for current tasks to complete or timeout - timeout := time.NewTimer(30 * time.Second) - defer timeout.Stop() - - for len(mws.currentTasks) > 0 { - select { - case <-timeout.C: - glog.Warningf("Worker %s stopping with %d tasks still running", mws.workerID, len(mws.currentTasks)) - return - case <-time.After(time.Second): - // Check again - } - } - - glog.Infof("Maintenance worker %s stopped", mws.workerID) -} - -// workerLoop is the main worker event loop -func (mws *MaintenanceWorkerService) workerLoop() { - heartbeatTicker := time.NewTicker(30 * time.Second) - defer heartbeatTicker.Stop() - - taskRequestTicker := time.NewTicker(5 * time.Second) - defer taskRequestTicker.Stop() - - for mws.running { - select { - case <-mws.stopChan: - return - case <-heartbeatTicker.C: - mws.sendHeartbeat() - case <-taskRequestTicker.C: - mws.requestTasks() - } - } -} - -// sendHeartbeat sends heartbeat to admin server -func (mws *MaintenanceWorkerService) sendHeartbeat() { - if mws.queue != nil { - mws.queue.UpdateWorkerHeartbeat(mws.workerID) - } -} - -// requestTasks requests new tasks from the admin server -func (mws *MaintenanceWorkerService) requestTasks() { - if len(mws.currentTasks) >= mws.maxConcurrent { - return // Already at capacity - } - - if mws.queue != nil { - task := mws.queue.GetNextTask(mws.workerID, mws.capabilities) - if task != nil { - mws.executeTask(task) - } - } -} - -// executeTask executes a maintenance task -func (mws *MaintenanceWorkerService) executeTask(task *MaintenanceTask) { - mws.currentTasks[task.ID] = task - - go func() { - defer func() { - delete(mws.currentTasks, task.ID) - }() - - glog.Infof("Worker %s executing task %s: %s", mws.workerID, task.ID, task.Type) - - // Execute task using dynamic executor registry - var err error - if executor, exists := mws.taskExecutors[task.Type]; exists { - err = executor(mws, task) - } else { - err = fmt.Errorf("unsupported task type: %s", task.Type) - glog.Errorf("No executor registered for task type: %s", task.Type) - } - - // Report task completion - if mws.queue != nil { - errorMsg := "" - if err != nil { - errorMsg = err.Error() - } - mws.queue.CompleteTask(task.ID, errorMsg) - } - - if err != nil { - glog.Errorf("Worker %s failed to execute task %s: %v", mws.workerID, task.ID, err) - } else { - glog.Infof("Worker %s completed task %s successfully", mws.workerID, task.ID) - } - }() -} - -// updateTaskProgress updates the progress of a task -func (mws *MaintenanceWorkerService) updateTaskProgress(taskID string, progress float64) { - if mws.queue != nil { - mws.queue.UpdateTaskProgress(taskID, progress) - } -} - -// GetStatus returns the current status of the worker -func (mws *MaintenanceWorkerService) GetStatus() map[string]interface{} { - return map[string]interface{}{ - "worker_id": mws.workerID, - "address": mws.address, - "running": mws.running, - "capabilities": mws.capabilities, - "max_concurrent": mws.maxConcurrent, - "current_tasks": len(mws.currentTasks), - "task_details": mws.currentTasks, - } -} - -// SetQueue sets the maintenance queue for the worker -func (mws *MaintenanceWorkerService) SetQueue(queue *MaintenanceQueue) { - mws.queue = queue -} - -// SetAdminClient sets the admin client for the worker -func (mws *MaintenanceWorkerService) SetAdminClient(client AdminClient) { - mws.adminClient = client -} - -// SetCapabilities sets the worker capabilities -func (mws *MaintenanceWorkerService) SetCapabilities(capabilities []MaintenanceTaskType) { - mws.capabilities = capabilities -} - -// SetMaxConcurrent sets the maximum concurrent tasks -func (mws *MaintenanceWorkerService) SetMaxConcurrent(max int) { - mws.maxConcurrent = max -} - -// SetHeartbeatInterval sets the heartbeat interval (placeholder for future use) -func (mws *MaintenanceWorkerService) SetHeartbeatInterval(interval time.Duration) { - // Future implementation for configurable heartbeat -} - -// SetTaskRequestInterval sets the task request interval (placeholder for future use) -func (mws *MaintenanceWorkerService) SetTaskRequestInterval(interval time.Duration) { - // Future implementation for configurable task requests -} - -// MaintenanceWorkerCommand represents a standalone maintenance worker command -type MaintenanceWorkerCommand struct { - workerService *MaintenanceWorkerService -} - -// NewMaintenanceWorkerCommand creates a new worker command -func NewMaintenanceWorkerCommand(workerID, address, adminServer string) *MaintenanceWorkerCommand { - return &MaintenanceWorkerCommand{ - workerService: NewMaintenanceWorkerService(workerID, address, adminServer), - } -} - -// Run starts the maintenance worker as a standalone service -func (mwc *MaintenanceWorkerCommand) Run() error { - // Generate or load persistent worker ID if not provided - if mwc.workerService.workerID == "" { - // Get current working directory for worker ID persistence - wd, err := os.Getwd() - if err != nil { - return fmt.Errorf("failed to get working directory: %w", err) - } - - workerID, err := worker.GenerateOrLoadWorkerID(wd) - if err != nil { - return fmt.Errorf("failed to generate or load worker ID: %w", err) - } - mwc.workerService.workerID = workerID - } - - // Start the worker service - err := mwc.workerService.Start() - if err != nil { - return fmt.Errorf("failed to start maintenance worker: %w", err) - } - - // Wait for interrupt signal - select {} -} diff --git a/weed/admin/plugin/plugin.go b/weed/admin/plugin/plugin.go index e14e7ae41..094a15830 100644 --- a/weed/admin/plugin/plugin.go +++ b/weed/admin/plugin/plugin.go @@ -122,6 +122,7 @@ type Plugin struct { type streamSession struct { workerID string outgoing chan *plugin_pb.AdminToWorkerMessage + done chan struct{} closeOnce sync.Once } @@ -274,6 +275,7 @@ func (r *Plugin) WorkerStream(stream plugin_pb.PluginControlService_WorkerStream session := &streamSession{ workerID: workerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, r.outgoingBuffer), + done: make(chan struct{}), } r.putSession(session) defer r.cleanupSession(workerID) @@ -908,8 +910,10 @@ func (r *Plugin) sendLoop( return nil case <-r.shutdownCh: return nil - case msg, ok := <-session.outgoing: - if !ok { + case <-session.done: + return nil + case msg := <-session.outgoing: + if msg == nil { return nil } if err := stream.Send(msg); err != nil { @@ -930,6 +934,8 @@ func (r *Plugin) sendToWorker(workerID string, message *plugin_pb.AdminToWorkerM select { case <-r.shutdownCh: return fmt.Errorf("plugin is shutting down") + case <-session.done: + return fmt.Errorf("worker %s session is closed", workerID) case session.outgoing <- message: return nil case <-time.After(r.sendTimeout): @@ -1425,7 +1431,7 @@ func CloneConfigValueMap(in map[string]*plugin_pb.ConfigValue) map[string]*plugi func (s *streamSession) close() { s.closeOnce.Do(func() { - close(s.outgoing) + close(s.done) }) } diff --git a/weed/admin/plugin/plugin_cancel_test.go b/weed/admin/plugin/plugin_cancel_test.go index 2a966ae8c..ef129ea08 100644 --- a/weed/admin/plugin/plugin_cancel_test.go +++ b/weed/admin/plugin/plugin_cancel_test.go @@ -26,7 +26,7 @@ func TestRunDetectionSendsCancelOnContextDone(t *testing.T) { {JobType: jobType, CanDetect: true, MaxDetectionConcurrency: 1}, }, }) - session := &streamSession{workerID: workerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 4)} + session := &streamSession{workerID: workerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 4), done: make(chan struct{})} pluginSvc.putSession(session) ctx, cancel := context.WithCancel(context.Background()) @@ -77,7 +77,7 @@ func TestExecuteJobSendsCancelOnContextDone(t *testing.T) { {JobType: jobType, CanExecute: true, MaxExecutionConcurrency: 1}, }, }) - session := &streamSession{workerID: workerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 4)} + session := &streamSession{workerID: workerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 4), done: make(chan struct{})} pluginSvc.putSession(session) job := &plugin_pb.JobSpec{JobId: "job-1", JobType: jobType} @@ -135,8 +135,8 @@ func TestAdminScriptExecutionBlocksOtherDetection(t *testing.T) { {JobType: "vacuum", CanDetect: true, MaxDetectionConcurrency: 1}, }, }) - adminSession := &streamSession{workerID: adminWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8)} - otherSession := &streamSession{workerID: otherWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8)} + adminSession := &streamSession{workerID: adminWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8), done: make(chan struct{})} + otherSession := &streamSession{workerID: otherWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8), done: make(chan struct{})} pluginSvc.putSession(adminSession) pluginSvc.putSession(otherSession) @@ -214,8 +214,8 @@ func TestAdminScriptExecutionBlocksOtherExecution(t *testing.T) { {JobType: "vacuum", CanExecute: true, MaxExecutionConcurrency: 1}, }, }) - adminSession := &streamSession{workerID: adminWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8)} - otherSession := &streamSession{workerID: otherWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8)} + adminSession := &streamSession{workerID: adminWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8), done: make(chan struct{})} + otherSession := &streamSession{workerID: otherWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8), done: make(chan struct{})} pluginSvc.putSession(adminSession) pluginSvc.putSession(otherSession) diff --git a/weed/admin/plugin/plugin_detection_test.go b/weed/admin/plugin/plugin_detection_test.go index be2aac50c..ee86c353a 100644 --- a/weed/admin/plugin/plugin_detection_test.go +++ b/weed/admin/plugin/plugin_detection_test.go @@ -22,7 +22,7 @@ func TestRunDetectionIncludesLatestSuccessfulRun(t *testing.T) { {JobType: jobType, CanDetect: true, MaxDetectionConcurrency: 1}, }, }) - session := &streamSession{workerID: "worker-a", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1)} + session := &streamSession{workerID: "worker-a", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1), done: make(chan struct{})} pluginSvc.putSession(session) oldSuccess := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) @@ -80,7 +80,7 @@ func TestRunDetectionOmitsLastSuccessfulRunWhenNoSuccessHistory(t *testing.T) { {JobType: jobType, CanDetect: true, MaxDetectionConcurrency: 1}, }, }) - session := &streamSession{workerID: "worker-a", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1)} + session := &streamSession{workerID: "worker-a", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1), done: make(chan struct{})} pluginSvc.putSession(session) if err := pluginSvc.store.AppendRunRecord(jobType, &JobRunRecord{ @@ -130,7 +130,7 @@ func TestRunDetectionWithReportCapturesDetectionActivities(t *testing.T) { {JobType: jobType, CanDetect: true, MaxDetectionConcurrency: 1}, }, }) - session := &streamSession{workerID: "worker-a", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1)} + session := &streamSession{workerID: "worker-a", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1), done: make(chan struct{})} pluginSvc.putSession(session) reportCh := make(chan *DetectionReport, 1) @@ -210,7 +210,7 @@ func TestRunDetectionAdminScriptUsesLastCompletedRun(t *testing.T) { {JobType: jobType, CanDetect: true, MaxDetectionConcurrency: 1}, }, }) - session := &streamSession{workerID: "worker-admin-script", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1)} + session := &streamSession{workerID: "worker-admin-script", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1), done: make(chan struct{})} pluginSvc.putSession(session) successCompleted := time.Date(2026, 2, 1, 10, 0, 0, 0, time.UTC) diff --git a/weed/admin/plugin/plugin_scheduler.go b/weed/admin/plugin/plugin_scheduler.go index 248ca8985..121074913 100644 --- a/weed/admin/plugin/plugin_scheduler.go +++ b/weed/admin/plugin/plugin_scheduler.go @@ -95,16 +95,6 @@ func (r *Plugin) laneSchedulerLoop(ls *schedulerLaneState) { } } -// schedulerLoop is kept for backward compatibility; it delegates to -// laneSchedulerLoop with the default lane. New code should not call this. -func (r *Plugin) schedulerLoop() { - ls := r.lanes[LaneDefault] - if ls == nil { - ls = newLaneState(LaneDefault) - } - r.laneSchedulerLoop(ls) -} - // runLaneSchedulerIteration runs one scheduling pass for a single lane, // processing only the job types assigned to that lane. // @@ -229,82 +219,6 @@ func (r *Plugin) runLaneSchedulerIterationConcurrent(ls *schedulerLaneState, job return hadJobs.Load() } -// runSchedulerIteration is kept for backward compatibility. It runs a -// single iteration across ALL job types (equivalent to the old single-loop -// behavior). It is only used by the legacy schedulerLoop() fallback. -func (r *Plugin) runSchedulerIteration() bool { - ls := r.lanes[LaneDefault] - if ls == nil { - ls = newLaneState(LaneDefault) - } - // For backward compat, the old function processes all job types. - r.expireStaleJobs(time.Now().UTC()) - - jobTypes := r.registry.DetectableJobTypes() - if len(jobTypes) == 0 { - r.setSchedulerLoopState("", "idle") - return false - } - - r.setSchedulerLoopState("", "waiting_for_lock") - releaseLock, err := r.acquireAdminLock("plugin scheduler iteration") - if err != nil { - glog.Warningf("Plugin scheduler failed to acquire lock: %v", err) - r.setSchedulerLoopState("", "idle") - return false - } - if releaseLock != nil { - defer releaseLock() - } - - active := make(map[string]struct{}, len(jobTypes)) - hadJobs := false - - for _, jobType := range jobTypes { - active[jobType] = struct{}{} - - policy, enabled, err := r.loadSchedulerPolicy(jobType) - if err != nil { - glog.Warningf("Plugin scheduler failed to load policy for %s: %v", jobType, err) - continue - } - if !enabled { - r.clearSchedulerJobType(jobType) - continue - } - initialDelay := time.Duration(0) - if runInfo := r.snapshotSchedulerRun(jobType); runInfo.lastRunStartedAt.IsZero() { - initialDelay = 5 * time.Second - } - if !r.markDetectionDue(jobType, policy.DetectionInterval, initialDelay) { - continue - } - - detected := r.runJobTypeIteration(jobType, policy) - if detected { - hadJobs = true - } - } - - r.pruneSchedulerState(active) - r.pruneDetectorLeases(active) - r.setSchedulerLoopState("", "idle") - return hadJobs -} - -// wakeLane wakes the scheduler goroutine for a specific lane. -func (r *Plugin) wakeLane(lane SchedulerLane) { - if r == nil { - return - } - if ls, ok := r.lanes[lane]; ok { - select { - case ls.wakeCh <- struct{}{}: - default: - } - } -} - // wakeAllLanes wakes all lane scheduler goroutines. func (r *Plugin) wakeAllLanes() { if r == nil { diff --git a/weed/admin/plugin/scheduler_status.go b/weed/admin/plugin/scheduler_status.go index 19de4ea2e..d5a33069a 100644 --- a/weed/admin/plugin/scheduler_status.go +++ b/weed/admin/plugin/scheduler_status.go @@ -210,16 +210,6 @@ func (r *Plugin) setSchedulerLoopStateForJobType(jobType, phase string) { } } -func (r *Plugin) recordSchedulerIterationComplete(hadJobs bool) { - if r == nil { - return - } - r.schedulerLoopMu.Lock() - r.schedulerLoopState.lastIterationHadJobs = hadJobs - r.schedulerLoopState.lastIterationCompleted = time.Now().UTC() - r.schedulerLoopMu.Unlock() -} - func (r *Plugin) snapshotSchedulerLoopState() schedulerLoopState { if r == nil { return schedulerLoopState{} diff --git a/weed/admin/view/app/template_helpers.go b/weed/admin/view/app/template_helpers.go index 14814a9bd..fff28de09 100644 --- a/weed/admin/view/app/template_helpers.go +++ b/weed/admin/view/app/template_helpers.go @@ -6,20 +6,6 @@ import ( "strings" ) -// getStatusColor returns Bootstrap color class for status -func getStatusColor(status string) string { - switch status { - case "active", "healthy": - return "success" - case "warning": - return "warning" - case "critical", "unreachable": - return "danger" - default: - return "secondary" - } -} - // formatBytes converts bytes to human readable format func formatBytes(bytes int64) string { if bytes == 0 { diff --git a/weed/cluster/cluster.go b/weed/cluster/cluster.go index 8327065b3..4d4614fb0 100644 --- a/weed/cluster/cluster.go +++ b/weed/cluster/cluster.go @@ -95,18 +95,6 @@ func NewCluster() *Cluster { } } -func (cluster *Cluster) getGroupMembers(filerGroup FilerGroupName, nodeType string, createIfNotFound bool) *GroupMembers { - switch nodeType { - case FilerType: - return cluster.filerGroups.getGroupMembers(filerGroup, createIfNotFound) - case BrokerType: - return cluster.brokerGroups.getGroupMembers(filerGroup, createIfNotFound) - case S3Type: - return cluster.s3Groups.getGroupMembers(filerGroup, createIfNotFound) - } - return nil -} - func (cluster *Cluster) AddClusterNode(ns, nodeType string, dataCenter DataCenter, rack Rack, address pb.ServerAddress, version string) []*master_pb.KeepConnectedResponse { filerGroup := FilerGroupName(ns) switch nodeType { diff --git a/weed/command/admin.go b/weed/command/admin.go index f5e4a8360..6d6dc7198 100644 --- a/weed/command/admin.go +++ b/weed/command/admin.go @@ -511,11 +511,6 @@ func recoveryMiddleware(next http.Handler) http.Handler { }) } -// GetAdminOptions returns the admin command options for testing -func GetAdminOptions() *AdminOptions { - return &AdminOptions{} -} - // loadOrGenerateSessionKeys loads or creates authentication/encryption keys for session cookies. func loadOrGenerateSessionKeys(dataDir string) ([]byte, []byte, error) { const keyLen = 32 diff --git a/weed/command/download.go b/weed/command/download.go index e44335097..a155ad74a 100644 --- a/weed/command/download.go +++ b/weed/command/download.go @@ -132,16 +132,3 @@ func fetchContent(masterFn operation.GetMasterFn, grpcDialOption grpc.DialOption content, e = io.ReadAll(rc.Body) return } - -func WriteFile(filename string, data []byte, perm os.FileMode) error { - f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm) - if err != nil { - return err - } - n, err := f.Write(data) - f.Close() - if err == nil && n < len(data) { - err = io.ErrShortWrite - } - return err -} diff --git a/weed/credential/config_loader.go b/weed/credential/config_loader.go index 959f1cfb4..df57b55d3 100644 --- a/weed/credential/config_loader.go +++ b/weed/credential/config_loader.go @@ -57,42 +57,6 @@ func LoadCredentialConfiguration() (*CredentialConfig, error) { }, nil } -// GetCredentialStoreConfig extracts credential store configuration from command line flags -// This is used when credential store is configured via command line instead of credential.toml -func GetCredentialStoreConfig(store string, config util.Configuration, prefix string) *CredentialConfig { - if store == "" { - return nil - } - - return &CredentialConfig{ - Store: store, - Config: config, - Prefix: prefix, - } -} - -// MergeCredentialConfig merges command line credential config with credential.toml config -// Command line flags take priority over credential.toml -func MergeCredentialConfig(cmdLineStore string, cmdLineConfig util.Configuration, cmdLinePrefix string) (*CredentialConfig, error) { - // If command line credential store is specified, use it - if cmdLineStore != "" { - glog.V(0).Infof("Using command line credential configuration: store=%s", cmdLineStore) - return GetCredentialStoreConfig(cmdLineStore, cmdLineConfig, cmdLinePrefix), nil - } - - // Otherwise, try to load from credential.toml - config, err := LoadCredentialConfiguration() - if err != nil { - return nil, err - } - - if config == nil { - glog.V(1).Info("No credential store configured") - } - - return config, nil -} - // NewCredentialManagerWithDefaults creates a credential manager with fallback to defaults // If explicitStore is provided, it will be used regardless of credential.toml // If explicitStore is empty, it tries credential.toml first, then defaults to "filer_etc" diff --git a/weed/credential/filer_etc/filer_etc_policy.go b/weed/credential/filer_etc/filer_etc_policy.go index c83e56647..98cf1e721 100644 --- a/weed/credential/filer_etc/filer_etc_policy.go +++ b/weed/credential/filer_etc/filer_etc_policy.go @@ -207,32 +207,6 @@ func (store *FilerEtcStore) loadPoliciesFromMultiFile(ctx context.Context, polic }) } -func (store *FilerEtcStore) migratePoliciesToMultiFile(ctx context.Context, policies map[string]policy_engine.PolicyDocument) error { - glog.Infof("Migrating IAM policies to multi-file layout...") - - // 1. Save all policies to individual files - for name, policy := range policies { - if err := store.savePolicy(ctx, name, policy); err != nil { - return err - } - } - - // 2. Rename legacy file - return store.withFilerClient(func(client filer_pb.SeaweedFilerClient) error { - _, err := client.AtomicRenameEntry(ctx, &filer_pb.AtomicRenameEntryRequest{ - OldDirectory: filer.IamConfigDirectory, - OldName: filer.IamPoliciesFile, - NewDirectory: filer.IamConfigDirectory, - NewName: IamLegacyPoliciesOldFile, - }) - if err != nil { - glog.Errorf("Failed to rename legacy IAM policies file %s/%s to %s: %v", - filer.IamConfigDirectory, filer.IamPoliciesFile, IamLegacyPoliciesOldFile, err) - } - return err - }) -} - func (store *FilerEtcStore) savePolicy(ctx context.Context, name string, document policy_engine.PolicyDocument) error { if err := validatePolicyName(name); err != nil { return err diff --git a/weed/credential/migration.go b/weed/credential/migration.go deleted file mode 100644 index 41d0e3840..000000000 --- a/weed/credential/migration.go +++ /dev/null @@ -1,221 +0,0 @@ -package credential - -import ( - "context" - "fmt" - - "github.com/seaweedfs/seaweedfs/weed/glog" - "github.com/seaweedfs/seaweedfs/weed/pb/iam_pb" - "github.com/seaweedfs/seaweedfs/weed/util" -) - -// MigrateCredentials migrates credentials from one store to another -func MigrateCredentials(fromStoreName, toStoreName CredentialStoreTypeName, configuration util.Configuration, fromPrefix, toPrefix string) error { - ctx := context.Background() - - // Create source credential manager - fromCM, err := NewCredentialManager(fromStoreName, configuration, fromPrefix) - if err != nil { - return fmt.Errorf("failed to create source credential manager (%s): %v", fromStoreName, err) - } - defer fromCM.Shutdown() - - // Create destination credential manager - toCM, err := NewCredentialManager(toStoreName, configuration, toPrefix) - if err != nil { - return fmt.Errorf("failed to create destination credential manager (%s): %v", toStoreName, err) - } - defer toCM.Shutdown() - - // Load configuration from source - glog.Infof("Loading configuration from %s store...", fromStoreName) - config, err := fromCM.LoadConfiguration(ctx) - if err != nil { - return fmt.Errorf("failed to load configuration from source store: %w", err) - } - - if config == nil || len(config.Identities) == 0 { - glog.Info("No identities found in source store") - return nil - } - - glog.Infof("Found %d identities in source store", len(config.Identities)) - - // Migrate each identity - var migrated, failed int - for _, identity := range config.Identities { - glog.V(1).Infof("Migrating user: %s", identity.Name) - - // Check if user already exists in destination - existingUser, err := toCM.GetUser(ctx, identity.Name) - if err != nil && err != ErrUserNotFound { - glog.Errorf("Failed to check if user %s exists in destination: %v", identity.Name, err) - failed++ - continue - } - - if existingUser != nil { - glog.Warningf("User %s already exists in destination store, skipping", identity.Name) - continue - } - - // Create user in destination - err = toCM.CreateUser(ctx, identity) - if err != nil { - glog.Errorf("Failed to create user %s in destination store: %v", identity.Name, err) - failed++ - continue - } - - migrated++ - glog.V(1).Infof("Successfully migrated user: %s", identity.Name) - } - - glog.Infof("Migration completed: %d migrated, %d failed", migrated, failed) - - if failed > 0 { - return fmt.Errorf("migration completed with %d failures", failed) - } - - return nil -} - -// ExportCredentials exports credentials from a store to a configuration -func ExportCredentials(storeName CredentialStoreTypeName, configuration util.Configuration, prefix string) (*iam_pb.S3ApiConfiguration, error) { - ctx := context.Background() - - // Create credential manager - cm, err := NewCredentialManager(storeName, configuration, prefix) - if err != nil { - return nil, fmt.Errorf("failed to create credential manager (%s): %v", storeName, err) - } - defer cm.Shutdown() - - // Load configuration - config, err := cm.LoadConfiguration(ctx) - if err != nil { - return nil, fmt.Errorf("failed to load configuration: %w", err) - } - - return config, nil -} - -// ImportCredentials imports credentials from a configuration to a store -func ImportCredentials(storeName CredentialStoreTypeName, configuration util.Configuration, prefix string, config *iam_pb.S3ApiConfiguration) error { - ctx := context.Background() - - // Create credential manager - cm, err := NewCredentialManager(storeName, configuration, prefix) - if err != nil { - return fmt.Errorf("failed to create credential manager (%s): %v", storeName, err) - } - defer cm.Shutdown() - - // Import each identity - var imported, failed int - for _, identity := range config.Identities { - glog.V(1).Infof("Importing user: %s", identity.Name) - - // Check if user already exists - existingUser, err := cm.GetUser(ctx, identity.Name) - if err != nil && err != ErrUserNotFound { - glog.Errorf("Failed to check if user %s exists: %v", identity.Name, err) - failed++ - continue - } - - if existingUser != nil { - glog.Warningf("User %s already exists, skipping", identity.Name) - continue - } - - // Create user - err = cm.CreateUser(ctx, identity) - if err != nil { - glog.Errorf("Failed to create user %s: %v", identity.Name, err) - failed++ - continue - } - - imported++ - glog.V(1).Infof("Successfully imported user: %s", identity.Name) - } - - glog.Infof("Import completed: %d imported, %d failed", imported, failed) - - if failed > 0 { - return fmt.Errorf("import completed with %d failures", failed) - } - - return nil -} - -// ValidateCredentials validates that all credentials in a store are accessible -func ValidateCredentials(storeName CredentialStoreTypeName, configuration util.Configuration, prefix string) error { - ctx := context.Background() - - // Create credential manager - cm, err := NewCredentialManager(storeName, configuration, prefix) - if err != nil { - return fmt.Errorf("failed to create credential manager (%s): %v", storeName, err) - } - defer cm.Shutdown() - - // Load configuration - config, err := cm.LoadConfiguration(ctx) - if err != nil { - return fmt.Errorf("failed to load configuration: %w", err) - } - - if config == nil || len(config.Identities) == 0 { - glog.Info("No identities found in store") - return nil - } - - glog.Infof("Validating %d identities...", len(config.Identities)) - - // Validate each identity - var validated, failed int - for _, identity := range config.Identities { - // Check if user can be retrieved - user, err := cm.GetUser(ctx, identity.Name) - if err != nil { - glog.Errorf("Failed to retrieve user %s: %v", identity.Name, err) - failed++ - continue - } - - if user == nil { - glog.Errorf("User %s not found", identity.Name) - failed++ - continue - } - - // Validate access keys - for _, credential := range identity.Credentials { - accessKeyUser, err := cm.GetUserByAccessKey(ctx, credential.AccessKey) - if err != nil { - glog.Errorf("Failed to retrieve user by access key %s: %v", credential.AccessKey, err) - failed++ - continue - } - - if accessKeyUser == nil || accessKeyUser.Name != identity.Name { - glog.Errorf("Access key %s does not map to correct user %s", credential.AccessKey, identity.Name) - failed++ - continue - } - } - - validated++ - glog.V(1).Infof("Successfully validated user: %s", identity.Name) - } - - glog.Infof("Validation completed: %d validated, %d failed", validated, failed) - - if failed > 0 { - return fmt.Errorf("validation completed with %d failures", failed) - } - - return nil -} diff --git a/weed/filer/filer_notify_read.go b/weed/filer/filer_notify_read.go index 0cf71efe1..cf0641852 100644 --- a/weed/filer/filer_notify_read.go +++ b/weed/filer/filer_notify_read.go @@ -246,10 +246,6 @@ func NewLogFileEntryCollector(f *Filer, startPosition log_buffer.MessagePosition } } -func (c *LogFileEntryCollector) hasMore() bool { - return c.dayEntryQueue.Len() > 0 -} - func (c *LogFileEntryCollector) collectMore(v *OrderedLogVisitor) (err error) { dayEntry := c.dayEntryQueue.Dequeue() if dayEntry == nil { diff --git a/weed/filer/meta_replay.go b/weed/filer/meta_replay.go index f6b009e92..51c4e6987 100644 --- a/weed/filer/meta_replay.go +++ b/weed/filer/meta_replay.go @@ -2,7 +2,6 @@ package filer import ( "context" - "sync" "github.com/seaweedfs/seaweedfs/weed/glog" "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" @@ -36,39 +35,3 @@ func Replay(filerStore FilerStore, resp *filer_pb.SubscribeMetadataResponse) err return nil } - -// ParallelProcessDirectoryStructure processes each entry in parallel, and also ensure parent directories are processed first. -// This also assumes the parent directories are in the entryChan already. -func ParallelProcessDirectoryStructure(entryChan chan *Entry, concurrency int, eachEntryFn func(entry *Entry) error) (firstErr error) { - - executors := util.NewLimitedConcurrentExecutor(concurrency) - - var wg sync.WaitGroup - for entry := range entryChan { - wg.Add(1) - if entry.IsDirectory() { - func() { - defer wg.Done() - if err := eachEntryFn(entry); err != nil { - if firstErr == nil { - firstErr = err - } - } - }() - } else { - executors.Execute(func() { - defer wg.Done() - if err := eachEntryFn(entry); err != nil { - if firstErr == nil { - firstErr = err - } - } - }) - } - if firstErr != nil { - break - } - } - wg.Wait() - return -} diff --git a/weed/filer/redis3/ItemList.go b/weed/filer/redis3/ItemList.go index 05457e596..b4043d01c 100644 --- a/weed/filer/redis3/ItemList.go +++ b/weed/filer/redis3/ItemList.go @@ -16,15 +16,6 @@ type ItemList struct { prefix string } -func newItemList(client redis.UniversalClient, prefix string, store skiplist.ListStore, batchSize int) *ItemList { - return &ItemList{ - skipList: skiplist.New(store), - batchSize: batchSize, - client: client, - prefix: prefix, - } -} - /* Be reluctant to create new nodes. Try to fit into either previous node or next node. Prefer to add to previous node. diff --git a/weed/filer/redis_lua/redis_cluster_store.go b/weed/filer/redis_lua/redis_cluster_store.go deleted file mode 100644 index b64342fc2..000000000 --- a/weed/filer/redis_lua/redis_cluster_store.go +++ /dev/null @@ -1,48 +0,0 @@ -package redis_lua - -import ( - "github.com/redis/go-redis/v9" - "github.com/seaweedfs/seaweedfs/weed/filer" - "github.com/seaweedfs/seaweedfs/weed/util" -) - -func init() { - filer.Stores = append(filer.Stores, &RedisLuaClusterStore{}) -} - -type RedisLuaClusterStore struct { - UniversalRedisLuaStore -} - -func (store *RedisLuaClusterStore) GetName() string { - return "redis_lua_cluster" -} - -func (store *RedisLuaClusterStore) Initialize(configuration util.Configuration, prefix string) (err error) { - - configuration.SetDefault(prefix+"useReadOnly", false) - configuration.SetDefault(prefix+"routeByLatency", false) - - return store.initialize( - configuration.GetStringSlice(prefix+"addresses"), - configuration.GetString(prefix+"username"), - configuration.GetString(prefix+"password"), - configuration.GetString(prefix+"keyPrefix"), - configuration.GetBool(prefix+"useReadOnly"), - configuration.GetBool(prefix+"routeByLatency"), - configuration.GetStringSlice(prefix+"superLargeDirectories"), - ) -} - -func (store *RedisLuaClusterStore) initialize(addresses []string, username string, password string, keyPrefix string, readOnly, routeByLatency bool, superLargeDirectories []string) (err error) { - store.Client = redis.NewClusterClient(&redis.ClusterOptions{ - Addrs: addresses, - Username: username, - Password: password, - ReadOnly: readOnly, - RouteByLatency: routeByLatency, - }) - store.keyPrefix = keyPrefix - store.loadSuperLargeDirectories(superLargeDirectories) - return -} diff --git a/weed/filer/redis_lua/redis_sentinel_store.go b/weed/filer/redis_lua/redis_sentinel_store.go deleted file mode 100644 index 6dd85dd06..000000000 --- a/weed/filer/redis_lua/redis_sentinel_store.go +++ /dev/null @@ -1,48 +0,0 @@ -package redis_lua - -import ( - "time" - - "github.com/redis/go-redis/v9" - "github.com/seaweedfs/seaweedfs/weed/filer" - "github.com/seaweedfs/seaweedfs/weed/util" -) - -func init() { - filer.Stores = append(filer.Stores, &RedisLuaSentinelStore{}) -} - -type RedisLuaSentinelStore struct { - UniversalRedisLuaStore -} - -func (store *RedisLuaSentinelStore) GetName() string { - return "redis_lua_sentinel" -} - -func (store *RedisLuaSentinelStore) Initialize(configuration util.Configuration, prefix string) (err error) { - return store.initialize( - configuration.GetStringSlice(prefix+"addresses"), - configuration.GetString(prefix+"masterName"), - configuration.GetString(prefix+"username"), - configuration.GetString(prefix+"password"), - configuration.GetInt(prefix+"database"), - configuration.GetString(prefix+"keyPrefix"), - ) -} - -func (store *RedisLuaSentinelStore) initialize(addresses []string, masterName string, username string, password string, database int, keyPrefix string) (err error) { - store.Client = redis.NewFailoverClient(&redis.FailoverOptions{ - MasterName: masterName, - SentinelAddrs: addresses, - Username: username, - Password: password, - DB: database, - MinRetryBackoff: time.Millisecond * 100, - MaxRetryBackoff: time.Minute * 1, - ReadTimeout: time.Second * 30, - WriteTimeout: time.Second * 5, - }) - store.keyPrefix = keyPrefix - return -} diff --git a/weed/filer/redis_lua/redis_store.go b/weed/filer/redis_lua/redis_store.go deleted file mode 100644 index 4f6354e96..000000000 --- a/weed/filer/redis_lua/redis_store.go +++ /dev/null @@ -1,42 +0,0 @@ -package redis_lua - -import ( - "github.com/redis/go-redis/v9" - "github.com/seaweedfs/seaweedfs/weed/filer" - "github.com/seaweedfs/seaweedfs/weed/util" -) - -func init() { - filer.Stores = append(filer.Stores, &RedisLuaStore{}) -} - -type RedisLuaStore struct { - UniversalRedisLuaStore -} - -func (store *RedisLuaStore) GetName() string { - return "redis_lua" -} - -func (store *RedisLuaStore) Initialize(configuration util.Configuration, prefix string) (err error) { - return store.initialize( - configuration.GetString(prefix+"address"), - configuration.GetString(prefix+"username"), - configuration.GetString(prefix+"password"), - configuration.GetInt(prefix+"database"), - configuration.GetString(prefix+"keyPrefix"), - configuration.GetStringSlice(prefix+"superLargeDirectories"), - ) -} - -func (store *RedisLuaStore) initialize(hostPort string, username string, password string, database int, keyPrefix string, superLargeDirectories []string) (err error) { - store.Client = redis.NewClient(&redis.Options{ - Addr: hostPort, - Username: username, - Password: password, - DB: database, - }) - store.keyPrefix = keyPrefix - store.loadSuperLargeDirectories(superLargeDirectories) - return -} diff --git a/weed/filer/redis_lua/stored_procedure/delete_entry.lua b/weed/filer/redis_lua/stored_procedure/delete_entry.lua deleted file mode 100644 index 445337c77..000000000 --- a/weed/filer/redis_lua/stored_procedure/delete_entry.lua +++ /dev/null @@ -1,19 +0,0 @@ --- KEYS[1]: full path of entry -local fullpath = KEYS[1] --- KEYS[2]: full path of entry -local fullpath_list_key = KEYS[2] --- KEYS[3]: dir of the entry -local dir_list_key = KEYS[3] - --- ARGV[1]: isSuperLargeDirectory -local isSuperLargeDirectory = ARGV[1] == "1" --- ARGV[2]: name of the entry -local name = ARGV[2] - -redis.call("DEL", fullpath, fullpath_list_key) - -if not isSuperLargeDirectory and name ~= "" then - redis.call("ZREM", dir_list_key, name) -end - -return 0 \ No newline at end of file diff --git a/weed/filer/redis_lua/stored_procedure/delete_folder_children.lua b/weed/filer/redis_lua/stored_procedure/delete_folder_children.lua deleted file mode 100644 index 77e4839f9..000000000 --- a/weed/filer/redis_lua/stored_procedure/delete_folder_children.lua +++ /dev/null @@ -1,15 +0,0 @@ --- KEYS[1]: full path of entry -local fullpath = KEYS[1] - -if fullpath ~= "" and string.sub(fullpath, -1) == "/" then - fullpath = string.sub(fullpath, 0, -2) -end - -local files = redis.call("ZRANGE", fullpath .. "\0", "0", "-1") - -for _, name in ipairs(files) do - local file_path = fullpath .. "/" .. name - redis.call("DEL", file_path, file_path .. "\0") -end - -return 0 \ No newline at end of file diff --git a/weed/filer/redis_lua/stored_procedure/init.go b/weed/filer/redis_lua/stored_procedure/init.go deleted file mode 100644 index 685ea364d..000000000 --- a/weed/filer/redis_lua/stored_procedure/init.go +++ /dev/null @@ -1,25 +0,0 @@ -package stored_procedure - -import ( - _ "embed" - - "github.com/redis/go-redis/v9" -) - -func init() { - InsertEntryScript = redis.NewScript(insertEntry) - DeleteEntryScript = redis.NewScript(deleteEntry) - DeleteFolderChildrenScript = redis.NewScript(deleteFolderChildren) -} - -//go:embed insert_entry.lua -var insertEntry string -var InsertEntryScript *redis.Script - -//go:embed delete_entry.lua -var deleteEntry string -var DeleteEntryScript *redis.Script - -//go:embed delete_folder_children.lua -var deleteFolderChildren string -var DeleteFolderChildrenScript *redis.Script diff --git a/weed/filer/redis_lua/stored_procedure/insert_entry.lua b/weed/filer/redis_lua/stored_procedure/insert_entry.lua deleted file mode 100644 index 8deef3446..000000000 --- a/weed/filer/redis_lua/stored_procedure/insert_entry.lua +++ /dev/null @@ -1,27 +0,0 @@ --- KEYS[1]: full path of entry -local full_path = KEYS[1] --- KEYS[2]: dir of the entry -local dir_list_key = KEYS[2] - --- ARGV[1]: content of the entry -local entry = ARGV[1] --- ARGV[2]: TTL of the entry -local ttlSec = tonumber(ARGV[2]) --- ARGV[3]: isSuperLargeDirectory -local isSuperLargeDirectory = ARGV[3] == "1" --- ARGV[4]: zscore of the entry in zset -local zscore = tonumber(ARGV[4]) --- ARGV[5]: name of the entry -local name = ARGV[5] - -if ttlSec > 0 then - redis.call("SET", full_path, entry, "EX", ttlSec) -else - redis.call("SET", full_path, entry) -end - -if not isSuperLargeDirectory and name ~= "" then - redis.call("ZADD", dir_list_key, "NX", zscore, name) -end - -return 0 \ No newline at end of file diff --git a/weed/filer/redis_lua/universal_redis_store.go b/weed/filer/redis_lua/universal_redis_store.go deleted file mode 100644 index 0a02a0730..000000000 --- a/weed/filer/redis_lua/universal_redis_store.go +++ /dev/null @@ -1,206 +0,0 @@ -package redis_lua - -import ( - "context" - "fmt" - "time" - - "github.com/redis/go-redis/v9" - - "github.com/seaweedfs/seaweedfs/weed/filer" - "github.com/seaweedfs/seaweedfs/weed/filer/redis_lua/stored_procedure" - "github.com/seaweedfs/seaweedfs/weed/glog" - "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" - "github.com/seaweedfs/seaweedfs/weed/util" -) - -const ( - DIR_LIST_MARKER = "\x00" -) - -type UniversalRedisLuaStore struct { - Client redis.UniversalClient - keyPrefix string - superLargeDirectoryHash map[string]bool -} - -func (store *UniversalRedisLuaStore) isSuperLargeDirectory(dir string) (isSuperLargeDirectory bool) { - _, isSuperLargeDirectory = store.superLargeDirectoryHash[dir] - return -} - -func (store *UniversalRedisLuaStore) loadSuperLargeDirectories(superLargeDirectories []string) { - // set directory hash - store.superLargeDirectoryHash = make(map[string]bool) - for _, dir := range superLargeDirectories { - store.superLargeDirectoryHash[dir] = true - } -} - -func (store *UniversalRedisLuaStore) getKey(key string) string { - if store.keyPrefix == "" { - return key - } - return store.keyPrefix + key -} - -func (store *UniversalRedisLuaStore) BeginTransaction(ctx context.Context) (context.Context, error) { - return ctx, nil -} -func (store *UniversalRedisLuaStore) CommitTransaction(ctx context.Context) error { - return nil -} -func (store *UniversalRedisLuaStore) RollbackTransaction(ctx context.Context) error { - return nil -} - -func (store *UniversalRedisLuaStore) InsertEntry(ctx context.Context, entry *filer.Entry) (err error) { - - value, err := entry.EncodeAttributesAndChunks() - if err != nil { - return fmt.Errorf("encoding %s %+v: %v", entry.FullPath, entry.Attr, err) - } - - if len(entry.GetChunks()) > filer.CountEntryChunksForGzip { - value = util.MaybeGzipData(value) - } - - dir, name := entry.FullPath.DirAndName() - - err = stored_procedure.InsertEntryScript.Run(ctx, store.Client, - []string{store.getKey(string(entry.FullPath)), store.getKey(genDirectoryListKey(dir))}, - value, entry.TtlSec, - store.isSuperLargeDirectory(dir), 0, name).Err() - - if err != nil { - return fmt.Errorf("persisting %s : %v", entry.FullPath, err) - } - - return nil -} - -func (store *UniversalRedisLuaStore) UpdateEntry(ctx context.Context, entry *filer.Entry) (err error) { - - return store.InsertEntry(ctx, entry) -} - -func (store *UniversalRedisLuaStore) FindEntry(ctx context.Context, fullpath util.FullPath) (entry *filer.Entry, err error) { - - data, err := store.Client.Get(ctx, store.getKey(string(fullpath))).Result() - if err == redis.Nil { - return nil, filer_pb.ErrNotFound - } - - if err != nil { - return nil, fmt.Errorf("get %s : %v", fullpath, err) - } - - entry = &filer.Entry{ - FullPath: fullpath, - } - err = entry.DecodeAttributesAndChunks(util.MaybeDecompressData([]byte(data))) - if err != nil { - return entry, fmt.Errorf("decode %s : %v", entry.FullPath, err) - } - - return entry, nil -} - -func (store *UniversalRedisLuaStore) DeleteEntry(ctx context.Context, fullpath util.FullPath) (err error) { - - dir, name := fullpath.DirAndName() - - err = stored_procedure.DeleteEntryScript.Run(ctx, store.Client, - []string{store.getKey(string(fullpath)), store.getKey(genDirectoryListKey(string(fullpath))), store.getKey(genDirectoryListKey(dir))}, - store.isSuperLargeDirectory(dir), name).Err() - - if err != nil { - return fmt.Errorf("DeleteEntry %s : %v", fullpath, err) - } - - return nil -} - -func (store *UniversalRedisLuaStore) DeleteFolderChildren(ctx context.Context, fullpath util.FullPath) (err error) { - - if store.isSuperLargeDirectory(string(fullpath)) { - return nil - } - - err = stored_procedure.DeleteFolderChildrenScript.Run(ctx, store.Client, - []string{store.getKey(string(fullpath))}).Err() - - if err != nil { - return fmt.Errorf("DeleteFolderChildren %s : %v", fullpath, err) - } - - return nil -} - -func (store *UniversalRedisLuaStore) ListDirectoryPrefixedEntries(ctx context.Context, dirPath util.FullPath, startFileName string, includeStartFile bool, limit int64, prefix string, eachEntryFunc filer.ListEachEntryFunc) (lastFileName string, err error) { - return lastFileName, filer.ErrUnsupportedListDirectoryPrefixed -} - -func (store *UniversalRedisLuaStore) ListDirectoryEntries(ctx context.Context, dirPath util.FullPath, startFileName string, includeStartFile bool, limit int64, eachEntryFunc filer.ListEachEntryFunc) (lastFileName string, err error) { - - dirListKey := store.getKey(genDirectoryListKey(string(dirPath))) - - min := "-" - if startFileName != "" { - if includeStartFile { - min = "[" + startFileName - } else { - min = "(" + startFileName - } - } - - members, err := store.Client.ZRangeByLex(ctx, dirListKey, &redis.ZRangeBy{ - Min: min, - Max: "+", - Offset: 0, - Count: limit, - }).Result() - if err != nil { - return lastFileName, fmt.Errorf("list %s : %v", dirPath, err) - } - - // fetch entry meta - for _, fileName := range members { - path := util.NewFullPath(string(dirPath), fileName) - entry, err := store.FindEntry(ctx, path) - lastFileName = fileName - if err != nil { - glog.V(0).InfofCtx(ctx, "list %s : %v", path, err) - if err == filer_pb.ErrNotFound { - continue - } - } else { - if entry.TtlSec > 0 { - if entry.Attr.Crtime.Add(time.Duration(entry.TtlSec) * time.Second).Before(time.Now()) { - store.DeleteEntry(ctx, path) - continue - } - } - - resEachEntryFunc, resEachEntryFuncErr := eachEntryFunc(entry) - if resEachEntryFuncErr != nil { - err = fmt.Errorf("failed to process eachEntryFunc: %w", resEachEntryFuncErr) - break - } - - if !resEachEntryFunc { - break - } - } - } - - return lastFileName, err -} - -func genDirectoryListKey(dir string) (dirList string) { - return dir + DIR_LIST_MARKER -} - -func (store *UniversalRedisLuaStore) Shutdown() { - store.Client.Close() -} diff --git a/weed/filer/redis_lua/universal_redis_store_kv.go b/weed/filer/redis_lua/universal_redis_store_kv.go deleted file mode 100644 index 79b6495ce..000000000 --- a/weed/filer/redis_lua/universal_redis_store_kv.go +++ /dev/null @@ -1,42 +0,0 @@ -package redis_lua - -import ( - "context" - "fmt" - - "github.com/redis/go-redis/v9" - "github.com/seaweedfs/seaweedfs/weed/filer" -) - -func (store *UniversalRedisLuaStore) KvPut(ctx context.Context, key []byte, value []byte) (err error) { - - _, err = store.Client.Set(ctx, string(key), value, 0).Result() - - if err != nil { - return fmt.Errorf("kv put: %w", err) - } - - return nil -} - -func (store *UniversalRedisLuaStore) KvGet(ctx context.Context, key []byte) (value []byte, err error) { - - data, err := store.Client.Get(ctx, string(key)).Result() - - if err == redis.Nil { - return nil, filer.ErrKvNotFound - } - - return []byte(data), err -} - -func (store *UniversalRedisLuaStore) KvDelete(ctx context.Context, key []byte) (err error) { - - _, err = store.Client.Del(ctx, string(key)).Result() - - if err != nil { - return fmt.Errorf("kv delete: %w", err) - } - - return nil -} diff --git a/weed/filer/stream.go b/weed/filer/stream.go index c60d147e5..e49794fd2 100644 --- a/weed/filer/stream.go +++ b/weed/filer/stream.go @@ -102,10 +102,6 @@ func PrepareStreamContent(masterClient wdclient.HasLookupFileIdFunction, jwtFunc type VolumeServerJwtFunction func(fileId string) string -func noJwtFunc(string) string { - return "" -} - type CacheInvalidator interface { InvalidateCache(fileId string) } @@ -276,33 +272,6 @@ func writeZero(w io.Writer, size int64) (err error) { return } -func ReadAll(ctx context.Context, buffer []byte, masterClient *wdclient.MasterClient, chunks []*filer_pb.FileChunk) error { - - lookupFileIdFn := func(ctx context.Context, fileId string) (targetUrls []string, err error) { - return masterClient.LookupFileId(ctx, fileId) - } - - chunkViews := ViewFromChunks(ctx, lookupFileIdFn, chunks, 0, int64(len(buffer))) - - idx := 0 - - for x := chunkViews.Front(); x != nil; x = x.Next { - chunkView := x.Value - urlStrings, err := lookupFileIdFn(ctx, chunkView.FileId) - if err != nil { - glog.V(1).InfofCtx(ctx, "operation LookupFileId %s failed, err: %v", chunkView.FileId, err) - return err - } - - n, err := util_http.RetriedFetchChunkData(ctx, buffer[idx:idx+int(chunkView.ViewSize)], urlStrings, chunkView.CipherKey, chunkView.IsGzipped, chunkView.IsFullChunk(), chunkView.OffsetInChunk, chunkView.FileId) - if err != nil { - return err - } - idx += n - } - return nil -} - // ---------------- ChunkStreamReader ---------------------------------- type ChunkStreamReader struct { head *Interval[*ChunkView] diff --git a/weed/filer/stream_failover_test.go b/weed/filer/stream_failover_test.go deleted file mode 100644 index aaa59c523..000000000 --- a/weed/filer/stream_failover_test.go +++ /dev/null @@ -1,281 +0,0 @@ -package filer - -import ( - "bytes" - "context" - "errors" - "testing" - "time" - - "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" - "github.com/seaweedfs/seaweedfs/weed/wdclient" -) - -// mockMasterClient implements HasLookupFileIdFunction and CacheInvalidator -type mockMasterClient struct { - lookupFunc func(ctx context.Context, fileId string) ([]string, error) - invalidatedFileIds []string -} - -func (m *mockMasterClient) GetLookupFileIdFunction() wdclient.LookupFileIdFunctionType { - return m.lookupFunc -} - -func (m *mockMasterClient) InvalidateCache(fileId string) { - m.invalidatedFileIds = append(m.invalidatedFileIds, fileId) -} - -// Test urlSlicesEqual helper function -func TestUrlSlicesEqual(t *testing.T) { - tests := []struct { - name string - a []string - b []string - expected bool - }{ - { - name: "identical slices", - a: []string{"http://server1", "http://server2"}, - b: []string{"http://server1", "http://server2"}, - expected: true, - }, - { - name: "same URLs different order", - a: []string{"http://server1", "http://server2"}, - b: []string{"http://server2", "http://server1"}, - expected: true, - }, - { - name: "different URLs", - a: []string{"http://server1", "http://server2"}, - b: []string{"http://server1", "http://server3"}, - expected: false, - }, - { - name: "different lengths", - a: []string{"http://server1"}, - b: []string{"http://server1", "http://server2"}, - expected: false, - }, - { - name: "empty slices", - a: []string{}, - b: []string{}, - expected: true, - }, - { - name: "duplicates in both", - a: []string{"http://server1", "http://server1"}, - b: []string{"http://server1", "http://server1"}, - expected: true, - }, - { - name: "different duplicate counts", - a: []string{"http://server1", "http://server1"}, - b: []string{"http://server1", "http://server2"}, - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := urlSlicesEqual(tt.a, tt.b) - if result != tt.expected { - t.Errorf("urlSlicesEqual(%v, %v) = %v; want %v", tt.a, tt.b, result, tt.expected) - } - }) - } -} - -// Test cache invalidation when read fails -func TestStreamContentWithCacheInvalidation(t *testing.T) { - ctx := context.Background() - fileId := "3,01234567890" - - callCount := 0 - oldUrls := []string{"http://failed-server:8080"} - newUrls := []string{"http://working-server:8080"} - - mock := &mockMasterClient{ - lookupFunc: func(ctx context.Context, fid string) ([]string, error) { - callCount++ - if callCount == 1 { - // First call returns failing server - return oldUrls, nil - } - // After invalidation, return working server - return newUrls, nil - }, - } - - // Create a simple chunk - chunks := []*filer_pb.FileChunk{ - { - FileId: fileId, - Offset: 0, - Size: 10, - }, - } - - streamFn, err := PrepareStreamContentWithThrottler(ctx, mock, noJwtFunc, chunks, 0, 10, 0) - if err != nil { - t.Fatalf("PrepareStreamContentWithThrottler failed: %v", err) - } - - // Note: This test can't fully execute streamFn because it would require actual HTTP servers - // However, we can verify the setup was created correctly - if streamFn == nil { - t.Fatal("Expected non-nil stream function") - } - - // Verify the lookup was called - if callCount != 1 { - t.Errorf("Expected 1 lookup call, got %d", callCount) - } -} - -// Test that InvalidateCache is called on read failure -func TestCacheInvalidationInterface(t *testing.T) { - mock := &mockMasterClient{ - lookupFunc: func(ctx context.Context, fileId string) ([]string, error) { - return []string{"http://server:8080"}, nil - }, - } - - fileId := "3,test123" - - // Simulate invalidation - if invalidator, ok := interface{}(mock).(CacheInvalidator); ok { - invalidator.InvalidateCache(fileId) - } else { - t.Fatal("mockMasterClient should implement CacheInvalidator") - } - - // Check that the file ID was recorded as invalidated - if len(mock.invalidatedFileIds) != 1 { - t.Fatalf("Expected 1 invalidated file ID, got %d", len(mock.invalidatedFileIds)) - } - if mock.invalidatedFileIds[0] != fileId { - t.Errorf("Expected invalidated file ID %s, got %s", fileId, mock.invalidatedFileIds[0]) - } -} - -// Test retry logic doesn't retry with same URLs -func TestRetryLogicSkipsSameUrls(t *testing.T) { - // This test verifies that the urlSlicesEqual check prevents infinite retries - sameUrls := []string{"http://server1:8080", "http://server2:8080"} - differentUrls := []string{"http://server3:8080", "http://server4:8080"} - - // Same URLs should return true (and thus skip retry) - if !urlSlicesEqual(sameUrls, sameUrls) { - t.Error("Expected same URLs to be equal") - } - - // Different URLs should return false (and thus allow retry) - if urlSlicesEqual(sameUrls, differentUrls) { - t.Error("Expected different URLs to not be equal") - } -} - -func TestCanceledStreamSkipsCacheInvalidation(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - fileId := "3,canceled" - - mock := &mockMasterClient{ - lookupFunc: func(ctx context.Context, fid string) ([]string, error) { - return []string{"http://server:8080"}, nil - }, - } - - chunks := []*filer_pb.FileChunk{ - { - FileId: fileId, - Offset: 0, - Size: 10, - }, - } - - streamFn, err := PrepareStreamContentWithThrottler(ctx, mock, noJwtFunc, chunks, 0, 10, 0) - if err != nil { - t.Fatalf("PrepareStreamContentWithThrottler failed: %v", err) - } - - cancel() - - err = streamFn(&bytes.Buffer{}) - if err != context.Canceled { - t.Fatalf("expected context.Canceled, got %v", err) - } - if len(mock.invalidatedFileIds) != 0 { - t.Fatalf("expected no cache invalidation on cancellation, got %v", mock.invalidatedFileIds) - } -} - -func TestPrepareStreamContentSkipsLookupWhenContextAlreadyCanceled(t *testing.T) { - oldSchedule := getLookupFileIdBackoffSchedule - getLookupFileIdBackoffSchedule = []time.Duration{time.Millisecond} - t.Cleanup(func() { - getLookupFileIdBackoffSchedule = oldSchedule - }) - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - lookupCalls := 0 - mock := &mockMasterClient{ - lookupFunc: func(ctx context.Context, fileId string) ([]string, error) { - lookupCalls++ - return nil, errors.New("lookup should not run") - }, - } - - chunks := []*filer_pb.FileChunk{ - { - FileId: "3,precanceled", - Offset: 0, - Size: 10, - }, - } - - _, err := PrepareStreamContentWithThrottler(ctx, mock, noJwtFunc, chunks, 0, 10, 0) - if !errors.Is(err, context.Canceled) { - t.Fatalf("expected context.Canceled, got %v", err) - } - if lookupCalls != 0 { - t.Fatalf("expected no lookup calls after cancellation, got %d", lookupCalls) - } -} - -func TestPrepareStreamContentStopsLookupRetriesAfterContextCancellation(t *testing.T) { - oldSchedule := getLookupFileIdBackoffSchedule - getLookupFileIdBackoffSchedule = []time.Duration{time.Millisecond, time.Millisecond, time.Millisecond} - t.Cleanup(func() { - getLookupFileIdBackoffSchedule = oldSchedule - }) - - ctx, cancel := context.WithCancel(context.Background()) - lookupCalls := 0 - mock := &mockMasterClient{ - lookupFunc: func(ctx context.Context, fileId string) ([]string, error) { - lookupCalls++ - cancel() - return nil, context.Canceled - }, - } - - chunks := []*filer_pb.FileChunk{ - { - FileId: "3,cancel-during-lookup", - Offset: 0, - Size: 10, - }, - } - - _, err := PrepareStreamContentWithThrottler(ctx, mock, noJwtFunc, chunks, 0, 10, 0) - if !errors.Is(err, context.Canceled) { - t.Fatalf("expected context.Canceled, got %v", err) - } - if lookupCalls != 1 { - t.Fatalf("expected lookup retries to stop after cancellation, got %d calls", lookupCalls) - } -} diff --git a/weed/iam/helpers.go b/weed/iam/helpers.go index ef94940af..cbbc86fb2 100644 --- a/weed/iam/helpers.go +++ b/weed/iam/helpers.go @@ -37,11 +37,6 @@ func GenerateRandomString(length int, charset string) (string, error) { return string(b), nil } -// GenerateAccessKeyId generates a new access key ID. -func GenerateAccessKeyId() (string, error) { - return GenerateRandomString(AccessKeyIdLength, CharsetUpper) -} - // GenerateSecretAccessKey generates a new secret access key. func GenerateSecretAccessKey() (string, error) { return GenerateRandomString(SecretAccessKeyLength, Charset) @@ -179,11 +174,3 @@ func MapToIdentitiesAction(action string) string { return "" } } - -// MaskAccessKey masks an access key for logging, showing only the first 4 characters. -func MaskAccessKey(accessKeyId string) string { - if len(accessKeyId) > 4 { - return accessKeyId[:4] + "***" - } - return accessKeyId -} diff --git a/weed/iam/helpers_test.go b/weed/iam/helpers_test.go deleted file mode 100644 index 6b39a3779..000000000 --- a/weed/iam/helpers_test.go +++ /dev/null @@ -1,164 +0,0 @@ -package iam - -import ( - "testing" - - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" - "github.com/stretchr/testify/assert" -) - -func TestHash(t *testing.T) { - input := "test" - result := Hash(&input) - assert.NotEmpty(t, result) - assert.Len(t, result, 40) // SHA1 hex is 40 chars - - // Same input should produce same hash - result2 := Hash(&input) - assert.Equal(t, result, result2) - - // Different input should produce different hash - different := "different" - result3 := Hash(&different) - assert.NotEqual(t, result, result3) -} - -func TestGenerateRandomString(t *testing.T) { - // Valid generation - result, err := GenerateRandomString(10, CharsetUpper) - assert.NoError(t, err) - assert.Len(t, result, 10) - - // Different calls should produce different results (with high probability) - result2, err := GenerateRandomString(10, CharsetUpper) - assert.NoError(t, err) - assert.NotEqual(t, result, result2) - - // Invalid length - _, err = GenerateRandomString(0, CharsetUpper) - assert.Error(t, err) - - _, err = GenerateRandomString(-1, CharsetUpper) - assert.Error(t, err) - - // Empty charset - _, err = GenerateRandomString(10, "") - assert.Error(t, err) -} - -func TestGenerateAccessKeyId(t *testing.T) { - keyId, err := GenerateAccessKeyId() - assert.NoError(t, err) - assert.Len(t, keyId, AccessKeyIdLength) -} - -func TestGenerateSecretAccessKey(t *testing.T) { - secretKey, err := GenerateSecretAccessKey() - assert.NoError(t, err) - assert.Len(t, secretKey, SecretAccessKeyLength) -} - -func TestGenerateSecretAccessKey_URLSafe(t *testing.T) { - // Generate multiple keys to increase probability of catching unsafe chars - for i := 0; i < 100; i++ { - secretKey, err := GenerateSecretAccessKey() - assert.NoError(t, err) - - // Verify no URL-unsafe characters that would cause authentication issues - assert.NotContains(t, secretKey, "/", "Secret key should not contain /") - assert.NotContains(t, secretKey, "+", "Secret key should not contain +") - - // Verify only expected characters are present - for _, char := range secretKey { - assert.Contains(t, Charset, string(char), "Secret key contains unexpected character: %c", char) - } - } -} - -func TestStringSlicesEqual(t *testing.T) { - tests := []struct { - a []string - b []string - expected bool - }{ - {[]string{"a", "b", "c"}, []string{"a", "b", "c"}, true}, - {[]string{"c", "b", "a"}, []string{"a", "b", "c"}, true}, // Order independent - {[]string{"a", "b"}, []string{"a", "b", "c"}, false}, - {[]string{}, []string{}, true}, - {nil, nil, true}, - {[]string{"a"}, []string{"b"}, false}, - } - - for _, test := range tests { - result := StringSlicesEqual(test.a, test.b) - assert.Equal(t, test.expected, result) - } -} - -func TestMapToStatementAction(t *testing.T) { - tests := []struct { - input string - expected string - }{ - {StatementActionAdmin, s3_constants.ACTION_ADMIN}, - {StatementActionWrite, s3_constants.ACTION_WRITE}, - {StatementActionRead, s3_constants.ACTION_READ}, - {StatementActionList, s3_constants.ACTION_LIST}, - {StatementActionDelete, s3_constants.ACTION_DELETE_BUCKET}, - // Test fine-grained S3 action mappings (Issue #7864) - {"DeleteObject", s3_constants.ACTION_WRITE}, - {"s3:DeleteObject", s3_constants.ACTION_WRITE}, - {"PutObject", s3_constants.ACTION_WRITE}, - {"s3:PutObject", s3_constants.ACTION_WRITE}, - {"GetObject", s3_constants.ACTION_READ}, - {"s3:GetObject", s3_constants.ACTION_READ}, - {"ListBucket", s3_constants.ACTION_LIST}, - {"s3:ListBucket", s3_constants.ACTION_LIST}, - {"PutObjectAcl", s3_constants.ACTION_WRITE_ACP}, - {"s3:PutObjectAcl", s3_constants.ACTION_WRITE_ACP}, - {"GetObjectAcl", s3_constants.ACTION_READ_ACP}, - {"s3:GetObjectAcl", s3_constants.ACTION_READ_ACP}, - {"unknown", ""}, - } - - for _, test := range tests { - result := MapToStatementAction(test.input) - assert.Equal(t, test.expected, result, "Failed for input: %s", test.input) - } -} - -func TestMapToIdentitiesAction(t *testing.T) { - tests := []struct { - input string - expected string - }{ - {s3_constants.ACTION_ADMIN, StatementActionAdmin}, - {s3_constants.ACTION_WRITE, StatementActionWrite}, - {s3_constants.ACTION_READ, StatementActionRead}, - {s3_constants.ACTION_LIST, StatementActionList}, - {s3_constants.ACTION_DELETE_BUCKET, StatementActionDelete}, - {"unknown", ""}, - } - - for _, test := range tests { - result := MapToIdentitiesAction(test.input) - assert.Equal(t, test.expected, result) - } -} - -func TestMaskAccessKey(t *testing.T) { - tests := []struct { - input string - expected string - }{ - {"AKIAIOSFODNN7EXAMPLE", "AKIA***"}, - {"AKIA", "AKIA"}, - {"AKI", "AKI"}, - {"", ""}, - } - - for _, test := range tests { - result := MaskAccessKey(test.input) - assert.Equal(t, test.expected, result) - } -} diff --git a/weed/iam/integration/iam_manager.go b/weed/iam/integration/iam_manager.go index fb8a47895..bfff9cefd 100644 --- a/weed/iam/integration/iam_manager.go +++ b/weed/iam/integration/iam_manager.go @@ -202,32 +202,6 @@ func (m *IAMManager) getFilerAddress() string { return "" // Fallback to empty string if no provider is set } -// createRoleStore creates a role store based on configuration -func (m *IAMManager) createRoleStore(config *RoleStoreConfig) (RoleStore, error) { - if config == nil { - // Default to generic cached filer role store when no config provided - return NewGenericCachedRoleStore(nil, nil) - } - - switch config.StoreType { - case "", "filer": - // Check if caching is explicitly disabled - if config.StoreConfig != nil { - if noCache, ok := config.StoreConfig["noCache"].(bool); ok && noCache { - return NewFilerRoleStore(config.StoreConfig, nil) - } - } - // Default to generic cached filer store for better performance - return NewGenericCachedRoleStore(config.StoreConfig, nil) - case "cached-filer", "generic-cached": - return NewGenericCachedRoleStore(config.StoreConfig, nil) - case "memory": - return NewMemoryRoleStore(), nil - default: - return nil, fmt.Errorf("unsupported role store type: %s", config.StoreType) - } -} - // createRoleStoreWithProvider creates a role store with a filer address provider function func (m *IAMManager) createRoleStoreWithProvider(config *RoleStoreConfig, filerAddressProvider func() string) (RoleStore, error) { if config == nil { diff --git a/weed/iam/integration/role_store.go b/weed/iam/integration/role_store.go index f2dc128c7..11fbbb44e 100644 --- a/weed/iam/integration/role_store.go +++ b/weed/iam/integration/role_store.go @@ -388,157 +388,3 @@ type CachedFilerRoleStoreConfig struct { ListTTL string `json:"listTtl,omitempty"` // e.g., "1m", "30s" MaxCacheSize int `json:"maxCacheSize,omitempty"` // Maximum number of cached roles } - -// NewCachedFilerRoleStore creates a new cached filer-based role store -func NewCachedFilerRoleStore(config map[string]interface{}) (*CachedFilerRoleStore, error) { - // Create underlying filer store - filerStore, err := NewFilerRoleStore(config, nil) - if err != nil { - return nil, fmt.Errorf("failed to create filer role store: %w", err) - } - - // Parse cache configuration with defaults - cacheTTL := 5 * time.Minute // Default 5 minutes for role cache - listTTL := 1 * time.Minute // Default 1 minute for list cache - maxCacheSize := 1000 // Default max 1000 cached roles - - if config != nil { - if ttlStr, ok := config["ttl"].(string); ok && ttlStr != "" { - if parsed, err := time.ParseDuration(ttlStr); err == nil { - cacheTTL = parsed - } - } - if listTTLStr, ok := config["listTtl"].(string); ok && listTTLStr != "" { - if parsed, err := time.ParseDuration(listTTLStr); err == nil { - listTTL = parsed - } - } - if maxSize, ok := config["maxCacheSize"].(int); ok && maxSize > 0 { - maxCacheSize = maxSize - } - } - - // Create ccache instances with appropriate configurations - pruneCount := int64(maxCacheSize) >> 3 - if pruneCount <= 0 { - pruneCount = 100 - } - - store := &CachedFilerRoleStore{ - filerStore: filerStore, - cache: ccache.New(ccache.Configure().MaxSize(int64(maxCacheSize)).ItemsToPrune(uint32(pruneCount))), - listCache: ccache.New(ccache.Configure().MaxSize(100).ItemsToPrune(10)), // Smaller cache for lists - ttl: cacheTTL, - listTTL: listTTL, - } - - glog.V(2).Infof("Initialized CachedFilerRoleStore with TTL %v, List TTL %v, Max Cache Size %d", - cacheTTL, listTTL, maxCacheSize) - - return store, nil -} - -// StoreRole stores a role definition and invalidates the cache -func (c *CachedFilerRoleStore) StoreRole(ctx context.Context, filerAddress string, roleName string, role *RoleDefinition) error { - // Store in filer - err := c.filerStore.StoreRole(ctx, filerAddress, roleName, role) - if err != nil { - return err - } - - // Invalidate cache entries - c.cache.Delete(roleName) - c.listCache.Clear() // Invalidate list cache - - glog.V(3).Infof("Stored and invalidated cache for role %s", roleName) - return nil -} - -// GetRole retrieves a role definition with caching -func (c *CachedFilerRoleStore) GetRole(ctx context.Context, filerAddress string, roleName string) (*RoleDefinition, error) { - // Try to get from cache first - item := c.cache.Get(roleName) - if item != nil { - // Cache hit - return cached role (DO NOT extend TTL) - role := item.Value().(*RoleDefinition) - glog.V(4).Infof("Cache hit for role %s", roleName) - return copyRoleDefinition(role), nil - } - - // Cache miss - fetch from filer - glog.V(4).Infof("Cache miss for role %s, fetching from filer", roleName) - role, err := c.filerStore.GetRole(ctx, filerAddress, roleName) - if err != nil { - return nil, err - } - - // Cache the result with TTL - c.cache.Set(roleName, copyRoleDefinition(role), c.ttl) - glog.V(3).Infof("Cached role %s with TTL %v", roleName, c.ttl) - return role, nil -} - -// ListRoles lists all role names with caching -func (c *CachedFilerRoleStore) ListRoles(ctx context.Context, filerAddress string) ([]string, error) { - // Use a constant key for the role list cache - const listCacheKey = "role_list" - - // Try to get from list cache first - item := c.listCache.Get(listCacheKey) - if item != nil { - // Cache hit - return cached list (DO NOT extend TTL) - roles := item.Value().([]string) - glog.V(4).Infof("List cache hit, returning %d roles", len(roles)) - return append([]string(nil), roles...), nil // Return a copy - } - - // Cache miss - fetch from filer - glog.V(4).Infof("List cache miss, fetching from filer") - roles, err := c.filerStore.ListRoles(ctx, filerAddress) - if err != nil { - return nil, err - } - - // Cache the result with TTL (store a copy) - rolesCopy := append([]string(nil), roles...) - c.listCache.Set(listCacheKey, rolesCopy, c.listTTL) - glog.V(3).Infof("Cached role list with %d entries, TTL %v", len(roles), c.listTTL) - return roles, nil -} - -// DeleteRole deletes a role definition and invalidates the cache -func (c *CachedFilerRoleStore) DeleteRole(ctx context.Context, filerAddress string, roleName string) error { - // Delete from filer - err := c.filerStore.DeleteRole(ctx, filerAddress, roleName) - if err != nil { - return err - } - - // Invalidate cache entries - c.cache.Delete(roleName) - c.listCache.Clear() // Invalidate list cache - - glog.V(3).Infof("Deleted and invalidated cache for role %s", roleName) - return nil -} - -// ClearCache clears all cached entries (for testing or manual cache invalidation) -func (c *CachedFilerRoleStore) ClearCache() { - c.cache.Clear() - c.listCache.Clear() - glog.V(2).Infof("Cleared all role cache entries") -} - -// GetCacheStats returns cache statistics -func (c *CachedFilerRoleStore) GetCacheStats() map[string]interface{} { - return map[string]interface{}{ - "roleCache": map[string]interface{}{ - "size": c.cache.ItemCount(), - "ttl": c.ttl.String(), - }, - "listCache": map[string]interface{}{ - "size": c.listCache.ItemCount(), - "ttl": c.listTTL.String(), - }, - } -} diff --git a/weed/iam/policy/condition_set_test.go b/weed/iam/policy/condition_set_test.go deleted file mode 100644 index 4c7e8bb67..000000000 --- a/weed/iam/policy/condition_set_test.go +++ /dev/null @@ -1,687 +0,0 @@ -package policy - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestConditionSetOperators(t *testing.T) { - engine := setupTestPolicyEngine(t) - - t.Run("ForAnyValue:StringEquals", func(t *testing.T) { - trustPolicy := &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Sid: "AllowOIDC", - Effect: "Allow", - Action: []string{"sts:AssumeRoleWithWebIdentity"}, - Condition: map[string]map[string]interface{}{ - "ForAnyValue:StringEquals": { - "oidc:roles": []string{"Dev.SeaweedFS.TestBucket.ReadWrite", "Dev.SeaweedFS.Admin"}, - }, - }, - }, - }, - } - - // Match: Admin is in the requested roles - evalCtxMatch := &EvaluationContext{ - Principal: "web-identity-user", - Action: "sts:AssumeRoleWithWebIdentity", - Resource: "arn:aws:iam::role/test-role", - RequestContext: map[string]interface{}{ - "oidc:roles": []string{"Dev.SeaweedFS.Admin", "OtherRole"}, - }, - } - resultMatch, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicy, evalCtxMatch) - require.NoError(t, err) - assert.Equal(t, EffectAllow, resultMatch.Effect) - - // No Match - evalCtxNoMatch := &EvaluationContext{ - Principal: "web-identity-user", - Action: "sts:AssumeRoleWithWebIdentity", - Resource: "arn:aws:iam::role/test-role", - RequestContext: map[string]interface{}{ - "oidc:roles": []string{"OtherRole1", "OtherRole2"}, - }, - } - resultNoMatch, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicy, evalCtxNoMatch) - require.NoError(t, err) - assert.Equal(t, EffectDeny, resultNoMatch.Effect) - - // No Match: Empty context for ForAnyValue (should deny) - evalCtxEmpty := &EvaluationContext{ - Principal: "web-identity-user", - Action: "sts:AssumeRoleWithWebIdentity", - Resource: "arn:aws:iam::role/test-role", - RequestContext: map[string]interface{}{ - "oidc:roles": []string{}, - }, - } - resultEmpty, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicy, evalCtxEmpty) - require.NoError(t, err) - assert.Equal(t, EffectDeny, resultEmpty.Effect, "ForAnyValue should deny when context is empty") - }) - - t.Run("ForAllValues:StringEquals", func(t *testing.T) { - trustPolicyAll := &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Sid: "AllowOIDCAll", - Effect: "Allow", - Action: []string{"sts:AssumeRoleWithWebIdentity"}, - Condition: map[string]map[string]interface{}{ - "ForAllValues:StringEquals": { - "oidc:roles": []string{"RoleA", "RoleB", "RoleC"}, - }, - }, - }, - }, - } - - // Match: All requested roles ARE in the allowed set - evalCtxAllMatch := &EvaluationContext{ - Principal: "web-identity-user", - Action: "sts:AssumeRoleWithWebIdentity", - Resource: "arn:aws:iam::role/test-role", - RequestContext: map[string]interface{}{ - "oidc:roles": []string{"RoleA", "RoleB"}, - }, - } - resultAllMatch, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicyAll, evalCtxAllMatch) - require.NoError(t, err) - assert.Equal(t, EffectAllow, resultAllMatch.Effect) - - // Fail: RoleD is NOT in the allowed set - evalCtxAllFail := &EvaluationContext{ - Principal: "web-identity-user", - Action: "sts:AssumeRoleWithWebIdentity", - Resource: "arn:aws:iam::role/test-role", - RequestContext: map[string]interface{}{ - "oidc:roles": []string{"RoleA", "RoleD"}, - }, - } - resultAllFail, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicyAll, evalCtxAllFail) - require.NoError(t, err) - assert.Equal(t, EffectDeny, resultAllFail.Effect) - - // Vacuously true: Request has NO roles - evalCtxEmpty := &EvaluationContext{ - Principal: "web-identity-user", - Action: "sts:AssumeRoleWithWebIdentity", - Resource: "arn:aws:iam::role/test-role", - RequestContext: map[string]interface{}{ - "oidc:roles": []string{}, - }, - } - resultEmpty, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicyAll, evalCtxEmpty) - require.NoError(t, err) - assert.Equal(t, EffectAllow, resultEmpty.Effect) - }) - - t.Run("ForAllValues:NumericEqualsVacuouslyTrue", func(t *testing.T) { - policy := &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Sid: "AllowNumericAll", - Effect: "Allow", - Action: []string{"sts:AssumeRole"}, - Condition: map[string]map[string]interface{}{ - "ForAllValues:NumericEquals": { - "aws:MultiFactorAuthAge": []string{"3600", "7200"}, - }, - }, - }, - }, - } - - // Vacuously true: Request has NO MFA age info - evalCtxEmpty := &EvaluationContext{ - Principal: "user", - Action: "sts:AssumeRole", - Resource: "arn:aws:iam::role/test-role", - RequestContext: map[string]interface{}{ - "aws:MultiFactorAuthAge": []string{}, - }, - } - resultEmpty, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxEmpty) - require.NoError(t, err) - assert.Equal(t, EffectAllow, resultEmpty.Effect, "Should allow when numeric context is empty for ForAllValues") - }) - - t.Run("ForAllValues:BoolVacuouslyTrue", func(t *testing.T) { - policy := &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Sid: "AllowBoolAll", - Effect: "Allow", - Action: []string{"sts:AssumeRole"}, - Condition: map[string]map[string]interface{}{ - "ForAllValues:Bool": { - "aws:SecureTransport": "true", - }, - }, - }, - }, - } - - // Vacuously true - evalCtxEmpty := &EvaluationContext{ - Principal: "user", - Action: "sts:AssumeRole", - Resource: "arn:aws:iam::role/test-role", - RequestContext: map[string]interface{}{ - "aws:SecureTransport": []interface{}{}, - }, - } - resultEmpty, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxEmpty) - require.NoError(t, err) - assert.Equal(t, EffectAllow, resultEmpty.Effect, "Should allow when bool context is empty for ForAllValues") - }) - - t.Run("ForAllValues:DateVacuouslyTrue", func(t *testing.T) { - policy := &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Sid: "AllowDateAll", - Effect: "Allow", - Action: []string{"sts:AssumeRole"}, - Condition: map[string]map[string]interface{}{ - "ForAllValues:DateGreaterThan": { - "aws:CurrentTime": "2020-01-01T00:00:00Z", - }, - }, - }, - }, - } - - // Vacuously true - evalCtxEmpty := &EvaluationContext{ - Principal: "user", - Action: "sts:AssumeRole", - Resource: "arn:aws:iam::role/test-role", - RequestContext: map[string]interface{}{ - "aws:CurrentTime": []interface{}{}, - }, - } - resultEmpty, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxEmpty) - require.NoError(t, err) - assert.Equal(t, EffectAllow, resultEmpty.Effect, "Should allow when date context is empty for ForAllValues") - }) - - t.Run("ForAllValues:DateWithLabelsAsStrings", func(t *testing.T) { - policy := &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Sid: "AllowDateStrings", - Effect: "Allow", - Action: []string{"sts:AssumeRole"}, - Condition: map[string]map[string]interface{}{ - "ForAllValues:DateGreaterThan": { - "aws:CurrentTime": "2020-01-01T00:00:00Z", - }, - }, - }, - }, - } - - evalCtx := &EvaluationContext{ - Principal: "user", - Action: "sts:AssumeRole", - Resource: "arn:aws:iam::role/test-role", - RequestContext: map[string]interface{}{ - "aws:CurrentTime": []string{"2021-01-01T00:00:00Z", "2022-01-01T00:00:00Z"}, - }, - } - result, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtx) - require.NoError(t, err) - assert.Equal(t, EffectAllow, result.Effect, "Should allow when date context is a slice of strings") - }) - - t.Run("ForAllValues:BoolWithLabelsAsStrings", func(t *testing.T) { - policy := &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Sid: "AllowBoolStrings", - Effect: "Allow", - Action: []string{"sts:AssumeRole"}, - Condition: map[string]map[string]interface{}{ - "ForAllValues:Bool": { - "aws:SecureTransport": "true", - }, - }, - }, - }, - } - - evalCtx := &EvaluationContext{ - Principal: "user", - Action: "sts:AssumeRole", - Resource: "arn:aws:iam::role/test-role", - RequestContext: map[string]interface{}{ - "aws:SecureTransport": []string{"true", "true"}, - }, - } - result, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtx) - require.NoError(t, err) - assert.Equal(t, EffectAllow, result.Effect, "Should allow when bool context is a slice of strings") - }) - - t.Run("StringEqualsIgnoreCaseWithVariable", func(t *testing.T) { - policyDoc := &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Sid: "AllowVar", - Effect: "Allow", - Action: []string{"s3:GetObject"}, - Resource: []string{"arn:aws:s3:::bucket/*"}, - Condition: map[string]map[string]interface{}{ - "StringEqualsIgnoreCase": { - "s3:prefix": "${aws:username}/", - }, - }, - }, - }, - } - - err := engine.AddPolicy("", "var-policy", policyDoc) - require.NoError(t, err) - - evalCtx := &EvaluationContext{ - Principal: "user", - Action: "s3:GetObject", - Resource: "arn:aws:s3:::bucket/ALICE/file.txt", - RequestContext: map[string]interface{}{ - "s3:prefix": "ALICE/", - "aws:username": "alice", - }, - } - - result, err := engine.Evaluate(context.Background(), "", evalCtx, []string{"var-policy"}) - require.NoError(t, err) - assert.Equal(t, EffectAllow, result.Effect, "Should allow when variable expands and matches case-insensitively") - }) - - t.Run("StringLike:CaseSensitivity", func(t *testing.T) { - policyDoc := &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Sid: "AllowCaseSensitiveLike", - Effect: "Allow", - Action: []string{"s3:GetObject"}, - Resource: []string{"arn:aws:s3:::bucket/*"}, - Condition: map[string]map[string]interface{}{ - "StringLike": { - "s3:prefix": "Project/*", - }, - }, - }, - }, - } - - err := engine.AddPolicy("", "like-policy", policyDoc) - require.NoError(t, err) - - // Match: Case sensitive match - evalCtxMatch := &EvaluationContext{ - Principal: "user", - Action: "s3:GetObject", - Resource: "arn:aws:s3:::bucket/Project/file.txt", - RequestContext: map[string]interface{}{ - "s3:prefix": "Project/data", - }, - } - resultMatch, err := engine.Evaluate(context.Background(), "", evalCtxMatch, []string{"like-policy"}) - require.NoError(t, err) - assert.Equal(t, EffectAllow, resultMatch.Effect, "Should allow when case matches exactly") - - // Fail: Case insensitive match (should fail for StringLike) - evalCtxFail := &EvaluationContext{ - Principal: "user", - Action: "s3:GetObject", - Resource: "arn:aws:s3:::bucket/project/file.txt", - RequestContext: map[string]interface{}{ - "s3:prefix": "project/data", // lowercase 'p' - }, - } - resultFail, err := engine.Evaluate(context.Background(), "", evalCtxFail, []string{"like-policy"}) - require.NoError(t, err) - assert.Equal(t, EffectDeny, resultFail.Effect, "Should deny when case does not match for StringLike") - }) - - t.Run("NumericNotEquals:Logic", func(t *testing.T) { - policy := &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Sid: "DenySpecificAges", - Effect: "Allow", - Action: []string{"sts:AssumeRole"}, - Resource: []string{"*"}, - Condition: map[string]map[string]interface{}{ - "ForAllValues:NumericNotEquals": { - "aws:MultiFactorAuthAge": []string{"3600", "7200"}, - }, - }, - }, - }, - } - - err := engine.AddPolicy("", "numeric-not-equals-policy", policy) - require.NoError(t, err) - - // Fail: One age matches an excluded value (3600) - evalCtxFail := &EvaluationContext{ - Principal: "user", - Action: "sts:AssumeRole", - Resource: "arn:aws:iam::role/test-role", - RequestContext: map[string]interface{}{ - "aws:MultiFactorAuthAge": []string{"3600", "1800"}, - }, - } - resultFail, err := engine.Evaluate(context.Background(), "", evalCtxFail, []string{"numeric-not-equals-policy"}) - require.NoError(t, err) - assert.Equal(t, EffectDeny, resultFail.Effect, "Should deny when one age matches an excluded value") - - // Pass: No age matches any excluded value - evalCtxPass := &EvaluationContext{ - Principal: "user", - Action: "sts:AssumeRole", - Resource: "arn:aws:iam::role/test-role", - RequestContext: map[string]interface{}{ - "aws:MultiFactorAuthAge": []string{"1800", "900"}, - }, - } - resultPass, err := engine.Evaluate(context.Background(), "", evalCtxPass, []string{"numeric-not-equals-policy"}) - require.NoError(t, err) - assert.Equal(t, EffectAllow, resultPass.Effect, "Should allow when no age matches excluded values") - }) - - t.Run("DateNotEquals:Logic", func(t *testing.T) { - policy := &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Sid: "DenySpecificTimes", - Effect: "Allow", - Action: []string{"sts:AssumeRole"}, - Resource: []string{"*"}, - Condition: map[string]map[string]interface{}{ - "ForAllValues:DateNotEquals": { - "aws:CurrentTime": []string{"2024-01-01T00:00:00Z", "2024-01-02T00:00:00Z"}, - }, - }, - }, - }, - } - - err := engine.AddPolicy("", "date-not-equals-policy", policy) - require.NoError(t, err) - - // Fail: One time matches an excluded value - evalCtxFail := &EvaluationContext{ - Principal: "user", - Action: "sts:AssumeRole", - Resource: "arn:aws:iam::role/test-role", - RequestContext: map[string]interface{}{ - "aws:CurrentTime": []string{"2024-01-01T00:00:00Z", "2024-01-03T00:00:00Z"}, - }, - } - resultFail, err := engine.Evaluate(context.Background(), "", evalCtxFail, []string{"date-not-equals-policy"}) - require.NoError(t, err) - assert.Equal(t, EffectDeny, resultFail.Effect, "Should deny when one date matches an excluded value") - }) - - t.Run("IpAddress:SetOperators", func(t *testing.T) { - policy := &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Sid: "AllowSpecificIPs", - Effect: "Allow", - Action: []string{"s3:GetObject"}, - Resource: []string{"*"}, - Condition: map[string]map[string]interface{}{ - "ForAllValues:IpAddress": { - "aws:SourceIp": []string{"192.168.1.0/24", "10.0.0.1"}, - }, - }, - }, - }, - } - - err := engine.AddPolicy("", "ip-set-policy", policy) - require.NoError(t, err) - - // Match: All source IPs are in allowed ranges - evalCtxMatch := &EvaluationContext{ - Principal: "user", - Action: "s3:GetObject", - Resource: "arn:aws:s3:::bucket/file.txt", - RequestContext: map[string]interface{}{ - "aws:SourceIp": []string{"192.168.1.10", "10.0.0.1"}, - }, - } - resultMatch, err := engine.Evaluate(context.Background(), "", evalCtxMatch, []string{"ip-set-policy"}) - require.NoError(t, err) - assert.Equal(t, EffectAllow, resultMatch.Effect) - - // Fail: One source IP is NOT in allowed ranges - evalCtxFail := &EvaluationContext{ - Principal: "user", - Action: "s3:GetObject", - Resource: "arn:aws:s3:::bucket/file.txt", - RequestContext: map[string]interface{}{ - "aws:SourceIp": []string{"192.168.1.10", "172.16.0.1"}, - }, - } - resultFail, err := engine.Evaluate(context.Background(), "", evalCtxFail, []string{"ip-set-policy"}) - require.NoError(t, err) - assert.Equal(t, EffectDeny, resultFail.Effect) - - // ForAnyValue: IPAddress - policyAny := &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Sid: "AllowAnySpecificIPs", - Effect: "Allow", - Action: []string{"s3:GetObject"}, - Resource: []string{"*"}, - Condition: map[string]map[string]interface{}{ - "ForAnyValue:IpAddress": { - "aws:SourceIp": []string{"192.168.1.0/24"}, - }, - }, - }, - }, - } - err = engine.AddPolicy("", "ip-any-policy", policyAny) - require.NoError(t, err) - - evalCtxAnyMatch := &EvaluationContext{ - Principal: "user", - Action: "s3:GetObject", - Resource: "arn:aws:s3:::bucket/file.txt", - RequestContext: map[string]interface{}{ - "aws:SourceIp": []string{"192.168.1.10", "172.16.0.1"}, - }, - } - resultAnyMatch, err := engine.Evaluate(context.Background(), "", evalCtxAnyMatch, []string{"ip-any-policy"}) - require.NoError(t, err) - assert.Equal(t, EffectAllow, resultAnyMatch.Effect) - }) - - t.Run("IpAddress:SingleStringValue", func(t *testing.T) { - policy := &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Sid: "AllowSingleIP", - Effect: "Allow", - Action: []string{"s3:GetObject"}, - Resource: []string{"*"}, - Condition: map[string]map[string]interface{}{ - "IpAddress": { - "aws:SourceIp": "192.168.1.1", - }, - }, - }, - }, - } - - err := engine.AddPolicy("", "ip-single-policy", policy) - require.NoError(t, err) - - evalCtxMatch := &EvaluationContext{ - Principal: "user", - Action: "s3:GetObject", - Resource: "arn:aws:s3:::bucket/file.txt", - RequestContext: map[string]interface{}{ - "aws:SourceIp": "192.168.1.1", - }, - } - resultMatch, err := engine.Evaluate(context.Background(), "", evalCtxMatch, []string{"ip-single-policy"}) - require.NoError(t, err) - assert.Equal(t, EffectAllow, resultMatch.Effect) - - evalCtxNoMatch := &EvaluationContext{ - Principal: "user", - Action: "s3:GetObject", - Resource: "arn:aws:s3:::bucket/file.txt", - RequestContext: map[string]interface{}{ - "aws:SourceIp": "10.0.0.1", - }, - } - resultNoMatch, err := engine.Evaluate(context.Background(), "", evalCtxNoMatch, []string{"ip-single-policy"}) - require.NoError(t, err) - assert.Equal(t, EffectDeny, resultNoMatch.Effect) - }) - - t.Run("Bool:StringSlicePolicyValues", func(t *testing.T) { - policy := &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Sid: "AllowWithBoolStrings", - Effect: "Allow", - Action: []string{"s3:GetObject"}, - Resource: []string{"*"}, - Condition: map[string]map[string]interface{}{ - "Bool": { - "aws:SecureTransport": []string{"true", "false"}, - }, - }, - }, - }, - } - - err := engine.AddPolicy("", "bool-string-slice-policy", policy) - require.NoError(t, err) - - evalCtx := &EvaluationContext{ - Principal: "user", - Action: "s3:GetObject", - Resource: "arn:aws:s3:::bucket/file.txt", - RequestContext: map[string]interface{}{ - "aws:SecureTransport": "true", - }, - } - result, err := engine.Evaluate(context.Background(), "", evalCtx, []string{"bool-string-slice-policy"}) - require.NoError(t, err) - assert.Equal(t, EffectAllow, result.Effect) - }) - - t.Run("StringEqualsIgnoreCase:StringSlicePolicyValues", func(t *testing.T) { - policy := &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Sid: "AllowWithIgnoreCaseStrings", - Effect: "Allow", - Action: []string{"s3:GetObject"}, - Resource: []string{"*"}, - Condition: map[string]map[string]interface{}{ - "StringEqualsIgnoreCase": { - "s3:x-amz-server-side-encryption": []string{"AES256", "aws:kms"}, - }, - }, - }, - }, - } - - err := engine.AddPolicy("", "string-ignorecase-slice-policy", policy) - require.NoError(t, err) - - evalCtx := &EvaluationContext{ - Principal: "user", - Action: "s3:GetObject", - Resource: "arn:aws:s3:::bucket/file.txt", - RequestContext: map[string]interface{}{ - "s3:x-amz-server-side-encryption": "aes256", - }, - } - result, err := engine.Evaluate(context.Background(), "", evalCtx, []string{"string-ignorecase-slice-policy"}) - require.NoError(t, err) - assert.Equal(t, EffectAllow, result.Effect) - }) - - t.Run("IpAddress:CustomContextKey", func(t *testing.T) { - policy := &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Sid: "AllowCustomIPKey", - Effect: "Allow", - Action: []string{"s3:GetObject"}, - Resource: []string{"*"}, - Condition: map[string]map[string]interface{}{ - "IpAddress": { - "custom:VpcIp": "10.0.0.0/16", - }, - }, - }, - }, - } - - err := engine.AddPolicy("", "ip-custom-key-policy", policy) - require.NoError(t, err) - - evalCtxMatch := &EvaluationContext{ - Principal: "user", - Action: "s3:GetObject", - Resource: "arn:aws:s3:::bucket/file.txt", - RequestContext: map[string]interface{}{ - "custom:VpcIp": "10.0.5.1", - }, - } - resultMatch, err := engine.Evaluate(context.Background(), "", evalCtxMatch, []string{"ip-custom-key-policy"}) - require.NoError(t, err) - assert.Equal(t, EffectAllow, resultMatch.Effect) - - evalCtxNoMatch := &EvaluationContext{ - Principal: "user", - Action: "s3:GetObject", - Resource: "arn:aws:s3:::bucket/file.txt", - RequestContext: map[string]interface{}{ - "custom:VpcIp": "192.168.1.1", - }, - } - resultNoMatch, err := engine.Evaluate(context.Background(), "", evalCtxNoMatch, []string{"ip-custom-key-policy"}) - require.NoError(t, err) - assert.Equal(t, EffectDeny, resultNoMatch.Effect) - }) -} diff --git a/weed/iam/policy/negation_test.go b/weed/iam/policy/negation_test.go deleted file mode 100644 index 31eed396f..000000000 --- a/weed/iam/policy/negation_test.go +++ /dev/null @@ -1,101 +0,0 @@ -package policy - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestNegationSetOperators(t *testing.T) { - engine := setupTestPolicyEngine(t) - - t.Run("ForAllValues:StringNotEquals", func(t *testing.T) { - policy := &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Sid: "DenyAdmin", - Effect: "Allow", - Action: []string{"sts:AssumeRole"}, - Condition: map[string]map[string]interface{}{ - "ForAllValues:StringNotEquals": { - "oidc:roles": []string{"Admin"}, - }, - }, - }, - }, - } - - // All roles are NOT "Admin" -> Should Allow - evalCtxAllow := &EvaluationContext{ - Principal: "user", - Action: "sts:AssumeRole", - Resource: "arn:aws:iam::role/test-role", - RequestContext: map[string]interface{}{ - "oidc:roles": []string{"User", "Developer"}, - }, - } - resultAllow, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxAllow) - require.NoError(t, err) - assert.Equal(t, EffectAllow, resultAllow.Effect, "Should allow when ALL roles satisfy StringNotEquals Admin") - - // One role is "Admin" -> Should Deny - evalCtxDeny := &EvaluationContext{ - Principal: "user", - Action: "sts:AssumeRole", - Resource: "arn:aws:iam::role/test-role", - RequestContext: map[string]interface{}{ - "oidc:roles": []string{"Admin", "User"}, - }, - } - resultDeny, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxDeny) - require.NoError(t, err) - assert.Equal(t, EffectDeny, resultDeny.Effect, "Should deny when one role is Admin and fails StringNotEquals") - }) - - t.Run("ForAnyValue:StringNotEquals", func(t *testing.T) { - policy := &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Sid: "Requirement", - Effect: "Allow", - Action: []string{"sts:AssumeRole"}, - Condition: map[string]map[string]interface{}{ - "ForAnyValue:StringNotEquals": { - "oidc:roles": []string{"Prohibited"}, - }, - }, - }, - }, - } - - // At least one role is NOT prohibited -> Should Allow - evalCtxAllow := &EvaluationContext{ - Principal: "user", - Action: "sts:AssumeRole", - Resource: "arn:aws:iam::role/test-role", - RequestContext: map[string]interface{}{ - "oidc:roles": []string{"Prohibited", "Allowed"}, - }, - } - resultAllow, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxAllow) - require.NoError(t, err) - assert.Equal(t, EffectAllow, resultAllow.Effect, "Should allow when at least one role is NOT Prohibited") - - // All roles are Prohibited -> Should Deny - evalCtxDeny := &EvaluationContext{ - Principal: "user", - Action: "sts:AssumeRole", - Resource: "arn:aws:iam::role/test-role", - RequestContext: map[string]interface{}{ - "oidc:roles": []string{"Prohibited", "Prohibited"}, - }, - } - resultDeny, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxDeny) - require.NoError(t, err) - assert.Equal(t, EffectDeny, resultDeny.Effect, "Should deny when ALL roles are Prohibited") - }) -} diff --git a/weed/iam/policy/policy_engine.go b/weed/iam/policy/policy_engine.go index c8cd07367..7feca5c92 100644 --- a/weed/iam/policy/policy_engine.go +++ b/weed/iam/policy/policy_engine.go @@ -1155,11 +1155,6 @@ func ValidatePolicyDocumentWithType(policy *PolicyDocument, policyType string) e return nil } -// validateStatement validates a single statement (for backward compatibility) -func validateStatement(statement *Statement) error { - return validateStatementWithType(statement, "resource") -} - // validateStatementWithType validates a single statement based on policy type func validateStatementWithType(statement *Statement, policyType string) error { if statement.Effect != "Allow" && statement.Effect != "Deny" { @@ -1198,29 +1193,6 @@ func validateStatementWithType(statement *Statement, policyType string) error { return nil } -// matchResource checks if a resource pattern matches a requested resource -// Uses hybrid approach: simple suffix wildcards for compatibility, filepath.Match for complex patterns -func matchResource(pattern, resource string) bool { - if pattern == resource { - return true - } - - // Handle simple suffix wildcard (backward compatibility) - if strings.HasSuffix(pattern, "*") { - prefix := pattern[:len(pattern)-1] - return strings.HasPrefix(resource, prefix) - } - - // For complex patterns, use filepath.Match for advanced wildcard support (*, ?, []) - matched, err := filepath.Match(pattern, resource) - if err != nil { - // Fallback to exact match if pattern is malformed - return pattern == resource - } - - return matched -} - // awsIAMMatch performs AWS IAM-compliant pattern matching with case-insensitivity and policy variable support func awsIAMMatch(pattern, value string, evalCtx *EvaluationContext) bool { // Step 1: Substitute policy variables (e.g., ${aws:username}, ${saml:username}) @@ -1274,16 +1246,6 @@ func expandPolicyVariables(pattern string, evalCtx *EvaluationContext) string { return result } -// getContextValue safely gets a value from the evaluation context -func getContextValue(evalCtx *EvaluationContext, key, defaultValue string) string { - if value, exists := evalCtx.RequestContext[key]; exists { - if str, ok := value.(string); ok { - return str - } - } - return defaultValue -} - // AwsWildcardMatch performs case-insensitive wildcard matching like AWS IAM func AwsWildcardMatch(pattern, value string) bool { // Create regex pattern key for caching @@ -1322,29 +1284,6 @@ func AwsWildcardMatch(pattern, value string) bool { return regex.MatchString(value) } -// matchAction checks if an action pattern matches a requested action -// Uses hybrid approach: simple suffix wildcards for compatibility, filepath.Match for complex patterns -func matchAction(pattern, action string) bool { - if pattern == action { - return true - } - - // Handle simple suffix wildcard (backward compatibility) - if strings.HasSuffix(pattern, "*") { - prefix := pattern[:len(pattern)-1] - return strings.HasPrefix(action, prefix) - } - - // For complex patterns, use filepath.Match for advanced wildcard support (*, ?, []) - matched, err := filepath.Match(pattern, action) - if err != nil { - // Fallback to exact match if pattern is malformed - return pattern == action - } - - return matched -} - // evaluateStringConditionIgnoreCase evaluates string conditions with case insensitivity func (e *PolicyEngine) evaluateStringConditionIgnoreCase(block map[string]interface{}, evalCtx *EvaluationContext, shouldMatch bool, useWildcard bool, forAllValues bool) bool { for key, expectedValues := range block { diff --git a/weed/iam/policy/policy_engine_principal_test.go b/weed/iam/policy/policy_engine_principal_test.go deleted file mode 100644 index 58714eb98..000000000 --- a/weed/iam/policy/policy_engine_principal_test.go +++ /dev/null @@ -1,421 +0,0 @@ -package policy - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// TestPrincipalMatching tests the matchesPrincipal method -func TestPrincipalMatching(t *testing.T) { - engine := setupTestPolicyEngine(t) - - tests := []struct { - name string - principal interface{} - evalCtx *EvaluationContext - want bool - }{ - { - name: "plain wildcard principal", - principal: "*", - evalCtx: &EvaluationContext{ - RequestContext: map[string]interface{}{}, - }, - want: true, - }, - { - name: "structured wildcard federated principal", - principal: map[string]interface{}{ - "Federated": "*", - }, - evalCtx: &EvaluationContext{ - RequestContext: map[string]interface{}{}, - }, - want: true, - }, - { - name: "wildcard in array", - principal: map[string]interface{}{ - "Federated": []interface{}{"specific-provider", "*"}, - }, - evalCtx: &EvaluationContext{ - RequestContext: map[string]interface{}{}, - }, - want: true, - }, - { - name: "specific federated provider match", - principal: map[string]interface{}{ - "Federated": "https://example.com/oidc", - }, - evalCtx: &EvaluationContext{ - RequestContext: map[string]interface{}{ - "aws:FederatedProvider": "https://example.com/oidc", - }, - }, - want: true, - }, - { - name: "specific federated provider no match", - principal: map[string]interface{}{ - "Federated": "https://example.com/oidc", - }, - evalCtx: &EvaluationContext{ - RequestContext: map[string]interface{}{ - "aws:FederatedProvider": "https://other.com/oidc", - }, - }, - want: false, - }, - { - name: "array with specific provider match", - principal: map[string]interface{}{ - "Federated": []string{"https://provider1.com", "https://provider2.com"}, - }, - evalCtx: &EvaluationContext{ - RequestContext: map[string]interface{}{ - "aws:FederatedProvider": "https://provider2.com", - }, - }, - want: true, - }, - { - name: "AWS principal match", - principal: map[string]interface{}{ - "AWS": "arn:aws:iam::123456789012:user/alice", - }, - evalCtx: &EvaluationContext{ - RequestContext: map[string]interface{}{ - "aws:PrincipalArn": "arn:aws:iam::123456789012:user/alice", - }, - }, - want: true, - }, - { - name: "Service principal match", - principal: map[string]interface{}{ - "Service": "s3.amazonaws.com", - }, - evalCtx: &EvaluationContext{ - RequestContext: map[string]interface{}{ - "aws:PrincipalServiceName": "s3.amazonaws.com", - }, - }, - want: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := engine.matchesPrincipal(tt.principal, tt.evalCtx) - assert.Equal(t, tt.want, result, "Principal matching failed for: %s", tt.name) - }) - } -} - -// TestEvaluatePrincipalValue tests the evaluatePrincipalValue method -func TestEvaluatePrincipalValue(t *testing.T) { - engine := setupTestPolicyEngine(t) - - tests := []struct { - name string - principalValue interface{} - contextKey string - evalCtx *EvaluationContext - want bool - }{ - { - name: "wildcard string", - principalValue: "*", - contextKey: "aws:FederatedProvider", - evalCtx: &EvaluationContext{ - RequestContext: map[string]interface{}{}, - }, - want: true, - }, - { - name: "specific string match", - principalValue: "https://example.com", - contextKey: "aws:FederatedProvider", - evalCtx: &EvaluationContext{ - RequestContext: map[string]interface{}{ - "aws:FederatedProvider": "https://example.com", - }, - }, - want: true, - }, - { - name: "specific string no match", - principalValue: "https://example.com", - contextKey: "aws:FederatedProvider", - evalCtx: &EvaluationContext{ - RequestContext: map[string]interface{}{ - "aws:FederatedProvider": "https://other.com", - }, - }, - want: false, - }, - { - name: "wildcard in array", - principalValue: []interface{}{"provider1", "*"}, - contextKey: "aws:FederatedProvider", - evalCtx: &EvaluationContext{ - RequestContext: map[string]interface{}{}, - }, - want: true, - }, - { - name: "array match", - principalValue: []string{"provider1", "provider2", "provider3"}, - contextKey: "aws:FederatedProvider", - evalCtx: &EvaluationContext{ - RequestContext: map[string]interface{}{ - "aws:FederatedProvider": "provider2", - }, - }, - want: true, - }, - { - name: "array no match", - principalValue: []string{"provider1", "provider2"}, - contextKey: "aws:FederatedProvider", - evalCtx: &EvaluationContext{ - RequestContext: map[string]interface{}{ - "aws:FederatedProvider": "provider3", - }, - }, - want: false, - }, - { - name: "missing context key", - principalValue: "specific-value", - contextKey: "aws:FederatedProvider", - evalCtx: &EvaluationContext{ - RequestContext: map[string]interface{}{}, - }, - want: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := engine.evaluatePrincipalValue(tt.principalValue, tt.evalCtx, tt.contextKey) - assert.Equal(t, tt.want, result, "Principal value evaluation failed for: %s", tt.name) - }) - } -} - -// TestTrustPolicyEvaluation tests the EvaluateTrustPolicy method -func TestTrustPolicyEvaluation(t *testing.T) { - engine := setupTestPolicyEngine(t) - - tests := []struct { - name string - trustPolicy *PolicyDocument - evalCtx *EvaluationContext - wantEffect Effect - wantErr bool - }{ - { - name: "wildcard federated principal allows any provider", - trustPolicy: &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Effect: "Allow", - Principal: map[string]interface{}{ - "Federated": "*", - }, - Action: []string{"sts:AssumeRoleWithWebIdentity"}, - }, - }, - }, - evalCtx: &EvaluationContext{ - Action: "sts:AssumeRoleWithWebIdentity", - RequestContext: map[string]interface{}{ - "aws:FederatedProvider": "https://any-provider.com", - }, - }, - wantEffect: EffectAllow, - wantErr: false, - }, - { - name: "specific federated principal matches", - trustPolicy: &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Effect: "Allow", - Principal: map[string]interface{}{ - "Federated": "https://example.com/oidc", - }, - Action: []string{"sts:AssumeRoleWithWebIdentity"}, - }, - }, - }, - evalCtx: &EvaluationContext{ - Action: "sts:AssumeRoleWithWebIdentity", - RequestContext: map[string]interface{}{ - "aws:FederatedProvider": "https://example.com/oidc", - }, - }, - wantEffect: EffectAllow, - wantErr: false, - }, - { - name: "specific federated principal does not match", - trustPolicy: &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Effect: "Allow", - Principal: map[string]interface{}{ - "Federated": "https://example.com/oidc", - }, - Action: []string{"sts:AssumeRoleWithWebIdentity"}, - }, - }, - }, - evalCtx: &EvaluationContext{ - Action: "sts:AssumeRoleWithWebIdentity", - RequestContext: map[string]interface{}{ - "aws:FederatedProvider": "https://other.com/oidc", - }, - }, - wantEffect: EffectDeny, - wantErr: false, - }, - { - name: "plain wildcard principal", - trustPolicy: &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Effect: "Allow", - Principal: "*", - Action: []string{"sts:AssumeRoleWithWebIdentity"}, - }, - }, - }, - evalCtx: &EvaluationContext{ - Action: "sts:AssumeRoleWithWebIdentity", - RequestContext: map[string]interface{}{ - "aws:FederatedProvider": "https://any-provider.com", - }, - }, - wantEffect: EffectAllow, - wantErr: false, - }, - { - name: "trust policy with conditions", - trustPolicy: &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Effect: "Allow", - Principal: map[string]interface{}{ - "Federated": "*", - }, - Action: []string{"sts:AssumeRoleWithWebIdentity"}, - Condition: map[string]map[string]interface{}{ - "StringEquals": { - "oidc:aud": "my-app-id", - }, - }, - }, - }, - }, - evalCtx: &EvaluationContext{ - Action: "sts:AssumeRoleWithWebIdentity", - RequestContext: map[string]interface{}{ - "aws:FederatedProvider": "https://provider.com", - "oidc:aud": "my-app-id", - }, - }, - wantEffect: EffectAllow, - wantErr: false, - }, - { - name: "trust policy condition not met", - trustPolicy: &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Effect: "Allow", - Principal: map[string]interface{}{ - "Federated": "*", - }, - Action: []string{"sts:AssumeRoleWithWebIdentity"}, - Condition: map[string]map[string]interface{}{ - "StringEquals": { - "oidc:aud": "my-app-id", - }, - }, - }, - }, - }, - evalCtx: &EvaluationContext{ - Action: "sts:AssumeRoleWithWebIdentity", - RequestContext: map[string]interface{}{ - "aws:FederatedProvider": "https://provider.com", - "oidc:aud": "wrong-app-id", - }, - }, - wantEffect: EffectDeny, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := engine.EvaluateTrustPolicy(context.Background(), tt.trustPolicy, tt.evalCtx) - - if tt.wantErr { - assert.Error(t, err) - } else { - require.NoError(t, err) - assert.Equal(t, tt.wantEffect, result.Effect, "Trust policy evaluation failed for: %s", tt.name) - } - }) - } -} - -// TestGetPrincipalContextKey tests the context key mapping -func TestGetPrincipalContextKey(t *testing.T) { - tests := []struct { - name string - principalType string - want string - }{ - { - name: "Federated principal", - principalType: "Federated", - want: "aws:FederatedProvider", - }, - { - name: "AWS principal", - principalType: "AWS", - want: "aws:PrincipalArn", - }, - { - name: "Service principal", - principalType: "Service", - want: "aws:PrincipalServiceName", - }, - { - name: "Custom principal type", - principalType: "CustomType", - want: "aws:PrincipalCustomType", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := getPrincipalContextKey(tt.principalType) - assert.Equal(t, tt.want, result, "Context key mapping failed for: %s", tt.name) - }) - } -} diff --git a/weed/iam/policy/policy_engine_test.go b/weed/iam/policy/policy_engine_test.go deleted file mode 100644 index 3a150ba99..000000000 --- a/weed/iam/policy/policy_engine_test.go +++ /dev/null @@ -1,426 +0,0 @@ -package policy - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// TestPolicyEngineInitialization tests policy engine initialization -func TestPolicyEngineInitialization(t *testing.T) { - tests := []struct { - name string - config *PolicyEngineConfig - wantErr bool - }{ - { - name: "valid config", - config: &PolicyEngineConfig{ - DefaultEffect: "Deny", - StoreType: "memory", - }, - wantErr: false, - }, - { - name: "invalid default effect", - config: &PolicyEngineConfig{ - DefaultEffect: "Invalid", - StoreType: "memory", - }, - wantErr: true, - }, - { - name: "nil config", - config: nil, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - engine := NewPolicyEngine() - - err := engine.Initialize(tt.config) - - if tt.wantErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.True(t, engine.IsInitialized()) - } - }) - } -} - -// TestPolicyDocumentValidation tests policy document structure validation -func TestPolicyDocumentValidation(t *testing.T) { - tests := []struct { - name string - policy *PolicyDocument - wantErr bool - errorMsg string - }{ - { - name: "valid policy document", - policy: &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Sid: "AllowS3Read", - Effect: "Allow", - Action: []string{"s3:GetObject", "s3:ListBucket"}, - Resource: []string{"arn:aws:s3:::mybucket/*"}, - }, - }, - }, - wantErr: false, - }, - { - name: "missing version", - policy: &PolicyDocument{ - Statement: []Statement{ - { - Effect: "Allow", - Action: []string{"s3:GetObject"}, - Resource: []string{"arn:aws:s3:::mybucket/*"}, - }, - }, - }, - wantErr: true, - errorMsg: "version is required", - }, - { - name: "empty statements", - policy: &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{}, - }, - wantErr: true, - errorMsg: "at least one statement is required", - }, - { - name: "invalid effect", - policy: &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Effect: "Maybe", - Action: []string{"s3:GetObject"}, - Resource: []string{"arn:aws:s3:::mybucket/*"}, - }, - }, - }, - wantErr: true, - errorMsg: "invalid effect", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := ValidatePolicyDocument(tt.policy) - - if tt.wantErr { - assert.Error(t, err) - if tt.errorMsg != "" { - assert.Contains(t, err.Error(), tt.errorMsg) - } - } else { - assert.NoError(t, err) - } - }) - } -} - -// TestPolicyEvaluation tests policy evaluation logic -func TestPolicyEvaluation(t *testing.T) { - engine := setupTestPolicyEngine(t) - - // Add test policies - readPolicy := &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Sid: "AllowS3Read", - Effect: "Allow", - Action: []string{"s3:GetObject", "s3:ListBucket"}, - Resource: []string{ - "arn:aws:s3:::public-bucket/*", // For object operations - "arn:aws:s3:::public-bucket", // For bucket operations - }, - }, - }, - } - - err := engine.AddPolicy("", "read-policy", readPolicy) - require.NoError(t, err) - - denyPolicy := &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Sid: "DenyS3Delete", - Effect: "Deny", - Action: []string{"s3:DeleteObject"}, - Resource: []string{"arn:aws:s3:::*"}, - }, - }, - } - - err = engine.AddPolicy("", "deny-policy", denyPolicy) - require.NoError(t, err) - - tests := []struct { - name string - context *EvaluationContext - policies []string - want Effect - }{ - { - name: "allow read access", - context: &EvaluationContext{ - Principal: "user:alice", - Action: "s3:GetObject", - Resource: "arn:aws:s3:::public-bucket/file.txt", - RequestContext: map[string]interface{}{ - "aws:SourceIp": "192.168.1.100", - }, - }, - policies: []string{"read-policy"}, - want: EffectAllow, - }, - { - name: "deny delete access (explicit deny)", - context: &EvaluationContext{ - Principal: "user:alice", - Action: "s3:DeleteObject", - Resource: "arn:aws:s3:::public-bucket/file.txt", - }, - policies: []string{"read-policy", "deny-policy"}, - want: EffectDeny, - }, - { - name: "deny by default (no matching policy)", - context: &EvaluationContext{ - Principal: "user:alice", - Action: "s3:PutObject", - Resource: "arn:aws:s3:::public-bucket/file.txt", - }, - policies: []string{"read-policy"}, - want: EffectDeny, - }, - { - name: "allow with wildcard action", - context: &EvaluationContext{ - Principal: "user:admin", - Action: "s3:ListBucket", - Resource: "arn:aws:s3:::public-bucket", - }, - policies: []string{"read-policy"}, - want: EffectAllow, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := engine.Evaluate(context.Background(), "", tt.context, tt.policies) - - assert.NoError(t, err) - assert.Equal(t, tt.want, result.Effect) - - // Verify evaluation details - assert.NotNil(t, result.EvaluationDetails) - assert.Equal(t, tt.context.Action, result.EvaluationDetails.Action) - assert.Equal(t, tt.context.Resource, result.EvaluationDetails.Resource) - }) - } -} - -// TestConditionEvaluation tests policy conditions -func TestConditionEvaluation(t *testing.T) { - engine := setupTestPolicyEngine(t) - - // Policy with IP address condition - conditionalPolicy := &PolicyDocument{ - Version: "2012-10-17", - Statement: []Statement{ - { - Sid: "AllowFromOfficeIP", - Effect: "Allow", - Action: []string{"s3:*"}, - Resource: []string{"arn:aws:s3:::*"}, - Condition: map[string]map[string]interface{}{ - "IpAddress": { - "aws:SourceIp": []string{"192.168.1.0/24", "10.0.0.0/8"}, - }, - }, - }, - }, - } - - err := engine.AddPolicy("", "ip-conditional", conditionalPolicy) - require.NoError(t, err) - - tests := []struct { - name string - context *EvaluationContext - want Effect - }{ - { - name: "allow from office IP", - context: &EvaluationContext{ - Principal: "user:alice", - Action: "s3:GetObject", - Resource: "arn:aws:s3:::mybucket/file.txt", - RequestContext: map[string]interface{}{ - "aws:SourceIp": "192.168.1.100", - }, - }, - want: EffectAllow, - }, - { - name: "deny from external IP", - context: &EvaluationContext{ - Principal: "user:alice", - Action: "s3:GetObject", - Resource: "arn:aws:s3:::mybucket/file.txt", - RequestContext: map[string]interface{}{ - "aws:SourceIp": "8.8.8.8", - }, - }, - want: EffectDeny, - }, - { - name: "allow from internal IP", - context: &EvaluationContext{ - Principal: "user:alice", - Action: "s3:PutObject", - Resource: "arn:aws:s3:::mybucket/newfile.txt", - RequestContext: map[string]interface{}{ - "aws:SourceIp": "10.1.2.3", - }, - }, - want: EffectAllow, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := engine.Evaluate(context.Background(), "", tt.context, []string{"ip-conditional"}) - - assert.NoError(t, err) - assert.Equal(t, tt.want, result.Effect) - }) - } -} - -// TestResourceMatching tests resource ARN matching -func TestResourceMatching(t *testing.T) { - tests := []struct { - name string - policyResource string - requestResource string - want bool - }{ - { - name: "exact match", - policyResource: "arn:aws:s3:::mybucket/file.txt", - requestResource: "arn:aws:s3:::mybucket/file.txt", - want: true, - }, - { - name: "wildcard match", - policyResource: "arn:aws:s3:::mybucket/*", - requestResource: "arn:aws:s3:::mybucket/folder/file.txt", - want: true, - }, - { - name: "bucket wildcard", - policyResource: "arn:aws:s3:::*", - requestResource: "arn:aws:s3:::anybucket/file.txt", - want: true, - }, - { - name: "no match different bucket", - policyResource: "arn:aws:s3:::mybucket/*", - requestResource: "arn:aws:s3:::otherbucket/file.txt", - want: false, - }, - { - name: "prefix match", - policyResource: "arn:aws:s3:::mybucket/documents/*", - requestResource: "arn:aws:s3:::mybucket/documents/secret.txt", - want: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := matchResource(tt.policyResource, tt.requestResource) - assert.Equal(t, tt.want, result) - }) - } -} - -// TestActionMatching tests action pattern matching -func TestActionMatching(t *testing.T) { - tests := []struct { - name string - policyAction string - requestAction string - want bool - }{ - { - name: "exact match", - policyAction: "s3:GetObject", - requestAction: "s3:GetObject", - want: true, - }, - { - name: "wildcard service", - policyAction: "s3:*", - requestAction: "s3:PutObject", - want: true, - }, - { - name: "wildcard all", - policyAction: "*", - requestAction: "filer:CreateEntry", - want: true, - }, - { - name: "prefix match", - policyAction: "s3:Get*", - requestAction: "s3:GetObject", - want: true, - }, - { - name: "no match different service", - policyAction: "s3:GetObject", - requestAction: "filer:GetEntry", - want: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := matchAction(tt.policyAction, tt.requestAction) - assert.Equal(t, tt.want, result) - }) - } -} - -// Helper function to set up test policy engine -func setupTestPolicyEngine(t *testing.T) *PolicyEngine { - engine := NewPolicyEngine() - config := &PolicyEngineConfig{ - DefaultEffect: "Deny", - StoreType: "memory", - } - - err := engine.Initialize(config) - require.NoError(t, err) - - return engine -} diff --git a/weed/iam/providers/provider_test.go b/weed/iam/providers/provider_test.go deleted file mode 100644 index 99cf360c1..000000000 --- a/weed/iam/providers/provider_test.go +++ /dev/null @@ -1,246 +0,0 @@ -package providers - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// TestIdentityProviderInterface tests the core identity provider interface -func TestIdentityProviderInterface(t *testing.T) { - tests := []struct { - name string - provider IdentityProvider - wantErr bool - }{ - // We'll add test cases as we implement providers - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Test provider name - name := tt.provider.Name() - assert.NotEmpty(t, name, "Provider name should not be empty") - - // Test initialization - err := tt.provider.Initialize(nil) - if tt.wantErr { - assert.Error(t, err) - return - } - require.NoError(t, err) - - // Test authentication with invalid token - ctx := context.Background() - _, err = tt.provider.Authenticate(ctx, "invalid-token") - assert.Error(t, err, "Should fail with invalid token") - }) - } -} - -// TestExternalIdentityValidation tests external identity structure validation -func TestExternalIdentityValidation(t *testing.T) { - tests := []struct { - name string - identity *ExternalIdentity - wantErr bool - }{ - { - name: "valid identity", - identity: &ExternalIdentity{ - UserID: "user123", - Email: "user@example.com", - DisplayName: "Test User", - Groups: []string{"group1", "group2"}, - Attributes: map[string]string{"dept": "engineering"}, - Provider: "test-provider", - }, - wantErr: false, - }, - { - name: "missing user id", - identity: &ExternalIdentity{ - Email: "user@example.com", - Provider: "test-provider", - }, - wantErr: true, - }, - { - name: "missing provider", - identity: &ExternalIdentity{ - UserID: "user123", - Email: "user@example.com", - }, - wantErr: true, - }, - { - name: "invalid email", - identity: &ExternalIdentity{ - UserID: "user123", - Email: "invalid-email", - Provider: "test-provider", - }, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := tt.identity.Validate() - if tt.wantErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - }) - } -} - -// TestTokenClaimsValidation tests token claims structure -func TestTokenClaimsValidation(t *testing.T) { - tests := []struct { - name string - claims *TokenClaims - valid bool - }{ - { - name: "valid claims", - claims: &TokenClaims{ - Subject: "user123", - Issuer: "https://provider.example.com", - Audience: "seaweedfs", - ExpiresAt: time.Now().Add(time.Hour), - IssuedAt: time.Now().Add(-time.Minute), - Claims: map[string]interface{}{"email": "user@example.com"}, - }, - valid: true, - }, - { - name: "expired token", - claims: &TokenClaims{ - Subject: "user123", - Issuer: "https://provider.example.com", - Audience: "seaweedfs", - ExpiresAt: time.Now().Add(-time.Hour), // Expired - IssuedAt: time.Now().Add(-time.Hour * 2), - Claims: map[string]interface{}{"email": "user@example.com"}, - }, - valid: false, - }, - { - name: "future issued token", - claims: &TokenClaims{ - Subject: "user123", - Issuer: "https://provider.example.com", - Audience: "seaweedfs", - ExpiresAt: time.Now().Add(time.Hour), - IssuedAt: time.Now().Add(time.Hour), // Future - Claims: map[string]interface{}{"email": "user@example.com"}, - }, - valid: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - valid := tt.claims.IsValid() - assert.Equal(t, tt.valid, valid) - }) - } -} - -// TestProviderRegistry tests provider registration and discovery -func TestProviderRegistry(t *testing.T) { - // Clear registry for test - registry := NewProviderRegistry() - - t.Run("register provider", func(t *testing.T) { - mockProvider := &MockProvider{name: "test-provider"} - - err := registry.RegisterProvider(mockProvider) - assert.NoError(t, err) - - // Test duplicate registration - err = registry.RegisterProvider(mockProvider) - assert.Error(t, err, "Should not allow duplicate registration") - }) - - t.Run("get provider", func(t *testing.T) { - provider, exists := registry.GetProvider("test-provider") - assert.True(t, exists) - assert.Equal(t, "test-provider", provider.Name()) - - // Test non-existent provider - _, exists = registry.GetProvider("non-existent") - assert.False(t, exists) - }) - - t.Run("list providers", func(t *testing.T) { - providers := registry.ListProviders() - assert.Len(t, providers, 1) - assert.Equal(t, "test-provider", providers[0]) - }) -} - -// MockProvider for testing -type MockProvider struct { - name string - initialized bool - shouldError bool -} - -func (m *MockProvider) Name() string { - return m.name -} - -func (m *MockProvider) Initialize(config interface{}) error { - if m.shouldError { - return assert.AnError - } - m.initialized = true - return nil -} - -func (m *MockProvider) Authenticate(ctx context.Context, token string) (*ExternalIdentity, error) { - if !m.initialized { - return nil, assert.AnError - } - if token == "invalid-token" { - return nil, assert.AnError - } - return &ExternalIdentity{ - UserID: "test-user", - Email: "test@example.com", - DisplayName: "Test User", - Provider: m.name, - }, nil -} - -func (m *MockProvider) GetUserInfo(ctx context.Context, userID string) (*ExternalIdentity, error) { - if !m.initialized || userID == "" { - return nil, assert.AnError - } - return &ExternalIdentity{ - UserID: userID, - Email: userID + "@example.com", - DisplayName: "User " + userID, - Provider: m.name, - }, nil -} - -func (m *MockProvider) ValidateToken(ctx context.Context, token string) (*TokenClaims, error) { - if !m.initialized || token == "invalid-token" { - return nil, assert.AnError - } - return &TokenClaims{ - Subject: "test-user", - Issuer: "test-issuer", - Audience: "seaweedfs", - ExpiresAt: time.Now().Add(time.Hour), - IssuedAt: time.Now(), - Claims: map[string]interface{}{"email": "test@example.com"}, - }, nil -} diff --git a/weed/iam/providers/registry.go b/weed/iam/providers/registry.go deleted file mode 100644 index dee50df44..000000000 --- a/weed/iam/providers/registry.go +++ /dev/null @@ -1,109 +0,0 @@ -package providers - -import ( - "fmt" - "sync" -) - -// ProviderRegistry manages registered identity providers -type ProviderRegistry struct { - mu sync.RWMutex - providers map[string]IdentityProvider -} - -// NewProviderRegistry creates a new provider registry -func NewProviderRegistry() *ProviderRegistry { - return &ProviderRegistry{ - providers: make(map[string]IdentityProvider), - } -} - -// RegisterProvider registers a new identity provider -func (r *ProviderRegistry) RegisterProvider(provider IdentityProvider) error { - if provider == nil { - return fmt.Errorf("provider cannot be nil") - } - - name := provider.Name() - if name == "" { - return fmt.Errorf("provider name cannot be empty") - } - - r.mu.Lock() - defer r.mu.Unlock() - - if _, exists := r.providers[name]; exists { - return fmt.Errorf("provider %s is already registered", name) - } - - r.providers[name] = provider - return nil -} - -// GetProvider retrieves a provider by name -func (r *ProviderRegistry) GetProvider(name string) (IdentityProvider, bool) { - r.mu.RLock() - defer r.mu.RUnlock() - - provider, exists := r.providers[name] - return provider, exists -} - -// ListProviders returns all registered provider names -func (r *ProviderRegistry) ListProviders() []string { - r.mu.RLock() - defer r.mu.RUnlock() - - var names []string - for name := range r.providers { - names = append(names, name) - } - return names -} - -// UnregisterProvider removes a provider from the registry -func (r *ProviderRegistry) UnregisterProvider(name string) error { - r.mu.Lock() - defer r.mu.Unlock() - - if _, exists := r.providers[name]; !exists { - return fmt.Errorf("provider %s is not registered", name) - } - - delete(r.providers, name) - return nil -} - -// Clear removes all providers from the registry -func (r *ProviderRegistry) Clear() { - r.mu.Lock() - defer r.mu.Unlock() - - r.providers = make(map[string]IdentityProvider) -} - -// GetProviderCount returns the number of registered providers -func (r *ProviderRegistry) GetProviderCount() int { - r.mu.RLock() - defer r.mu.RUnlock() - - return len(r.providers) -} - -// Default global registry -var defaultRegistry = NewProviderRegistry() - -// RegisterProvider registers a provider in the default registry -func RegisterProvider(provider IdentityProvider) error { - return defaultRegistry.RegisterProvider(provider) -} - -// GetProvider retrieves a provider from the default registry -func GetProvider(name string) (IdentityProvider, bool) { - return defaultRegistry.GetProvider(name) -} - -// ListProviders returns all provider names from the default registry -func ListProviders() []string { - return defaultRegistry.ListProviders() -} diff --git a/weed/iam/sts/cross_instance_token_test.go b/weed/iam/sts/cross_instance_token_test.go deleted file mode 100644 index 8a375a885..000000000 --- a/weed/iam/sts/cross_instance_token_test.go +++ /dev/null @@ -1,503 +0,0 @@ -package sts - -import ( - "context" - "testing" - "time" - - "github.com/golang-jwt/jwt/v5" - "github.com/seaweedfs/seaweedfs/weed/iam/oidc" - "github.com/seaweedfs/seaweedfs/weed/iam/providers" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// Test-only constants for mock providers -const ( - ProviderTypeMock = "mock" -) - -// createMockOIDCProvider creates a mock OIDC provider for testing -// This is only available in test builds -func createMockOIDCProvider(name string, config map[string]interface{}) (providers.IdentityProvider, error) { - // Convert config to OIDC format - factory := NewProviderFactory() - oidcConfig, err := factory.convertToOIDCConfig(config) - if err != nil { - return nil, err - } - - // Set default values for mock provider if not provided - if oidcConfig.Issuer == "" { - oidcConfig.Issuer = "http://localhost:9999" - } - - provider := oidc.NewMockOIDCProvider(name) - if err := provider.Initialize(oidcConfig); err != nil { - return nil, err - } - - // Set up default test data for the mock provider - provider.SetupDefaultTestData() - - return provider, nil -} - -// createMockJWT creates a test JWT token with the specified issuer for mock provider testing -func createMockJWT(t *testing.T, issuer, subject string) string { - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "iss": issuer, - "sub": subject, - "aud": "test-client", - "exp": time.Now().Add(time.Hour).Unix(), - "iat": time.Now().Unix(), - }) - - tokenString, err := token.SignedString([]byte("test-signing-key")) - require.NoError(t, err) - return tokenString -} - -// TestCrossInstanceTokenUsage verifies that tokens generated by one STS instance -// can be used and validated by other STS instances in a distributed environment -func TestCrossInstanceTokenUsage(t *testing.T) { - ctx := context.Background() - // Dummy filer address for testing - - // Common configuration that would be shared across all instances in production - sharedConfig := &STSConfig{ - TokenDuration: FlexibleDuration{time.Hour}, - MaxSessionLength: FlexibleDuration{12 * time.Hour}, - Issuer: "distributed-sts-cluster", // SAME across all instances - SigningKey: []byte(TestSigningKey32Chars), // SAME across all instances - Providers: []*ProviderConfig{ - { - Name: "company-oidc", - Type: ProviderTypeOIDC, - Enabled: true, - Config: map[string]interface{}{ - ConfigFieldIssuer: "https://sso.company.com/realms/production", - ConfigFieldClientID: "seaweedfs-cluster", - ConfigFieldJWKSUri: "https://sso.company.com/realms/production/protocol/openid-connect/certs", - }, - }, - }, - } - - // Create multiple STS instances simulating different S3 gateway instances - instanceA := NewSTSService() // e.g., s3-gateway-1 - instanceB := NewSTSService() // e.g., s3-gateway-2 - instanceC := NewSTSService() // e.g., s3-gateway-3 - - // Initialize all instances with IDENTICAL configuration - err := instanceA.Initialize(sharedConfig) - require.NoError(t, err, "Instance A should initialize") - - err = instanceB.Initialize(sharedConfig) - require.NoError(t, err, "Instance B should initialize") - - err = instanceC.Initialize(sharedConfig) - require.NoError(t, err, "Instance C should initialize") - - // Set up mock trust policy validator for all instances (required for STS testing) - mockValidator := &MockTrustPolicyValidator{} - instanceA.SetTrustPolicyValidator(mockValidator) - instanceB.SetTrustPolicyValidator(mockValidator) - instanceC.SetTrustPolicyValidator(mockValidator) - - // Manually register mock provider for testing (not available in production) - mockProviderConfig := map[string]interface{}{ - ConfigFieldIssuer: "http://test-mock:9999", - ConfigFieldClientID: TestClientID, - } - mockProviderA, err := createMockOIDCProvider("test-mock", mockProviderConfig) - require.NoError(t, err) - mockProviderB, err := createMockOIDCProvider("test-mock", mockProviderConfig) - require.NoError(t, err) - mockProviderC, err := createMockOIDCProvider("test-mock", mockProviderConfig) - require.NoError(t, err) - - instanceA.RegisterProvider(mockProviderA) - instanceB.RegisterProvider(mockProviderB) - instanceC.RegisterProvider(mockProviderC) - - // Test 1: Token generated on Instance A can be validated on Instance B & C - t.Run("cross_instance_token_validation", func(t *testing.T) { - // Generate session token on Instance A - sessionId := TestSessionID - expiresAt := time.Now().Add(time.Hour) - - tokenFromA, err := instanceA.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt) - require.NoError(t, err, "Instance A should generate token") - - // Validate token on Instance B - claimsFromB, err := instanceB.GetTokenGenerator().ValidateSessionToken(tokenFromA) - require.NoError(t, err, "Instance B should validate token from Instance A") - assert.Equal(t, sessionId, claimsFromB.SessionId, "Session ID should match") - - // Validate same token on Instance C - claimsFromC, err := instanceC.GetTokenGenerator().ValidateSessionToken(tokenFromA) - require.NoError(t, err, "Instance C should validate token from Instance A") - assert.Equal(t, sessionId, claimsFromC.SessionId, "Session ID should match") - - // All instances should extract identical claims - assert.Equal(t, claimsFromB.SessionId, claimsFromC.SessionId) - assert.Equal(t, claimsFromB.ExpiresAt.Unix(), claimsFromC.ExpiresAt.Unix()) - assert.Equal(t, claimsFromB.IssuedAt.Unix(), claimsFromC.IssuedAt.Unix()) - }) - - // Test 2: Complete assume role flow across instances - t.Run("cross_instance_assume_role_flow", func(t *testing.T) { - // Step 1: User authenticates and assumes role on Instance A - // Create a valid JWT token for the mock provider - mockToken := createMockJWT(t, "http://test-mock:9999", "test-user") - - assumeRequest := &AssumeRoleWithWebIdentityRequest{ - RoleArn: "arn:aws:iam::role/CrossInstanceTestRole", - WebIdentityToken: mockToken, // JWT token for mock provider - RoleSessionName: "cross-instance-test-session", - DurationSeconds: int64ToPtr(3600), - } - - // Instance A processes assume role request - responseFromA, err := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest) - require.NoError(t, err, "Instance A should process assume role") - - sessionToken := responseFromA.Credentials.SessionToken - accessKeyId := responseFromA.Credentials.AccessKeyId - secretAccessKey := responseFromA.Credentials.SecretAccessKey - - // Verify response structure - assert.NotEmpty(t, sessionToken, "Should have session token") - assert.NotEmpty(t, accessKeyId, "Should have access key ID") - assert.NotEmpty(t, secretAccessKey, "Should have secret access key") - assert.NotNil(t, responseFromA.AssumedRoleUser, "Should have assumed role user") - - // Step 2: Use session token on Instance B (different instance) - sessionInfoFromB, err := instanceB.ValidateSessionToken(ctx, sessionToken) - require.NoError(t, err, "Instance B should validate session token from Instance A") - - assert.Equal(t, assumeRequest.RoleSessionName, sessionInfoFromB.SessionName) - assert.Equal(t, assumeRequest.RoleArn, sessionInfoFromB.RoleArn) - - // Step 3: Use same session token on Instance C (yet another instance) - sessionInfoFromC, err := instanceC.ValidateSessionToken(ctx, sessionToken) - require.NoError(t, err, "Instance C should validate session token from Instance A") - - // All instances should return identical session information - assert.Equal(t, sessionInfoFromB.SessionId, sessionInfoFromC.SessionId) - assert.Equal(t, sessionInfoFromB.SessionName, sessionInfoFromC.SessionName) - assert.Equal(t, sessionInfoFromB.RoleArn, sessionInfoFromC.RoleArn) - assert.Equal(t, sessionInfoFromB.Subject, sessionInfoFromC.Subject) - assert.Equal(t, sessionInfoFromB.Provider, sessionInfoFromC.Provider) - }) - - // Test 3: Session revocation across instances - t.Run("cross_instance_session_revocation", func(t *testing.T) { - // Create session on Instance A - mockToken := createMockJWT(t, "http://test-mock:9999", "test-user") - - assumeRequest := &AssumeRoleWithWebIdentityRequest{ - RoleArn: "arn:aws:iam::role/RevocationTestRole", - WebIdentityToken: mockToken, - RoleSessionName: "revocation-test-session", - } - - response, err := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest) - require.NoError(t, err) - sessionToken := response.Credentials.SessionToken - - // Verify token works on Instance B - _, err = instanceB.ValidateSessionToken(ctx, sessionToken) - require.NoError(t, err, "Token should be valid on Instance B initially") - - // Validate session on Instance C to verify cross-instance token compatibility - _, err = instanceC.ValidateSessionToken(ctx, sessionToken) - require.NoError(t, err, "Instance C should be able to validate session token") - - // In a stateless JWT system, tokens remain valid on all instances since they're self-contained - // No revocation is possible without breaking the stateless architecture - _, err = instanceA.ValidateSessionToken(ctx, sessionToken) - assert.NoError(t, err, "Token should still be valid on Instance A (stateless system)") - - // Verify token is still valid on Instance B - _, err = instanceB.ValidateSessionToken(ctx, sessionToken) - assert.NoError(t, err, "Token should still be valid on Instance B (stateless system)") - }) - - // Test 4: Provider consistency across instances - t.Run("provider_consistency_affects_token_generation", func(t *testing.T) { - // All instances should have same providers and be able to process same OIDC tokens - providerNamesA := instanceA.getProviderNames() - providerNamesB := instanceB.getProviderNames() - providerNamesC := instanceC.getProviderNames() - - assert.ElementsMatch(t, providerNamesA, providerNamesB, "Instance A and B should have same providers") - assert.ElementsMatch(t, providerNamesB, providerNamesC, "Instance B and C should have same providers") - - // All instances should be able to process same web identity token - testToken := createMockJWT(t, "http://test-mock:9999", "test-user") - - // Try to assume role with same token on different instances - assumeRequest := &AssumeRoleWithWebIdentityRequest{ - RoleArn: "arn:aws:iam::role/ProviderTestRole", - WebIdentityToken: testToken, - RoleSessionName: "provider-consistency-test", - } - - // Should work on any instance - responseA, errA := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest) - responseB, errB := instanceB.AssumeRoleWithWebIdentity(ctx, assumeRequest) - responseC, errC := instanceC.AssumeRoleWithWebIdentity(ctx, assumeRequest) - - require.NoError(t, errA, "Instance A should process OIDC token") - require.NoError(t, errB, "Instance B should process OIDC token") - require.NoError(t, errC, "Instance C should process OIDC token") - - // All should return valid responses (sessions will have different IDs but same structure) - assert.NotEmpty(t, responseA.Credentials.SessionToken) - assert.NotEmpty(t, responseB.Credentials.SessionToken) - assert.NotEmpty(t, responseC.Credentials.SessionToken) - }) -} - -// TestSTSDistributedConfigurationRequirements tests the configuration requirements -// for cross-instance token compatibility -func TestSTSDistributedConfigurationRequirements(t *testing.T) { - _ = "localhost:8888" // Dummy filer address for testing (not used in these tests) - - t.Run("same_signing_key_required", func(t *testing.T) { - // Instance A with signing key 1 - configA := &STSConfig{ - TokenDuration: FlexibleDuration{time.Hour}, - MaxSessionLength: FlexibleDuration{12 * time.Hour}, - Issuer: "test-sts", - SigningKey: []byte("signing-key-1-32-characters-long"), - } - - // Instance B with different signing key - configB := &STSConfig{ - TokenDuration: FlexibleDuration{time.Hour}, - MaxSessionLength: FlexibleDuration{12 * time.Hour}, - Issuer: "test-sts", - SigningKey: []byte("signing-key-2-32-characters-long"), // DIFFERENT! - } - - instanceA := NewSTSService() - instanceB := NewSTSService() - - err := instanceA.Initialize(configA) - require.NoError(t, err) - - err = instanceB.Initialize(configB) - require.NoError(t, err) - - // Generate token on Instance A - sessionId := "test-session" - expiresAt := time.Now().Add(time.Hour) - tokenFromA, err := instanceA.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt) - require.NoError(t, err) - - // Instance A should validate its own token - _, err = instanceA.GetTokenGenerator().ValidateSessionToken(tokenFromA) - assert.NoError(t, err, "Instance A should validate own token") - - // Instance B should REJECT token due to different signing key - _, err = instanceB.GetTokenGenerator().ValidateSessionToken(tokenFromA) - assert.Error(t, err, "Instance B should reject token with different signing key") - assert.Contains(t, err.Error(), "invalid token", "Should be signature validation error") - }) - - t.Run("same_issuer_required", func(t *testing.T) { - sharedSigningKey := []byte("shared-signing-key-32-characters-lo") - - // Instance A with issuer 1 - configA := &STSConfig{ - TokenDuration: FlexibleDuration{time.Hour}, - MaxSessionLength: FlexibleDuration{12 * time.Hour}, - Issuer: "sts-cluster-1", - SigningKey: sharedSigningKey, - } - - // Instance B with different issuer - configB := &STSConfig{ - TokenDuration: FlexibleDuration{time.Hour}, - MaxSessionLength: FlexibleDuration{12 * time.Hour}, - Issuer: "sts-cluster-2", // DIFFERENT! - SigningKey: sharedSigningKey, - } - - instanceA := NewSTSService() - instanceB := NewSTSService() - - err := instanceA.Initialize(configA) - require.NoError(t, err) - - err = instanceB.Initialize(configB) - require.NoError(t, err) - - // Generate token on Instance A - sessionId := "test-session" - expiresAt := time.Now().Add(time.Hour) - tokenFromA, err := instanceA.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt) - require.NoError(t, err) - - // Instance B should REJECT token due to different issuer - _, err = instanceB.GetTokenGenerator().ValidateSessionToken(tokenFromA) - assert.Error(t, err, "Instance B should reject token with different issuer") - assert.Contains(t, err.Error(), "invalid issuer", "Should be issuer validation error") - }) - - t.Run("identical_configuration_required", func(t *testing.T) { - // Identical configuration - identicalConfig := &STSConfig{ - TokenDuration: FlexibleDuration{time.Hour}, - MaxSessionLength: FlexibleDuration{12 * time.Hour}, - Issuer: "production-sts-cluster", - SigningKey: []byte("production-signing-key-32-chars-l"), - } - - // Create multiple instances with identical config - instances := make([]*STSService, 5) - for i := 0; i < 5; i++ { - instances[i] = NewSTSService() - err := instances[i].Initialize(identicalConfig) - require.NoError(t, err, "Instance %d should initialize", i) - } - - // Generate token on Instance 0 - sessionId := "multi-instance-test" - expiresAt := time.Now().Add(time.Hour) - token, err := instances[0].GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt) - require.NoError(t, err) - - // All other instances should validate the token - for i := 1; i < 5; i++ { - claims, err := instances[i].GetTokenGenerator().ValidateSessionToken(token) - require.NoError(t, err, "Instance %d should validate token", i) - assert.Equal(t, sessionId, claims.SessionId, "Instance %d should extract correct session ID", i) - } - }) -} - -// TestSTSRealWorldDistributedScenarios tests realistic distributed deployment scenarios -func TestSTSRealWorldDistributedScenarios(t *testing.T) { - ctx := context.Background() - - t.Run("load_balanced_s3_gateway_scenario", func(t *testing.T) { - // Simulate real production scenario: - // 1. User authenticates with OIDC provider - // 2. User calls AssumeRoleWithWebIdentity on S3 Gateway 1 - // 3. User makes S3 requests that hit S3 Gateway 2 & 3 via load balancer - // 4. All instances should handle the session token correctly - - productionConfig := &STSConfig{ - TokenDuration: FlexibleDuration{2 * time.Hour}, - MaxSessionLength: FlexibleDuration{24 * time.Hour}, - Issuer: "seaweedfs-production-sts", - SigningKey: []byte("prod-signing-key-32-characters-lon"), - - Providers: []*ProviderConfig{ - { - Name: "corporate-oidc", - Type: "oidc", - Enabled: true, - Config: map[string]interface{}{ - "issuer": "https://sso.company.com/realms/production", - "clientId": "seaweedfs-prod-cluster", - "clientSecret": "supersecret-prod-key", - "scopes": []string{"openid", "profile", "email", "groups"}, - }, - }, - }, - } - - // Create 3 S3 Gateway instances behind load balancer - gateway1 := NewSTSService() - gateway2 := NewSTSService() - gateway3 := NewSTSService() - - err := gateway1.Initialize(productionConfig) - require.NoError(t, err) - - err = gateway2.Initialize(productionConfig) - require.NoError(t, err) - - err = gateway3.Initialize(productionConfig) - require.NoError(t, err) - - // Set up mock trust policy validator for all gateway instances - mockValidator := &MockTrustPolicyValidator{} - gateway1.SetTrustPolicyValidator(mockValidator) - gateway2.SetTrustPolicyValidator(mockValidator) - gateway3.SetTrustPolicyValidator(mockValidator) - - // Manually register mock provider for testing (not available in production) - mockProviderConfig := map[string]interface{}{ - ConfigFieldIssuer: "http://test-mock:9999", - ConfigFieldClientID: "test-client-id", - } - mockProvider1, err := createMockOIDCProvider("test-mock", mockProviderConfig) - require.NoError(t, err) - mockProvider2, err := createMockOIDCProvider("test-mock", mockProviderConfig) - require.NoError(t, err) - mockProvider3, err := createMockOIDCProvider("test-mock", mockProviderConfig) - require.NoError(t, err) - - gateway1.RegisterProvider(mockProvider1) - gateway2.RegisterProvider(mockProvider2) - gateway3.RegisterProvider(mockProvider3) - - // Step 1: User authenticates and hits Gateway 1 for AssumeRole - mockToken := createMockJWT(t, "http://test-mock:9999", "production-user") - - assumeRequest := &AssumeRoleWithWebIdentityRequest{ - RoleArn: "arn:aws:iam::role/ProductionS3User", - WebIdentityToken: mockToken, // JWT token from mock provider - RoleSessionName: "user-production-session", - DurationSeconds: int64ToPtr(7200), // 2 hours - } - - stsResponse, err := gateway1.AssumeRoleWithWebIdentity(ctx, assumeRequest) - require.NoError(t, err, "Gateway 1 should handle AssumeRole") - - sessionToken := stsResponse.Credentials.SessionToken - accessKey := stsResponse.Credentials.AccessKeyId - secretKey := stsResponse.Credentials.SecretAccessKey - - // Step 2: User makes S3 requests that hit different gateways via load balancer - // Simulate S3 request validation on Gateway 2 - sessionInfo2, err := gateway2.ValidateSessionToken(ctx, sessionToken) - require.NoError(t, err, "Gateway 2 should validate session from Gateway 1") - assert.Equal(t, "user-production-session", sessionInfo2.SessionName) - assert.Equal(t, "arn:aws:iam::role/ProductionS3User", sessionInfo2.RoleArn) - - // Simulate S3 request validation on Gateway 3 - sessionInfo3, err := gateway3.ValidateSessionToken(ctx, sessionToken) - require.NoError(t, err, "Gateway 3 should validate session from Gateway 1") - assert.Equal(t, sessionInfo2.SessionId, sessionInfo3.SessionId, "Should be same session") - - // Step 3: Verify credentials are consistent - assert.Equal(t, accessKey, stsResponse.Credentials.AccessKeyId, "Access key should be consistent") - assert.Equal(t, secretKey, stsResponse.Credentials.SecretAccessKey, "Secret key should be consistent") - - // Step 4: Session expiration should be honored across all instances - assert.True(t, sessionInfo2.ExpiresAt.After(time.Now()), "Session should not be expired") - assert.True(t, sessionInfo3.ExpiresAt.After(time.Now()), "Session should not be expired") - - // Step 5: Token should be identical when parsed - claims2, err := gateway2.GetTokenGenerator().ValidateSessionToken(sessionToken) - require.NoError(t, err) - - claims3, err := gateway3.GetTokenGenerator().ValidateSessionToken(sessionToken) - require.NoError(t, err) - - assert.Equal(t, claims2.SessionId, claims3.SessionId, "Session IDs should match") - assert.Equal(t, claims2.ExpiresAt.Unix(), claims3.ExpiresAt.Unix(), "Expiration should match") - }) -} - -// Helper function to convert int64 to pointer -func int64ToPtr(i int64) *int64 { - return &i -} diff --git a/weed/iam/sts/distributed_sts_test.go b/weed/iam/sts/distributed_sts_test.go deleted file mode 100644 index 7997e7b8e..000000000 --- a/weed/iam/sts/distributed_sts_test.go +++ /dev/null @@ -1,340 +0,0 @@ -package sts - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// TestDistributedSTSService verifies that multiple STS instances with identical configurations -// behave consistently across distributed environments -func TestDistributedSTSService(t *testing.T) { - ctx := context.Background() - - // Common configuration for all instances - commonConfig := &STSConfig{ - TokenDuration: FlexibleDuration{time.Hour}, - MaxSessionLength: FlexibleDuration{12 * time.Hour}, - Issuer: "distributed-sts-test", - SigningKey: []byte("test-signing-key-32-characters-long"), - - Providers: []*ProviderConfig{ - { - Name: "keycloak-oidc", - Type: "oidc", - Enabled: true, - Config: map[string]interface{}{ - "issuer": "http://keycloak:8080/realms/seaweedfs-test", - "clientId": "seaweedfs-s3", - "jwksUri": "http://keycloak:8080/realms/seaweedfs-test/protocol/openid-connect/certs", - }, - }, - - { - Name: "disabled-ldap", - Type: "oidc", // Use OIDC as placeholder since LDAP isn't implemented - Enabled: false, - Config: map[string]interface{}{ - "issuer": "ldap://company.com", - "clientId": "ldap-client", - }, - }, - }, - } - - // Create multiple STS instances simulating distributed deployment - instance1 := NewSTSService() - instance2 := NewSTSService() - instance3 := NewSTSService() - - // Initialize all instances with identical configuration - err := instance1.Initialize(commonConfig) - require.NoError(t, err, "Instance 1 should initialize successfully") - - err = instance2.Initialize(commonConfig) - require.NoError(t, err, "Instance 2 should initialize successfully") - - err = instance3.Initialize(commonConfig) - require.NoError(t, err, "Instance 3 should initialize successfully") - - // Manually register mock providers for testing (not available in production) - mockProviderConfig := map[string]interface{}{ - "issuer": "http://localhost:9999", - "clientId": "test-client", - } - mockProvider1, err := createMockOIDCProvider("test-mock-provider", mockProviderConfig) - require.NoError(t, err) - mockProvider2, err := createMockOIDCProvider("test-mock-provider", mockProviderConfig) - require.NoError(t, err) - mockProvider3, err := createMockOIDCProvider("test-mock-provider", mockProviderConfig) - require.NoError(t, err) - - instance1.RegisterProvider(mockProvider1) - instance2.RegisterProvider(mockProvider2) - instance3.RegisterProvider(mockProvider3) - - // Verify all instances have identical provider configurations - t.Run("provider_consistency", func(t *testing.T) { - // All instances should have same number of providers - assert.Len(t, instance1.providers, 2, "Instance 1 should have 2 enabled providers") - assert.Len(t, instance2.providers, 2, "Instance 2 should have 2 enabled providers") - assert.Len(t, instance3.providers, 2, "Instance 3 should have 2 enabled providers") - - // All instances should have same provider names - instance1Names := instance1.getProviderNames() - instance2Names := instance2.getProviderNames() - instance3Names := instance3.getProviderNames() - - assert.ElementsMatch(t, instance1Names, instance2Names, "Instance 1 and 2 should have same providers") - assert.ElementsMatch(t, instance2Names, instance3Names, "Instance 2 and 3 should have same providers") - - // Verify specific providers exist on all instances - expectedProviders := []string{"keycloak-oidc", "test-mock-provider"} - assert.ElementsMatch(t, instance1Names, expectedProviders, "Instance 1 should have expected providers") - assert.ElementsMatch(t, instance2Names, expectedProviders, "Instance 2 should have expected providers") - assert.ElementsMatch(t, instance3Names, expectedProviders, "Instance 3 should have expected providers") - - // Verify disabled providers are not loaded - assert.NotContains(t, instance1Names, "disabled-ldap", "Disabled providers should not be loaded") - assert.NotContains(t, instance2Names, "disabled-ldap", "Disabled providers should not be loaded") - assert.NotContains(t, instance3Names, "disabled-ldap", "Disabled providers should not be loaded") - }) - - // Test token generation consistency across instances - t.Run("token_generation_consistency", func(t *testing.T) { - sessionId := "test-session-123" - expiresAt := time.Now().Add(time.Hour) - - // Generate tokens from different instances - token1, err1 := instance1.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt) - token2, err2 := instance2.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt) - token3, err3 := instance3.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt) - - require.NoError(t, err1, "Instance 1 token generation should succeed") - require.NoError(t, err2, "Instance 2 token generation should succeed") - require.NoError(t, err3, "Instance 3 token generation should succeed") - - // All tokens should be different (due to timestamp variations) - // But they should all be valid JWTs with same signing key - assert.NotEmpty(t, token1) - assert.NotEmpty(t, token2) - assert.NotEmpty(t, token3) - }) - - // Test token validation consistency - any instance should validate tokens from any other instance - t.Run("cross_instance_token_validation", func(t *testing.T) { - sessionId := "cross-validation-session" - expiresAt := time.Now().Add(time.Hour) - - // Generate token on instance 1 - token, err := instance1.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt) - require.NoError(t, err) - - // Validate on all instances - claims1, err1 := instance1.GetTokenGenerator().ValidateSessionToken(token) - claims2, err2 := instance2.GetTokenGenerator().ValidateSessionToken(token) - claims3, err3 := instance3.GetTokenGenerator().ValidateSessionToken(token) - - require.NoError(t, err1, "Instance 1 should validate token from instance 1") - require.NoError(t, err2, "Instance 2 should validate token from instance 1") - require.NoError(t, err3, "Instance 3 should validate token from instance 1") - - // All instances should extract same session ID - assert.Equal(t, sessionId, claims1.SessionId) - assert.Equal(t, sessionId, claims2.SessionId) - assert.Equal(t, sessionId, claims3.SessionId) - - assert.Equal(t, claims1.SessionId, claims2.SessionId) - assert.Equal(t, claims2.SessionId, claims3.SessionId) - }) - - // Test provider access consistency - t.Run("provider_access_consistency", func(t *testing.T) { - // All instances should be able to access the same providers - provider1, exists1 := instance1.providers["test-mock-provider"] - provider2, exists2 := instance2.providers["test-mock-provider"] - provider3, exists3 := instance3.providers["test-mock-provider"] - - assert.True(t, exists1, "Instance 1 should have test-mock-provider") - assert.True(t, exists2, "Instance 2 should have test-mock-provider") - assert.True(t, exists3, "Instance 3 should have test-mock-provider") - - assert.Equal(t, provider1.Name(), provider2.Name()) - assert.Equal(t, provider2.Name(), provider3.Name()) - - // Test authentication with the mock provider on all instances - testToken := "valid_test_token" - - identity1, err1 := provider1.Authenticate(ctx, testToken) - identity2, err2 := provider2.Authenticate(ctx, testToken) - identity3, err3 := provider3.Authenticate(ctx, testToken) - - require.NoError(t, err1, "Instance 1 provider should authenticate successfully") - require.NoError(t, err2, "Instance 2 provider should authenticate successfully") - require.NoError(t, err3, "Instance 3 provider should authenticate successfully") - - // All instances should return identical identity information - assert.Equal(t, identity1.UserID, identity2.UserID) - assert.Equal(t, identity2.UserID, identity3.UserID) - assert.Equal(t, identity1.Email, identity2.Email) - assert.Equal(t, identity2.Email, identity3.Email) - assert.Equal(t, identity1.Provider, identity2.Provider) - assert.Equal(t, identity2.Provider, identity3.Provider) - }) -} - -// TestSTSConfigurationValidation tests configuration validation for distributed deployments -func TestSTSConfigurationValidation(t *testing.T) { - t.Run("consistent_signing_keys_required", func(t *testing.T) { - // Different signing keys should result in incompatible token validation - config1 := &STSConfig{ - TokenDuration: FlexibleDuration{time.Hour}, - MaxSessionLength: FlexibleDuration{12 * time.Hour}, - Issuer: "test-sts", - SigningKey: []byte("signing-key-1-32-characters-long"), - } - - config2 := &STSConfig{ - TokenDuration: FlexibleDuration{time.Hour}, - MaxSessionLength: FlexibleDuration{12 * time.Hour}, - Issuer: "test-sts", - SigningKey: []byte("signing-key-2-32-characters-long"), // Different key! - } - - instance1 := NewSTSService() - instance2 := NewSTSService() - - err1 := instance1.Initialize(config1) - err2 := instance2.Initialize(config2) - - require.NoError(t, err1) - require.NoError(t, err2) - - // Generate token on instance 1 - sessionId := "test-session" - expiresAt := time.Now().Add(time.Hour) - token, err := instance1.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt) - require.NoError(t, err) - - // Instance 1 should validate its own token - _, err = instance1.GetTokenGenerator().ValidateSessionToken(token) - assert.NoError(t, err, "Instance 1 should validate its own token") - - // Instance 2 should reject token from instance 1 (different signing key) - _, err = instance2.GetTokenGenerator().ValidateSessionToken(token) - assert.Error(t, err, "Instance 2 should reject token with different signing key") - }) - - t.Run("consistent_issuer_required", func(t *testing.T) { - // Different issuers should result in incompatible tokens - commonSigningKey := []byte("shared-signing-key-32-characters-lo") - - config1 := &STSConfig{ - TokenDuration: FlexibleDuration{time.Hour}, - MaxSessionLength: FlexibleDuration{12 * time.Hour}, - Issuer: "sts-instance-1", - SigningKey: commonSigningKey, - } - - config2 := &STSConfig{ - TokenDuration: FlexibleDuration{time.Hour}, - MaxSessionLength: FlexibleDuration{12 * time.Hour}, - Issuer: "sts-instance-2", // Different issuer! - SigningKey: commonSigningKey, - } - - instance1 := NewSTSService() - instance2 := NewSTSService() - - err1 := instance1.Initialize(config1) - err2 := instance2.Initialize(config2) - - require.NoError(t, err1) - require.NoError(t, err2) - - // Generate token on instance 1 - sessionId := "test-session" - expiresAt := time.Now().Add(time.Hour) - token, err := instance1.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt) - require.NoError(t, err) - - // Instance 2 should reject token due to issuer mismatch - // (Even though signing key is the same, issuer validation will fail) - _, err = instance2.GetTokenGenerator().ValidateSessionToken(token) - assert.Error(t, err, "Instance 2 should reject token with different issuer") - }) -} - -// TestProviderFactoryDistributed tests the provider factory in distributed scenarios -func TestProviderFactoryDistributed(t *testing.T) { - factory := NewProviderFactory() - - // Simulate configuration that would be identical across all instances - configs := []*ProviderConfig{ - { - Name: "production-keycloak", - Type: "oidc", - Enabled: true, - Config: map[string]interface{}{ - "issuer": "https://keycloak.company.com/realms/seaweedfs", - "clientId": "seaweedfs-prod", - "clientSecret": "super-secret-key", - "jwksUri": "https://keycloak.company.com/realms/seaweedfs/protocol/openid-connect/certs", - "scopes": []string{"openid", "profile", "email", "roles"}, - }, - }, - { - Name: "backup-oidc", - Type: "oidc", - Enabled: false, // Disabled by default - Config: map[string]interface{}{ - "issuer": "https://backup-oidc.company.com", - "clientId": "seaweedfs-backup", - }, - }, - } - - // Create providers multiple times (simulating multiple instances) - providers1, err1 := factory.LoadProvidersFromConfig(configs) - providers2, err2 := factory.LoadProvidersFromConfig(configs) - providers3, err3 := factory.LoadProvidersFromConfig(configs) - - require.NoError(t, err1, "First load should succeed") - require.NoError(t, err2, "Second load should succeed") - require.NoError(t, err3, "Third load should succeed") - - // All instances should have same provider counts - assert.Len(t, providers1, 1, "First instance should have 1 enabled provider") - assert.Len(t, providers2, 1, "Second instance should have 1 enabled provider") - assert.Len(t, providers3, 1, "Third instance should have 1 enabled provider") - - // All instances should have same provider names - names1 := make([]string, 0, len(providers1)) - names2 := make([]string, 0, len(providers2)) - names3 := make([]string, 0, len(providers3)) - - for name := range providers1 { - names1 = append(names1, name) - } - for name := range providers2 { - names2 = append(names2, name) - } - for name := range providers3 { - names3 = append(names3, name) - } - - assert.ElementsMatch(t, names1, names2, "Instance 1 and 2 should have same provider names") - assert.ElementsMatch(t, names2, names3, "Instance 2 and 3 should have same provider names") - - // Verify specific providers - expectedProviders := []string{"production-keycloak"} - assert.ElementsMatch(t, names1, expectedProviders, "Should have expected enabled providers") - - // Verify disabled providers are not included - assert.NotContains(t, names1, "backup-oidc", "Disabled providers should not be loaded") - assert.NotContains(t, names2, "backup-oidc", "Disabled providers should not be loaded") - assert.NotContains(t, names3, "backup-oidc", "Disabled providers should not be loaded") -} diff --git a/weed/iam/sts/provider_factory.go b/weed/iam/sts/provider_factory.go index 53635c8f2..eb87d6d7e 100644 --- a/weed/iam/sts/provider_factory.go +++ b/weed/iam/sts/provider_factory.go @@ -274,69 +274,3 @@ func (f *ProviderFactory) convertToRoleMapping(value interface{}) (*providers.Ro return roleMapping, nil } - -// ValidateProviderConfig validates a provider configuration -func (f *ProviderFactory) ValidateProviderConfig(config *ProviderConfig) error { - if config == nil { - return fmt.Errorf("provider config cannot be nil") - } - - if config.Name == "" { - return fmt.Errorf("provider name cannot be empty") - } - - if config.Type == "" { - return fmt.Errorf("provider type cannot be empty") - } - - if config.Config == nil { - return fmt.Errorf("provider config cannot be nil") - } - - // Type-specific validation - switch config.Type { - case "oidc": - return f.validateOIDCConfig(config.Config) - case "ldap": - return f.validateLDAPConfig(config.Config) - case "saml": - return f.validateSAMLConfig(config.Config) - default: - return fmt.Errorf("unsupported provider type: %s", config.Type) - } -} - -// validateOIDCConfig validates OIDC provider configuration -func (f *ProviderFactory) validateOIDCConfig(config map[string]interface{}) error { - if _, ok := config[ConfigFieldIssuer]; !ok { - return fmt.Errorf("OIDC provider requires '%s' field", ConfigFieldIssuer) - } - - if _, ok := config[ConfigFieldClientID]; !ok { - return fmt.Errorf("OIDC provider requires '%s' field", ConfigFieldClientID) - } - - return nil -} - -// validateLDAPConfig validates LDAP provider configuration -func (f *ProviderFactory) validateLDAPConfig(config map[string]interface{}) error { - if _, ok := config["server"]; !ok { - return fmt.Errorf("LDAP provider requires 'server' field") - } - if _, ok := config["baseDN"]; !ok { - return fmt.Errorf("LDAP provider requires 'baseDN' field") - } - return nil -} - -// validateSAMLConfig validates SAML provider configuration -func (f *ProviderFactory) validateSAMLConfig(config map[string]interface{}) error { - // TODO: Implement when SAML provider is available - return nil -} - -// GetSupportedProviderTypes returns list of supported provider types -func (f *ProviderFactory) GetSupportedProviderTypes() []string { - return []string{ProviderTypeOIDC} -} diff --git a/weed/iam/sts/provider_factory_test.go b/weed/iam/sts/provider_factory_test.go deleted file mode 100644 index 8c36142a7..000000000 --- a/weed/iam/sts/provider_factory_test.go +++ /dev/null @@ -1,312 +0,0 @@ -package sts - -import ( - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestProviderFactory_CreateOIDCProvider(t *testing.T) { - factory := NewProviderFactory() - - config := &ProviderConfig{ - Name: "test-oidc", - Type: "oidc", - Enabled: true, - Config: map[string]interface{}{ - "issuer": "https://test-issuer.com", - "clientId": "test-client", - "clientSecret": "test-secret", - "jwksUri": "https://test-issuer.com/.well-known/jwks.json", - "scopes": []string{"openid", "profile", "email"}, - }, - } - - provider, err := factory.CreateProvider(config) - require.NoError(t, err) - assert.NotNil(t, provider) - assert.Equal(t, "test-oidc", provider.Name()) -} - -// Note: Mock provider tests removed - mock providers are now test-only -// and not available through the production ProviderFactory - -func TestProviderFactory_DisabledProvider(t *testing.T) { - factory := NewProviderFactory() - - config := &ProviderConfig{ - Name: "disabled-provider", - Type: "oidc", - Enabled: false, - Config: map[string]interface{}{ - "issuer": "https://test-issuer.com", - "clientId": "test-client", - }, - } - - provider, err := factory.CreateProvider(config) - require.NoError(t, err) - assert.Nil(t, provider) // Should return nil for disabled providers -} - -func TestProviderFactory_InvalidProviderType(t *testing.T) { - factory := NewProviderFactory() - - config := &ProviderConfig{ - Name: "invalid-provider", - Type: "unsupported-type", - Enabled: true, - Config: map[string]interface{}{}, - } - - provider, err := factory.CreateProvider(config) - assert.Error(t, err) - assert.Nil(t, provider) - assert.Contains(t, err.Error(), "unsupported provider type") -} - -func TestProviderFactory_LoadMultipleProviders(t *testing.T) { - factory := NewProviderFactory() - - configs := []*ProviderConfig{ - { - Name: "oidc-provider", - Type: "oidc", - Enabled: true, - Config: map[string]interface{}{ - "issuer": "https://oidc-issuer.com", - "clientId": "oidc-client", - }, - }, - - { - Name: "disabled-provider", - Type: "oidc", - Enabled: false, - Config: map[string]interface{}{ - "issuer": "https://disabled-issuer.com", - "clientId": "disabled-client", - }, - }, - } - - providers, err := factory.LoadProvidersFromConfig(configs) - require.NoError(t, err) - assert.Len(t, providers, 1) // Only enabled providers should be loaded - - assert.Contains(t, providers, "oidc-provider") - assert.NotContains(t, providers, "disabled-provider") -} - -func TestProviderFactory_ValidateOIDCConfig(t *testing.T) { - factory := NewProviderFactory() - - t.Run("valid config", func(t *testing.T) { - config := &ProviderConfig{ - Name: "valid-oidc", - Type: "oidc", - Enabled: true, - Config: map[string]interface{}{ - "issuer": "https://valid-issuer.com", - "clientId": "valid-client", - }, - } - - err := factory.ValidateProviderConfig(config) - assert.NoError(t, err) - }) - - t.Run("missing issuer", func(t *testing.T) { - config := &ProviderConfig{ - Name: "invalid-oidc", - Type: "oidc", - Enabled: true, - Config: map[string]interface{}{ - "clientId": "valid-client", - }, - } - - err := factory.ValidateProviderConfig(config) - assert.Error(t, err) - assert.Contains(t, err.Error(), "issuer") - }) - - t.Run("missing clientId", func(t *testing.T) { - config := &ProviderConfig{ - Name: "invalid-oidc", - Type: "oidc", - Enabled: true, - Config: map[string]interface{}{ - "issuer": "https://valid-issuer.com", - }, - } - - err := factory.ValidateProviderConfig(config) - assert.Error(t, err) - assert.Contains(t, err.Error(), "clientId") - }) -} - -func TestProviderFactory_ConvertToStringSlice(t *testing.T) { - factory := NewProviderFactory() - - t.Run("string slice", func(t *testing.T) { - input := []string{"a", "b", "c"} - result, err := factory.convertToStringSlice(input) - require.NoError(t, err) - assert.Equal(t, []string{"a", "b", "c"}, result) - }) - - t.Run("interface slice", func(t *testing.T) { - input := []interface{}{"a", "b", "c"} - result, err := factory.convertToStringSlice(input) - require.NoError(t, err) - assert.Equal(t, []string{"a", "b", "c"}, result) - }) - - t.Run("invalid type", func(t *testing.T) { - input := "not-a-slice" - result, err := factory.convertToStringSlice(input) - assert.Error(t, err) - assert.Nil(t, result) - }) -} - -func TestProviderFactory_ConfigConversionErrors(t *testing.T) { - factory := NewProviderFactory() - - t.Run("invalid scopes type", func(t *testing.T) { - config := &ProviderConfig{ - Name: "invalid-scopes", - Type: "oidc", - Enabled: true, - Config: map[string]interface{}{ - "issuer": "https://test-issuer.com", - "clientId": "test-client", - "scopes": "invalid-not-array", // Should be array - }, - } - - provider, err := factory.CreateProvider(config) - assert.Error(t, err) - assert.Nil(t, provider) - assert.Contains(t, err.Error(), "failed to convert scopes") - }) - - t.Run("invalid claimsMapping type", func(t *testing.T) { - config := &ProviderConfig{ - Name: "invalid-claims", - Type: "oidc", - Enabled: true, - Config: map[string]interface{}{ - "issuer": "https://test-issuer.com", - "clientId": "test-client", - "claimsMapping": "invalid-not-map", // Should be map - }, - } - - provider, err := factory.CreateProvider(config) - assert.Error(t, err) - assert.Nil(t, provider) - assert.Contains(t, err.Error(), "failed to convert claimsMapping") - }) - - t.Run("invalid roleMapping type", func(t *testing.T) { - config := &ProviderConfig{ - Name: "invalid-roles", - Type: "oidc", - Enabled: true, - Config: map[string]interface{}{ - "issuer": "https://test-issuer.com", - "clientId": "test-client", - "roleMapping": "invalid-not-map", // Should be map - }, - } - - provider, err := factory.CreateProvider(config) - assert.Error(t, err) - assert.Nil(t, provider) - assert.Contains(t, err.Error(), "failed to convert roleMapping") - }) -} - -func TestProviderFactory_ConvertToStringMap(t *testing.T) { - factory := NewProviderFactory() - - t.Run("string map", func(t *testing.T) { - input := map[string]string{"key1": "value1", "key2": "value2"} - result, err := factory.convertToStringMap(input) - require.NoError(t, err) - assert.Equal(t, map[string]string{"key1": "value1", "key2": "value2"}, result) - }) - - t.Run("interface map", func(t *testing.T) { - input := map[string]interface{}{"key1": "value1", "key2": "value2"} - result, err := factory.convertToStringMap(input) - require.NoError(t, err) - assert.Equal(t, map[string]string{"key1": "value1", "key2": "value2"}, result) - }) - - t.Run("invalid type", func(t *testing.T) { - input := "not-a-map" - result, err := factory.convertToStringMap(input) - assert.Error(t, err) - assert.Nil(t, result) - }) -} - -func TestProviderFactory_GetSupportedProviderTypes(t *testing.T) { - factory := NewProviderFactory() - - supportedTypes := factory.GetSupportedProviderTypes() - assert.Contains(t, supportedTypes, "oidc") - assert.Len(t, supportedTypes, 1) // Currently only OIDC is supported in production -} - -func TestSTSService_LoadProvidersFromConfig(t *testing.T) { - stsConfig := &STSConfig{ - TokenDuration: FlexibleDuration{3600 * time.Second}, - MaxSessionLength: FlexibleDuration{43200 * time.Second}, - Issuer: "test-issuer", - SigningKey: []byte("test-signing-key-32-characters-long"), - Providers: []*ProviderConfig{ - { - Name: "test-provider", - Type: "oidc", - Enabled: true, - Config: map[string]interface{}{ - "issuer": "https://test-issuer.com", - "clientId": "test-client", - }, - }, - }, - } - - stsService := NewSTSService() - err := stsService.Initialize(stsConfig) - require.NoError(t, err) - - // Check that provider was loaded - assert.Len(t, stsService.providers, 1) - assert.Contains(t, stsService.providers, "test-provider") - assert.Equal(t, "test-provider", stsService.providers["test-provider"].Name()) -} - -func TestSTSService_NoProvidersConfig(t *testing.T) { - stsConfig := &STSConfig{ - TokenDuration: FlexibleDuration{3600 * time.Second}, - MaxSessionLength: FlexibleDuration{43200 * time.Second}, - Issuer: "test-issuer", - SigningKey: []byte("test-signing-key-32-characters-long"), - // No providers configured - } - - stsService := NewSTSService() - err := stsService.Initialize(stsConfig) - require.NoError(t, err) - - // Should initialize successfully with no providers - assert.Len(t, stsService.providers, 0) -} diff --git a/weed/iam/sts/security_test.go b/weed/iam/sts/security_test.go deleted file mode 100644 index 2d230d796..000000000 --- a/weed/iam/sts/security_test.go +++ /dev/null @@ -1,193 +0,0 @@ -package sts - -import ( - "context" - "fmt" - "strings" - "testing" - "time" - - "github.com/golang-jwt/jwt/v5" - "github.com/seaweedfs/seaweedfs/weed/iam/providers" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// TestSecurityIssuerToProviderMapping tests the security fix that ensures JWT tokens -// with specific issuer claims can only be validated by the provider registered for that issuer -func TestSecurityIssuerToProviderMapping(t *testing.T) { - ctx := context.Background() - - // Create STS service with two mock providers - service := NewSTSService() - config := &STSConfig{ - TokenDuration: FlexibleDuration{time.Hour}, - MaxSessionLength: FlexibleDuration{time.Hour * 12}, - Issuer: "test-sts", - SigningKey: []byte("test-signing-key-32-characters-long"), - } - - err := service.Initialize(config) - require.NoError(t, err) - - // Set up mock trust policy validator - mockValidator := &MockTrustPolicyValidator{} - service.SetTrustPolicyValidator(mockValidator) - - // Create two mock providers with different issuers - providerA := &MockIdentityProviderWithIssuer{ - name: "provider-a", - issuer: "https://provider-a.com", - validTokens: map[string]bool{ - "token-for-provider-a": true, - }, - } - - providerB := &MockIdentityProviderWithIssuer{ - name: "provider-b", - issuer: "https://provider-b.com", - validTokens: map[string]bool{ - "token-for-provider-b": true, - }, - } - - // Register both providers - err = service.RegisterProvider(providerA) - require.NoError(t, err) - err = service.RegisterProvider(providerB) - require.NoError(t, err) - - // Create JWT tokens with specific issuer claims - tokenForProviderA := createTestJWT(t, "https://provider-a.com", "user-a") - tokenForProviderB := createTestJWT(t, "https://provider-b.com", "user-b") - - t.Run("jwt_token_with_issuer_a_only_validated_by_provider_a", func(t *testing.T) { - // This should succeed - token has issuer A and provider A is registered - identity, provider, err := service.validateWebIdentityToken(ctx, tokenForProviderA) - assert.NoError(t, err) - assert.NotNil(t, identity) - assert.Equal(t, "provider-a", provider.Name()) - }) - - t.Run("jwt_token_with_issuer_b_only_validated_by_provider_b", func(t *testing.T) { - // This should succeed - token has issuer B and provider B is registered - identity, provider, err := service.validateWebIdentityToken(ctx, tokenForProviderB) - assert.NoError(t, err) - assert.NotNil(t, identity) - assert.Equal(t, "provider-b", provider.Name()) - }) - - t.Run("jwt_token_with_unregistered_issuer_fails", func(t *testing.T) { - // Create token with unregistered issuer - tokenWithUnknownIssuer := createTestJWT(t, "https://unknown-issuer.com", "user-x") - - // This should fail - no provider registered for this issuer - identity, provider, err := service.validateWebIdentityToken(ctx, tokenWithUnknownIssuer) - assert.Error(t, err) - assert.Nil(t, identity) - assert.Nil(t, provider) - assert.Contains(t, err.Error(), "no identity provider registered for issuer: https://unknown-issuer.com") - }) - - t.Run("non_jwt_tokens_are_rejected", func(t *testing.T) { - // Non-JWT tokens should be rejected - no fallback mechanism exists for security - identity, provider, err := service.validateWebIdentityToken(ctx, "token-for-provider-a") - assert.Error(t, err) - assert.Nil(t, identity) - assert.Nil(t, provider) - assert.Contains(t, err.Error(), "web identity token must be a valid JWT token") - }) -} - -// createTestJWT creates a test JWT token with the specified issuer and subject -func createTestJWT(t *testing.T, issuer, subject string) string { - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "iss": issuer, - "sub": subject, - "aud": "test-client", - "exp": time.Now().Add(time.Hour).Unix(), - "iat": time.Now().Unix(), - }) - - tokenString, err := token.SignedString([]byte("test-signing-key")) - require.NoError(t, err) - return tokenString -} - -// MockIdentityProviderWithIssuer is a mock provider that supports issuer mapping -type MockIdentityProviderWithIssuer struct { - name string - issuer string - validTokens map[string]bool -} - -func (m *MockIdentityProviderWithIssuer) Name() string { - return m.name -} - -func (m *MockIdentityProviderWithIssuer) GetIssuer() string { - return m.issuer -} - -func (m *MockIdentityProviderWithIssuer) Initialize(config interface{}) error { - return nil -} - -func (m *MockIdentityProviderWithIssuer) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) { - // For JWT tokens, parse and validate the token format - if len(token) > 50 && strings.Contains(token, ".") { - // This looks like a JWT - parse it to get the subject - parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{}) - if err != nil { - return nil, fmt.Errorf("invalid JWT token") - } - - claims, ok := parsedToken.Claims.(jwt.MapClaims) - if !ok { - return nil, fmt.Errorf("invalid claims") - } - - issuer, _ := claims["iss"].(string) - subject, _ := claims["sub"].(string) - - // Verify the issuer matches what we expect - if issuer != m.issuer { - return nil, fmt.Errorf("token issuer %s does not match provider issuer %s", issuer, m.issuer) - } - - return &providers.ExternalIdentity{ - UserID: subject, - Email: subject + "@" + m.name + ".com", - Provider: m.name, - }, nil - } - - // For non-JWT tokens, check our simple token list - if m.validTokens[token] { - return &providers.ExternalIdentity{ - UserID: "test-user", - Email: "test@" + m.name + ".com", - Provider: m.name, - }, nil - } - - return nil, fmt.Errorf("invalid token") -} - -func (m *MockIdentityProviderWithIssuer) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) { - return &providers.ExternalIdentity{ - UserID: userID, - Email: userID + "@" + m.name + ".com", - Provider: m.name, - }, nil -} - -func (m *MockIdentityProviderWithIssuer) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) { - if m.validTokens[token] { - return &providers.TokenClaims{ - Subject: "test-user", - Issuer: m.issuer, - }, nil - } - return nil, fmt.Errorf("invalid token") -} diff --git a/weed/iam/sts/session_policy_test.go b/weed/iam/sts/session_policy_test.go deleted file mode 100644 index 992fde929..000000000 --- a/weed/iam/sts/session_policy_test.go +++ /dev/null @@ -1,168 +0,0 @@ -package sts - -import ( - "context" - "testing" - "time" - - "github.com/golang-jwt/jwt/v5" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// createSessionPolicyTestJWT creates a test JWT token for session policy tests -func createSessionPolicyTestJWT(t *testing.T, issuer, subject string) string { - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "iss": issuer, - "sub": subject, - "aud": "test-client", - "exp": time.Now().Add(time.Hour).Unix(), - "iat": time.Now().Unix(), - }) - - tokenString, err := token.SignedString([]byte("test-signing-key")) - require.NoError(t, err) - return tokenString -} - -// TestAssumeRoleWithWebIdentity_SessionPolicy verifies inline session policies are preserved in tokens. -func TestAssumeRoleWithWebIdentity_SessionPolicy(t *testing.T) { - service := setupTestSTSService(t) - ctx := context.Background() - - sessionPolicy := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:GetObject","Resource":"arn:aws:s3:::example-bucket/*"}]}` - testToken := createSessionPolicyTestJWT(t, "test-issuer", "test-user") - - request := &AssumeRoleWithWebIdentityRequest{ - RoleArn: "arn:aws:iam::role/TestRole", - WebIdentityToken: testToken, - RoleSessionName: "test-session", - Policy: &sessionPolicy, - } - - response, err := service.AssumeRoleWithWebIdentity(ctx, request) - require.NoError(t, err) - require.NotNil(t, response) - - sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken) - require.NoError(t, err) - - normalized, err := NormalizeSessionPolicy(sessionPolicy) - require.NoError(t, err) - assert.Equal(t, normalized, sessionInfo.SessionPolicy) - - t.Run("should_succeed_without_session_policy", func(t *testing.T) { - request := &AssumeRoleWithWebIdentityRequest{ - RoleArn: "arn:aws:iam::role/TestRole", - WebIdentityToken: createSessionPolicyTestJWT(t, "test-issuer", "test-user"), - RoleSessionName: "test-session", - } - - response, err := service.AssumeRoleWithWebIdentity(ctx, request) - require.NoError(t, err) - require.NotNil(t, response) - - sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken) - require.NoError(t, err) - assert.Empty(t, sessionInfo.SessionPolicy) - }) -} - -// Test edge case scenarios for the Policy field handling -func TestAssumeRoleWithWebIdentity_SessionPolicy_EdgeCases(t *testing.T) { - service := setupTestSTSService(t) - ctx := context.Background() - - t.Run("malformed_json_policy_rejected", func(t *testing.T) { - malformedPolicy := `{"Version": "2012-10-17", "Statement": [` // Incomplete JSON - - request := &AssumeRoleWithWebIdentityRequest{ - RoleArn: "arn:aws:iam::role/TestRole", - WebIdentityToken: createSessionPolicyTestJWT(t, "test-issuer", "test-user"), - RoleSessionName: "test-session", - Policy: &malformedPolicy, - } - - response, err := service.AssumeRoleWithWebIdentity(ctx, request) - assert.Error(t, err) - assert.Nil(t, response) - assert.Contains(t, err.Error(), "invalid session policy JSON") - }) - - t.Run("invalid_policy_document_rejected", func(t *testing.T) { - invalidPolicy := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow"}]}` - - request := &AssumeRoleWithWebIdentityRequest{ - RoleArn: "arn:aws:iam::role/TestRole", - WebIdentityToken: createSessionPolicyTestJWT(t, "test-issuer", "test-user"), - RoleSessionName: "test-session", - Policy: &invalidPolicy, - } - - response, err := service.AssumeRoleWithWebIdentity(ctx, request) - assert.Error(t, err) - assert.Nil(t, response) - assert.Contains(t, err.Error(), "invalid session policy document") - }) - - t.Run("whitespace_policy_ignored", func(t *testing.T) { - whitespacePolicy := " \t\n " - - request := &AssumeRoleWithWebIdentityRequest{ - RoleArn: "arn:aws:iam::role/TestRole", - WebIdentityToken: createSessionPolicyTestJWT(t, "test-issuer", "test-user"), - RoleSessionName: "test-session", - Policy: &whitespacePolicy, - } - - response, err := service.AssumeRoleWithWebIdentity(ctx, request) - require.NoError(t, err) - require.NotNil(t, response) - - sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken) - require.NoError(t, err) - assert.Empty(t, sessionInfo.SessionPolicy) - }) -} - -// TestAssumeRoleWithWebIdentity_PolicyFieldDocumentation verifies that the struct field exists and is optional. -func TestAssumeRoleWithWebIdentity_PolicyFieldDocumentation(t *testing.T) { - request := &AssumeRoleWithWebIdentityRequest{} - - assert.IsType(t, (*string)(nil), request.Policy, - "Policy field should be *string type for optional JSON policy") - assert.Nil(t, request.Policy, - "Policy field should default to nil (no session policy)") - - policyValue := `{"Version": "2012-10-17"}` - request.Policy = &policyValue - assert.NotNil(t, request.Policy, "Should be able to assign policy value") - assert.Equal(t, policyValue, *request.Policy, "Policy value should be preserved") -} - -// TestAssumeRoleWithCredentials_SessionPolicy verifies session policy support for credentials-based flow. -func TestAssumeRoleWithCredentials_SessionPolicy(t *testing.T) { - service := setupTestSTSService(t) - ctx := context.Background() - - sessionPolicy := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"filer:CreateEntry","Resource":"arn:aws:filer::path/user-docs/*"}]}` - request := &AssumeRoleWithCredentialsRequest{ - RoleArn: "arn:aws:iam::role/TestRole", - Username: "testuser", - Password: "testpass", - RoleSessionName: "test-session", - ProviderName: "test-ldap", - Policy: &sessionPolicy, - } - - response, err := service.AssumeRoleWithCredentials(ctx, request) - require.NoError(t, err) - require.NotNil(t, response) - - sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken) - require.NoError(t, err) - - normalized, err := NormalizeSessionPolicy(sessionPolicy) - require.NoError(t, err) - assert.Equal(t, normalized, sessionInfo.SessionPolicy) -} diff --git a/weed/iam/sts/sts_service.go b/weed/iam/sts/sts_service.go index d02c82ae1..0d0481795 100644 --- a/weed/iam/sts/sts_service.go +++ b/weed/iam/sts/sts_service.go @@ -879,21 +879,6 @@ func (s *STSService) calculateSessionDuration(durationSeconds *int64, tokenExpir return duration } -// extractSessionIdFromToken extracts session ID from JWT session token -func (s *STSService) extractSessionIdFromToken(sessionToken string) string { - // Validate JWT and extract session claims - claims, err := s.tokenGenerator.ValidateJWTWithClaims(sessionToken) - if err != nil { - // For test compatibility, also handle direct session IDs - if len(sessionToken) == 32 { // Typical session ID length - return sessionToken - } - return "" - } - - return claims.SessionId -} - // validateAssumeRoleWithCredentialsRequest validates the credentials request parameters func (s *STSService) validateAssumeRoleWithCredentialsRequest(request *AssumeRoleWithCredentialsRequest) error { if request.RoleArn == "" { diff --git a/weed/iam/sts/sts_service_test.go b/weed/iam/sts/sts_service_test.go deleted file mode 100644 index e16b3209a..000000000 --- a/weed/iam/sts/sts_service_test.go +++ /dev/null @@ -1,778 +0,0 @@ -package sts - -import ( - "context" - "fmt" - "strings" - "testing" - "time" - - "github.com/golang-jwt/jwt/v5" - "github.com/seaweedfs/seaweedfs/weed/iam/providers" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// createSTSTestJWT creates a test JWT token for STS service tests -func createSTSTestJWT(t *testing.T, issuer, subject string) string { - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "iss": issuer, - "sub": subject, - "aud": "test-client", - "exp": time.Now().Add(time.Hour).Unix(), - "iat": time.Now().Unix(), - }) - - tokenString, err := token.SignedString([]byte("test-signing-key")) - require.NoError(t, err) - return tokenString -} - -// TestSTSServiceInitialization tests STS service initialization -func TestSTSServiceInitialization(t *testing.T) { - tests := []struct { - name string - config *STSConfig - wantErr bool - }{ - { - name: "valid config", - config: &STSConfig{ - TokenDuration: FlexibleDuration{time.Hour}, - MaxSessionLength: FlexibleDuration{time.Hour * 12}, - Issuer: "seaweedfs-sts", - SigningKey: []byte("test-signing-key"), - }, - wantErr: false, - }, - { - name: "missing signing key", - config: &STSConfig{ - TokenDuration: FlexibleDuration{time.Hour}, - Issuer: "seaweedfs-sts", - }, - wantErr: true, - }, - { - name: "invalid token duration", - config: &STSConfig{ - TokenDuration: FlexibleDuration{-time.Hour}, - Issuer: "seaweedfs-sts", - SigningKey: []byte("test-key"), - }, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - service := NewSTSService() - - err := service.Initialize(tt.config) - - if tt.wantErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.True(t, service.IsInitialized()) - - // Verify defaults if applicable - if tt.config.Issuer == "" { - assert.Equal(t, DefaultIssuer, service.Config.Issuer) - } - if tt.config.TokenDuration.Duration == 0 { - assert.Equal(t, time.Duration(DefaultTokenDuration)*time.Second, service.Config.TokenDuration.Duration) - } - } - }) - } -} - -func TestSTSServiceDefaults(t *testing.T) { - service := NewSTSService() - config := &STSConfig{ - SigningKey: []byte("test-signing-key"), - // Missing duration and issuer - } - - err := service.Initialize(config) - assert.NoError(t, err) - - assert.Equal(t, DefaultIssuer, config.Issuer) - assert.Equal(t, time.Duration(DefaultTokenDuration)*time.Second, config.TokenDuration.Duration) - assert.Equal(t, time.Duration(DefaultMaxSessionLength)*time.Second, config.MaxSessionLength.Duration) -} - -// TestAssumeRoleWithWebIdentity tests role assumption with OIDC tokens -func TestAssumeRoleWithWebIdentity(t *testing.T) { - service := setupTestSTSService(t) - - tests := []struct { - name string - roleArn string - webIdentityToken string - sessionName string - durationSeconds *int64 - wantErr bool - expectedSubject string - }{ - { - name: "successful role assumption", - roleArn: "arn:aws:iam::role/TestRole", - webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user-id"), - sessionName: "test-session", - durationSeconds: nil, // Use default - wantErr: false, - expectedSubject: "test-user-id", - }, - { - name: "invalid web identity token", - roleArn: "arn:aws:iam::role/TestRole", - webIdentityToken: "invalid-token", - sessionName: "test-session", - wantErr: true, - }, - { - name: "non-existent role", - roleArn: "arn:aws:iam::role/NonExistentRole", - webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"), - sessionName: "test-session", - wantErr: true, - }, - { - name: "custom session duration", - roleArn: "arn:aws:iam::role/TestRole", - webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"), - sessionName: "test-session", - durationSeconds: int64Ptr(7200), // 2 hours - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() - - request := &AssumeRoleWithWebIdentityRequest{ - RoleArn: tt.roleArn, - WebIdentityToken: tt.webIdentityToken, - RoleSessionName: tt.sessionName, - DurationSeconds: tt.durationSeconds, - } - - response, err := service.AssumeRoleWithWebIdentity(ctx, request) - - if tt.wantErr { - assert.Error(t, err) - assert.Nil(t, response) - } else { - assert.NoError(t, err) - assert.NotNil(t, response) - assert.NotNil(t, response.Credentials) - assert.NotNil(t, response.AssumedRoleUser) - - // Verify credentials - creds := response.Credentials - assert.NotEmpty(t, creds.AccessKeyId) - assert.NotEmpty(t, creds.SecretAccessKey) - assert.NotEmpty(t, creds.SessionToken) - assert.True(t, creds.Expiration.After(time.Now())) - - // Verify assumed role user - user := response.AssumedRoleUser - assert.Equal(t, tt.roleArn, user.AssumedRoleId) - assert.Contains(t, user.Arn, tt.sessionName) - - if tt.expectedSubject != "" { - assert.Equal(t, tt.expectedSubject, user.Subject) - } - } - }) - } -} - -// TestAssumeRoleWithLDAP tests role assumption with LDAP credentials -func TestAssumeRoleWithLDAP(t *testing.T) { - service := setupTestSTSService(t) - - tests := []struct { - name string - roleArn string - username string - password string - sessionName string - wantErr bool - }{ - { - name: "successful LDAP role assumption", - roleArn: "arn:aws:iam::role/LDAPRole", - username: "testuser", - password: "testpass", - sessionName: "ldap-session", - wantErr: false, - }, - { - name: "invalid LDAP credentials", - roleArn: "arn:aws:iam::role/LDAPRole", - username: "testuser", - password: "wrongpass", - sessionName: "ldap-session", - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() - - request := &AssumeRoleWithCredentialsRequest{ - RoleArn: tt.roleArn, - Username: tt.username, - Password: tt.password, - RoleSessionName: tt.sessionName, - ProviderName: "test-ldap", - } - - response, err := service.AssumeRoleWithCredentials(ctx, request) - - if tt.wantErr { - assert.Error(t, err) - assert.Nil(t, response) - } else { - assert.NoError(t, err) - assert.NotNil(t, response) - assert.NotNil(t, response.Credentials) - } - }) - } -} - -// TestSessionTokenValidation tests session token validation -func TestSessionTokenValidation(t *testing.T) { - service := setupTestSTSService(t) - ctx := context.Background() - - // First, create a session - request := &AssumeRoleWithWebIdentityRequest{ - RoleArn: "arn:aws:iam::role/TestRole", - WebIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"), - RoleSessionName: "test-session", - } - - response, err := service.AssumeRoleWithWebIdentity(ctx, request) - require.NoError(t, err) - require.NotNil(t, response) - - sessionToken := response.Credentials.SessionToken - - tests := []struct { - name string - token string - wantErr bool - }{ - { - name: "valid session token", - token: sessionToken, - wantErr: false, - }, - { - name: "invalid session token", - token: "invalid-session-token", - wantErr: true, - }, - { - name: "empty session token", - token: "", - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - session, err := service.ValidateSessionToken(ctx, tt.token) - - if tt.wantErr { - assert.Error(t, err) - assert.Nil(t, session) - } else { - assert.NoError(t, err) - assert.NotNil(t, session) - assert.Equal(t, "test-session", session.SessionName) - assert.Equal(t, "arn:aws:iam::role/TestRole", session.RoleArn) - } - }) - } -} - -// TestSessionTokenPersistence tests that JWT tokens remain valid throughout their lifetime -// Note: In the stateless JWT design, tokens cannot be revoked and remain valid until expiration -func TestSessionTokenPersistence(t *testing.T) { - service := setupTestSTSService(t) - ctx := context.Background() - - // Create a session first - request := &AssumeRoleWithWebIdentityRequest{ - RoleArn: "arn:aws:iam::role/TestRole", - WebIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"), - RoleSessionName: "test-session", - } - - response, err := service.AssumeRoleWithWebIdentity(ctx, request) - require.NoError(t, err) - - sessionToken := response.Credentials.SessionToken - - // Verify token is valid initially - session, err := service.ValidateSessionToken(ctx, sessionToken) - assert.NoError(t, err) - assert.NotNil(t, session) - assert.Equal(t, "test-session", session.SessionName) - - // In a stateless JWT system, tokens remain valid throughout their lifetime - // Multiple validations should all succeed as long as the token hasn't expired - session2, err := service.ValidateSessionToken(ctx, sessionToken) - assert.NoError(t, err, "Token should remain valid in stateless system") - assert.NotNil(t, session2, "Session should be returned from JWT token") - assert.Equal(t, session.SessionId, session2.SessionId, "Session ID should be consistent") -} - -// Helper functions - -func setupTestSTSService(t *testing.T) *STSService { - service := NewSTSService() - - config := &STSConfig{ - TokenDuration: FlexibleDuration{time.Hour}, - MaxSessionLength: FlexibleDuration{time.Hour * 12}, - Issuer: "test-sts", - SigningKey: []byte("test-signing-key-32-characters-long"), - } - - err := service.Initialize(config) - require.NoError(t, err) - - // Set up mock trust policy validator (required for STS testing) - mockValidator := &MockTrustPolicyValidator{} - service.SetTrustPolicyValidator(mockValidator) - - // Register test providers - mockOIDCProvider := &MockIdentityProvider{ - name: "test-oidc", - validTokens: map[string]*providers.TokenClaims{ - createSTSTestJWT(t, "test-issuer", "test-user"): { - Subject: "test-user-id", - Issuer: "test-issuer", - Claims: map[string]interface{}{ - "email": "test@example.com", - "name": "Test User", - }, - }, - }, - } - - mockLDAPProvider := &MockIdentityProvider{ - name: "test-ldap", - validCredentials: map[string]string{ - "testuser": "testpass", - }, - } - - service.RegisterProvider(mockOIDCProvider) - service.RegisterProvider(mockLDAPProvider) - - return service -} - -func int64Ptr(v int64) *int64 { - return &v -} - -// Mock identity provider for testing -type MockIdentityProvider struct { - name string - validTokens map[string]*providers.TokenClaims - validCredentials map[string]string -} - -func (m *MockIdentityProvider) Name() string { - return m.name -} - -func (m *MockIdentityProvider) GetIssuer() string { - return "test-issuer" // This matches the issuer in the token claims -} - -func (m *MockIdentityProvider) Initialize(config interface{}) error { - return nil -} - -func (m *MockIdentityProvider) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) { - // First try to parse as JWT token - if len(token) > 20 && strings.Count(token, ".") >= 2 { - parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{}) - if err == nil { - if claims, ok := parsedToken.Claims.(jwt.MapClaims); ok { - issuer, _ := claims["iss"].(string) - subject, _ := claims["sub"].(string) - - // Verify the issuer matches what we expect - if issuer == "test-issuer" && subject != "" { - return &providers.ExternalIdentity{ - UserID: subject, - Email: subject + "@test-domain.com", - DisplayName: "Test User " + subject, - Provider: m.name, - }, nil - } - } - } - } - - // Handle legacy OIDC tokens (for backwards compatibility) - if claims, exists := m.validTokens[token]; exists { - email, _ := claims.GetClaimString("email") - name, _ := claims.GetClaimString("name") - - return &providers.ExternalIdentity{ - UserID: claims.Subject, - Email: email, - DisplayName: name, - Provider: m.name, - }, nil - } - - // Handle LDAP credentials (username:password format) - if m.validCredentials != nil { - parts := strings.Split(token, ":") - if len(parts) == 2 { - username, password := parts[0], parts[1] - if expectedPassword, exists := m.validCredentials[username]; exists && expectedPassword == password { - return &providers.ExternalIdentity{ - UserID: username, - Email: username + "@" + m.name + ".com", - DisplayName: "Test User " + username, - Provider: m.name, - }, nil - } - } - } - - return nil, fmt.Errorf("unknown test token: %s", token) -} - -func (m *MockIdentityProvider) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) { - return &providers.ExternalIdentity{ - UserID: userID, - Email: userID + "@" + m.name + ".com", - Provider: m.name, - }, nil -} - -func (m *MockIdentityProvider) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) { - if claims, exists := m.validTokens[token]; exists { - return claims, nil - } - return nil, fmt.Errorf("invalid token") -} - -// TestSessionDurationCappedByTokenExpiration tests that session duration is capped by the source token's exp claim -func TestSessionDurationCappedByTokenExpiration(t *testing.T) { - service := NewSTSService() - - config := &STSConfig{ - TokenDuration: FlexibleDuration{time.Hour}, // Default: 1 hour - MaxSessionLength: FlexibleDuration{time.Hour * 12}, - Issuer: "test-sts", - SigningKey: []byte("test-signing-key-32-characters-long"), - } - - err := service.Initialize(config) - require.NoError(t, err) - - tests := []struct { - name string - durationSeconds *int64 - tokenExpiration *time.Time - expectedMaxSeconds int64 - description string - }{ - { - name: "no token expiration - use default duration", - durationSeconds: nil, - tokenExpiration: nil, - expectedMaxSeconds: 3600, // 1 hour default - description: "When no token expiration is set, use the configured default duration", - }, - { - name: "token expires before default duration", - durationSeconds: nil, - tokenExpiration: timePtr(time.Now().Add(30 * time.Minute)), - expectedMaxSeconds: 30 * 60, // 30 minutes - description: "When token expires in 30 min, session should be capped at 30 min", - }, - { - name: "token expires after default duration - use default", - durationSeconds: nil, - tokenExpiration: timePtr(time.Now().Add(2 * time.Hour)), - expectedMaxSeconds: 3600, // 1 hour default, since it's less than 2 hour token expiry - description: "When token expires after default duration, use the default duration", - }, - { - name: "requested duration shorter than token expiry", - durationSeconds: int64Ptr(1800), // 30 min requested - tokenExpiration: timePtr(time.Now().Add(time.Hour)), - expectedMaxSeconds: 1800, // 30 minutes as requested - description: "When requested duration is shorter than token expiry, use requested duration", - }, - { - name: "requested duration longer than token expiry - cap at token expiry", - durationSeconds: int64Ptr(3600), // 1 hour requested - tokenExpiration: timePtr(time.Now().Add(15 * time.Minute)), - expectedMaxSeconds: 15 * 60, // Capped at 15 minutes - description: "When requested duration exceeds token expiry, cap at token expiry", - }, - { - name: "GitLab CI short-lived token scenario", - durationSeconds: nil, - tokenExpiration: timePtr(time.Now().Add(5 * time.Minute)), - expectedMaxSeconds: 5 * 60, // 5 minutes - description: "GitLab CI job with 5 minute timeout should result in 5 minute session", - }, - { - name: "already expired token - defense in depth", - durationSeconds: nil, - tokenExpiration: timePtr(time.Now().Add(-5 * time.Minute)), // Expired 5 minutes ago - expectedMaxSeconds: 60, // 1 minute minimum - description: "Already expired token should result in minimal 1 minute session", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - duration := service.calculateSessionDuration(tt.durationSeconds, tt.tokenExpiration) - - // Allow 5 second tolerance for time calculations - maxExpected := time.Duration(tt.expectedMaxSeconds+5) * time.Second - minExpected := time.Duration(tt.expectedMaxSeconds-5) * time.Second - - assert.GreaterOrEqual(t, duration, minExpected, - "%s: duration %v should be >= %v", tt.description, duration, minExpected) - assert.LessOrEqual(t, duration, maxExpected, - "%s: duration %v should be <= %v", tt.description, duration, maxExpected) - }) - } -} - -// TestAssumeRoleWithWebIdentityRespectsTokenExpiration tests end-to-end that session duration is capped -func TestAssumeRoleWithWebIdentityRespectsTokenExpiration(t *testing.T) { - service := NewSTSService() - - config := &STSConfig{ - TokenDuration: FlexibleDuration{time.Hour}, - MaxSessionLength: FlexibleDuration{time.Hour * 12}, - Issuer: "test-sts", - SigningKey: []byte("test-signing-key-32-characters-long"), - } - - err := service.Initialize(config) - require.NoError(t, err) - - // Set up mock trust policy validator - mockValidator := &MockTrustPolicyValidator{} - service.SetTrustPolicyValidator(mockValidator) - - // Create a mock provider that returns tokens with short expiration - shortLivedTokenExpiration := time.Now().Add(10 * time.Minute) - mockProvider := &MockIdentityProviderWithExpiration{ - name: "short-lived-issuer", - tokenExpiration: &shortLivedTokenExpiration, - } - service.RegisterProvider(mockProvider) - - ctx := context.Background() - - // Create a JWT token with short expiration - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "iss": "short-lived-issuer", - "sub": "test-user", - "aud": "test-client", - "exp": shortLivedTokenExpiration.Unix(), - "iat": time.Now().Unix(), - }) - tokenString, err := token.SignedString([]byte("test-signing-key")) - require.NoError(t, err) - - request := &AssumeRoleWithWebIdentityRequest{ - RoleArn: "arn:aws:iam::role/TestRole", - WebIdentityToken: tokenString, - RoleSessionName: "test-session", - } - - response, err := service.AssumeRoleWithWebIdentity(ctx, request) - require.NoError(t, err) - require.NotNil(t, response) - - // Verify the session expires at or before the token expiration - // Allow 5 second tolerance - assert.True(t, response.Credentials.Expiration.Before(shortLivedTokenExpiration.Add(5*time.Second)), - "Session expiration (%v) should not exceed token expiration (%v)", - response.Credentials.Expiration, shortLivedTokenExpiration) -} - -// MockIdentityProviderWithExpiration is a mock provider that returns tokens with configurable expiration -type MockIdentityProviderWithExpiration struct { - name string - tokenExpiration *time.Time -} - -func (m *MockIdentityProviderWithExpiration) Name() string { - return m.name -} - -func (m *MockIdentityProviderWithExpiration) GetIssuer() string { - return m.name -} - -func (m *MockIdentityProviderWithExpiration) Initialize(config interface{}) error { - return nil -} - -func (m *MockIdentityProviderWithExpiration) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) { - // Parse the token to get subject - parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{}) - if err != nil { - return nil, fmt.Errorf("failed to parse token: %w", err) - } - - claims, ok := parsedToken.Claims.(jwt.MapClaims) - if !ok { - return nil, fmt.Errorf("invalid claims") - } - - subject, _ := claims["sub"].(string) - - identity := &providers.ExternalIdentity{ - UserID: subject, - Email: subject + "@example.com", - DisplayName: "Test User", - Provider: m.name, - TokenExpiration: m.tokenExpiration, - } - - return identity, nil -} - -func (m *MockIdentityProviderWithExpiration) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) { - return &providers.ExternalIdentity{ - UserID: userID, - Provider: m.name, - }, nil -} - -func (m *MockIdentityProviderWithExpiration) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) { - claims := &providers.TokenClaims{ - Subject: "test-user", - Issuer: m.name, - } - if m.tokenExpiration != nil { - claims.ExpiresAt = *m.tokenExpiration - } - return claims, nil -} - -func timePtr(t time.Time) *time.Time { - return &t -} - -// TestAssumeRoleWithWebIdentity_PreservesAttributes tests that attributes from the identity provider -// are correctly propagated to the session token's request context -func TestAssumeRoleWithWebIdentity_PreservesAttributes(t *testing.T) { - service := setupTestSTSService(t) - - // Create a mock provider that returns a user with attributes - mockProvider := &MockIdentityProviderWithAttributes{ - name: "attr-provider", - attributes: map[string]string{ - "preferred_username": "my-user", - "department": "engineering", - "project": "seaweedfs", - }, - } - service.RegisterProvider(mockProvider) - - // Create a valid JWT token for the provider - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "iss": "attr-provider", - "sub": "test-user-id", - "aud": "test-client", - "exp": time.Now().Add(time.Hour).Unix(), - "iat": time.Now().Unix(), - }) - tokenString, err := token.SignedString([]byte("test-signing-key")) - require.NoError(t, err) - - ctx := context.Background() - request := &AssumeRoleWithWebIdentityRequest{ - RoleArn: "arn:aws:iam::role/TestRole", - WebIdentityToken: tokenString, - RoleSessionName: "test-session", - } - - response, err := service.AssumeRoleWithWebIdentity(ctx, request) - require.NoError(t, err) - require.NotNil(t, response) - - // Validate the session token to check claims - sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken) - require.NoError(t, err) - - // Check that attributes are present in RequestContext - require.NotNil(t, sessionInfo.RequestContext, "RequestContext should not be nil") - assert.Equal(t, "my-user", sessionInfo.RequestContext["preferred_username"]) - assert.Equal(t, "engineering", sessionInfo.RequestContext["department"]) - assert.Equal(t, "seaweedfs", sessionInfo.RequestContext["project"]) - - // Check standard claims are also present - assert.Equal(t, "test-user-id", sessionInfo.RequestContext["sub"]) - assert.Equal(t, "test@example.com", sessionInfo.RequestContext["email"]) - assert.Equal(t, "Test User", sessionInfo.RequestContext["name"]) -} - -// MockIdentityProviderWithAttributes is a mock provider that returns configured attributes -type MockIdentityProviderWithAttributes struct { - name string - attributes map[string]string -} - -func (m *MockIdentityProviderWithAttributes) Name() string { - return m.name -} - -func (m *MockIdentityProviderWithAttributes) GetIssuer() string { - return m.name -} - -func (m *MockIdentityProviderWithAttributes) Initialize(config interface{}) error { - return nil -} - -func (m *MockIdentityProviderWithAttributes) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) { - return &providers.ExternalIdentity{ - UserID: "test-user-id", - Email: "test@example.com", - DisplayName: "Test User", - Provider: m.name, - Attributes: m.attributes, - }, nil -} - -func (m *MockIdentityProviderWithAttributes) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) { - return nil, nil -} - -func (m *MockIdentityProviderWithAttributes) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) { - return &providers.TokenClaims{ - Subject: "test-user-id", - Issuer: m.name, - }, nil -} diff --git a/weed/iam/sts/test_utils.go b/weed/iam/sts/test_utils.go index 61ef72570..61de76bbd 100644 --- a/weed/iam/sts/test_utils.go +++ b/weed/iam/sts/test_utils.go @@ -1,53 +1,4 @@ package sts -import ( - "context" - "fmt" - "strings" - - "github.com/seaweedfs/seaweedfs/weed/iam/providers" -) - // MockTrustPolicyValidator is a simple mock for testing STS functionality type MockTrustPolicyValidator struct{} - -// ValidateTrustPolicyForWebIdentity allows valid JWT test tokens for STS testing -func (m *MockTrustPolicyValidator) ValidateTrustPolicyForWebIdentity(ctx context.Context, roleArn string, webIdentityToken string, durationSeconds *int64) error { - // Reject non-existent roles for testing - if strings.Contains(roleArn, "NonExistentRole") { - return fmt.Errorf("trust policy validation failed: role does not exist") - } - - // For STS unit tests, allow JWT tokens that look valid (contain dots for JWT structure) - // In real implementation, this would validate against actual trust policies - if len(webIdentityToken) > 20 && strings.Count(webIdentityToken, ".") >= 2 { - // This appears to be a JWT token - allow it for testing - return nil - } - - // Legacy support for specific test tokens during migration - if webIdentityToken == "valid_test_token" || webIdentityToken == "valid-oidc-token" { - return nil - } - - // Reject invalid tokens - if webIdentityToken == "invalid_token" || webIdentityToken == "expired_token" || webIdentityToken == "invalid-token" { - return fmt.Errorf("trust policy denies token") - } - - return nil -} - -// ValidateTrustPolicyForCredentials allows valid test identities for STS testing -func (m *MockTrustPolicyValidator) ValidateTrustPolicyForCredentials(ctx context.Context, roleArn string, identity *providers.ExternalIdentity) error { - // Reject non-existent roles for testing - if strings.Contains(roleArn, "NonExistentRole") { - return fmt.Errorf("trust policy validation failed: role does not exist") - } - - // For STS unit tests, allow test identities - if identity != nil && identity.UserID != "" { - return nil - } - return fmt.Errorf("invalid identity for role assumption") -} diff --git a/weed/images/preprocess.go b/weed/images/preprocess.go deleted file mode 100644 index f6f3b554d..000000000 --- a/weed/images/preprocess.go +++ /dev/null @@ -1,29 +0,0 @@ -package images - -import ( - "bytes" - "io" - "path/filepath" - "strings" -) - -/* -* Preprocess image files on client side. -* 1. possibly adjust the orientation -* 2. resize the image to a width or height limit -* 3. remove the exif data -* Call this function on any file uploaded to SeaweedFS -* - */ -func MaybePreprocessImage(filename string, data []byte, width, height int) (resized io.ReadSeeker, w int, h int) { - ext := filepath.Ext(filename) - ext = strings.ToLower(ext) - switch ext { - case ".png", ".gif": - return Resized(ext, bytes.NewReader(data), width, height, "") - case ".jpg", ".jpeg": - data = FixJpgOrientation(data) - return Resized(ext, bytes.NewReader(data), width, height, "") - } - return bytes.NewReader(data), 0, 0 -} diff --git a/weed/kms/config_loader.go b/weed/kms/config_loader.go index 3778c0f59..5f31259c6 100644 --- a/weed/kms/config_loader.go +++ b/weed/kms/config_loader.go @@ -290,15 +290,6 @@ func (loader *ConfigLoader) ValidateConfiguration() error { return nil } -// LoadKMSFromFilerToml is a convenience function to load KMS configuration from filer.toml -func LoadKMSFromFilerToml(v ViperConfig) error { - loader := NewConfigLoader(v) - if err := loader.LoadConfigurations(); err != nil { - return err - } - return loader.ValidateConfiguration() -} - // LoadKMSFromConfig loads KMS configuration directly from parsed JSON data func LoadKMSFromConfig(kmsConfig interface{}) error { kmsMap, ok := kmsConfig.(map[string]interface{}) @@ -415,12 +406,3 @@ func getIntFromConfig(config map[string]interface{}, key string, defaultValue in } return defaultValue } - -func getStringFromConfig(config map[string]interface{}, key string, defaultValue string) string { - if value, exists := config[key]; exists { - if stringValue, ok := value.(string); ok { - return stringValue - } - } - return defaultValue -} diff --git a/weed/mount/filehandle.go b/weed/mount/filehandle.go index 98ca6737f..485a00e41 100644 --- a/weed/mount/filehandle.go +++ b/weed/mount/filehandle.go @@ -147,13 +147,6 @@ func (fh *FileHandle) ReleaseHandle() { } } -func lessThan(a, b *filer_pb.FileChunk) bool { - if a.ModifiedTsNs == b.ModifiedTsNs { - return a.Fid.FileKey < b.Fid.FileKey - } - return a.ModifiedTsNs < b.ModifiedTsNs -} - // getCumulativeOffsets returns cached cumulative offsets for chunks, computing them if necessary func (fh *FileHandle) getCumulativeOffsets(chunks []*filer_pb.FileChunk) []int64 { fh.chunkCacheLock.RLock() diff --git a/weed/mount/page_writer/dirty_pages.go b/weed/mount/page_writer/dirty_pages.go index cec365231..472815dd5 100644 --- a/weed/mount/page_writer/dirty_pages.go +++ b/weed/mount/page_writer/dirty_pages.go @@ -21,9 +21,3 @@ func min(x, y int64) int64 { } return y } -func minInt(x, y int) int { - if x < y { - return x - } - return y -} diff --git a/weed/mount/rdma_client.go b/weed/mount/rdma_client.go index e9ee802ce..6a77f8f52 100644 --- a/weed/mount/rdma_client.go +++ b/weed/mount/rdma_client.go @@ -119,13 +119,6 @@ func (c *RDMAMountClient) lookupVolumeLocationByFileID(ctx context.Context, file return bestAddress, nil } -// lookupVolumeLocation finds the best volume server for a given volume ID (legacy method) -func (c *RDMAMountClient) lookupVolumeLocation(ctx context.Context, volumeID uint32, needleID uint64, cookie uint32) (string, error) { - // Create a file ID for lookup (format: volumeId,needleId,cookie) - fileID := fmt.Sprintf("%d,%x,%d", volumeID, needleID, cookie) - return c.lookupVolumeLocationByFileID(ctx, fileID) -} - // healthCheck verifies that the RDMA sidecar is available and functioning func (c *RDMAMountClient) healthCheck() error { ctx, cancel := context.WithTimeout(context.Background(), c.timeout) diff --git a/weed/mq/broker/broker_errors.go b/weed/mq/broker/broker_errors.go index b3d4cc42c..529a03e44 100644 --- a/weed/mq/broker/broker_errors.go +++ b/weed/mq/broker/broker_errors.go @@ -117,11 +117,6 @@ func GetBrokerErrorInfo(code int32) BrokerErrorInfo { } } -// GetKafkaErrorCode returns the corresponding Kafka protocol error code for a broker error -func GetKafkaErrorCode(brokerErrorCode int32) int16 { - return GetBrokerErrorInfo(brokerErrorCode).KafkaCode -} - // CreateBrokerError creates a structured broker error with both error code and message func CreateBrokerError(code int32, message string) (int32, string) { info := GetBrokerErrorInfo(code) diff --git a/weed/mq/broker/broker_offset_integration_test.go b/weed/mq/broker/broker_offset_integration_test.go deleted file mode 100644 index 49df58a64..000000000 --- a/weed/mq/broker/broker_offset_integration_test.go +++ /dev/null @@ -1,351 +0,0 @@ -package broker - -import ( - "testing" - "time" - - "github.com/seaweedfs/seaweedfs/weed/mq/topic" - "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" -) - -func createTestTopic() topic.Topic { - return topic.Topic{ - Namespace: "test", - Name: "offset-test", - } -} - -func createTestPartition() topic.Partition { - return topic.Partition{ - RingSize: 1024, - RangeStart: 0, - RangeStop: 31, - UnixTimeNs: time.Now().UnixNano(), - } -} - -func TestBrokerOffsetManager_AssignOffset(t *testing.T) { - storage := NewInMemoryOffsetStorageForTesting() - manager := NewBrokerOffsetManagerWithStorage(storage) - testTopic := createTestTopic() - testPartition := createTestPartition() - - // Test sequential offset assignment - for i := int64(0); i < 10; i++ { - assignedOffset, err := manager.AssignOffset(testTopic, testPartition) - if err != nil { - t.Fatalf("Failed to assign offset %d: %v", i, err) - } - - if assignedOffset != i { - t.Errorf("Expected offset %d, got %d", i, assignedOffset) - } - } -} - -func TestBrokerOffsetManager_AssignBatchOffsets(t *testing.T) { - storage := NewInMemoryOffsetStorageForTesting() - manager := NewBrokerOffsetManagerWithStorage(storage) - testTopic := createTestTopic() - testPartition := createTestPartition() - - // Assign batch of offsets - baseOffset, lastOffset, err := manager.AssignBatchOffsets(testTopic, testPartition, 5) - if err != nil { - t.Fatalf("Failed to assign batch offsets: %v", err) - } - - if baseOffset != 0 { - t.Errorf("Expected base offset 0, got %d", baseOffset) - } - - if lastOffset != 4 { - t.Errorf("Expected last offset 4, got %d", lastOffset) - } - - // Assign another batch - baseOffset2, lastOffset2, err := manager.AssignBatchOffsets(testTopic, testPartition, 3) - if err != nil { - t.Fatalf("Failed to assign second batch offsets: %v", err) - } - - if baseOffset2 != 5 { - t.Errorf("Expected base offset 5, got %d", baseOffset2) - } - - if lastOffset2 != 7 { - t.Errorf("Expected last offset 7, got %d", lastOffset2) - } -} - -func TestBrokerOffsetManager_GetHighWaterMark(t *testing.T) { - storage := NewInMemoryOffsetStorageForTesting() - manager := NewBrokerOffsetManagerWithStorage(storage) - testTopic := createTestTopic() - testPartition := createTestPartition() - - // Initially should be 0 - hwm, err := manager.GetHighWaterMark(testTopic, testPartition) - if err != nil { - t.Fatalf("Failed to get initial high water mark: %v", err) - } - - if hwm != 0 { - t.Errorf("Expected initial high water mark 0, got %d", hwm) - } - - // Assign some offsets - manager.AssignBatchOffsets(testTopic, testPartition, 10) - - // High water mark should be updated - hwm, err = manager.GetHighWaterMark(testTopic, testPartition) - if err != nil { - t.Fatalf("Failed to get high water mark after assignment: %v", err) - } - - if hwm != 10 { - t.Errorf("Expected high water mark 10, got %d", hwm) - } -} - -func TestBrokerOffsetManager_CreateSubscription(t *testing.T) { - storage := NewInMemoryOffsetStorageForTesting() - manager := NewBrokerOffsetManagerWithStorage(storage) - testTopic := createTestTopic() - testPartition := createTestPartition() - - // Assign some offsets first - manager.AssignBatchOffsets(testTopic, testPartition, 5) - - // Create subscription - sub, err := manager.CreateSubscription( - "test-sub", - testTopic, - testPartition, - schema_pb.OffsetType_RESET_TO_EARLIEST, - 0, - ) - - if err != nil { - t.Fatalf("Failed to create subscription: %v", err) - } - - if sub.ID != "test-sub" { - t.Errorf("Expected subscription ID 'test-sub', got %s", sub.ID) - } - - if sub.StartOffset != 0 { - t.Errorf("Expected start offset 0, got %d", sub.StartOffset) - } -} - -func TestBrokerOffsetManager_GetPartitionOffsetInfo(t *testing.T) { - storage := NewInMemoryOffsetStorageForTesting() - manager := NewBrokerOffsetManagerWithStorage(storage) - testTopic := createTestTopic() - testPartition := createTestPartition() - - // Test empty partition - info, err := manager.GetPartitionOffsetInfo(testTopic, testPartition) - if err != nil { - t.Fatalf("Failed to get partition offset info: %v", err) - } - - if info.EarliestOffset != 0 { - t.Errorf("Expected earliest offset 0, got %d", info.EarliestOffset) - } - - if info.LatestOffset != -1 { - t.Errorf("Expected latest offset -1 for empty partition, got %d", info.LatestOffset) - } - - // Assign offsets and test again - manager.AssignBatchOffsets(testTopic, testPartition, 5) - - info, err = manager.GetPartitionOffsetInfo(testTopic, testPartition) - if err != nil { - t.Fatalf("Failed to get partition offset info after assignment: %v", err) - } - - if info.LatestOffset != 4 { - t.Errorf("Expected latest offset 4, got %d", info.LatestOffset) - } - - if info.HighWaterMark != 5 { - t.Errorf("Expected high water mark 5, got %d", info.HighWaterMark) - } -} - -func TestBrokerOffsetManager_MultiplePartitions(t *testing.T) { - storage := NewInMemoryOffsetStorageForTesting() - manager := NewBrokerOffsetManagerWithStorage(storage) - testTopic := createTestTopic() - - // Create different partitions - partition1 := topic.Partition{ - RingSize: 1024, - RangeStart: 0, - RangeStop: 31, - UnixTimeNs: time.Now().UnixNano(), - } - - partition2 := topic.Partition{ - RingSize: 1024, - RangeStart: 32, - RangeStop: 63, - UnixTimeNs: time.Now().UnixNano(), - } - - // Assign offsets to different partitions - assignedOffset1, err := manager.AssignOffset(testTopic, partition1) - if err != nil { - t.Fatalf("Failed to assign offset to partition1: %v", err) - } - - assignedOffset2, err := manager.AssignOffset(testTopic, partition2) - if err != nil { - t.Fatalf("Failed to assign offset to partition2: %v", err) - } - - // Both should start at 0 - if assignedOffset1 != 0 { - t.Errorf("Expected offset 0 for partition1, got %d", assignedOffset1) - } - - if assignedOffset2 != 0 { - t.Errorf("Expected offset 0 for partition2, got %d", assignedOffset2) - } - - // Assign more offsets to partition1 - assignedOffset1_2, err := manager.AssignOffset(testTopic, partition1) - if err != nil { - t.Fatalf("Failed to assign second offset to partition1: %v", err) - } - - if assignedOffset1_2 != 1 { - t.Errorf("Expected offset 1 for partition1, got %d", assignedOffset1_2) - } - - // Partition2 should still be at 0 for next assignment - assignedOffset2_2, err := manager.AssignOffset(testTopic, partition2) - if err != nil { - t.Fatalf("Failed to assign second offset to partition2: %v", err) - } - - if assignedOffset2_2 != 1 { - t.Errorf("Expected offset 1 for partition2, got %d", assignedOffset2_2) - } -} - -func TestOffsetAwarePublisher(t *testing.T) { - storage := NewInMemoryOffsetStorageForTesting() - manager := NewBrokerOffsetManagerWithStorage(storage) - testTopic := createTestTopic() - testPartition := createTestPartition() - - // Create a mock local partition (simplified for testing) - localPartition := &topic.LocalPartition{} - - // Create offset assignment function - assignOffsetFn := func() (int64, error) { - return manager.AssignOffset(testTopic, testPartition) - } - - // Create offset-aware publisher - publisher := topic.NewOffsetAwarePublisher(localPartition, assignOffsetFn) - - if publisher.GetPartition() != localPartition { - t.Error("Publisher should return the correct partition") - } - - // Test would require more setup to actually publish messages - // This tests the basic structure -} - -func TestBrokerOffsetManager_GetOffsetMetrics(t *testing.T) { - storage := NewInMemoryOffsetStorageForTesting() - manager := NewBrokerOffsetManagerWithStorage(storage) - testTopic := createTestTopic() - testPartition := createTestPartition() - - // Initial metrics - metrics := manager.GetOffsetMetrics() - if metrics.TotalOffsets != 0 { - t.Errorf("Expected 0 total offsets initially, got %d", metrics.TotalOffsets) - } - - // Assign some offsets - manager.AssignBatchOffsets(testTopic, testPartition, 5) - - // Create subscription - manager.CreateSubscription("test-sub", testTopic, testPartition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0) - - // Check updated metrics - metrics = manager.GetOffsetMetrics() - if metrics.PartitionCount != 1 { - t.Errorf("Expected 1 partition, got %d", metrics.PartitionCount) - } -} - -func TestBrokerOffsetManager_AssignOffsetsWithResult(t *testing.T) { - storage := NewInMemoryOffsetStorageForTesting() - manager := NewBrokerOffsetManagerWithStorage(storage) - testTopic := createTestTopic() - testPartition := createTestPartition() - - // Assign offsets with result - result := manager.AssignOffsetsWithResult(testTopic, testPartition, 3) - - if result.Error != nil { - t.Fatalf("Expected no error, got: %v", result.Error) - } - - if result.BaseOffset != 0 { - t.Errorf("Expected base offset 0, got %d", result.BaseOffset) - } - - if result.LastOffset != 2 { - t.Errorf("Expected last offset 2, got %d", result.LastOffset) - } - - if result.Count != 3 { - t.Errorf("Expected count 3, got %d", result.Count) - } - - if result.Topic != testTopic { - t.Error("Topic mismatch in result") - } - - if result.Partition != testPartition { - t.Error("Partition mismatch in result") - } - - if result.Timestamp <= 0 { - t.Error("Timestamp should be set") - } -} - -func TestBrokerOffsetManager_Shutdown(t *testing.T) { - storage := NewInMemoryOffsetStorageForTesting() - manager := NewBrokerOffsetManagerWithStorage(storage) - testTopic := createTestTopic() - testPartition := createTestPartition() - - // Assign some offsets and create subscriptions - manager.AssignBatchOffsets(testTopic, testPartition, 5) - manager.CreateSubscription("test-sub", testTopic, testPartition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0) - - // Shutdown should not panic - manager.Shutdown() - - // After shutdown, operations should still work (using new managers) - offset, err := manager.AssignOffset(testTopic, testPartition) - if err != nil { - t.Fatalf("Operations should still work after shutdown: %v", err) - } - - // Should start from 0 again (new manager) - if offset != 0 { - t.Errorf("Expected offset 0 after shutdown, got %d", offset) - } -} diff --git a/weed/mq/broker/broker_server.go b/weed/mq/broker/broker_server.go index 5116ff5a5..a3e658c35 100644 --- a/weed/mq/broker/broker_server.go +++ b/weed/mq/broker/broker_server.go @@ -203,14 +203,6 @@ func (b *MessageQueueBroker) GetDataCenter() string { } -func (b *MessageQueueBroker) withMasterClient(streamingMode bool, master pb.ServerAddress, fn func(client master_pb.SeaweedClient) error) error { - - return pb.WithMasterClient(streamingMode, master, b.grpcDialOption, false, func(client master_pb.SeaweedClient) error { - return fn(client) - }) - -} - func (b *MessageQueueBroker) withBrokerClient(streamingMode bool, server pb.ServerAddress, fn func(client mq_pb.SeaweedMessagingClient) error) error { return pb.WithBrokerGrpcClient(streamingMode, server.String(), b.grpcDialOption, func(client mq_pb.SeaweedMessagingClient) error { diff --git a/weed/mq/kafka/consumer_offset/filer_storage.go b/weed/mq/kafka/consumer_offset/filer_storage.go index 9d92ad730..967982683 100644 --- a/weed/mq/kafka/consumer_offset/filer_storage.go +++ b/weed/mq/kafka/consumer_offset/filer_storage.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "io" - "strings" "time" "github.com/seaweedfs/seaweedfs/weed/filer_client" @@ -192,10 +191,6 @@ func (f *FilerStorage) getOffsetPath(group, topic string, partition int32) strin return fmt.Sprintf("%s/offset", f.getPartitionPath(group, topic, partition)) } -func (f *FilerStorage) getMetadataPath(group, topic string, partition int32) string { - return fmt.Sprintf("%s/metadata", f.getPartitionPath(group, topic, partition)) -} - func (f *FilerStorage) writeFile(path string, data []byte) error { fullPath := util.FullPath(path) dir, name := fullPath.DirAndName() @@ -311,16 +306,3 @@ func (f *FilerStorage) deleteDirectory(path string) error { return err }) } - -// normalizePath removes leading/trailing slashes and collapses multiple slashes -func normalizePath(path string) string { - path = strings.Trim(path, "/") - parts := strings.Split(path, "/") - normalized := []string{} - for _, part := range parts { - if part != "" { - normalized = append(normalized, part) - } - } - return "/" + strings.Join(normalized, "/") -} diff --git a/weed/mq/kafka/consumer_offset/filer_storage_test.go b/weed/mq/kafka/consumer_offset/filer_storage_test.go deleted file mode 100644 index 67a0e7e09..000000000 --- a/weed/mq/kafka/consumer_offset/filer_storage_test.go +++ /dev/null @@ -1,65 +0,0 @@ -package consumer_offset - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -// Note: These tests require a running filer instance -// They are marked as integration tests and should be run with: -// go test -tags=integration - -func TestFilerStorageCommitAndFetch(t *testing.T) { - t.Skip("Requires running filer - integration test") - - // This will be implemented once we have test infrastructure - // Test will: - // 1. Create filer storage - // 2. Commit offset - // 3. Fetch offset - // 4. Verify values match -} - -func TestFilerStoragePersistence(t *testing.T) { - t.Skip("Requires running filer - integration test") - - // Test will: - // 1. Commit offset with first storage instance - // 2. Close first instance - // 3. Create new storage instance - // 4. Fetch offset and verify it persisted -} - -func TestFilerStorageMultipleGroups(t *testing.T) { - t.Skip("Requires running filer - integration test") - - // Test will: - // 1. Commit offsets for multiple groups - // 2. Fetch all offsets per group - // 3. Verify isolation between groups -} - -func TestFilerStoragePath(t *testing.T) { - // Test path generation (doesn't require filer) - storage := &FilerStorage{} - - group := "test-group" - topic := "test-topic" - partition := int32(5) - - groupPath := storage.getGroupPath(group) - assert.Equal(t, ConsumerOffsetsBasePath+"/test-group", groupPath) - - topicPath := storage.getTopicPath(group, topic) - assert.Equal(t, ConsumerOffsetsBasePath+"/test-group/test-topic", topicPath) - - partitionPath := storage.getPartitionPath(group, topic, partition) - assert.Equal(t, ConsumerOffsetsBasePath+"/test-group/test-topic/5", partitionPath) - - offsetPath := storage.getOffsetPath(group, topic, partition) - assert.Equal(t, ConsumerOffsetsBasePath+"/test-group/test-topic/5/offset", offsetPath) - - metadataPath := storage.getMetadataPath(group, topic, partition) - assert.Equal(t, ConsumerOffsetsBasePath+"/test-group/test-topic/5/metadata", metadataPath) -} diff --git a/weed/mq/kafka/integration/seaweedmq_handler_topics.go b/weed/mq/kafka/integration/seaweedmq_handler_topics.go index b635b40af..b2071fd00 100644 --- a/weed/mq/kafka/integration/seaweedmq_handler_topics.go +++ b/weed/mq/kafka/integration/seaweedmq_handler_topics.go @@ -278,38 +278,3 @@ func (h *SeaweedMQHandler) checkTopicInFiler(topicName string) bool { return exists } - -// listTopicsFromFiler lists all topics from the filer -func (h *SeaweedMQHandler) listTopicsFromFiler() []string { - if h.filerClientAccessor == nil { - return []string{} - } - - var topics []string - - h.filerClientAccessor.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error { - request := &filer_pb.ListEntriesRequest{ - Directory: "/topics/kafka", - } - - stream, err := client.ListEntries(context.Background(), request) - if err != nil { - return nil // Don't propagate error, just return empty list - } - - for { - resp, err := stream.Recv() - if err != nil { - break // End of stream or error - } - - if resp.Entry != nil && resp.Entry.IsDirectory { - topics = append(topics, resp.Entry.Name) - } else if resp.Entry != nil { - } - } - return nil - }) - - return topics -} diff --git a/weed/mq/kafka/partition_mapping.go b/weed/mq/kafka/partition_mapping.go deleted file mode 100644 index a956c3cde..000000000 --- a/weed/mq/kafka/partition_mapping.go +++ /dev/null @@ -1,53 +0,0 @@ -package kafka - -import ( - "github.com/seaweedfs/seaweedfs/weed/mq/pub_balancer" - "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" -) - -// Convenience functions for partition mapping used by production code -// The full PartitionMapper implementation is in partition_mapping_test.go for testing - -// MapKafkaPartitionToSMQRange maps a Kafka partition to SeaweedMQ ring range -func MapKafkaPartitionToSMQRange(kafkaPartition int32) (rangeStart, rangeStop int32) { - // Use a range size that divides evenly into MaxPartitionCount (2520) - // Range size 35 gives us exactly 72 Kafka partitions: 2520 / 35 = 72 - rangeSize := int32(35) - rangeStart = kafkaPartition * rangeSize - rangeStop = rangeStart + rangeSize - 1 - return rangeStart, rangeStop -} - -// CreateSMQPartition creates a SeaweedMQ partition from a Kafka partition -func CreateSMQPartition(kafkaPartition int32, unixTimeNs int64) *schema_pb.Partition { - rangeStart, rangeStop := MapKafkaPartitionToSMQRange(kafkaPartition) - - return &schema_pb.Partition{ - RingSize: pub_balancer.MaxPartitionCount, - RangeStart: rangeStart, - RangeStop: rangeStop, - UnixTimeNs: unixTimeNs, - } -} - -// ExtractKafkaPartitionFromSMQRange extracts the Kafka partition from SeaweedMQ range -func ExtractKafkaPartitionFromSMQRange(rangeStart int32) int32 { - rangeSize := int32(35) - return rangeStart / rangeSize -} - -// ValidateKafkaPartition validates that a Kafka partition is within supported range -func ValidateKafkaPartition(kafkaPartition int32) bool { - maxPartitions := int32(pub_balancer.MaxPartitionCount) / 35 // 72 partitions - return kafkaPartition >= 0 && kafkaPartition < maxPartitions -} - -// GetRangeSize returns the range size used for partition mapping -func GetRangeSize() int32 { - return 35 -} - -// GetMaxKafkaPartitions returns the maximum number of Kafka partitions supported -func GetMaxKafkaPartitions() int32 { - return int32(pub_balancer.MaxPartitionCount) / 35 // 72 partitions -} diff --git a/weed/mq/kafka/partition_mapping_test.go b/weed/mq/kafka/partition_mapping_test.go deleted file mode 100644 index 6f41a68d4..000000000 --- a/weed/mq/kafka/partition_mapping_test.go +++ /dev/null @@ -1,294 +0,0 @@ -package kafka - -import ( - "testing" - "time" - - "github.com/seaweedfs/seaweedfs/weed/mq/pub_balancer" - "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" -) - -// PartitionMapper provides consistent Kafka partition to SeaweedMQ ring mapping -// NOTE: This is test-only code and not used in the actual Kafka Gateway implementation -type PartitionMapper struct{} - -// NewPartitionMapper creates a new partition mapper -func NewPartitionMapper() *PartitionMapper { - return &PartitionMapper{} -} - -// GetRangeSize returns the consistent range size for Kafka partition mapping -// This ensures all components use the same calculation -func (pm *PartitionMapper) GetRangeSize() int32 { - // Use a range size that divides evenly into MaxPartitionCount (2520) - // Range size 35 gives us exactly 72 Kafka partitions: 2520 / 35 = 72 - // This provides a good balance between partition granularity and ring utilization - return 35 -} - -// GetMaxKafkaPartitions returns the maximum number of Kafka partitions supported -func (pm *PartitionMapper) GetMaxKafkaPartitions() int32 { - // With range size 35, we can support: 2520 / 35 = 72 Kafka partitions - return int32(pub_balancer.MaxPartitionCount) / pm.GetRangeSize() -} - -// MapKafkaPartitionToSMQRange maps a Kafka partition to SeaweedMQ ring range -func (pm *PartitionMapper) MapKafkaPartitionToSMQRange(kafkaPartition int32) (rangeStart, rangeStop int32) { - rangeSize := pm.GetRangeSize() - rangeStart = kafkaPartition * rangeSize - rangeStop = rangeStart + rangeSize - 1 - return rangeStart, rangeStop -} - -// CreateSMQPartition creates a SeaweedMQ partition from a Kafka partition -func (pm *PartitionMapper) CreateSMQPartition(kafkaPartition int32, unixTimeNs int64) *schema_pb.Partition { - rangeStart, rangeStop := pm.MapKafkaPartitionToSMQRange(kafkaPartition) - - return &schema_pb.Partition{ - RingSize: pub_balancer.MaxPartitionCount, - RangeStart: rangeStart, - RangeStop: rangeStop, - UnixTimeNs: unixTimeNs, - } -} - -// ExtractKafkaPartitionFromSMQRange extracts the Kafka partition from SeaweedMQ range -func (pm *PartitionMapper) ExtractKafkaPartitionFromSMQRange(rangeStart int32) int32 { - rangeSize := pm.GetRangeSize() - return rangeStart / rangeSize -} - -// ValidateKafkaPartition validates that a Kafka partition is within supported range -func (pm *PartitionMapper) ValidateKafkaPartition(kafkaPartition int32) bool { - return kafkaPartition >= 0 && kafkaPartition < pm.GetMaxKafkaPartitions() -} - -// GetPartitionMappingInfo returns debug information about the partition mapping -func (pm *PartitionMapper) GetPartitionMappingInfo() map[string]interface{} { - return map[string]interface{}{ - "ring_size": pub_balancer.MaxPartitionCount, - "range_size": pm.GetRangeSize(), - "max_kafka_partitions": pm.GetMaxKafkaPartitions(), - "ring_utilization": float64(pm.GetMaxKafkaPartitions()*pm.GetRangeSize()) / float64(pub_balancer.MaxPartitionCount), - } -} - -// Global instance for consistent usage across the test codebase -var DefaultPartitionMapper = NewPartitionMapper() - -func TestPartitionMapper_GetRangeSize(t *testing.T) { - mapper := NewPartitionMapper() - rangeSize := mapper.GetRangeSize() - - if rangeSize != 35 { - t.Errorf("Expected range size 35, got %d", rangeSize) - } - - // Verify that the range size divides evenly into available partitions - maxPartitions := mapper.GetMaxKafkaPartitions() - totalUsed := maxPartitions * rangeSize - - if totalUsed > int32(pub_balancer.MaxPartitionCount) { - t.Errorf("Total used slots (%d) exceeds MaxPartitionCount (%d)", totalUsed, pub_balancer.MaxPartitionCount) - } - - t.Logf("Range size: %d, Max Kafka partitions: %d, Ring utilization: %.2f%%", - rangeSize, maxPartitions, float64(totalUsed)/float64(pub_balancer.MaxPartitionCount)*100) -} - -func TestPartitionMapper_MapKafkaPartitionToSMQRange(t *testing.T) { - mapper := NewPartitionMapper() - - tests := []struct { - kafkaPartition int32 - expectedStart int32 - expectedStop int32 - }{ - {0, 0, 34}, - {1, 35, 69}, - {2, 70, 104}, - {10, 350, 384}, - } - - for _, tt := range tests { - t.Run("", func(t *testing.T) { - start, stop := mapper.MapKafkaPartitionToSMQRange(tt.kafkaPartition) - - if start != tt.expectedStart { - t.Errorf("Kafka partition %d: expected start %d, got %d", tt.kafkaPartition, tt.expectedStart, start) - } - - if stop != tt.expectedStop { - t.Errorf("Kafka partition %d: expected stop %d, got %d", tt.kafkaPartition, tt.expectedStop, stop) - } - - // Verify range size is consistent - rangeSize := stop - start + 1 - if rangeSize != mapper.GetRangeSize() { - t.Errorf("Inconsistent range size: expected %d, got %d", mapper.GetRangeSize(), rangeSize) - } - }) - } -} - -func TestPartitionMapper_ExtractKafkaPartitionFromSMQRange(t *testing.T) { - mapper := NewPartitionMapper() - - tests := []struct { - rangeStart int32 - expectedKafka int32 - }{ - {0, 0}, - {35, 1}, - {70, 2}, - {350, 10}, - } - - for _, tt := range tests { - t.Run("", func(t *testing.T) { - kafkaPartition := mapper.ExtractKafkaPartitionFromSMQRange(tt.rangeStart) - - if kafkaPartition != tt.expectedKafka { - t.Errorf("Range start %d: expected Kafka partition %d, got %d", - tt.rangeStart, tt.expectedKafka, kafkaPartition) - } - }) - } -} - -func TestPartitionMapper_RoundTrip(t *testing.T) { - mapper := NewPartitionMapper() - - // Test round-trip conversion for all valid Kafka partitions - maxPartitions := mapper.GetMaxKafkaPartitions() - - for kafkaPartition := int32(0); kafkaPartition < maxPartitions; kafkaPartition++ { - // Kafka -> SMQ -> Kafka - rangeStart, rangeStop := mapper.MapKafkaPartitionToSMQRange(kafkaPartition) - extractedKafka := mapper.ExtractKafkaPartitionFromSMQRange(rangeStart) - - if extractedKafka != kafkaPartition { - t.Errorf("Round-trip failed for partition %d: got %d", kafkaPartition, extractedKafka) - } - - // Verify no overlap with next partition - if kafkaPartition < maxPartitions-1 { - nextStart, _ := mapper.MapKafkaPartitionToSMQRange(kafkaPartition + 1) - if rangeStop >= nextStart { - t.Errorf("Partition %d range [%d,%d] overlaps with partition %d start %d", - kafkaPartition, rangeStart, rangeStop, kafkaPartition+1, nextStart) - } - } - } -} - -func TestPartitionMapper_CreateSMQPartition(t *testing.T) { - mapper := NewPartitionMapper() - - kafkaPartition := int32(5) - unixTimeNs := time.Now().UnixNano() - - partition := mapper.CreateSMQPartition(kafkaPartition, unixTimeNs) - - if partition.RingSize != pub_balancer.MaxPartitionCount { - t.Errorf("Expected ring size %d, got %d", pub_balancer.MaxPartitionCount, partition.RingSize) - } - - expectedStart, expectedStop := mapper.MapKafkaPartitionToSMQRange(kafkaPartition) - if partition.RangeStart != expectedStart { - t.Errorf("Expected range start %d, got %d", expectedStart, partition.RangeStart) - } - - if partition.RangeStop != expectedStop { - t.Errorf("Expected range stop %d, got %d", expectedStop, partition.RangeStop) - } - - if partition.UnixTimeNs != unixTimeNs { - t.Errorf("Expected timestamp %d, got %d", unixTimeNs, partition.UnixTimeNs) - } -} - -func TestPartitionMapper_ValidateKafkaPartition(t *testing.T) { - mapper := NewPartitionMapper() - - tests := []struct { - partition int32 - valid bool - }{ - {-1, false}, - {0, true}, - {1, true}, - {mapper.GetMaxKafkaPartitions() - 1, true}, - {mapper.GetMaxKafkaPartitions(), false}, - {1000, false}, - } - - for _, tt := range tests { - t.Run("", func(t *testing.T) { - valid := mapper.ValidateKafkaPartition(tt.partition) - if valid != tt.valid { - t.Errorf("Partition %d: expected valid=%v, got %v", tt.partition, tt.valid, valid) - } - }) - } -} - -func TestPartitionMapper_ConsistencyWithGlobalFunctions(t *testing.T) { - mapper := NewPartitionMapper() - - kafkaPartition := int32(7) - unixTimeNs := time.Now().UnixNano() - - // Test that global functions produce same results as mapper methods - start1, stop1 := mapper.MapKafkaPartitionToSMQRange(kafkaPartition) - start2, stop2 := MapKafkaPartitionToSMQRange(kafkaPartition) - - if start1 != start2 || stop1 != stop2 { - t.Errorf("Global function inconsistent: mapper=(%d,%d), global=(%d,%d)", - start1, stop1, start2, stop2) - } - - partition1 := mapper.CreateSMQPartition(kafkaPartition, unixTimeNs) - partition2 := CreateSMQPartition(kafkaPartition, unixTimeNs) - - if partition1.RangeStart != partition2.RangeStart || partition1.RangeStop != partition2.RangeStop { - t.Errorf("Global CreateSMQPartition inconsistent") - } - - extracted1 := mapper.ExtractKafkaPartitionFromSMQRange(start1) - extracted2 := ExtractKafkaPartitionFromSMQRange(start1) - - if extracted1 != extracted2 { - t.Errorf("Global ExtractKafkaPartitionFromSMQRange inconsistent: %d vs %d", extracted1, extracted2) - } -} - -func TestPartitionMapper_GetPartitionMappingInfo(t *testing.T) { - mapper := NewPartitionMapper() - - info := mapper.GetPartitionMappingInfo() - - // Verify all expected keys are present - expectedKeys := []string{"ring_size", "range_size", "max_kafka_partitions", "ring_utilization"} - for _, key := range expectedKeys { - if _, exists := info[key]; !exists { - t.Errorf("Missing key in mapping info: %s", key) - } - } - - // Verify values are reasonable - if info["ring_size"].(int) != pub_balancer.MaxPartitionCount { - t.Errorf("Incorrect ring_size in info") - } - - if info["range_size"].(int32) != mapper.GetRangeSize() { - t.Errorf("Incorrect range_size in info") - } - - utilization := info["ring_utilization"].(float64) - if utilization <= 0 || utilization > 1 { - t.Errorf("Invalid ring utilization: %f", utilization) - } - - t.Logf("Partition mapping info: %+v", info) -} diff --git a/weed/mq/offset/benchmark_test.go b/weed/mq/offset/benchmark_test.go index 0fdacf127..8c78cc0f8 100644 --- a/weed/mq/offset/benchmark_test.go +++ b/weed/mq/offset/benchmark_test.go @@ -2,11 +2,9 @@ package offset import ( "fmt" - "os" "testing" "time" - _ "github.com/mattn/go-sqlite3" "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" ) @@ -62,151 +60,6 @@ func BenchmarkBatchOffsetAssignment(b *testing.B) { } } -// BenchmarkSQLOffsetStorage benchmarks SQL storage operations -func BenchmarkSQLOffsetStorage(b *testing.B) { - // Create temporary database - tmpFile, err := os.CreateTemp("", "benchmark_*.db") - if err != nil { - b.Fatalf("Failed to create temp database: %v", err) - } - tmpFile.Close() - defer os.Remove(tmpFile.Name()) - - db, err := CreateDatabase(tmpFile.Name()) - if err != nil { - b.Fatalf("Failed to create database: %v", err) - } - defer db.Close() - - storage, err := NewSQLOffsetStorage(db) - if err != nil { - b.Fatalf("Failed to create SQL storage: %v", err) - } - defer storage.Close() - - partition := &schema_pb.Partition{ - RingSize: 1024, - RangeStart: 0, - RangeStop: 31, - UnixTimeNs: time.Now().UnixNano(), - } - - partitionKey := partitionKey(partition) - - b.Run("SaveCheckpoint", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - storage.SaveCheckpoint("test-namespace", "test-topic", partition, int64(i)) - } - }) - - b.Run("LoadCheckpoint", func(b *testing.B) { - storage.SaveCheckpoint("test-namespace", "test-topic", partition, 1000) - b.ResetTimer() - for i := 0; i < b.N; i++ { - storage.LoadCheckpoint("test-namespace", "test-topic", partition) - } - }) - - b.Run("SaveOffsetMapping", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - storage.SaveOffsetMapping(partitionKey, int64(i), int64(i*1000), 100) - } - }) - - // Pre-populate for read benchmarks - for i := 0; i < 1000; i++ { - storage.SaveOffsetMapping(partitionKey, int64(i), int64(i*1000), 100) - } - - b.Run("GetHighestOffset", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - storage.GetHighestOffset("test-namespace", "test-topic", partition) - } - }) - - b.Run("LoadOffsetMappings", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - storage.LoadOffsetMappings(partitionKey) - } - }) - - b.Run("GetOffsetMappingsByRange", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - start := int64(i % 900) - end := start + 100 - storage.GetOffsetMappingsByRange(partitionKey, start, end) - } - }) - - b.Run("GetPartitionStats", func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - storage.GetPartitionStats(partitionKey) - } - }) -} - -// BenchmarkInMemoryVsSQL compares in-memory and SQL storage performance -func BenchmarkInMemoryVsSQL(b *testing.B) { - partition := &schema_pb.Partition{ - RingSize: 1024, - RangeStart: 0, - RangeStop: 31, - UnixTimeNs: time.Now().UnixNano(), - } - - // In-memory storage benchmark - b.Run("InMemory", func(b *testing.B) { - storage := NewInMemoryOffsetStorage() - manager, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage) - if err != nil { - b.Fatalf("Failed to create partition manager: %v", err) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - manager.AssignOffset() - } - }) - - // SQL storage benchmark - b.Run("SQL", func(b *testing.B) { - tmpFile, err := os.CreateTemp("", "benchmark_sql_*.db") - if err != nil { - b.Fatalf("Failed to create temp database: %v", err) - } - tmpFile.Close() - defer os.Remove(tmpFile.Name()) - - db, err := CreateDatabase(tmpFile.Name()) - if err != nil { - b.Fatalf("Failed to create database: %v", err) - } - defer db.Close() - - storage, err := NewSQLOffsetStorage(db) - if err != nil { - b.Fatalf("Failed to create SQL storage: %v", err) - } - defer storage.Close() - - manager, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage) - if err != nil { - b.Fatalf("Failed to create partition manager: %v", err) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - manager.AssignOffset() - } - }) -} - // BenchmarkOffsetSubscription benchmarks subscription operations func BenchmarkOffsetSubscription(b *testing.B) { storage := NewInMemoryOffsetStorage() diff --git a/weed/mq/offset/end_to_end_test.go b/weed/mq/offset/end_to_end_test.go deleted file mode 100644 index f2b57b843..000000000 --- a/weed/mq/offset/end_to_end_test.go +++ /dev/null @@ -1,473 +0,0 @@ -package offset - -import ( - "fmt" - "os" - "testing" - "time" - - _ "github.com/mattn/go-sqlite3" - "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" -) - -// TestEndToEndOffsetFlow tests the complete offset management flow -func TestEndToEndOffsetFlow(t *testing.T) { - // Create temporary database - tmpFile, err := os.CreateTemp("", "e2e_offset_test_*.db") - if err != nil { - t.Fatalf("Failed to create temp database: %v", err) - } - tmpFile.Close() - defer os.Remove(tmpFile.Name()) - - // Create database with migrations - db, err := CreateDatabase(tmpFile.Name()) - if err != nil { - t.Fatalf("Failed to create database: %v", err) - } - defer db.Close() - - // Create SQL storage - storage, err := NewSQLOffsetStorage(db) - if err != nil { - t.Fatalf("Failed to create SQL storage: %v", err) - } - defer storage.Close() - - // Create SMQ offset integration - integration := NewSMQOffsetIntegration(storage) - - // Test partition - partition := &schema_pb.Partition{ - RingSize: 1024, - RangeStart: 0, - RangeStop: 31, - UnixTimeNs: time.Now().UnixNano(), - } - - t.Run("PublishAndAssignOffsets", func(t *testing.T) { - // Simulate publishing messages with offset assignment - records := []PublishRecordRequest{ - {Key: []byte("user1"), Value: &schema_pb.RecordValue{}}, - {Key: []byte("user2"), Value: &schema_pb.RecordValue{}}, - {Key: []byte("user3"), Value: &schema_pb.RecordValue{}}, - } - - response, err := integration.PublishRecordBatch("test-namespace", "test-topic", partition, records) - if err != nil { - t.Fatalf("Failed to publish record batch: %v", err) - } - - if response.BaseOffset != 0 { - t.Errorf("Expected base offset 0, got %d", response.BaseOffset) - } - - if response.LastOffset != 2 { - t.Errorf("Expected last offset 2, got %d", response.LastOffset) - } - - // Verify high water mark - hwm, err := integration.GetHighWaterMark("test-namespace", "test-topic", partition) - if err != nil { - t.Fatalf("Failed to get high water mark: %v", err) - } - - if hwm != 3 { - t.Errorf("Expected high water mark 3, got %d", hwm) - } - }) - - t.Run("CreateAndUseSubscription", func(t *testing.T) { - // Create subscription from earliest - sub, err := integration.CreateSubscription( - "e2e-test-sub", - "test-namespace", "test-topic", - partition, - schema_pb.OffsetType_RESET_TO_EARLIEST, - 0, - ) - if err != nil { - t.Fatalf("Failed to create subscription: %v", err) - } - - // Subscribe to records - responses, err := integration.SubscribeRecords(sub, 2) - if err != nil { - t.Fatalf("Failed to subscribe to records: %v", err) - } - - if len(responses) != 2 { - t.Errorf("Expected 2 responses, got %d", len(responses)) - } - - // Check subscription advancement - if sub.CurrentOffset != 2 { - t.Errorf("Expected current offset 2, got %d", sub.CurrentOffset) - } - - // Get subscription lag - lag, err := sub.GetLag() - if err != nil { - t.Fatalf("Failed to get lag: %v", err) - } - - if lag != 1 { // 3 (hwm) - 2 (current) = 1 - t.Errorf("Expected lag 1, got %d", lag) - } - }) - - t.Run("OffsetSeekingAndRanges", func(t *testing.T) { - // Create subscription at specific offset - sub, err := integration.CreateSubscription( - "seek-test-sub", - "test-namespace", "test-topic", - partition, - schema_pb.OffsetType_EXACT_OFFSET, - 1, - ) - if err != nil { - t.Fatalf("Failed to create subscription at offset 1: %v", err) - } - - // Verify starting position - if sub.CurrentOffset != 1 { - t.Errorf("Expected current offset 1, got %d", sub.CurrentOffset) - } - - // Get offset range - offsetRange, err := sub.GetOffsetRange(2) - if err != nil { - t.Fatalf("Failed to get offset range: %v", err) - } - - if offsetRange.StartOffset != 1 { - t.Errorf("Expected start offset 1, got %d", offsetRange.StartOffset) - } - - if offsetRange.Count != 2 { - t.Errorf("Expected count 2, got %d", offsetRange.Count) - } - - // Seek to different offset - err = sub.SeekToOffset(0) - if err != nil { - t.Fatalf("Failed to seek to offset 0: %v", err) - } - - if sub.CurrentOffset != 0 { - t.Errorf("Expected current offset 0 after seek, got %d", sub.CurrentOffset) - } - }) - - t.Run("PartitionInformationAndMetrics", func(t *testing.T) { - // Get partition offset info - info, err := integration.GetPartitionOffsetInfo("test-namespace", "test-topic", partition) - if err != nil { - t.Fatalf("Failed to get partition offset info: %v", err) - } - - if info.EarliestOffset != 0 { - t.Errorf("Expected earliest offset 0, got %d", info.EarliestOffset) - } - - if info.LatestOffset != 2 { - t.Errorf("Expected latest offset 2, got %d", info.LatestOffset) - } - - if info.HighWaterMark != 3 { - t.Errorf("Expected high water mark 3, got %d", info.HighWaterMark) - } - - if info.ActiveSubscriptions != 2 { // Two subscriptions created above - t.Errorf("Expected 2 active subscriptions, got %d", info.ActiveSubscriptions) - } - - // Get offset metrics - metrics := integration.GetOffsetMetrics() - if metrics.PartitionCount != 1 { - t.Errorf("Expected 1 partition, got %d", metrics.PartitionCount) - } - - if metrics.ActiveSubscriptions != 2 { - t.Errorf("Expected 2 active subscriptions in metrics, got %d", metrics.ActiveSubscriptions) - } - }) -} - -// TestOffsetPersistenceAcrossRestarts tests that offsets persist across system restarts -func TestOffsetPersistenceAcrossRestarts(t *testing.T) { - // Create temporary database - tmpFile, err := os.CreateTemp("", "persistence_test_*.db") - if err != nil { - t.Fatalf("Failed to create temp database: %v", err) - } - tmpFile.Close() - defer os.Remove(tmpFile.Name()) - - partition := &schema_pb.Partition{ - RingSize: 1024, - RangeStart: 0, - RangeStop: 31, - UnixTimeNs: time.Now().UnixNano(), - } - - var lastOffset int64 - - // First session: Create database and assign offsets - { - db, err := CreateDatabase(tmpFile.Name()) - if err != nil { - t.Fatalf("Failed to create database: %v", err) - } - - storage, err := NewSQLOffsetStorage(db) - if err != nil { - t.Fatalf("Failed to create SQL storage: %v", err) - } - - integration := NewSMQOffsetIntegration(storage) - - // Publish some records - records := []PublishRecordRequest{ - {Key: []byte("msg1"), Value: &schema_pb.RecordValue{}}, - {Key: []byte("msg2"), Value: &schema_pb.RecordValue{}}, - {Key: []byte("msg3"), Value: &schema_pb.RecordValue{}}, - } - - response, err := integration.PublishRecordBatch("test-namespace", "test-topic", partition, records) - if err != nil { - t.Fatalf("Failed to publish records: %v", err) - } - - lastOffset = response.LastOffset - - // Close connections - Close integration first to trigger final checkpoint - integration.Close() - storage.Close() - db.Close() - } - - // Second session: Reopen database and verify persistence - { - db, err := CreateDatabase(tmpFile.Name()) - if err != nil { - t.Fatalf("Failed to reopen database: %v", err) - } - defer db.Close() - - storage, err := NewSQLOffsetStorage(db) - if err != nil { - t.Fatalf("Failed to create SQL storage: %v", err) - } - defer storage.Close() - - integration := NewSMQOffsetIntegration(storage) - - // Verify high water mark persisted - hwm, err := integration.GetHighWaterMark("test-namespace", "test-topic", partition) - if err != nil { - t.Fatalf("Failed to get high water mark after restart: %v", err) - } - - if hwm != lastOffset+1 { - t.Errorf("Expected high water mark %d after restart, got %d", lastOffset+1, hwm) - } - - // Assign new offsets and verify continuity - newResponse, err := integration.PublishRecord("test-namespace", "test-topic", partition, []byte("msg4"), &schema_pb.RecordValue{}) - if err != nil { - t.Fatalf("Failed to publish new record after restart: %v", err) - } - - expectedNextOffset := lastOffset + 1 - if newResponse.BaseOffset != expectedNextOffset { - t.Errorf("Expected next offset %d after restart, got %d", expectedNextOffset, newResponse.BaseOffset) - } - } -} - -// TestConcurrentOffsetOperations tests concurrent offset operations -func TestConcurrentOffsetOperations(t *testing.T) { - // Create temporary database - tmpFile, err := os.CreateTemp("", "concurrent_test_*.db") - if err != nil { - t.Fatalf("Failed to create temp database: %v", err) - } - tmpFile.Close() - defer os.Remove(tmpFile.Name()) - - db, err := CreateDatabase(tmpFile.Name()) - if err != nil { - t.Fatalf("Failed to create database: %v", err) - } - defer db.Close() - - storage, err := NewSQLOffsetStorage(db) - if err != nil { - t.Fatalf("Failed to create SQL storage: %v", err) - } - defer storage.Close() - - integration := NewSMQOffsetIntegration(storage) - - partition := &schema_pb.Partition{ - RingSize: 1024, - RangeStart: 0, - RangeStop: 31, - UnixTimeNs: time.Now().UnixNano(), - } - - // Concurrent publishers - const numPublishers = 5 - const recordsPerPublisher = 10 - - done := make(chan bool, numPublishers) - - for i := 0; i < numPublishers; i++ { - go func(publisherID int) { - defer func() { done <- true }() - - for j := 0; j < recordsPerPublisher; j++ { - key := fmt.Sprintf("publisher-%d-msg-%d", publisherID, j) - _, err := integration.PublishRecord("test-namespace", "test-topic", partition, []byte(key), &schema_pb.RecordValue{}) - if err != nil { - t.Errorf("Publisher %d failed to publish message %d: %v", publisherID, j, err) - return - } - } - }(i) - } - - // Wait for all publishers to complete - for i := 0; i < numPublishers; i++ { - <-done - } - - // Verify total records - hwm, err := integration.GetHighWaterMark("test-namespace", "test-topic", partition) - if err != nil { - t.Fatalf("Failed to get high water mark: %v", err) - } - - expectedTotal := int64(numPublishers * recordsPerPublisher) - if hwm != expectedTotal { - t.Errorf("Expected high water mark %d, got %d", expectedTotal, hwm) - } - - // Verify no duplicate offsets - info, err := integration.GetPartitionOffsetInfo("test-namespace", "test-topic", partition) - if err != nil { - t.Fatalf("Failed to get partition info: %v", err) - } - - if info.RecordCount != expectedTotal { - t.Errorf("Expected record count %d, got %d", expectedTotal, info.RecordCount) - } -} - -// TestOffsetValidationAndErrorHandling tests error conditions and validation -func TestOffsetValidationAndErrorHandling(t *testing.T) { - // Create temporary database - tmpFile, err := os.CreateTemp("", "validation_test_*.db") - if err != nil { - t.Fatalf("Failed to create temp database: %v", err) - } - tmpFile.Close() - defer os.Remove(tmpFile.Name()) - - db, err := CreateDatabase(tmpFile.Name()) - if err != nil { - t.Fatalf("Failed to create database: %v", err) - } - defer db.Close() - - storage, err := NewSQLOffsetStorage(db) - if err != nil { - t.Fatalf("Failed to create SQL storage: %v", err) - } - defer storage.Close() - - integration := NewSMQOffsetIntegration(storage) - - partition := &schema_pb.Partition{ - RingSize: 1024, - RangeStart: 0, - RangeStop: 31, - UnixTimeNs: time.Now().UnixNano(), - } - - t.Run("InvalidOffsetSubscription", func(t *testing.T) { - // Try to create subscription with invalid offset - _, err := integration.CreateSubscription( - "invalid-sub", - "test-namespace", "test-topic", - partition, - schema_pb.OffsetType_EXACT_OFFSET, - 100, // Beyond any existing data - ) - if err == nil { - t.Error("Expected error for subscription beyond high water mark") - } - }) - - t.Run("NegativeOffsetValidation", func(t *testing.T) { - // Try to create subscription with negative offset - _, err := integration.CreateSubscription( - "negative-sub", - "test-namespace", "test-topic", - partition, - schema_pb.OffsetType_EXACT_OFFSET, - -1, - ) - if err == nil { - t.Error("Expected error for negative offset") - } - }) - - t.Run("DuplicateSubscriptionID", func(t *testing.T) { - // Create first subscription - _, err := integration.CreateSubscription( - "duplicate-id", - "test-namespace", "test-topic", - partition, - schema_pb.OffsetType_RESET_TO_EARLIEST, - 0, - ) - if err != nil { - t.Fatalf("Failed to create first subscription: %v", err) - } - - // Try to create duplicate - _, err = integration.CreateSubscription( - "duplicate-id", - "test-namespace", "test-topic", - partition, - schema_pb.OffsetType_RESET_TO_EARLIEST, - 0, - ) - if err == nil { - t.Error("Expected error for duplicate subscription ID") - } - }) - - t.Run("OffsetRangeValidation", func(t *testing.T) { - // Add some data first - integration.PublishRecord("test-namespace", "test-topic", partition, []byte("test"), &schema_pb.RecordValue{}) - - // Test invalid range validation - err := integration.ValidateOffsetRange("test-namespace", "test-topic", partition, 5, 10) // Beyond high water mark - if err == nil { - t.Error("Expected error for range beyond high water mark") - } - - err = integration.ValidateOffsetRange("test-namespace", "test-topic", partition, 10, 5) // End before start - if err == nil { - t.Error("Expected error for end offset before start offset") - } - - err = integration.ValidateOffsetRange("test-namespace", "test-topic", partition, -1, 5) // Negative start - if err == nil { - t.Error("Expected error for negative start offset") - } - }) -} diff --git a/weed/mq/offset/filer_storage.go b/weed/mq/offset/filer_storage.go index 6f1a71e39..10b54bae3 100644 --- a/weed/mq/offset/filer_storage.go +++ b/weed/mq/offset/filer_storage.go @@ -93,9 +93,3 @@ func (f *FilerOffsetStorage) getPartitionDir(namespace, topicName string, partit return fmt.Sprintf("%s/%s/%s/%s/%s", filer.TopicsDir, namespace, topicName, version, partitionRange) } - -// getPartitionKey generates a unique key for a partition -func (f *FilerOffsetStorage) getPartitionKey(partition *schema_pb.Partition) string { - return fmt.Sprintf("ring:%d:range:%d-%d:time:%d", - partition.RingSize, partition.RangeStart, partition.RangeStop, partition.UnixTimeNs) -} diff --git a/weed/mq/offset/integration_test.go b/weed/mq/offset/integration_test.go deleted file mode 100644 index 35299be65..000000000 --- a/weed/mq/offset/integration_test.go +++ /dev/null @@ -1,544 +0,0 @@ -package offset - -import ( - "testing" - - "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" -) - -func TestSMQOffsetIntegration_PublishRecord(t *testing.T) { - storage := NewInMemoryOffsetStorage() - integration := NewSMQOffsetIntegration(storage) - partition := createTestPartition() - - // Publish a single record - response, err := integration.PublishRecord( - "test-namespace", "test-topic", - partition, - []byte("test-key"), - &schema_pb.RecordValue{}, - ) - - if err != nil { - t.Fatalf("Failed to publish record: %v", err) - } - - if response.Error != "" { - t.Errorf("Expected no error, got: %s", response.Error) - } - - if response.BaseOffset != 0 { - t.Errorf("Expected base offset 0, got %d", response.BaseOffset) - } - - if response.LastOffset != 0 { - t.Errorf("Expected last offset 0, got %d", response.LastOffset) - } -} - -func TestSMQOffsetIntegration_PublishRecordBatch(t *testing.T) { - storage := NewInMemoryOffsetStorage() - integration := NewSMQOffsetIntegration(storage) - partition := createTestPartition() - - // Create batch of records - records := []PublishRecordRequest{ - {Key: []byte("key1"), Value: &schema_pb.RecordValue{}}, - {Key: []byte("key2"), Value: &schema_pb.RecordValue{}}, - {Key: []byte("key3"), Value: &schema_pb.RecordValue{}}, - } - - // Publish batch - response, err := integration.PublishRecordBatch("test-namespace", "test-topic", partition, records) - if err != nil { - t.Fatalf("Failed to publish record batch: %v", err) - } - - if response.Error != "" { - t.Errorf("Expected no error, got: %s", response.Error) - } - - if response.BaseOffset != 0 { - t.Errorf("Expected base offset 0, got %d", response.BaseOffset) - } - - if response.LastOffset != 2 { - t.Errorf("Expected last offset 2, got %d", response.LastOffset) - } - - // Verify high water mark - hwm, err := integration.GetHighWaterMark("test-namespace", "test-topic", partition) - if err != nil { - t.Fatalf("Failed to get high water mark: %v", err) - } - - if hwm != 3 { - t.Errorf("Expected high water mark 3, got %d", hwm) - } -} - -func TestSMQOffsetIntegration_EmptyBatch(t *testing.T) { - storage := NewInMemoryOffsetStorage() - integration := NewSMQOffsetIntegration(storage) - partition := createTestPartition() - - // Publish empty batch - response, err := integration.PublishRecordBatch("test-namespace", "test-topic", partition, []PublishRecordRequest{}) - if err != nil { - t.Fatalf("Failed to publish empty batch: %v", err) - } - - if response.Error == "" { - t.Error("Expected error for empty batch") - } -} - -func TestSMQOffsetIntegration_CreateSubscription(t *testing.T) { - storage := NewInMemoryOffsetStorage() - integration := NewSMQOffsetIntegration(storage) - partition := createTestPartition() - - // Publish some records first - records := []PublishRecordRequest{ - {Key: []byte("key1"), Value: &schema_pb.RecordValue{}}, - {Key: []byte("key2"), Value: &schema_pb.RecordValue{}}, - } - integration.PublishRecordBatch("test-namespace", "test-topic", partition, records) - - // Create subscription - sub, err := integration.CreateSubscription( - "test-sub", - "test-namespace", "test-topic", - partition, - schema_pb.OffsetType_RESET_TO_EARLIEST, - 0, - ) - - if err != nil { - t.Fatalf("Failed to create subscription: %v", err) - } - - if sub.ID != "test-sub" { - t.Errorf("Expected subscription ID 'test-sub', got %s", sub.ID) - } - - if sub.StartOffset != 0 { - t.Errorf("Expected start offset 0, got %d", sub.StartOffset) - } -} - -func TestSMQOffsetIntegration_SubscribeRecords(t *testing.T) { - storage := NewInMemoryOffsetStorage() - integration := NewSMQOffsetIntegration(storage) - partition := createTestPartition() - - // Publish some records - records := []PublishRecordRequest{ - {Key: []byte("key1"), Value: &schema_pb.RecordValue{}}, - {Key: []byte("key2"), Value: &schema_pb.RecordValue{}}, - {Key: []byte("key3"), Value: &schema_pb.RecordValue{}}, - } - integration.PublishRecordBatch("test-namespace", "test-topic", partition, records) - - // Create subscription - sub, err := integration.CreateSubscription( - "test-sub", - "test-namespace", "test-topic", - partition, - schema_pb.OffsetType_RESET_TO_EARLIEST, - 0, - ) - if err != nil { - t.Fatalf("Failed to create subscription: %v", err) - } - - // Subscribe to records - responses, err := integration.SubscribeRecords(sub, 2) - if err != nil { - t.Fatalf("Failed to subscribe to records: %v", err) - } - - if len(responses) != 2 { - t.Errorf("Expected 2 responses, got %d", len(responses)) - } - - // Check offset progression - if responses[0].Offset != 0 { - t.Errorf("Expected first record offset 0, got %d", responses[0].Offset) - } - - if responses[1].Offset != 1 { - t.Errorf("Expected second record offset 1, got %d", responses[1].Offset) - } - - // Check subscription advancement - if sub.CurrentOffset != 2 { - t.Errorf("Expected subscription current offset 2, got %d", sub.CurrentOffset) - } -} - -func TestSMQOffsetIntegration_SubscribeEmptyPartition(t *testing.T) { - storage := NewInMemoryOffsetStorage() - integration := NewSMQOffsetIntegration(storage) - partition := createTestPartition() - - // Create subscription on empty partition - sub, err := integration.CreateSubscription( - "empty-sub", - "test-namespace", "test-topic", - partition, - schema_pb.OffsetType_RESET_TO_EARLIEST, - 0, - ) - if err != nil { - t.Fatalf("Failed to create subscription: %v", err) - } - - // Subscribe to records (should return empty) - responses, err := integration.SubscribeRecords(sub, 10) - if err != nil { - t.Fatalf("Failed to subscribe to empty partition: %v", err) - } - - if len(responses) != 0 { - t.Errorf("Expected 0 responses from empty partition, got %d", len(responses)) - } -} - -func TestSMQOffsetIntegration_SeekSubscription(t *testing.T) { - storage := NewInMemoryOffsetStorage() - integration := NewSMQOffsetIntegration(storage) - partition := createTestPartition() - - // Publish records - records := []PublishRecordRequest{ - {Key: []byte("key1"), Value: &schema_pb.RecordValue{}}, - {Key: []byte("key2"), Value: &schema_pb.RecordValue{}}, - {Key: []byte("key3"), Value: &schema_pb.RecordValue{}}, - {Key: []byte("key4"), Value: &schema_pb.RecordValue{}}, - {Key: []byte("key5"), Value: &schema_pb.RecordValue{}}, - } - integration.PublishRecordBatch("test-namespace", "test-topic", partition, records) - - // Create subscription - sub, err := integration.CreateSubscription( - "seek-sub", - "test-namespace", "test-topic", - partition, - schema_pb.OffsetType_RESET_TO_EARLIEST, - 0, - ) - if err != nil { - t.Fatalf("Failed to create subscription: %v", err) - } - - // Seek to offset 3 - err = integration.SeekSubscription("seek-sub", 3) - if err != nil { - t.Fatalf("Failed to seek subscription: %v", err) - } - - if sub.CurrentOffset != 3 { - t.Errorf("Expected current offset 3 after seek, got %d", sub.CurrentOffset) - } - - // Subscribe from new position - responses, err := integration.SubscribeRecords(sub, 2) - if err != nil { - t.Fatalf("Failed to subscribe after seek: %v", err) - } - - if len(responses) != 2 { - t.Errorf("Expected 2 responses after seek, got %d", len(responses)) - } - - if responses[0].Offset != 3 { - t.Errorf("Expected first record offset 3 after seek, got %d", responses[0].Offset) - } -} - -func TestSMQOffsetIntegration_GetSubscriptionLag(t *testing.T) { - storage := NewInMemoryOffsetStorage() - integration := NewSMQOffsetIntegration(storage) - partition := createTestPartition() - - // Publish records - records := []PublishRecordRequest{ - {Key: []byte("key1"), Value: &schema_pb.RecordValue{}}, - {Key: []byte("key2"), Value: &schema_pb.RecordValue{}}, - {Key: []byte("key3"), Value: &schema_pb.RecordValue{}}, - } - integration.PublishRecordBatch("test-namespace", "test-topic", partition, records) - - // Create subscription at offset 1 - sub, err := integration.CreateSubscription( - "lag-sub", - "test-namespace", "test-topic", - partition, - schema_pb.OffsetType_EXACT_OFFSET, - 1, - ) - if err != nil { - t.Fatalf("Failed to create subscription: %v", err) - } - - // Get lag - lag, err := integration.GetSubscriptionLag("lag-sub") - if err != nil { - t.Fatalf("Failed to get subscription lag: %v", err) - } - - expectedLag := int64(3 - 1) // hwm - current - if lag != expectedLag { - t.Errorf("Expected lag %d, got %d", expectedLag, lag) - } - - // Advance subscription and check lag again - integration.SubscribeRecords(sub, 1) - - lag, err = integration.GetSubscriptionLag("lag-sub") - if err != nil { - t.Fatalf("Failed to get lag after advance: %v", err) - } - - expectedLag = int64(3 - 2) // hwm - current - if lag != expectedLag { - t.Errorf("Expected lag %d after advance, got %d", expectedLag, lag) - } -} - -func TestSMQOffsetIntegration_CloseSubscription(t *testing.T) { - storage := NewInMemoryOffsetStorage() - integration := NewSMQOffsetIntegration(storage) - partition := createTestPartition() - - // Create subscription - _, err := integration.CreateSubscription( - "close-sub", - "test-namespace", "test-topic", - partition, - schema_pb.OffsetType_RESET_TO_EARLIEST, - 0, - ) - if err != nil { - t.Fatalf("Failed to create subscription: %v", err) - } - - // Close subscription - err = integration.CloseSubscription("close-sub") - if err != nil { - t.Fatalf("Failed to close subscription: %v", err) - } - - // Try to get lag (should fail) - _, err = integration.GetSubscriptionLag("close-sub") - if err == nil { - t.Error("Expected error when getting lag for closed subscription") - } -} - -func TestSMQOffsetIntegration_ValidateOffsetRange(t *testing.T) { - storage := NewInMemoryOffsetStorage() - integration := NewSMQOffsetIntegration(storage) - partition := createTestPartition() - - // Publish some records - records := []PublishRecordRequest{ - {Key: []byte("key1"), Value: &schema_pb.RecordValue{}}, - {Key: []byte("key2"), Value: &schema_pb.RecordValue{}}, - {Key: []byte("key3"), Value: &schema_pb.RecordValue{}}, - } - integration.PublishRecordBatch("test-namespace", "test-topic", partition, records) - - // Test valid range - err := integration.ValidateOffsetRange("test-namespace", "test-topic", partition, 0, 2) - if err != nil { - t.Errorf("Valid range should not return error: %v", err) - } - - // Test invalid range (beyond hwm) - err = integration.ValidateOffsetRange("test-namespace", "test-topic", partition, 0, 5) - if err == nil { - t.Error("Expected error for range beyond high water mark") - } -} - -func TestSMQOffsetIntegration_GetAvailableOffsetRange(t *testing.T) { - storage := NewInMemoryOffsetStorage() - integration := NewSMQOffsetIntegration(storage) - partition := createTestPartition() - - // Test empty partition - offsetRange, err := integration.GetAvailableOffsetRange("test-namespace", "test-topic", partition) - if err != nil { - t.Fatalf("Failed to get available range for empty partition: %v", err) - } - - if offsetRange.Count != 0 { - t.Errorf("Expected empty range for empty partition, got count %d", offsetRange.Count) - } - - // Publish records - records := []PublishRecordRequest{ - {Key: []byte("key1"), Value: &schema_pb.RecordValue{}}, - {Key: []byte("key2"), Value: &schema_pb.RecordValue{}}, - } - integration.PublishRecordBatch("test-namespace", "test-topic", partition, records) - - // Test with data - offsetRange, err = integration.GetAvailableOffsetRange("test-namespace", "test-topic", partition) - if err != nil { - t.Fatalf("Failed to get available range: %v", err) - } - - if offsetRange.StartOffset != 0 { - t.Errorf("Expected start offset 0, got %d", offsetRange.StartOffset) - } - - if offsetRange.EndOffset != 1 { - t.Errorf("Expected end offset 1, got %d", offsetRange.EndOffset) - } - - if offsetRange.Count != 2 { - t.Errorf("Expected count 2, got %d", offsetRange.Count) - } -} - -func TestSMQOffsetIntegration_GetOffsetMetrics(t *testing.T) { - storage := NewInMemoryOffsetStorage() - integration := NewSMQOffsetIntegration(storage) - partition := createTestPartition() - - // Initial metrics - metrics := integration.GetOffsetMetrics() - if metrics.TotalOffsets != 0 { - t.Errorf("Expected 0 total offsets initially, got %d", metrics.TotalOffsets) - } - - if metrics.ActiveSubscriptions != 0 { - t.Errorf("Expected 0 active subscriptions initially, got %d", metrics.ActiveSubscriptions) - } - - // Publish records - records := []PublishRecordRequest{ - {Key: []byte("key1"), Value: &schema_pb.RecordValue{}}, - {Key: []byte("key2"), Value: &schema_pb.RecordValue{}}, - } - integration.PublishRecordBatch("test-namespace", "test-topic", partition, records) - - // Create subscriptions - integration.CreateSubscription("sub1", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0) - integration.CreateSubscription("sub2", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0) - - // Check updated metrics - metrics = integration.GetOffsetMetrics() - if metrics.TotalOffsets != 2 { - t.Errorf("Expected 2 total offsets, got %d", metrics.TotalOffsets) - } - - if metrics.ActiveSubscriptions != 2 { - t.Errorf("Expected 2 active subscriptions, got %d", metrics.ActiveSubscriptions) - } - - if metrics.PartitionCount != 1 { - t.Errorf("Expected 1 partition, got %d", metrics.PartitionCount) - } -} - -func TestSMQOffsetIntegration_GetOffsetInfo(t *testing.T) { - storage := NewInMemoryOffsetStorage() - integration := NewSMQOffsetIntegration(storage) - partition := createTestPartition() - - // Test non-existent offset - info, err := integration.GetOffsetInfo("test-namespace", "test-topic", partition, 0) - if err != nil { - t.Fatalf("Failed to get offset info: %v", err) - } - - if info.Exists { - t.Error("Offset should not exist in empty partition") - } - - // Publish record - integration.PublishRecord("test-namespace", "test-topic", partition, []byte("key1"), &schema_pb.RecordValue{}) - - // Test existing offset - info, err = integration.GetOffsetInfo("test-namespace", "test-topic", partition, 0) - if err != nil { - t.Fatalf("Failed to get offset info for existing offset: %v", err) - } - - if !info.Exists { - t.Error("Offset should exist after publishing") - } - - if info.Offset != 0 { - t.Errorf("Expected offset 0, got %d", info.Offset) - } -} - -func TestSMQOffsetIntegration_GetPartitionOffsetInfo(t *testing.T) { - storage := NewInMemoryOffsetStorage() - integration := NewSMQOffsetIntegration(storage) - partition := createTestPartition() - - // Test empty partition - info, err := integration.GetPartitionOffsetInfo("test-namespace", "test-topic", partition) - if err != nil { - t.Fatalf("Failed to get partition offset info: %v", err) - } - - if info.EarliestOffset != 0 { - t.Errorf("Expected earliest offset 0, got %d", info.EarliestOffset) - } - - if info.LatestOffset != -1 { - t.Errorf("Expected latest offset -1 for empty partition, got %d", info.LatestOffset) - } - - if info.HighWaterMark != 0 { - t.Errorf("Expected high water mark 0, got %d", info.HighWaterMark) - } - - if info.RecordCount != 0 { - t.Errorf("Expected record count 0, got %d", info.RecordCount) - } - - // Publish records - records := []PublishRecordRequest{ - {Key: []byte("key1"), Value: &schema_pb.RecordValue{}}, - {Key: []byte("key2"), Value: &schema_pb.RecordValue{}}, - {Key: []byte("key3"), Value: &schema_pb.RecordValue{}}, - } - integration.PublishRecordBatch("test-namespace", "test-topic", partition, records) - - // Create subscription - integration.CreateSubscription("test-sub", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0) - - // Test with data - info, err = integration.GetPartitionOffsetInfo("test-namespace", "test-topic", partition) - if err != nil { - t.Fatalf("Failed to get partition offset info with data: %v", err) - } - - if info.EarliestOffset != 0 { - t.Errorf("Expected earliest offset 0, got %d", info.EarliestOffset) - } - - if info.LatestOffset != 2 { - t.Errorf("Expected latest offset 2, got %d", info.LatestOffset) - } - - if info.HighWaterMark != 3 { - t.Errorf("Expected high water mark 3, got %d", info.HighWaterMark) - } - - if info.RecordCount != 3 { - t.Errorf("Expected record count 3, got %d", info.RecordCount) - } - - if info.ActiveSubscriptions != 1 { - t.Errorf("Expected 1 active subscription, got %d", info.ActiveSubscriptions) - } -} diff --git a/weed/mq/offset/manager.go b/weed/mq/offset/manager.go index 53388d82f..b78307f3a 100644 --- a/weed/mq/offset/manager.go +++ b/weed/mq/offset/manager.go @@ -338,13 +338,6 @@ type OffsetAssigner struct { registry *PartitionOffsetRegistry } -// NewOffsetAssigner creates a new offset assigner -func NewOffsetAssigner(storage OffsetStorage) *OffsetAssigner { - return &OffsetAssigner{ - registry: NewPartitionOffsetRegistry(storage), - } -} - // AssignSingleOffset assigns a single offset with timestamp func (a *OffsetAssigner) AssignSingleOffset(namespace, topicName string, partition *schema_pb.Partition) *AssignmentResult { offset, err := a.registry.AssignOffset(namespace, topicName, partition) diff --git a/weed/mq/offset/manager_test.go b/weed/mq/offset/manager_test.go deleted file mode 100644 index 0db301e84..000000000 --- a/weed/mq/offset/manager_test.go +++ /dev/null @@ -1,388 +0,0 @@ -package offset - -import ( - "testing" - "time" - - "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" -) - -func createTestPartition() *schema_pb.Partition { - return &schema_pb.Partition{ - RingSize: 1024, - RangeStart: 0, - RangeStop: 31, - UnixTimeNs: time.Now().UnixNano(), - } -} - -func TestPartitionOffsetManager_BasicAssignment(t *testing.T) { - storage := NewInMemoryOffsetStorage() - partition := createTestPartition() - - manager, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage) - if err != nil { - t.Fatalf("Failed to create offset manager: %v", err) - } - - // Test sequential offset assignment - for i := int64(0); i < 10; i++ { - offset := manager.AssignOffset() - if offset != i { - t.Errorf("Expected offset %d, got %d", i, offset) - } - } - - // Test high water mark - hwm := manager.GetHighWaterMark() - if hwm != 10 { - t.Errorf("Expected high water mark 10, got %d", hwm) - } -} - -func TestPartitionOffsetManager_BatchAssignment(t *testing.T) { - storage := NewInMemoryOffsetStorage() - partition := createTestPartition() - - manager, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage) - if err != nil { - t.Fatalf("Failed to create offset manager: %v", err) - } - - // Assign batch of 5 offsets - baseOffset, lastOffset := manager.AssignOffsets(5) - if baseOffset != 0 { - t.Errorf("Expected base offset 0, got %d", baseOffset) - } - if lastOffset != 4 { - t.Errorf("Expected last offset 4, got %d", lastOffset) - } - - // Assign another batch - baseOffset, lastOffset = manager.AssignOffsets(3) - if baseOffset != 5 { - t.Errorf("Expected base offset 5, got %d", baseOffset) - } - if lastOffset != 7 { - t.Errorf("Expected last offset 7, got %d", lastOffset) - } - - // Check high water mark - hwm := manager.GetHighWaterMark() - if hwm != 8 { - t.Errorf("Expected high water mark 8, got %d", hwm) - } -} - -func TestPartitionOffsetManager_Recovery(t *testing.T) { - storage := NewInMemoryOffsetStorage() - partition := createTestPartition() - - // Create manager and assign some offsets - manager1, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage) - if err != nil { - t.Fatalf("Failed to create offset manager: %v", err) - } - - // Assign offsets and simulate records - for i := 0; i < 150; i++ { // More than checkpoint interval - offset := manager1.AssignOffset() - storage.AddRecord("test-namespace", "test-topic", partition, offset) - } - - // Wait for checkpoint to complete - time.Sleep(100 * time.Millisecond) - - // Create new manager (simulates restart) - manager2, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage) - if err != nil { - t.Fatalf("Failed to create offset manager after recovery: %v", err) - } - - // Next offset should continue from checkpoint + 1 - // With checkpoint interval 100, checkpoint happens at offset 100 - // So recovery should start from 101, but we assigned 150 offsets (0-149) - // The checkpoint should be at 100, so next offset should be 101 - // But since we have records up to 149, it should recover from storage scan - nextOffset := manager2.AssignOffset() - if nextOffset != 150 { - t.Errorf("Expected next offset 150 after recovery, got %d", nextOffset) - } -} - -func TestPartitionOffsetManager_RecoveryFromStorage(t *testing.T) { - storage := NewInMemoryOffsetStorage() - partition := createTestPartition() - - // Simulate existing records in storage without checkpoint - for i := int64(0); i < 50; i++ { - storage.AddRecord("test-namespace", "test-topic", partition, i) - } - - // Create manager - should recover from storage scan - manager, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage) - if err != nil { - t.Fatalf("Failed to create offset manager: %v", err) - } - - // Next offset should be 50 - nextOffset := manager.AssignOffset() - if nextOffset != 50 { - t.Errorf("Expected next offset 50 after storage recovery, got %d", nextOffset) - } -} - -func TestPartitionOffsetRegistry_MultiplePartitions(t *testing.T) { - storage := NewInMemoryOffsetStorage() - registry := NewPartitionOffsetRegistry(storage) - - // Create different partitions - partition1 := &schema_pb.Partition{ - RingSize: 1024, - RangeStart: 0, - RangeStop: 31, - UnixTimeNs: time.Now().UnixNano(), - } - - partition2 := &schema_pb.Partition{ - RingSize: 1024, - RangeStart: 32, - RangeStop: 63, - UnixTimeNs: time.Now().UnixNano(), - } - - // Assign offsets to different partitions - offset1, err := registry.AssignOffset("test-namespace", "test-topic", partition1) - if err != nil { - t.Fatalf("Failed to assign offset to partition1: %v", err) - } - if offset1 != 0 { - t.Errorf("Expected offset 0 for partition1, got %d", offset1) - } - - offset2, err := registry.AssignOffset("test-namespace", "test-topic", partition2) - if err != nil { - t.Fatalf("Failed to assign offset to partition2: %v", err) - } - if offset2 != 0 { - t.Errorf("Expected offset 0 for partition2, got %d", offset2) - } - - // Assign more offsets to partition1 - offset1_2, err := registry.AssignOffset("test-namespace", "test-topic", partition1) - if err != nil { - t.Fatalf("Failed to assign second offset to partition1: %v", err) - } - if offset1_2 != 1 { - t.Errorf("Expected offset 1 for partition1, got %d", offset1_2) - } - - // Partition2 should still be at 0 for next assignment - offset2_2, err := registry.AssignOffset("test-namespace", "test-topic", partition2) - if err != nil { - t.Fatalf("Failed to assign second offset to partition2: %v", err) - } - if offset2_2 != 1 { - t.Errorf("Expected offset 1 for partition2, got %d", offset2_2) - } -} - -func TestPartitionOffsetRegistry_BatchAssignment(t *testing.T) { - storage := NewInMemoryOffsetStorage() - registry := NewPartitionOffsetRegistry(storage) - partition := createTestPartition() - - // Assign batch of offsets - baseOffset, lastOffset, err := registry.AssignOffsets("test-namespace", "test-topic", partition, 10) - if err != nil { - t.Fatalf("Failed to assign batch offsets: %v", err) - } - - if baseOffset != 0 { - t.Errorf("Expected base offset 0, got %d", baseOffset) - } - if lastOffset != 9 { - t.Errorf("Expected last offset 9, got %d", lastOffset) - } - - // Get high water mark - hwm, err := registry.GetHighWaterMark("test-namespace", "test-topic", partition) - if err != nil { - t.Fatalf("Failed to get high water mark: %v", err) - } - if hwm != 10 { - t.Errorf("Expected high water mark 10, got %d", hwm) - } -} - -func TestOffsetAssigner_SingleAssignment(t *testing.T) { - storage := NewInMemoryOffsetStorage() - assigner := NewOffsetAssigner(storage) - partition := createTestPartition() - - // Assign single offset - result := assigner.AssignSingleOffset("test-namespace", "test-topic", partition) - if result.Error != nil { - t.Fatalf("Failed to assign single offset: %v", result.Error) - } - - if result.Assignment == nil { - t.Fatal("Assignment result is nil") - } - - if result.Assignment.Offset != 0 { - t.Errorf("Expected offset 0, got %d", result.Assignment.Offset) - } - - if result.Assignment.Partition != partition { - t.Error("Partition mismatch in assignment") - } - - if result.Assignment.Timestamp <= 0 { - t.Error("Timestamp should be set") - } -} - -func TestOffsetAssigner_BatchAssignment(t *testing.T) { - storage := NewInMemoryOffsetStorage() - assigner := NewOffsetAssigner(storage) - partition := createTestPartition() - - // Assign batch of offsets - result := assigner.AssignBatchOffsets("test-namespace", "test-topic", partition, 5) - if result.Error != nil { - t.Fatalf("Failed to assign batch offsets: %v", result.Error) - } - - if result.Batch == nil { - t.Fatal("Batch result is nil") - } - - if result.Batch.BaseOffset != 0 { - t.Errorf("Expected base offset 0, got %d", result.Batch.BaseOffset) - } - - if result.Batch.LastOffset != 4 { - t.Errorf("Expected last offset 4, got %d", result.Batch.LastOffset) - } - - if result.Batch.Count != 5 { - t.Errorf("Expected count 5, got %d", result.Batch.Count) - } - - if result.Batch.Timestamp <= 0 { - t.Error("Timestamp should be set") - } -} - -func TestOffsetAssigner_HighWaterMark(t *testing.T) { - storage := NewInMemoryOffsetStorage() - assigner := NewOffsetAssigner(storage) - partition := createTestPartition() - - // Initially should be 0 - hwm, err := assigner.GetHighWaterMark("test-namespace", "test-topic", partition) - if err != nil { - t.Fatalf("Failed to get initial high water mark: %v", err) - } - if hwm != 0 { - t.Errorf("Expected initial high water mark 0, got %d", hwm) - } - - // Assign some offsets - assigner.AssignBatchOffsets("test-namespace", "test-topic", partition, 10) - - // High water mark should be updated - hwm, err = assigner.GetHighWaterMark("test-namespace", "test-topic", partition) - if err != nil { - t.Fatalf("Failed to get high water mark after assignment: %v", err) - } - if hwm != 10 { - t.Errorf("Expected high water mark 10, got %d", hwm) - } -} - -func TestPartitionKey(t *testing.T) { - partition1 := &schema_pb.Partition{ - RingSize: 1024, - RangeStart: 0, - RangeStop: 31, - UnixTimeNs: 1234567890, - } - - partition2 := &schema_pb.Partition{ - RingSize: 1024, - RangeStart: 0, - RangeStop: 31, - UnixTimeNs: 1234567890, - } - - partition3 := &schema_pb.Partition{ - RingSize: 1024, - RangeStart: 32, - RangeStop: 63, - UnixTimeNs: 1234567890, - } - - key1 := partitionKey(partition1) - key2 := partitionKey(partition2) - key3 := partitionKey(partition3) - - // Same partitions should have same key - if key1 != key2 { - t.Errorf("Same partitions should have same key: %s vs %s", key1, key2) - } - - // Different partitions should have different keys - if key1 == key3 { - t.Errorf("Different partitions should have different keys: %s vs %s", key1, key3) - } -} - -func TestConcurrentOffsetAssignment(t *testing.T) { - storage := NewInMemoryOffsetStorage() - registry := NewPartitionOffsetRegistry(storage) - partition := createTestPartition() - - const numGoroutines = 10 - const offsetsPerGoroutine = 100 - - results := make(chan int64, numGoroutines*offsetsPerGoroutine) - - // Start concurrent offset assignments - for i := 0; i < numGoroutines; i++ { - go func() { - for j := 0; j < offsetsPerGoroutine; j++ { - offset, err := registry.AssignOffset("test-namespace", "test-topic", partition) - if err != nil { - t.Errorf("Failed to assign offset: %v", err) - return - } - results <- offset - } - }() - } - - // Collect all results - offsets := make(map[int64]bool) - for i := 0; i < numGoroutines*offsetsPerGoroutine; i++ { - offset := <-results - if offsets[offset] { - t.Errorf("Duplicate offset assigned: %d", offset) - } - offsets[offset] = true - } - - // Verify we got all expected offsets - expectedCount := numGoroutines * offsetsPerGoroutine - if len(offsets) != expectedCount { - t.Errorf("Expected %d unique offsets, got %d", expectedCount, len(offsets)) - } - - // Verify offsets are in expected range - for offset := range offsets { - if offset < 0 || offset >= int64(expectedCount) { - t.Errorf("Offset %d is out of expected range [0, %d)", offset, expectedCount) - } - } -} diff --git a/weed/mq/offset/migration.go b/weed/mq/offset/migration.go deleted file mode 100644 index 4e0a6ab12..000000000 --- a/weed/mq/offset/migration.go +++ /dev/null @@ -1,302 +0,0 @@ -package offset - -import ( - "database/sql" - "fmt" - "time" -) - -// MigrationVersion represents a database migration version -type MigrationVersion struct { - Version int - Description string - SQL string -} - -// GetMigrations returns all available migrations for offset storage -func GetMigrations() []MigrationVersion { - return []MigrationVersion{ - { - Version: 1, - Description: "Create initial offset storage tables", - SQL: ` - -- Partition offset checkpoints table - -- TODO: Add _index as computed column when supported by database - CREATE TABLE IF NOT EXISTS partition_offset_checkpoints ( - partition_key TEXT PRIMARY KEY, - ring_size INTEGER NOT NULL, - range_start INTEGER NOT NULL, - range_stop INTEGER NOT NULL, - unix_time_ns INTEGER NOT NULL, - checkpoint_offset INTEGER NOT NULL, - updated_at INTEGER NOT NULL - ); - - -- Offset mappings table for detailed tracking - -- TODO: Add _index as computed column when supported by database - CREATE TABLE IF NOT EXISTS offset_mappings ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - partition_key TEXT NOT NULL, - kafka_offset INTEGER NOT NULL, - smq_timestamp INTEGER NOT NULL, - message_size INTEGER NOT NULL, - created_at INTEGER NOT NULL, - UNIQUE(partition_key, kafka_offset) - ); - - -- Schema migrations tracking table - CREATE TABLE IF NOT EXISTS schema_migrations ( - version INTEGER PRIMARY KEY, - description TEXT NOT NULL, - applied_at INTEGER NOT NULL - ); - `, - }, - { - Version: 2, - Description: "Add indexes for performance optimization", - SQL: ` - -- Indexes for performance - CREATE INDEX IF NOT EXISTS idx_partition_offset_checkpoints_partition - ON partition_offset_checkpoints(partition_key); - - CREATE INDEX IF NOT EXISTS idx_offset_mappings_partition_offset - ON offset_mappings(partition_key, kafka_offset); - - CREATE INDEX IF NOT EXISTS idx_offset_mappings_timestamp - ON offset_mappings(partition_key, smq_timestamp); - - CREATE INDEX IF NOT EXISTS idx_offset_mappings_created_at - ON offset_mappings(created_at); - `, - }, - { - Version: 3, - Description: "Add partition metadata table for enhanced tracking", - SQL: ` - -- Partition metadata table - CREATE TABLE IF NOT EXISTS partition_metadata ( - partition_key TEXT PRIMARY KEY, - ring_size INTEGER NOT NULL, - range_start INTEGER NOT NULL, - range_stop INTEGER NOT NULL, - unix_time_ns INTEGER NOT NULL, - created_at INTEGER NOT NULL, - last_activity_at INTEGER NOT NULL, - record_count INTEGER DEFAULT 0, - total_size INTEGER DEFAULT 0 - ); - - -- Index for partition metadata - CREATE INDEX IF NOT EXISTS idx_partition_metadata_activity - ON partition_metadata(last_activity_at); - `, - }, - } -} - -// MigrationManager handles database schema migrations -type MigrationManager struct { - db *sql.DB -} - -// NewMigrationManager creates a new migration manager -func NewMigrationManager(db *sql.DB) *MigrationManager { - return &MigrationManager{db: db} -} - -// GetCurrentVersion returns the current schema version -func (m *MigrationManager) GetCurrentVersion() (int, error) { - // First, ensure the migrations table exists - _, err := m.db.Exec(` - CREATE TABLE IF NOT EXISTS schema_migrations ( - version INTEGER PRIMARY KEY, - description TEXT NOT NULL, - applied_at INTEGER NOT NULL - ) - `) - if err != nil { - return 0, fmt.Errorf("failed to create migrations table: %w", err) - } - - var version sql.NullInt64 - err = m.db.QueryRow("SELECT MAX(version) FROM schema_migrations").Scan(&version) - if err != nil { - return 0, fmt.Errorf("failed to get current version: %w", err) - } - - if !version.Valid { - return 0, nil // No migrations applied yet - } - - return int(version.Int64), nil -} - -// ApplyMigrations applies all pending migrations -func (m *MigrationManager) ApplyMigrations() error { - currentVersion, err := m.GetCurrentVersion() - if err != nil { - return fmt.Errorf("failed to get current version: %w", err) - } - - migrations := GetMigrations() - - for _, migration := range migrations { - if migration.Version <= currentVersion { - continue // Already applied - } - - fmt.Printf("Applying migration %d: %s\n", migration.Version, migration.Description) - - // Begin transaction - tx, err := m.db.Begin() - if err != nil { - return fmt.Errorf("failed to begin transaction for migration %d: %w", migration.Version, err) - } - - // Execute migration SQL - _, err = tx.Exec(migration.SQL) - if err != nil { - tx.Rollback() - return fmt.Errorf("failed to execute migration %d: %w", migration.Version, err) - } - - // Record migration as applied - _, err = tx.Exec( - "INSERT INTO schema_migrations (version, description, applied_at) VALUES (?, ?, ?)", - migration.Version, - migration.Description, - getCurrentTimestamp(), - ) - if err != nil { - tx.Rollback() - return fmt.Errorf("failed to record migration %d: %w", migration.Version, err) - } - - // Commit transaction - err = tx.Commit() - if err != nil { - return fmt.Errorf("failed to commit migration %d: %w", migration.Version, err) - } - - fmt.Printf("Successfully applied migration %d\n", migration.Version) - } - - return nil -} - -// RollbackMigration rolls back a specific migration (if supported) -func (m *MigrationManager) RollbackMigration(version int) error { - // TODO: Implement rollback functionality - // ASSUMPTION: For now, rollbacks are not supported as they require careful planning - return fmt.Errorf("migration rollbacks not implemented - manual intervention required") -} - -// GetAppliedMigrations returns a list of all applied migrations -func (m *MigrationManager) GetAppliedMigrations() ([]AppliedMigration, error) { - rows, err := m.db.Query(` - SELECT version, description, applied_at - FROM schema_migrations - ORDER BY version - `) - if err != nil { - return nil, fmt.Errorf("failed to query applied migrations: %w", err) - } - defer rows.Close() - - var migrations []AppliedMigration - for rows.Next() { - var migration AppliedMigration - err := rows.Scan(&migration.Version, &migration.Description, &migration.AppliedAt) - if err != nil { - return nil, fmt.Errorf("failed to scan migration: %w", err) - } - migrations = append(migrations, migration) - } - - return migrations, nil -} - -// ValidateSchema validates that the database schema is up to date -func (m *MigrationManager) ValidateSchema() error { - currentVersion, err := m.GetCurrentVersion() - if err != nil { - return fmt.Errorf("failed to get current version: %w", err) - } - - migrations := GetMigrations() - if len(migrations) == 0 { - return nil - } - - latestVersion := migrations[len(migrations)-1].Version - if currentVersion < latestVersion { - return fmt.Errorf("schema is outdated: current version %d, latest version %d", currentVersion, latestVersion) - } - - return nil -} - -// AppliedMigration represents a migration that has been applied -type AppliedMigration struct { - Version int - Description string - AppliedAt int64 -} - -// getCurrentTimestamp returns the current timestamp in nanoseconds -func getCurrentTimestamp() int64 { - return time.Now().UnixNano() -} - -// CreateDatabase creates and initializes a new offset storage database -func CreateDatabase(dbPath string) (*sql.DB, error) { - // TODO: Support different database types (PostgreSQL, MySQL, etc.) - // ASSUMPTION: Using SQLite for now, can be extended for other databases - - db, err := sql.Open("sqlite3", dbPath) - if err != nil { - return nil, fmt.Errorf("failed to open database: %w", err) - } - - // Configure SQLite for better performance - pragmas := []string{ - "PRAGMA journal_mode=WAL", // Write-Ahead Logging for better concurrency - "PRAGMA synchronous=NORMAL", // Balance between safety and performance - "PRAGMA cache_size=10000", // Increase cache size - "PRAGMA foreign_keys=ON", // Enable foreign key constraints - "PRAGMA temp_store=MEMORY", // Store temporary tables in memory - } - - for _, pragma := range pragmas { - _, err := db.Exec(pragma) - if err != nil { - db.Close() - return nil, fmt.Errorf("failed to set pragma %s: %w", pragma, err) - } - } - - // Apply migrations - migrationManager := NewMigrationManager(db) - err = migrationManager.ApplyMigrations() - if err != nil { - db.Close() - return nil, fmt.Errorf("failed to apply migrations: %w", err) - } - - return db, nil -} - -// BackupDatabase creates a backup of the offset storage database -func BackupDatabase(sourceDB *sql.DB, backupPath string) error { - // TODO: Implement database backup functionality - // ASSUMPTION: This would use database-specific backup mechanisms - return fmt.Errorf("database backup not implemented yet") -} - -// RestoreDatabase restores a database from a backup -func RestoreDatabase(backupPath, targetPath string) error { - // TODO: Implement database restore functionality - // ASSUMPTION: This would use database-specific restore mechanisms - return fmt.Errorf("database restore not implemented yet") -} diff --git a/weed/mq/offset/sql_storage.go b/weed/mq/offset/sql_storage.go deleted file mode 100644 index c3107e5a4..000000000 --- a/weed/mq/offset/sql_storage.go +++ /dev/null @@ -1,394 +0,0 @@ -package offset - -import ( - "database/sql" - "fmt" - "time" - - "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" -) - -// OffsetEntry represents a mapping between Kafka offset and SMQ timestamp -type OffsetEntry struct { - KafkaOffset int64 - SMQTimestamp int64 - MessageSize int32 -} - -// SQLOffsetStorage implements OffsetStorage using SQL database with _index column -type SQLOffsetStorage struct { - db *sql.DB -} - -// NewSQLOffsetStorage creates a new SQL-based offset storage -func NewSQLOffsetStorage(db *sql.DB) (*SQLOffsetStorage, error) { - storage := &SQLOffsetStorage{db: db} - - // Initialize database schema - if err := storage.initializeSchema(); err != nil { - return nil, fmt.Errorf("failed to initialize schema: %w", err) - } - - return storage, nil -} - -// initializeSchema creates the necessary tables for offset storage -func (s *SQLOffsetStorage) initializeSchema() error { - // TODO: Create offset storage tables with _index as hidden column - // ASSUMPTION: Using SQLite-compatible syntax, may need adaptation for other databases - - queries := []string{ - // Partition offset checkpoints table - // TODO: Add _index as computed column when supported by database - // ASSUMPTION: Using regular columns for now, _index concept preserved for future enhancement - `CREATE TABLE IF NOT EXISTS partition_offset_checkpoints ( - partition_key TEXT PRIMARY KEY, - ring_size INTEGER NOT NULL, - range_start INTEGER NOT NULL, - range_stop INTEGER NOT NULL, - unix_time_ns INTEGER NOT NULL, - checkpoint_offset INTEGER NOT NULL, - updated_at INTEGER NOT NULL - )`, - - // Offset mappings table for detailed tracking - // TODO: Add _index as computed column when supported by database - `CREATE TABLE IF NOT EXISTS offset_mappings ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - partition_key TEXT NOT NULL, - kafka_offset INTEGER NOT NULL, - smq_timestamp INTEGER NOT NULL, - message_size INTEGER NOT NULL, - created_at INTEGER NOT NULL, - UNIQUE(partition_key, kafka_offset) - )`, - - // Indexes for performance - `CREATE INDEX IF NOT EXISTS idx_partition_offset_checkpoints_partition - ON partition_offset_checkpoints(partition_key)`, - - `CREATE INDEX IF NOT EXISTS idx_offset_mappings_partition_offset - ON offset_mappings(partition_key, kafka_offset)`, - - `CREATE INDEX IF NOT EXISTS idx_offset_mappings_timestamp - ON offset_mappings(partition_key, smq_timestamp)`, - } - - for _, query := range queries { - if _, err := s.db.Exec(query); err != nil { - return fmt.Errorf("failed to execute schema query: %w", err) - } - } - - return nil -} - -// SaveCheckpoint saves the checkpoint for a partition -func (s *SQLOffsetStorage) SaveCheckpoint(namespace, topicName string, partition *schema_pb.Partition, offset int64) error { - // Use TopicPartitionKey to ensure each topic has isolated checkpoint storage - partitionKey := TopicPartitionKey(namespace, topicName, partition) - now := time.Now().UnixNano() - - // TODO: Use UPSERT for better performance - // ASSUMPTION: SQLite REPLACE syntax, may need adaptation for other databases - query := ` - REPLACE INTO partition_offset_checkpoints - (partition_key, ring_size, range_start, range_stop, unix_time_ns, checkpoint_offset, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?) - ` - - _, err := s.db.Exec(query, - partitionKey, - partition.RingSize, - partition.RangeStart, - partition.RangeStop, - partition.UnixTimeNs, - offset, - now, - ) - - if err != nil { - return fmt.Errorf("failed to save checkpoint: %w", err) - } - - return nil -} - -// LoadCheckpoint loads the checkpoint for a partition -func (s *SQLOffsetStorage) LoadCheckpoint(namespace, topicName string, partition *schema_pb.Partition) (int64, error) { - // Use TopicPartitionKey to match SaveCheckpoint - partitionKey := TopicPartitionKey(namespace, topicName, partition) - - query := ` - SELECT checkpoint_offset - FROM partition_offset_checkpoints - WHERE partition_key = ? - ` - - var checkpointOffset int64 - err := s.db.QueryRow(query, partitionKey).Scan(&checkpointOffset) - - if err == sql.ErrNoRows { - return -1, fmt.Errorf("no checkpoint found") - } - - if err != nil { - return -1, fmt.Errorf("failed to load checkpoint: %w", err) - } - - return checkpointOffset, nil -} - -// GetHighestOffset finds the highest offset in storage for a partition -func (s *SQLOffsetStorage) GetHighestOffset(namespace, topicName string, partition *schema_pb.Partition) (int64, error) { - // Use TopicPartitionKey to match SaveCheckpoint - partitionKey := TopicPartitionKey(namespace, topicName, partition) - - // TODO: Use _index column for efficient querying - // ASSUMPTION: kafka_offset represents the sequential offset we're tracking - query := ` - SELECT MAX(kafka_offset) - FROM offset_mappings - WHERE partition_key = ? - ` - - var highestOffset sql.NullInt64 - err := s.db.QueryRow(query, partitionKey).Scan(&highestOffset) - - if err != nil { - return -1, fmt.Errorf("failed to get highest offset: %w", err) - } - - if !highestOffset.Valid { - return -1, fmt.Errorf("no records found") - } - - return highestOffset.Int64, nil -} - -// SaveOffsetMapping stores an offset mapping (extends OffsetStorage interface) -func (s *SQLOffsetStorage) SaveOffsetMapping(partitionKey string, kafkaOffset, smqTimestamp int64, size int32) error { - now := time.Now().UnixNano() - - // TODO: Handle duplicate key conflicts gracefully - // ASSUMPTION: Using INSERT OR REPLACE for conflict resolution - query := ` - INSERT OR REPLACE INTO offset_mappings - (partition_key, kafka_offset, smq_timestamp, message_size, created_at) - VALUES (?, ?, ?, ?, ?) - ` - - _, err := s.db.Exec(query, partitionKey, kafkaOffset, smqTimestamp, size, now) - if err != nil { - return fmt.Errorf("failed to save offset mapping: %w", err) - } - - return nil -} - -// LoadOffsetMappings retrieves all offset mappings for a partition -func (s *SQLOffsetStorage) LoadOffsetMappings(partitionKey string) ([]OffsetEntry, error) { - // TODO: Add pagination for large result sets - // ASSUMPTION: Loading all mappings for now, should be paginated in production - query := ` - SELECT kafka_offset, smq_timestamp, message_size - FROM offset_mappings - WHERE partition_key = ? - ORDER BY kafka_offset ASC - ` - - rows, err := s.db.Query(query, partitionKey) - if err != nil { - return nil, fmt.Errorf("failed to query offset mappings: %w", err) - } - defer rows.Close() - - var entries []OffsetEntry - for rows.Next() { - var entry OffsetEntry - err := rows.Scan(&entry.KafkaOffset, &entry.SMQTimestamp, &entry.MessageSize) - if err != nil { - return nil, fmt.Errorf("failed to scan offset entry: %w", err) - } - entries = append(entries, entry) - } - - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("error iterating offset mappings: %w", err) - } - - return entries, nil -} - -// GetOffsetMappingsByRange retrieves offset mappings within a specific range -func (s *SQLOffsetStorage) GetOffsetMappingsByRange(partitionKey string, startOffset, endOffset int64) ([]OffsetEntry, error) { - // TODO: Use _index column for efficient range queries - query := ` - SELECT kafka_offset, smq_timestamp, message_size - FROM offset_mappings - WHERE partition_key = ? AND kafka_offset >= ? AND kafka_offset <= ? - ORDER BY kafka_offset ASC - ` - - rows, err := s.db.Query(query, partitionKey, startOffset, endOffset) - if err != nil { - return nil, fmt.Errorf("failed to query offset range: %w", err) - } - defer rows.Close() - - var entries []OffsetEntry - for rows.Next() { - var entry OffsetEntry - err := rows.Scan(&entry.KafkaOffset, &entry.SMQTimestamp, &entry.MessageSize) - if err != nil { - return nil, fmt.Errorf("failed to scan offset entry: %w", err) - } - entries = append(entries, entry) - } - - return entries, nil -} - -// GetPartitionStats returns statistics about a partition's offset usage -func (s *SQLOffsetStorage) GetPartitionStats(partitionKey string) (*PartitionStats, error) { - query := ` - SELECT - COUNT(*) as record_count, - MIN(kafka_offset) as earliest_offset, - MAX(kafka_offset) as latest_offset, - SUM(message_size) as total_size, - MIN(created_at) as first_record_time, - MAX(created_at) as last_record_time - FROM offset_mappings - WHERE partition_key = ? - ` - - var stats PartitionStats - var earliestOffset, latestOffset sql.NullInt64 - var totalSize sql.NullInt64 - var firstRecordTime, lastRecordTime sql.NullInt64 - - err := s.db.QueryRow(query, partitionKey).Scan( - &stats.RecordCount, - &earliestOffset, - &latestOffset, - &totalSize, - &firstRecordTime, - &lastRecordTime, - ) - - if err != nil { - return nil, fmt.Errorf("failed to get partition stats: %w", err) - } - - stats.PartitionKey = partitionKey - - if earliestOffset.Valid { - stats.EarliestOffset = earliestOffset.Int64 - } else { - stats.EarliestOffset = -1 - } - - if latestOffset.Valid { - stats.LatestOffset = latestOffset.Int64 - stats.HighWaterMark = latestOffset.Int64 + 1 - } else { - stats.LatestOffset = -1 - stats.HighWaterMark = 0 - } - - if firstRecordTime.Valid { - stats.FirstRecordTime = firstRecordTime.Int64 - } - - if lastRecordTime.Valid { - stats.LastRecordTime = lastRecordTime.Int64 - } - - if totalSize.Valid { - stats.TotalSize = totalSize.Int64 - } - - return &stats, nil -} - -// CleanupOldMappings removes offset mappings older than the specified time -func (s *SQLOffsetStorage) CleanupOldMappings(olderThanNs int64) error { - // TODO: Add configurable cleanup policies - // ASSUMPTION: Simple time-based cleanup, could be enhanced with retention policies - query := ` - DELETE FROM offset_mappings - WHERE created_at < ? - ` - - result, err := s.db.Exec(query, olderThanNs) - if err != nil { - return fmt.Errorf("failed to cleanup old mappings: %w", err) - } - - rowsAffected, _ := result.RowsAffected() - if rowsAffected > 0 { - // Log cleanup activity - fmt.Printf("Cleaned up %d old offset mappings\n", rowsAffected) - } - - return nil -} - -// Close closes the database connection -func (s *SQLOffsetStorage) Close() error { - if s.db != nil { - return s.db.Close() - } - return nil -} - -// PartitionStats provides statistics about a partition's offset usage -type PartitionStats struct { - PartitionKey string - RecordCount int64 - EarliestOffset int64 - LatestOffset int64 - HighWaterMark int64 - TotalSize int64 - FirstRecordTime int64 - LastRecordTime int64 -} - -// GetAllPartitions returns a list of all partitions with offset data -func (s *SQLOffsetStorage) GetAllPartitions() ([]string, error) { - query := ` - SELECT DISTINCT partition_key - FROM offset_mappings - ORDER BY partition_key - ` - - rows, err := s.db.Query(query) - if err != nil { - return nil, fmt.Errorf("failed to get all partitions: %w", err) - } - defer rows.Close() - - var partitions []string - for rows.Next() { - var partitionKey string - if err := rows.Scan(&partitionKey); err != nil { - return nil, fmt.Errorf("failed to scan partition key: %w", err) - } - partitions = append(partitions, partitionKey) - } - - return partitions, nil -} - -// Vacuum performs database maintenance operations -func (s *SQLOffsetStorage) Vacuum() error { - // TODO: Add database-specific optimization commands - // ASSUMPTION: SQLite VACUUM command, may need adaptation for other databases - _, err := s.db.Exec("VACUUM") - if err != nil { - return fmt.Errorf("failed to vacuum database: %w", err) - } - - return nil -} diff --git a/weed/mq/offset/sql_storage_test.go b/weed/mq/offset/sql_storage_test.go deleted file mode 100644 index 661f317de..000000000 --- a/weed/mq/offset/sql_storage_test.go +++ /dev/null @@ -1,516 +0,0 @@ -package offset - -import ( - "database/sql" - "os" - "testing" - "time" - - _ "github.com/mattn/go-sqlite3" // SQLite driver - "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" -) - -func createTestDB(t *testing.T) *sql.DB { - // Create temporary database file - tmpFile, err := os.CreateTemp("", "offset_test_*.db") - if err != nil { - t.Fatalf("Failed to create temp database file: %v", err) - } - tmpFile.Close() - - // Clean up the file when test completes - t.Cleanup(func() { - os.Remove(tmpFile.Name()) - }) - - db, err := sql.Open("sqlite3", tmpFile.Name()) - if err != nil { - t.Fatalf("Failed to open database: %v", err) - } - - t.Cleanup(func() { - db.Close() - }) - - return db -} - -func createTestPartitionForSQL() *schema_pb.Partition { - return &schema_pb.Partition{ - RingSize: 1024, - RangeStart: 0, - RangeStop: 31, - UnixTimeNs: time.Now().UnixNano(), - } -} - -func TestSQLOffsetStorage_InitializeSchema(t *testing.T) { - db := createTestDB(t) - - storage, err := NewSQLOffsetStorage(db) - if err != nil { - t.Fatalf("Failed to create SQL storage: %v", err) - } - defer storage.Close() - - // Verify tables were created - tables := []string{ - "partition_offset_checkpoints", - "offset_mappings", - } - - for _, table := range tables { - var count int - err := db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", table).Scan(&count) - if err != nil { - t.Fatalf("Failed to check table %s: %v", table, err) - } - - if count != 1 { - t.Errorf("Table %s was not created", table) - } - } -} - -func TestSQLOffsetStorage_SaveLoadCheckpoint(t *testing.T) { - db := createTestDB(t) - storage, err := NewSQLOffsetStorage(db) - if err != nil { - t.Fatalf("Failed to create SQL storage: %v", err) - } - defer storage.Close() - - partition := createTestPartitionForSQL() - - // Test saving checkpoint - err = storage.SaveCheckpoint("test-namespace", "test-topic", partition, 100) - if err != nil { - t.Fatalf("Failed to save checkpoint: %v", err) - } - - // Test loading checkpoint - checkpoint, err := storage.LoadCheckpoint("test-namespace", "test-topic", partition) - if err != nil { - t.Fatalf("Failed to load checkpoint: %v", err) - } - - if checkpoint != 100 { - t.Errorf("Expected checkpoint 100, got %d", checkpoint) - } - - // Test updating checkpoint - err = storage.SaveCheckpoint("test-namespace", "test-topic", partition, 200) - if err != nil { - t.Fatalf("Failed to update checkpoint: %v", err) - } - - checkpoint, err = storage.LoadCheckpoint("test-namespace", "test-topic", partition) - if err != nil { - t.Fatalf("Failed to load updated checkpoint: %v", err) - } - - if checkpoint != 200 { - t.Errorf("Expected updated checkpoint 200, got %d", checkpoint) - } -} - -func TestSQLOffsetStorage_LoadCheckpointNotFound(t *testing.T) { - db := createTestDB(t) - storage, err := NewSQLOffsetStorage(db) - if err != nil { - t.Fatalf("Failed to create SQL storage: %v", err) - } - defer storage.Close() - - partition := createTestPartitionForSQL() - - // Test loading non-existent checkpoint - _, err = storage.LoadCheckpoint("test-namespace", "test-topic", partition) - if err == nil { - t.Error("Expected error for non-existent checkpoint") - } -} - -func TestSQLOffsetStorage_SaveLoadOffsetMappings(t *testing.T) { - db := createTestDB(t) - storage, err := NewSQLOffsetStorage(db) - if err != nil { - t.Fatalf("Failed to create SQL storage: %v", err) - } - defer storage.Close() - - partition := createTestPartitionForSQL() - partitionKey := partitionKey(partition) - - // Save multiple offset mappings - mappings := []struct { - offset int64 - timestamp int64 - size int32 - }{ - {0, 1000, 100}, - {1, 2000, 150}, - {2, 3000, 200}, - } - - for _, mapping := range mappings { - err := storage.SaveOffsetMapping(partitionKey, mapping.offset, mapping.timestamp, mapping.size) - if err != nil { - t.Fatalf("Failed to save offset mapping: %v", err) - } - } - - // Load offset mappings - entries, err := storage.LoadOffsetMappings(partitionKey) - if err != nil { - t.Fatalf("Failed to load offset mappings: %v", err) - } - - if len(entries) != len(mappings) { - t.Errorf("Expected %d entries, got %d", len(mappings), len(entries)) - } - - // Verify entries are sorted by offset - for i, entry := range entries { - expected := mappings[i] - if entry.KafkaOffset != expected.offset { - t.Errorf("Entry %d: expected offset %d, got %d", i, expected.offset, entry.KafkaOffset) - } - if entry.SMQTimestamp != expected.timestamp { - t.Errorf("Entry %d: expected timestamp %d, got %d", i, expected.timestamp, entry.SMQTimestamp) - } - if entry.MessageSize != expected.size { - t.Errorf("Entry %d: expected size %d, got %d", i, expected.size, entry.MessageSize) - } - } -} - -func TestSQLOffsetStorage_GetHighestOffset(t *testing.T) { - db := createTestDB(t) - storage, err := NewSQLOffsetStorage(db) - if err != nil { - t.Fatalf("Failed to create SQL storage: %v", err) - } - defer storage.Close() - - partition := createTestPartitionForSQL() - partitionKey := TopicPartitionKey("test-namespace", "test-topic", partition) - - // Test empty partition - _, err = storage.GetHighestOffset("test-namespace", "test-topic", partition) - if err == nil { - t.Error("Expected error for empty partition") - } - - // Add some offset mappings - offsets := []int64{5, 1, 3, 2, 4} - for _, offset := range offsets { - err := storage.SaveOffsetMapping(partitionKey, offset, offset*1000, 100) - if err != nil { - t.Fatalf("Failed to save offset mapping: %v", err) - } - } - - // Get highest offset - highest, err := storage.GetHighestOffset("test-namespace", "test-topic", partition) - if err != nil { - t.Fatalf("Failed to get highest offset: %v", err) - } - - if highest != 5 { - t.Errorf("Expected highest offset 5, got %d", highest) - } -} - -func TestSQLOffsetStorage_GetOffsetMappingsByRange(t *testing.T) { - db := createTestDB(t) - storage, err := NewSQLOffsetStorage(db) - if err != nil { - t.Fatalf("Failed to create SQL storage: %v", err) - } - defer storage.Close() - - partition := createTestPartitionForSQL() - partitionKey := partitionKey(partition) - - // Add offset mappings - for i := int64(0); i < 10; i++ { - err := storage.SaveOffsetMapping(partitionKey, i, i*1000, 100) - if err != nil { - t.Fatalf("Failed to save offset mapping: %v", err) - } - } - - // Get range of offsets - entries, err := storage.GetOffsetMappingsByRange(partitionKey, 3, 7) - if err != nil { - t.Fatalf("Failed to get offset range: %v", err) - } - - expectedCount := 5 // offsets 3, 4, 5, 6, 7 - if len(entries) != expectedCount { - t.Errorf("Expected %d entries, got %d", expectedCount, len(entries)) - } - - // Verify range - for i, entry := range entries { - expectedOffset := int64(3 + i) - if entry.KafkaOffset != expectedOffset { - t.Errorf("Entry %d: expected offset %d, got %d", i, expectedOffset, entry.KafkaOffset) - } - } -} - -func TestSQLOffsetStorage_GetPartitionStats(t *testing.T) { - db := createTestDB(t) - storage, err := NewSQLOffsetStorage(db) - if err != nil { - t.Fatalf("Failed to create SQL storage: %v", err) - } - defer storage.Close() - - partition := createTestPartitionForSQL() - partitionKey := partitionKey(partition) - - // Test empty partition stats - stats, err := storage.GetPartitionStats(partitionKey) - if err != nil { - t.Fatalf("Failed to get empty partition stats: %v", err) - } - - if stats.RecordCount != 0 { - t.Errorf("Expected record count 0, got %d", stats.RecordCount) - } - - if stats.EarliestOffset != -1 { - t.Errorf("Expected earliest offset -1, got %d", stats.EarliestOffset) - } - - // Add some data - sizes := []int32{100, 150, 200} - for i, size := range sizes { - err := storage.SaveOffsetMapping(partitionKey, int64(i), int64(i*1000), size) - if err != nil { - t.Fatalf("Failed to save offset mapping: %v", err) - } - } - - // Get stats with data - stats, err = storage.GetPartitionStats(partitionKey) - if err != nil { - t.Fatalf("Failed to get partition stats: %v", err) - } - - if stats.RecordCount != 3 { - t.Errorf("Expected record count 3, got %d", stats.RecordCount) - } - - if stats.EarliestOffset != 0 { - t.Errorf("Expected earliest offset 0, got %d", stats.EarliestOffset) - } - - if stats.LatestOffset != 2 { - t.Errorf("Expected latest offset 2, got %d", stats.LatestOffset) - } - - if stats.HighWaterMark != 3 { - t.Errorf("Expected high water mark 3, got %d", stats.HighWaterMark) - } - - expectedTotalSize := int64(100 + 150 + 200) - if stats.TotalSize != expectedTotalSize { - t.Errorf("Expected total size %d, got %d", expectedTotalSize, stats.TotalSize) - } -} - -func TestSQLOffsetStorage_GetAllPartitions(t *testing.T) { - db := createTestDB(t) - storage, err := NewSQLOffsetStorage(db) - if err != nil { - t.Fatalf("Failed to create SQL storage: %v", err) - } - defer storage.Close() - - // Test empty database - partitions, err := storage.GetAllPartitions() - if err != nil { - t.Fatalf("Failed to get all partitions: %v", err) - } - - if len(partitions) != 0 { - t.Errorf("Expected 0 partitions, got %d", len(partitions)) - } - - // Add data for multiple partitions - partition1 := createTestPartitionForSQL() - partition2 := &schema_pb.Partition{ - RingSize: 1024, - RangeStart: 32, - RangeStop: 63, - UnixTimeNs: time.Now().UnixNano(), - } - - partitionKey1 := partitionKey(partition1) - partitionKey2 := partitionKey(partition2) - - storage.SaveOffsetMapping(partitionKey1, 0, 1000, 100) - storage.SaveOffsetMapping(partitionKey2, 0, 2000, 150) - - // Get all partitions - partitions, err = storage.GetAllPartitions() - if err != nil { - t.Fatalf("Failed to get all partitions: %v", err) - } - - if len(partitions) != 2 { - t.Errorf("Expected 2 partitions, got %d", len(partitions)) - } - - // Verify partition keys are present - partitionMap := make(map[string]bool) - for _, p := range partitions { - partitionMap[p] = true - } - - if !partitionMap[partitionKey1] { - t.Errorf("Partition key %s not found", partitionKey1) - } - - if !partitionMap[partitionKey2] { - t.Errorf("Partition key %s not found", partitionKey2) - } -} - -func TestSQLOffsetStorage_CleanupOldMappings(t *testing.T) { - db := createTestDB(t) - storage, err := NewSQLOffsetStorage(db) - if err != nil { - t.Fatalf("Failed to create SQL storage: %v", err) - } - defer storage.Close() - - partition := createTestPartitionForSQL() - partitionKey := partitionKey(partition) - - // Add mappings with different timestamps - now := time.Now().UnixNano() - - // Add old mapping by directly inserting with old timestamp - oldTime := now - (24 * time.Hour).Nanoseconds() // 24 hours ago - _, err = db.Exec(` - INSERT INTO offset_mappings - (partition_key, kafka_offset, smq_timestamp, message_size, created_at) - VALUES (?, ?, ?, ?, ?) - `, partitionKey, 0, oldTime, 100, oldTime) - if err != nil { - t.Fatalf("Failed to insert old mapping: %v", err) - } - - // Add recent mapping - storage.SaveOffsetMapping(partitionKey, 1, now, 150) - - // Verify both mappings exist - entries, err := storage.LoadOffsetMappings(partitionKey) - if err != nil { - t.Fatalf("Failed to load mappings: %v", err) - } - - if len(entries) != 2 { - t.Errorf("Expected 2 mappings before cleanup, got %d", len(entries)) - } - - // Cleanup old mappings (older than 12 hours) - cutoffTime := now - (12 * time.Hour).Nanoseconds() - err = storage.CleanupOldMappings(cutoffTime) - if err != nil { - t.Fatalf("Failed to cleanup old mappings: %v", err) - } - - // Verify only recent mapping remains - entries, err = storage.LoadOffsetMappings(partitionKey) - if err != nil { - t.Fatalf("Failed to load mappings after cleanup: %v", err) - } - - if len(entries) != 1 { - t.Errorf("Expected 1 mapping after cleanup, got %d", len(entries)) - } - - if entries[0].KafkaOffset != 1 { - t.Errorf("Expected remaining mapping offset 1, got %d", entries[0].KafkaOffset) - } -} - -func TestSQLOffsetStorage_Vacuum(t *testing.T) { - db := createTestDB(t) - storage, err := NewSQLOffsetStorage(db) - if err != nil { - t.Fatalf("Failed to create SQL storage: %v", err) - } - defer storage.Close() - - // Vacuum should not fail on empty database - err = storage.Vacuum() - if err != nil { - t.Fatalf("Failed to vacuum database: %v", err) - } - - // Add some data and vacuum again - partition := createTestPartitionForSQL() - partitionKey := partitionKey(partition) - storage.SaveOffsetMapping(partitionKey, 0, 1000, 100) - - err = storage.Vacuum() - if err != nil { - t.Fatalf("Failed to vacuum database with data: %v", err) - } -} - -func TestSQLOffsetStorage_ConcurrentAccess(t *testing.T) { - db := createTestDB(t) - storage, err := NewSQLOffsetStorage(db) - if err != nil { - t.Fatalf("Failed to create SQL storage: %v", err) - } - defer storage.Close() - - partition := createTestPartitionForSQL() - partitionKey := partitionKey(partition) - - // Test concurrent writes - const numGoroutines = 10 - const offsetsPerGoroutine = 10 - - done := make(chan bool, numGoroutines) - - for i := 0; i < numGoroutines; i++ { - go func(goroutineID int) { - defer func() { done <- true }() - - for j := 0; j < offsetsPerGoroutine; j++ { - offset := int64(goroutineID*offsetsPerGoroutine + j) - err := storage.SaveOffsetMapping(partitionKey, offset, offset*1000, 100) - if err != nil { - t.Errorf("Failed to save offset mapping %d: %v", offset, err) - return - } - } - }(i) - } - - // Wait for all goroutines to complete - for i := 0; i < numGoroutines; i++ { - <-done - } - - // Verify all mappings were saved - entries, err := storage.LoadOffsetMappings(partitionKey) - if err != nil { - t.Fatalf("Failed to load mappings: %v", err) - } - - expectedCount := numGoroutines * offsetsPerGoroutine - if len(entries) != expectedCount { - t.Errorf("Expected %d mappings, got %d", expectedCount, len(entries)) - } -} diff --git a/weed/mq/offset/subscriber_test.go b/weed/mq/offset/subscriber_test.go deleted file mode 100644 index 1ab97dadc..000000000 --- a/weed/mq/offset/subscriber_test.go +++ /dev/null @@ -1,457 +0,0 @@ -package offset - -import ( - "testing" - - "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" -) - -func TestOffsetSubscriber_CreateSubscription(t *testing.T) { - storage := NewInMemoryOffsetStorage() - registry := NewPartitionOffsetRegistry(storage) - subscriber := NewOffsetSubscriber(registry) - partition := createTestPartition() - - // Assign some offsets first - registry.AssignOffsets("test-namespace", "test-topic", partition, 10) - - // Test EXACT_OFFSET subscription - sub, err := subscriber.CreateSubscription("test-sub-1", "test-namespace", "test-topic", partition, schema_pb.OffsetType_EXACT_OFFSET, 5) - if err != nil { - t.Fatalf("Failed to create EXACT_OFFSET subscription: %v", err) - } - - if sub.StartOffset != 5 { - t.Errorf("Expected start offset 5, got %d", sub.StartOffset) - } - if sub.CurrentOffset != 5 { - t.Errorf("Expected current offset 5, got %d", sub.CurrentOffset) - } - - // Test RESET_TO_LATEST subscription - sub2, err := subscriber.CreateSubscription("test-sub-2", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_LATEST, 0) - if err != nil { - t.Fatalf("Failed to create RESET_TO_LATEST subscription: %v", err) - } - - if sub2.StartOffset != 10 { // Should be at high water mark - t.Errorf("Expected start offset 10, got %d", sub2.StartOffset) - } -} - -func TestOffsetSubscriber_InvalidSubscription(t *testing.T) { - storage := NewInMemoryOffsetStorage() - registry := NewPartitionOffsetRegistry(storage) - subscriber := NewOffsetSubscriber(registry) - partition := createTestPartition() - - // Assign some offsets - registry.AssignOffsets("test-namespace", "test-topic", partition, 5) - - // Test invalid offset (beyond high water mark) - _, err := subscriber.CreateSubscription("invalid-sub", "test-namespace", "test-topic", partition, schema_pb.OffsetType_EXACT_OFFSET, 10) - if err == nil { - t.Error("Expected error for offset beyond high water mark") - } - - // Test negative offset - _, err = subscriber.CreateSubscription("invalid-sub-2", "test-namespace", "test-topic", partition, schema_pb.OffsetType_EXACT_OFFSET, -1) - if err == nil { - t.Error("Expected error for negative offset") - } -} - -func TestOffsetSubscriber_DuplicateSubscription(t *testing.T) { - storage := NewInMemoryOffsetStorage() - registry := NewPartitionOffsetRegistry(storage) - subscriber := NewOffsetSubscriber(registry) - partition := createTestPartition() - - // Create first subscription - _, err := subscriber.CreateSubscription("duplicate-sub", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0) - if err != nil { - t.Fatalf("Failed to create first subscription: %v", err) - } - - // Try to create duplicate - _, err = subscriber.CreateSubscription("duplicate-sub", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0) - if err == nil { - t.Error("Expected error for duplicate subscription ID") - } -} - -func TestOffsetSubscription_SeekToOffset(t *testing.T) { - storage := NewInMemoryOffsetStorage() - registry := NewPartitionOffsetRegistry(storage) - subscriber := NewOffsetSubscriber(registry) - partition := createTestPartition() - - // Assign offsets - registry.AssignOffsets("test-namespace", "test-topic", partition, 20) - - // Create subscription - sub, err := subscriber.CreateSubscription("seek-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0) - if err != nil { - t.Fatalf("Failed to create subscription: %v", err) - } - - // Test valid seek - err = sub.SeekToOffset(10) - if err != nil { - t.Fatalf("Failed to seek to offset 10: %v", err) - } - - if sub.CurrentOffset != 10 { - t.Errorf("Expected current offset 10, got %d", sub.CurrentOffset) - } - - // Test invalid seek (beyond high water mark) - err = sub.SeekToOffset(25) - if err == nil { - t.Error("Expected error for seek beyond high water mark") - } - - // Test negative seek - err = sub.SeekToOffset(-1) - if err == nil { - t.Error("Expected error for negative seek offset") - } -} - -func TestOffsetSubscription_AdvanceOffset(t *testing.T) { - storage := NewInMemoryOffsetStorage() - registry := NewPartitionOffsetRegistry(storage) - subscriber := NewOffsetSubscriber(registry) - partition := createTestPartition() - - // Create subscription - sub, err := subscriber.CreateSubscription("advance-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0) - if err != nil { - t.Fatalf("Failed to create subscription: %v", err) - } - - // Test single advance - initialOffset := sub.GetNextOffset() - sub.AdvanceOffset() - - if sub.GetNextOffset() != initialOffset+1 { - t.Errorf("Expected offset %d, got %d", initialOffset+1, sub.GetNextOffset()) - } - - // Test batch advance - sub.AdvanceOffsetBy(5) - - if sub.GetNextOffset() != initialOffset+6 { - t.Errorf("Expected offset %d, got %d", initialOffset+6, sub.GetNextOffset()) - } -} - -func TestOffsetSubscription_GetLag(t *testing.T) { - storage := NewInMemoryOffsetStorage() - registry := NewPartitionOffsetRegistry(storage) - subscriber := NewOffsetSubscriber(registry) - partition := createTestPartition() - - // Assign offsets - registry.AssignOffsets("test-namespace", "test-topic", partition, 15) - - // Create subscription at offset 5 - sub, err := subscriber.CreateSubscription("lag-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_EXACT_OFFSET, 5) - if err != nil { - t.Fatalf("Failed to create subscription: %v", err) - } - - // Check initial lag - lag, err := sub.GetLag() - if err != nil { - t.Fatalf("Failed to get lag: %v", err) - } - - expectedLag := int64(15 - 5) // hwm - current - if lag != expectedLag { - t.Errorf("Expected lag %d, got %d", expectedLag, lag) - } - - // Advance and check lag again - sub.AdvanceOffsetBy(3) - - lag, err = sub.GetLag() - if err != nil { - t.Fatalf("Failed to get lag after advance: %v", err) - } - - expectedLag = int64(15 - 8) // hwm - current - if lag != expectedLag { - t.Errorf("Expected lag %d after advance, got %d", expectedLag, lag) - } -} - -func TestOffsetSubscription_IsAtEnd(t *testing.T) { - storage := NewInMemoryOffsetStorage() - registry := NewPartitionOffsetRegistry(storage) - subscriber := NewOffsetSubscriber(registry) - partition := createTestPartition() - - // Assign offsets - registry.AssignOffsets("test-namespace", "test-topic", partition, 10) - - // Create subscription at end - sub, err := subscriber.CreateSubscription("end-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_LATEST, 0) - if err != nil { - t.Fatalf("Failed to create subscription: %v", err) - } - - // Should be at end - atEnd, err := sub.IsAtEnd() - if err != nil { - t.Fatalf("Failed to check if at end: %v", err) - } - - if !atEnd { - t.Error("Expected subscription to be at end") - } - - // Seek to middle and check again - sub.SeekToOffset(5) - - atEnd, err = sub.IsAtEnd() - if err != nil { - t.Fatalf("Failed to check if at end after seek: %v", err) - } - - if atEnd { - t.Error("Expected subscription not to be at end after seek") - } -} - -func TestOffsetSubscription_GetOffsetRange(t *testing.T) { - storage := NewInMemoryOffsetStorage() - registry := NewPartitionOffsetRegistry(storage) - subscriber := NewOffsetSubscriber(registry) - partition := createTestPartition() - - // Assign offsets - registry.AssignOffsets("test-namespace", "test-topic", partition, 20) - - // Create subscription - sub, err := subscriber.CreateSubscription("range-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_EXACT_OFFSET, 5) - if err != nil { - t.Fatalf("Failed to create subscription: %v", err) - } - - // Test normal range - offsetRange, err := sub.GetOffsetRange(10) - if err != nil { - t.Fatalf("Failed to get offset range: %v", err) - } - - if offsetRange.StartOffset != 5 { - t.Errorf("Expected start offset 5, got %d", offsetRange.StartOffset) - } - if offsetRange.EndOffset != 14 { - t.Errorf("Expected end offset 14, got %d", offsetRange.EndOffset) - } - if offsetRange.Count != 10 { - t.Errorf("Expected count 10, got %d", offsetRange.Count) - } - - // Test range that exceeds high water mark - sub.SeekToOffset(15) - offsetRange, err = sub.GetOffsetRange(10) - if err != nil { - t.Fatalf("Failed to get offset range near end: %v", err) - } - - if offsetRange.StartOffset != 15 { - t.Errorf("Expected start offset 15, got %d", offsetRange.StartOffset) - } - if offsetRange.EndOffset != 19 { // Should be capped at hwm-1 - t.Errorf("Expected end offset 19, got %d", offsetRange.EndOffset) - } - if offsetRange.Count != 5 { - t.Errorf("Expected count 5, got %d", offsetRange.Count) - } -} - -func TestOffsetSubscription_EmptyRange(t *testing.T) { - storage := NewInMemoryOffsetStorage() - registry := NewPartitionOffsetRegistry(storage) - subscriber := NewOffsetSubscriber(registry) - partition := createTestPartition() - - // Assign offsets - registry.AssignOffsets("test-namespace", "test-topic", partition, 10) - - // Create subscription at end - sub, err := subscriber.CreateSubscription("empty-range-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_LATEST, 0) - if err != nil { - t.Fatalf("Failed to create subscription: %v", err) - } - - // Request range when at end - offsetRange, err := sub.GetOffsetRange(5) - if err != nil { - t.Fatalf("Failed to get offset range at end: %v", err) - } - - if offsetRange.Count != 0 { - t.Errorf("Expected empty range (count 0), got count %d", offsetRange.Count) - } - - if offsetRange.StartOffset != 10 { - t.Errorf("Expected start offset 10, got %d", offsetRange.StartOffset) - } - - if offsetRange.EndOffset != 9 { // Empty range: end < start - t.Errorf("Expected end offset 9 (empty range), got %d", offsetRange.EndOffset) - } -} - -func TestOffsetSeeker_ValidateOffsetRange(t *testing.T) { - storage := NewInMemoryOffsetStorage() - registry := NewPartitionOffsetRegistry(storage) - seeker := NewOffsetSeeker(registry) - partition := createTestPartition() - - // Assign offsets - registry.AssignOffsets("test-namespace", "test-topic", partition, 15) - - // Test valid range - err := seeker.ValidateOffsetRange("test-namespace", "test-topic", partition, 5, 10) - if err != nil { - t.Errorf("Valid range should not return error: %v", err) - } - - // Test invalid ranges - testCases := []struct { - name string - startOffset int64 - endOffset int64 - expectError bool - }{ - {"negative start", -1, 5, true}, - {"end before start", 10, 5, true}, - {"start beyond hwm", 20, 25, true}, - {"valid range", 0, 14, false}, - {"single offset", 5, 5, false}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - err := seeker.ValidateOffsetRange("test-namespace", "test-topic", partition, tc.startOffset, tc.endOffset) - if tc.expectError && err == nil { - t.Error("Expected error but got none") - } - if !tc.expectError && err != nil { - t.Errorf("Expected no error but got: %v", err) - } - }) - } -} - -func TestOffsetSeeker_GetAvailableOffsetRange(t *testing.T) { - storage := NewInMemoryOffsetStorage() - registry := NewPartitionOffsetRegistry(storage) - seeker := NewOffsetSeeker(registry) - partition := createTestPartition() - - // Test empty partition - offsetRange, err := seeker.GetAvailableOffsetRange("test-namespace", "test-topic", partition) - if err != nil { - t.Fatalf("Failed to get available range for empty partition: %v", err) - } - - if offsetRange.Count != 0 { - t.Errorf("Expected empty range for empty partition, got count %d", offsetRange.Count) - } - - // Assign offsets and test again - registry.AssignOffsets("test-namespace", "test-topic", partition, 25) - - offsetRange, err = seeker.GetAvailableOffsetRange("test-namespace", "test-topic", partition) - if err != nil { - t.Fatalf("Failed to get available range: %v", err) - } - - if offsetRange.StartOffset != 0 { - t.Errorf("Expected start offset 0, got %d", offsetRange.StartOffset) - } - if offsetRange.EndOffset != 24 { - t.Errorf("Expected end offset 24, got %d", offsetRange.EndOffset) - } - if offsetRange.Count != 25 { - t.Errorf("Expected count 25, got %d", offsetRange.Count) - } -} - -func TestOffsetSubscriber_CloseSubscription(t *testing.T) { - storage := NewInMemoryOffsetStorage() - registry := NewPartitionOffsetRegistry(storage) - subscriber := NewOffsetSubscriber(registry) - partition := createTestPartition() - - // Create subscription - sub, err := subscriber.CreateSubscription("close-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0) - if err != nil { - t.Fatalf("Failed to create subscription: %v", err) - } - - // Verify subscription exists - _, err = subscriber.GetSubscription("close-test") - if err != nil { - t.Fatalf("Subscription should exist: %v", err) - } - - // Close subscription - err = subscriber.CloseSubscription("close-test") - if err != nil { - t.Fatalf("Failed to close subscription: %v", err) - } - - // Verify subscription is gone - _, err = subscriber.GetSubscription("close-test") - if err == nil { - t.Error("Subscription should not exist after close") - } - - // Verify subscription is marked inactive - if sub.IsActive { - t.Error("Subscription should be marked inactive after close") - } -} - -func TestOffsetSubscription_InactiveOperations(t *testing.T) { - storage := NewInMemoryOffsetStorage() - registry := NewPartitionOffsetRegistry(storage) - subscriber := NewOffsetSubscriber(registry) - partition := createTestPartition() - - // Create and close subscription - sub, err := subscriber.CreateSubscription("inactive-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0) - if err != nil { - t.Fatalf("Failed to create subscription: %v", err) - } - - subscriber.CloseSubscription("inactive-test") - - // Test operations on inactive subscription - err = sub.SeekToOffset(5) - if err == nil { - t.Error("Expected error for seek on inactive subscription") - } - - _, err = sub.GetLag() - if err == nil { - t.Error("Expected error for GetLag on inactive subscription") - } - - _, err = sub.IsAtEnd() - if err == nil { - t.Error("Expected error for IsAtEnd on inactive subscription") - } - - _, err = sub.GetOffsetRange(10) - if err == nil { - t.Error("Expected error for GetOffsetRange on inactive subscription") - } -} diff --git a/weed/mq/pub_balancer/repair.go b/weed/mq/pub_balancer/repair.go index 9af81d27f..549843978 100644 --- a/weed/mq/pub_balancer/repair.go +++ b/weed/mq/pub_balancer/repair.go @@ -1,13 +1,6 @@ package pub_balancer -import ( - "math/rand/v2" - "sort" - - cmap "github.com/orcaman/concurrent-map/v2" - "github.com/seaweedfs/seaweedfs/weed/mq/topic" - "modernc.org/mathutil" -) +import () func (balancer *PubBalancer) RepairTopics() []BalanceAction { action := BalanceTopicPartitionOnBrokers(balancer.Brokers) @@ -17,107 +10,3 @@ func (balancer *PubBalancer) RepairTopics() []BalanceAction { type TopicPartitionInfo struct { Broker string } - -// RepairMissingTopicPartitions check the stats of all brokers, -// and repair the missing topic partitions on the brokers. -func RepairMissingTopicPartitions(brokers cmap.ConcurrentMap[string, *BrokerStats]) (actions []BalanceAction) { - - // find all topic partitions - topicToTopicPartitions := make(map[topic.Topic]map[topic.Partition]*TopicPartitionInfo) - for brokerStatsItem := range brokers.IterBuffered() { - broker, brokerStats := brokerStatsItem.Key, brokerStatsItem.Val - for topicPartitionStatsItem := range brokerStats.TopicPartitionStats.IterBuffered() { - topicPartitionStat := topicPartitionStatsItem.Val - topicPartitionToInfo, found := topicToTopicPartitions[topicPartitionStat.Topic] - if !found { - topicPartitionToInfo = make(map[topic.Partition]*TopicPartitionInfo) - topicToTopicPartitions[topicPartitionStat.Topic] = topicPartitionToInfo - } - tpi, found := topicPartitionToInfo[topicPartitionStat.Partition] - if !found { - tpi = &TopicPartitionInfo{} - topicPartitionToInfo[topicPartitionStat.Partition] = tpi - } - tpi.Broker = broker - } - } - - // collect all brokers as candidates - candidates := make([]string, 0, brokers.Count()) - for brokerStatsItem := range brokers.IterBuffered() { - candidates = append(candidates, brokerStatsItem.Key) - } - - // find the missing topic partitions - for t, topicPartitionToInfo := range topicToTopicPartitions { - missingPartitions := EachTopicRepairMissingTopicPartitions(t, topicPartitionToInfo) - for _, partition := range missingPartitions { - actions = append(actions, BalanceActionCreate{ - TopicPartition: topic.TopicPartition{ - Topic: t, - Partition: partition, - }, - TargetBroker: candidates[rand.IntN(len(candidates))], - }) - } - } - - return actions -} - -func EachTopicRepairMissingTopicPartitions(t topic.Topic, info map[topic.Partition]*TopicPartitionInfo) (missingPartitions []topic.Partition) { - - // find the missing topic partitions - var partitions []topic.Partition - for partition := range info { - partitions = append(partitions, partition) - } - return findMissingPartitions(partitions, MaxPartitionCount) -} - -// findMissingPartitions find the missing partitions -func findMissingPartitions(partitions []topic.Partition, ringSize int32) (missingPartitions []topic.Partition) { - // sort the partitions by range start - sort.Slice(partitions, func(i, j int) bool { - return partitions[i].RangeStart < partitions[j].RangeStart - }) - - // calculate the average partition size - var covered int32 - for _, partition := range partitions { - covered += partition.RangeStop - partition.RangeStart - } - averagePartitionSize := covered / int32(len(partitions)) - - // find the missing partitions - var coveredWatermark int32 - i := 0 - for i < len(partitions) { - partition := partitions[i] - if partition.RangeStart > coveredWatermark { - upperBound := mathutil.MinInt32(coveredWatermark+averagePartitionSize, partition.RangeStart) - missingPartitions = append(missingPartitions, topic.Partition{ - RangeStart: coveredWatermark, - RangeStop: upperBound, - RingSize: ringSize, - }) - coveredWatermark = upperBound - if coveredWatermark == partition.RangeStop { - i++ - } - } else { - coveredWatermark = partition.RangeStop - i++ - } - } - for coveredWatermark < ringSize { - upperBound := mathutil.MinInt32(coveredWatermark+averagePartitionSize, ringSize) - missingPartitions = append(missingPartitions, topic.Partition{ - RangeStart: coveredWatermark, - RangeStop: upperBound, - RingSize: ringSize, - }) - coveredWatermark = upperBound - } - return missingPartitions -} diff --git a/weed/mq/pub_balancer/repair_test.go b/weed/mq/pub_balancer/repair_test.go deleted file mode 100644 index 4ccf59e13..000000000 --- a/weed/mq/pub_balancer/repair_test.go +++ /dev/null @@ -1,98 +0,0 @@ -package pub_balancer - -import ( - "reflect" - "testing" - - "github.com/seaweedfs/seaweedfs/weed/mq/topic" -) - -func Test_findMissingPartitions(t *testing.T) { - type args struct { - partitions []topic.Partition - } - tests := []struct { - name string - args args - wantMissingPartitions []topic.Partition - }{ - { - name: "one partition", - args: args{ - partitions: []topic.Partition{ - {RingSize: 1024, RangeStart: 0, RangeStop: 1024}, - }, - }, - wantMissingPartitions: nil, - }, - { - name: "two partitions", - args: args{ - partitions: []topic.Partition{ - {RingSize: 1024, RangeStart: 0, RangeStop: 512}, - {RingSize: 1024, RangeStart: 512, RangeStop: 1024}, - }, - }, - wantMissingPartitions: nil, - }, - { - name: "four partitions, missing last two", - args: args{ - partitions: []topic.Partition{ - {RingSize: 1024, RangeStart: 0, RangeStop: 256}, - {RingSize: 1024, RangeStart: 256, RangeStop: 512}, - }, - }, - wantMissingPartitions: []topic.Partition{ - {RingSize: 1024, RangeStart: 512, RangeStop: 768}, - {RingSize: 1024, RangeStart: 768, RangeStop: 1024}, - }, - }, - { - name: "four partitions, missing first two", - args: args{ - partitions: []topic.Partition{ - {RingSize: 1024, RangeStart: 512, RangeStop: 768}, - {RingSize: 1024, RangeStart: 768, RangeStop: 1024}, - }, - }, - wantMissingPartitions: []topic.Partition{ - {RingSize: 1024, RangeStart: 0, RangeStop: 256}, - {RingSize: 1024, RangeStart: 256, RangeStop: 512}, - }, - }, - { - name: "four partitions, missing middle two", - args: args{ - partitions: []topic.Partition{ - {RingSize: 1024, RangeStart: 0, RangeStop: 256}, - {RingSize: 1024, RangeStart: 768, RangeStop: 1024}, - }, - }, - wantMissingPartitions: []topic.Partition{ - {RingSize: 1024, RangeStart: 256, RangeStop: 512}, - {RingSize: 1024, RangeStart: 512, RangeStop: 768}, - }, - }, - { - name: "four partitions, missing three", - args: args{ - partitions: []topic.Partition{ - {RingSize: 1024, RangeStart: 512, RangeStop: 768}, - }, - }, - wantMissingPartitions: []topic.Partition{ - {RingSize: 1024, RangeStart: 0, RangeStop: 256}, - {RingSize: 1024, RangeStart: 256, RangeStop: 512}, - {RingSize: 1024, RangeStart: 768, RangeStop: 1024}, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if gotMissingPartitions := findMissingPartitions(tt.args.partitions, 1024); !reflect.DeepEqual(gotMissingPartitions, tt.wantMissingPartitions) { - t.Errorf("findMissingPartitions() = %v, want %v", gotMissingPartitions, tt.wantMissingPartitions) - } - }) - } -} diff --git a/weed/mq/segment/message_serde.go b/weed/mq/segment/message_serde.go deleted file mode 100644 index 66a76c57d..000000000 --- a/weed/mq/segment/message_serde.go +++ /dev/null @@ -1,109 +0,0 @@ -package segment - -import ( - flatbuffers "github.com/google/flatbuffers/go" - "github.com/seaweedfs/seaweedfs/weed/pb/message_fbs" -) - -type MessageBatchBuilder struct { - b *flatbuffers.Builder - producerId int32 - producerEpoch int32 - segmentId int32 - flags int32 - messageOffsets []flatbuffers.UOffsetT - segmentSeqBase int64 - segmentSeqLast int64 - tsMsBase int64 - tsMsLast int64 -} - -func NewMessageBatchBuilder(b *flatbuffers.Builder, - producerId int32, - producerEpoch int32, - segmentId int32, - flags int32) *MessageBatchBuilder { - - b.Reset() - - return &MessageBatchBuilder{ - b: b, - producerId: producerId, - producerEpoch: producerEpoch, - segmentId: segmentId, - flags: flags, - } -} - -func (builder *MessageBatchBuilder) AddMessage(segmentSeq int64, tsMs int64, properties map[string][]byte, key []byte, value []byte) { - if builder.segmentSeqBase == 0 { - builder.segmentSeqBase = segmentSeq - } - builder.segmentSeqLast = segmentSeq - if builder.tsMsBase == 0 { - builder.tsMsBase = tsMs - } - builder.tsMsLast = tsMs - - var names, values, pairs []flatbuffers.UOffsetT - for k, v := range properties { - names = append(names, builder.b.CreateString(k)) - values = append(values, builder.b.CreateByteVector(v)) - } - for i, _ := range names { - message_fbs.NameValueStart(builder.b) - message_fbs.NameValueAddName(builder.b, names[i]) - message_fbs.NameValueAddValue(builder.b, values[i]) - pair := message_fbs.NameValueEnd(builder.b) - pairs = append(pairs, pair) - } - - message_fbs.MessageStartPropertiesVector(builder.b, len(properties)) - for i := len(pairs) - 1; i >= 0; i-- { - builder.b.PrependUOffsetT(pairs[i]) - } - propOffset := builder.b.EndVector(len(properties)) - - keyOffset := builder.b.CreateByteVector(key) - valueOffset := builder.b.CreateByteVector(value) - - message_fbs.MessageStart(builder.b) - message_fbs.MessageAddSeqDelta(builder.b, int32(segmentSeq-builder.segmentSeqBase)) - message_fbs.MessageAddTsMsDelta(builder.b, int32(tsMs-builder.tsMsBase)) - - message_fbs.MessageAddProperties(builder.b, propOffset) - message_fbs.MessageAddKey(builder.b, keyOffset) - message_fbs.MessageAddData(builder.b, valueOffset) - messageOffset := message_fbs.MessageEnd(builder.b) - - builder.messageOffsets = append(builder.messageOffsets, messageOffset) - -} - -func (builder *MessageBatchBuilder) BuildMessageBatch() { - message_fbs.MessageBatchStartMessagesVector(builder.b, len(builder.messageOffsets)) - for i := len(builder.messageOffsets) - 1; i >= 0; i-- { - builder.b.PrependUOffsetT(builder.messageOffsets[i]) - } - messagesOffset := builder.b.EndVector(len(builder.messageOffsets)) - - message_fbs.MessageBatchStart(builder.b) - message_fbs.MessageBatchAddProducerId(builder.b, builder.producerId) - message_fbs.MessageBatchAddProducerEpoch(builder.b, builder.producerEpoch) - message_fbs.MessageBatchAddSegmentId(builder.b, builder.segmentId) - message_fbs.MessageBatchAddFlags(builder.b, builder.flags) - message_fbs.MessageBatchAddSegmentSeqBase(builder.b, builder.segmentSeqBase) - message_fbs.MessageBatchAddSegmentSeqMaxDelta(builder.b, int32(builder.segmentSeqLast-builder.segmentSeqBase)) - message_fbs.MessageBatchAddTsMsBase(builder.b, builder.tsMsBase) - message_fbs.MessageBatchAddTsMsMaxDelta(builder.b, int32(builder.tsMsLast-builder.tsMsBase)) - - message_fbs.MessageBatchAddMessages(builder.b, messagesOffset) - - messageBatch := message_fbs.MessageBatchEnd(builder.b) - - builder.b.Finish(messageBatch) -} - -func (builder *MessageBatchBuilder) GetBytes() []byte { - return builder.b.FinishedBytes() -} diff --git a/weed/mq/segment/message_serde_test.go b/weed/mq/segment/message_serde_test.go deleted file mode 100644 index 52c9d8e55..000000000 --- a/weed/mq/segment/message_serde_test.go +++ /dev/null @@ -1,61 +0,0 @@ -package segment - -import ( - "testing" - - flatbuffers "github.com/google/flatbuffers/go" - "github.com/seaweedfs/seaweedfs/weed/pb/message_fbs" - "github.com/stretchr/testify/assert" -) - -func TestMessageSerde(t *testing.T) { - b := flatbuffers.NewBuilder(1024) - - prop := make(map[string][]byte) - prop["n1"] = []byte("v1") - prop["n2"] = []byte("v2") - - bb := NewMessageBatchBuilder(b, 1, 2, 3, 4) - - bb.AddMessage(5, 6, prop, []byte("the primary key"), []byte("body is here")) - bb.AddMessage(5, 7, prop, []byte("the primary 2"), []byte("body is 2")) - - bb.BuildMessageBatch() - - buf := bb.GetBytes() - - println("serialized size", len(buf)) - - mb := message_fbs.GetRootAsMessageBatch(buf, 0) - - assert.Equal(t, int32(1), mb.ProducerId()) - assert.Equal(t, int32(2), mb.ProducerEpoch()) - assert.Equal(t, int32(3), mb.SegmentId()) - assert.Equal(t, int32(4), mb.Flags()) - assert.Equal(t, int64(5), mb.SegmentSeqBase()) - assert.Equal(t, int32(0), mb.SegmentSeqMaxDelta()) - assert.Equal(t, int64(6), mb.TsMsBase()) - assert.Equal(t, int32(1), mb.TsMsMaxDelta()) - - assert.Equal(t, 2, mb.MessagesLength()) - - m := &message_fbs.Message{} - mb.Messages(m, 0) - - /* - // the vector seems not consistent - nv := &message_fbs.NameValue{} - m.Properties(nv, 0) - assert.Equal(t, "n1", string(nv.Name())) - assert.Equal(t, "v1", string(nv.Value())) - m.Properties(nv, 1) - assert.Equal(t, "n2", string(nv.Name())) - assert.Equal(t, "v2", string(nv.Value())) - */ - assert.Equal(t, []byte("the primary key"), m.Key()) - assert.Equal(t, []byte("body is here"), m.Data()) - - assert.Equal(t, int32(0), m.SeqDelta()) - assert.Equal(t, int32(0), m.TsMsDelta()) - -} diff --git a/weed/mq/sub_coordinator/inflight_message_tracker.go b/weed/mq/sub_coordinator/inflight_message_tracker.go index 8ecbb2ccd..c78e7883e 100644 --- a/weed/mq/sub_coordinator/inflight_message_tracker.go +++ b/weed/mq/sub_coordinator/inflight_message_tracker.go @@ -28,28 +28,6 @@ func (imt *InflightMessageTracker) EnflightMessage(key []byte, tsNs int64) { imt.timestamps.EnflightTimestamp(tsNs) } -// IsMessageAcknowledged returns true if the message has been acknowledged. -// If the message is older than the oldest inflight messages, returns false. -// returns false if the message is inflight. -// Otherwise, returns false if the message is old and can be ignored. -func (imt *InflightMessageTracker) IsMessageAcknowledged(key []byte, tsNs int64) bool { - imt.mu.Lock() - defer imt.mu.Unlock() - - if tsNs <= imt.timestamps.OldestAckedTimestamp() { - return true - } - if tsNs > imt.timestamps.Latest() { - return false - } - - if _, found := imt.messages[string(key)]; found { - return false - } - - return true -} - // AcknowledgeMessage acknowledges the message with the key and timestamp. func (imt *InflightMessageTracker) AcknowledgeMessage(key []byte, tsNs int64) bool { // fmt.Printf("AcknowledgeMessage(%s,%d)\n", string(key), tsNs) @@ -164,8 +142,3 @@ func (rb *RingBuffer) AckTimestamp(timestamp int64) { func (rb *RingBuffer) OldestAckedTimestamp() int64 { return rb.maxAllAckedTs } - -// Latest returns the most recently known timestamp in the ring buffer. -func (rb *RingBuffer) Latest() int64 { - return rb.maxTimestamp -} diff --git a/weed/mq/sub_coordinator/inflight_message_tracker_test.go b/weed/mq/sub_coordinator/inflight_message_tracker_test.go deleted file mode 100644 index a5c63d561..000000000 --- a/weed/mq/sub_coordinator/inflight_message_tracker_test.go +++ /dev/null @@ -1,134 +0,0 @@ -package sub_coordinator - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestRingBuffer(t *testing.T) { - // Initialize a RingBuffer with capacity 5 - rb := NewRingBuffer(5) - - // Add timestamps to the buffer - timestamps := []int64{100, 200, 300, 400, 500} - for _, ts := range timestamps { - rb.EnflightTimestamp(ts) - } - - // Test Add method and buffer size - expectedSize := 5 - if rb.size != expectedSize { - t.Errorf("Expected buffer size %d, got %d", expectedSize, rb.size) - } - - assert.Equal(t, int64(0), rb.OldestAckedTimestamp()) - assert.Equal(t, int64(500), rb.Latest()) - - rb.AckTimestamp(200) - assert.Equal(t, int64(0), rb.OldestAckedTimestamp()) - rb.AckTimestamp(100) - assert.Equal(t, int64(200), rb.OldestAckedTimestamp()) - - rb.EnflightTimestamp(int64(600)) - rb.EnflightTimestamp(int64(700)) - - rb.AckTimestamp(500) - assert.Equal(t, int64(200), rb.OldestAckedTimestamp()) - rb.AckTimestamp(400) - assert.Equal(t, int64(200), rb.OldestAckedTimestamp()) - rb.AckTimestamp(300) - assert.Equal(t, int64(500), rb.OldestAckedTimestamp()) - - assert.Equal(t, int64(700), rb.Latest()) -} - -func TestInflightMessageTracker(t *testing.T) { - // Initialize an InflightMessageTracker with capacity 5 - tracker := NewInflightMessageTracker(5) - - // Add inflight messages - key := []byte("1") - timestamp := int64(1) - tracker.EnflightMessage(key, timestamp) - - // Test IsMessageAcknowledged method - isOld := tracker.IsMessageAcknowledged(key, timestamp-10) - if !isOld { - t.Error("Expected message to be old") - } - - // Test AcknowledgeMessage method - acked := tracker.AcknowledgeMessage(key, timestamp) - if !acked { - t.Error("Expected message to be acked") - } - if _, exists := tracker.messages[string(key)]; exists { - t.Error("Expected message to be deleted after ack") - } - if tracker.timestamps.size != 0 { - t.Error("Expected buffer size to be 0 after ack") - } - assert.Equal(t, timestamp, tracker.GetOldestAckedTimestamp()) -} - -func TestInflightMessageTracker2(t *testing.T) { - // Initialize an InflightMessageTracker with initial capacity 1 - tracker := NewInflightMessageTracker(1) - - tracker.EnflightMessage([]byte("1"), int64(1)) - tracker.EnflightMessage([]byte("2"), int64(2)) - tracker.EnflightMessage([]byte("3"), int64(3)) - tracker.EnflightMessage([]byte("4"), int64(4)) - tracker.EnflightMessage([]byte("5"), int64(5)) - assert.True(t, tracker.AcknowledgeMessage([]byte("1"), int64(1))) - assert.Equal(t, int64(1), tracker.GetOldestAckedTimestamp()) - - // Test IsMessageAcknowledged method - isAcked := tracker.IsMessageAcknowledged([]byte("2"), int64(2)) - if isAcked { - t.Error("Expected message to be not acked") - } - - // Test AcknowledgeMessage method - assert.True(t, tracker.AcknowledgeMessage([]byte("2"), int64(2))) - assert.Equal(t, int64(2), tracker.GetOldestAckedTimestamp()) - -} - -func TestInflightMessageTracker3(t *testing.T) { - // Initialize an InflightMessageTracker with initial capacity 1 - tracker := NewInflightMessageTracker(1) - - tracker.EnflightMessage([]byte("1"), int64(1)) - tracker.EnflightMessage([]byte("2"), int64(2)) - tracker.EnflightMessage([]byte("3"), int64(3)) - assert.True(t, tracker.AcknowledgeMessage([]byte("1"), int64(1))) - tracker.EnflightMessage([]byte("4"), int64(4)) - tracker.EnflightMessage([]byte("5"), int64(5)) - assert.True(t, tracker.AcknowledgeMessage([]byte("2"), int64(2))) - assert.True(t, tracker.AcknowledgeMessage([]byte("3"), int64(3))) - tracker.EnflightMessage([]byte("6"), int64(6)) - tracker.EnflightMessage([]byte("7"), int64(7)) - assert.True(t, tracker.AcknowledgeMessage([]byte("4"), int64(4))) - assert.True(t, tracker.AcknowledgeMessage([]byte("5"), int64(5))) - assert.True(t, tracker.AcknowledgeMessage([]byte("6"), int64(6))) - assert.Equal(t, int64(6), tracker.GetOldestAckedTimestamp()) - assert.True(t, tracker.AcknowledgeMessage([]byte("7"), int64(7))) - assert.Equal(t, int64(7), tracker.GetOldestAckedTimestamp()) - -} - -func TestInflightMessageTracker4(t *testing.T) { - // Initialize an InflightMessageTracker with initial capacity 1 - tracker := NewInflightMessageTracker(1) - - tracker.EnflightMessage([]byte("1"), int64(1)) - tracker.EnflightMessage([]byte("2"), int64(2)) - assert.True(t, tracker.AcknowledgeMessage([]byte("1"), int64(1))) - assert.True(t, tracker.AcknowledgeMessage([]byte("2"), int64(2))) - tracker.EnflightMessage([]byte("3"), int64(3)) - assert.True(t, tracker.AcknowledgeMessage([]byte("3"), int64(3))) - assert.Equal(t, int64(3), tracker.GetOldestAckedTimestamp()) - -} diff --git a/weed/mq/sub_coordinator/partition_consumer_mapping.go b/weed/mq/sub_coordinator/partition_consumer_mapping.go index e4d00a0dd..ec3a6582f 100644 --- a/weed/mq/sub_coordinator/partition_consumer_mapping.go +++ b/weed/mq/sub_coordinator/partition_consumer_mapping.go @@ -1,130 +1,6 @@ package sub_coordinator -import ( - "fmt" - "time" - - "github.com/seaweedfs/seaweedfs/weed/mq/pub_balancer" -) - type PartitionConsumerMapping struct { currentMapping *PartitionSlotToConsumerInstanceList prevMappings []*PartitionSlotToConsumerInstanceList } - -// Balance goal: -// 1. max processing power utilization -// 2. allow one consumer instance to be down unexpectedly -// without affecting the processing power utilization - -func (pcm *PartitionConsumerMapping) BalanceToConsumerInstances(partitionSlotToBrokerList *pub_balancer.PartitionSlotToBrokerList, consumerInstances []*ConsumerGroupInstance) { - if len(partitionSlotToBrokerList.PartitionSlots) == 0 || len(consumerInstances) == 0 { - return - } - newMapping := NewPartitionSlotToConsumerInstanceList(partitionSlotToBrokerList.RingSize, time.Now()) - var prevMapping *PartitionSlotToConsumerInstanceList - if len(pcm.prevMappings) > 0 { - prevMapping = pcm.prevMappings[len(pcm.prevMappings)-1] - } else { - prevMapping = nil - } - newMapping.PartitionSlots = doBalanceSticky(partitionSlotToBrokerList.PartitionSlots, consumerInstances, prevMapping) - if pcm.currentMapping != nil { - pcm.prevMappings = append(pcm.prevMappings, pcm.currentMapping) - if len(pcm.prevMappings) > 10 { - pcm.prevMappings = pcm.prevMappings[1:] - } - } - pcm.currentMapping = newMapping -} - -func doBalanceSticky(partitions []*pub_balancer.PartitionSlotToBroker, consumerInstances []*ConsumerGroupInstance, prevMapping *PartitionSlotToConsumerInstanceList) (partitionSlots []*PartitionSlotToConsumerInstance) { - // collect previous consumer instance ids - prevConsumerInstanceIds := make(map[ConsumerGroupInstanceId]struct{}) - if prevMapping != nil { - for _, prevPartitionSlot := range prevMapping.PartitionSlots { - if prevPartitionSlot.AssignedInstanceId != "" { - prevConsumerInstanceIds[prevPartitionSlot.AssignedInstanceId] = struct{}{} - } - } - } - // collect current consumer instance ids - currConsumerInstanceIds := make(map[ConsumerGroupInstanceId]struct{}) - for _, consumerInstance := range consumerInstances { - currConsumerInstanceIds[consumerInstance.InstanceId] = struct{}{} - } - - // check deleted consumer instances - deletedConsumerInstanceIds := make(map[ConsumerGroupInstanceId]struct{}) - for consumerInstanceId := range prevConsumerInstanceIds { - if _, ok := currConsumerInstanceIds[consumerInstanceId]; !ok { - deletedConsumerInstanceIds[consumerInstanceId] = struct{}{} - } - } - - // convert partition slots from list to a map - prevPartitionSlotMap := make(map[string]*PartitionSlotToConsumerInstance) - if prevMapping != nil { - for _, partitionSlot := range prevMapping.PartitionSlots { - key := fmt.Sprintf("%d-%d", partitionSlot.RangeStart, partitionSlot.RangeStop) - prevPartitionSlotMap[key] = partitionSlot - } - } - - // make a copy of old mapping, skipping the deleted consumer instances - newPartitionSlots := make([]*PartitionSlotToConsumerInstance, 0, len(partitions)) - for _, partition := range partitions { - newPartitionSlots = append(newPartitionSlots, &PartitionSlotToConsumerInstance{ - RangeStart: partition.RangeStart, - RangeStop: partition.RangeStop, - UnixTimeNs: partition.UnixTimeNs, - Broker: partition.AssignedBroker, - FollowerBroker: partition.FollowerBroker, - }) - } - for _, newPartitionSlot := range newPartitionSlots { - key := fmt.Sprintf("%d-%d", newPartitionSlot.RangeStart, newPartitionSlot.RangeStop) - if prevPartitionSlot, ok := prevPartitionSlotMap[key]; ok { - if _, ok := deletedConsumerInstanceIds[prevPartitionSlot.AssignedInstanceId]; !ok { - newPartitionSlot.AssignedInstanceId = prevPartitionSlot.AssignedInstanceId - } - } - } - - // for all consumer instances, count the average number of partitions - // that are assigned to them - consumerInstancePartitionCount := make(map[ConsumerGroupInstanceId]int) - for _, newPartitionSlot := range newPartitionSlots { - if newPartitionSlot.AssignedInstanceId != "" { - consumerInstancePartitionCount[newPartitionSlot.AssignedInstanceId]++ - } - } - // average number of partitions that are assigned to each consumer instance - averageConsumerInstanceLoad := float32(len(partitions)) / float32(len(consumerInstances)) - - // assign unassigned partition slots to consumer instances that is underloaded - consumerInstanceIdsIndex := 0 - for _, newPartitionSlot := range newPartitionSlots { - if newPartitionSlot.AssignedInstanceId == "" { - for avoidDeadLoop := len(consumerInstances); avoidDeadLoop > 0; avoidDeadLoop-- { - consumerInstance := consumerInstances[consumerInstanceIdsIndex] - if float32(consumerInstancePartitionCount[consumerInstance.InstanceId]) < averageConsumerInstanceLoad { - newPartitionSlot.AssignedInstanceId = consumerInstance.InstanceId - consumerInstancePartitionCount[consumerInstance.InstanceId]++ - consumerInstanceIdsIndex++ - if consumerInstanceIdsIndex >= len(consumerInstances) { - consumerInstanceIdsIndex = 0 - } - break - } else { - consumerInstanceIdsIndex++ - if consumerInstanceIdsIndex >= len(consumerInstances) { - consumerInstanceIdsIndex = 0 - } - } - } - } - } - - return newPartitionSlots -} diff --git a/weed/mq/sub_coordinator/partition_consumer_mapping_test.go b/weed/mq/sub_coordinator/partition_consumer_mapping_test.go deleted file mode 100644 index ccc4e8601..000000000 --- a/weed/mq/sub_coordinator/partition_consumer_mapping_test.go +++ /dev/null @@ -1,385 +0,0 @@ -package sub_coordinator - -import ( - "reflect" - "testing" - - "github.com/seaweedfs/seaweedfs/weed/mq/pub_balancer" -) - -func Test_doBalanceSticky(t *testing.T) { - type args struct { - partitions []*pub_balancer.PartitionSlotToBroker - consumerInstanceIds []*ConsumerGroupInstance - prevMapping *PartitionSlotToConsumerInstanceList - } - tests := []struct { - name string - args args - wantPartitionSlots []*PartitionSlotToConsumerInstance - }{ - { - name: "1 consumer instance, 1 partition", - args: args{ - partitions: []*pub_balancer.PartitionSlotToBroker{ - { - RangeStart: 0, - RangeStop: 100, - }, - }, - consumerInstanceIds: []*ConsumerGroupInstance{ - { - InstanceId: "consumer-instance-1", - MaxPartitionCount: 1, - }, - }, - prevMapping: nil, - }, - wantPartitionSlots: []*PartitionSlotToConsumerInstance{ - { - RangeStart: 0, - RangeStop: 100, - AssignedInstanceId: "consumer-instance-1", - }, - }, - }, - { - name: "2 consumer instances, 1 partition", - args: args{ - partitions: []*pub_balancer.PartitionSlotToBroker{ - { - RangeStart: 0, - RangeStop: 100, - }, - }, - consumerInstanceIds: []*ConsumerGroupInstance{ - { - InstanceId: "consumer-instance-1", - MaxPartitionCount: 1, - }, - { - InstanceId: "consumer-instance-2", - MaxPartitionCount: 1, - }, - }, - prevMapping: nil, - }, - wantPartitionSlots: []*PartitionSlotToConsumerInstance{ - { - RangeStart: 0, - RangeStop: 100, - AssignedInstanceId: "consumer-instance-1", - }, - }, - }, - { - name: "1 consumer instance, 2 partitions", - args: args{ - partitions: []*pub_balancer.PartitionSlotToBroker{ - { - RangeStart: 0, - RangeStop: 50, - }, - { - RangeStart: 50, - RangeStop: 100, - }, - }, - consumerInstanceIds: []*ConsumerGroupInstance{ - { - InstanceId: "consumer-instance-1", - MaxPartitionCount: 1, - }, - }, - prevMapping: nil, - }, - wantPartitionSlots: []*PartitionSlotToConsumerInstance{ - { - RangeStart: 0, - RangeStop: 50, - AssignedInstanceId: "consumer-instance-1", - }, - { - RangeStart: 50, - RangeStop: 100, - AssignedInstanceId: "consumer-instance-1", - }, - }, - }, - { - name: "2 consumer instances, 2 partitions", - args: args{ - partitions: []*pub_balancer.PartitionSlotToBroker{ - { - RangeStart: 0, - RangeStop: 50, - }, - { - RangeStart: 50, - RangeStop: 100, - }, - }, - consumerInstanceIds: []*ConsumerGroupInstance{ - { - InstanceId: "consumer-instance-1", - MaxPartitionCount: 1, - }, - { - InstanceId: "consumer-instance-2", - MaxPartitionCount: 1, - }, - }, - prevMapping: nil, - }, - wantPartitionSlots: []*PartitionSlotToConsumerInstance{ - { - RangeStart: 0, - RangeStop: 50, - AssignedInstanceId: "consumer-instance-1", - }, - { - RangeStart: 50, - RangeStop: 100, - AssignedInstanceId: "consumer-instance-2", - }, - }, - }, - { - name: "2 consumer instances, 2 partitions, 1 deleted consumer instance", - args: args{ - partitions: []*pub_balancer.PartitionSlotToBroker{ - { - RangeStart: 0, - RangeStop: 50, - }, - { - RangeStart: 50, - RangeStop: 100, - }, - }, - consumerInstanceIds: []*ConsumerGroupInstance{ - { - InstanceId: "consumer-instance-1", - MaxPartitionCount: 1, - }, - { - InstanceId: "consumer-instance-2", - MaxPartitionCount: 1, - }, - }, - prevMapping: &PartitionSlotToConsumerInstanceList{ - PartitionSlots: []*PartitionSlotToConsumerInstance{ - { - RangeStart: 0, - RangeStop: 50, - AssignedInstanceId: "consumer-instance-3", - }, - { - RangeStart: 50, - RangeStop: 100, - AssignedInstanceId: "consumer-instance-2", - }, - }, - }, - }, - wantPartitionSlots: []*PartitionSlotToConsumerInstance{ - { - RangeStart: 0, - RangeStop: 50, - AssignedInstanceId: "consumer-instance-1", - }, - { - RangeStart: 50, - RangeStop: 100, - AssignedInstanceId: "consumer-instance-2", - }, - }, - }, - { - name: "2 consumer instances, 2 partitions, 1 new consumer instance", - args: args{ - partitions: []*pub_balancer.PartitionSlotToBroker{ - { - RangeStart: 0, - RangeStop: 50, - }, - { - RangeStart: 50, - RangeStop: 100, - }, - }, - consumerInstanceIds: []*ConsumerGroupInstance{ - { - InstanceId: "consumer-instance-1", - MaxPartitionCount: 1, - }, - { - InstanceId: "consumer-instance-2", - MaxPartitionCount: 1, - }, - { - InstanceId: "consumer-instance-3", - MaxPartitionCount: 1, - }, - }, - prevMapping: &PartitionSlotToConsumerInstanceList{ - PartitionSlots: []*PartitionSlotToConsumerInstance{ - { - RangeStart: 0, - RangeStop: 50, - AssignedInstanceId: "consumer-instance-3", - }, - { - RangeStart: 50, - RangeStop: 100, - AssignedInstanceId: "consumer-instance-2", - }, - }, - }, - }, - wantPartitionSlots: []*PartitionSlotToConsumerInstance{ - { - RangeStart: 0, - RangeStop: 50, - AssignedInstanceId: "consumer-instance-3", - }, - { - RangeStart: 50, - RangeStop: 100, - AssignedInstanceId: "consumer-instance-2", - }, - }, - }, - { - name: "2 consumer instances, 2 partitions, 1 new partition", - args: args{ - partitions: []*pub_balancer.PartitionSlotToBroker{ - { - RangeStart: 0, - RangeStop: 50, - }, - { - RangeStart: 50, - RangeStop: 100, - }, - { - RangeStart: 100, - RangeStop: 150, - }, - }, - consumerInstanceIds: []*ConsumerGroupInstance{ - { - InstanceId: "consumer-instance-1", - MaxPartitionCount: 1, - }, - { - InstanceId: "consumer-instance-2", - MaxPartitionCount: 1, - }, - }, - prevMapping: &PartitionSlotToConsumerInstanceList{ - PartitionSlots: []*PartitionSlotToConsumerInstance{ - { - RangeStart: 0, - RangeStop: 50, - AssignedInstanceId: "consumer-instance-1", - }, - { - RangeStart: 50, - RangeStop: 100, - AssignedInstanceId: "consumer-instance-2", - }, - }, - }, - }, - wantPartitionSlots: []*PartitionSlotToConsumerInstance{ - { - RangeStart: 0, - RangeStop: 50, - AssignedInstanceId: "consumer-instance-1", - }, - { - RangeStart: 50, - RangeStop: 100, - AssignedInstanceId: "consumer-instance-2", - }, - { - RangeStart: 100, - RangeStop: 150, - AssignedInstanceId: "consumer-instance-1", - }, - }, - }, - { - name: "2 consumer instances, 2 partitions, 1 new partition, 1 new consumer instance", - args: args{ - partitions: []*pub_balancer.PartitionSlotToBroker{ - { - RangeStart: 0, - RangeStop: 50, - }, - { - RangeStart: 50, - RangeStop: 100, - }, - { - RangeStart: 100, - RangeStop: 150, - }, - }, - consumerInstanceIds: []*ConsumerGroupInstance{ - { - InstanceId: "consumer-instance-1", - MaxPartitionCount: 1, - }, - { - InstanceId: "consumer-instance-2", - MaxPartitionCount: 1, - }, - { - InstanceId: "consumer-instance-3", - MaxPartitionCount: 1, - }, - }, - prevMapping: &PartitionSlotToConsumerInstanceList{ - PartitionSlots: []*PartitionSlotToConsumerInstance{ - { - RangeStart: 0, - RangeStop: 50, - AssignedInstanceId: "consumer-instance-1", - }, - { - RangeStart: 50, - RangeStop: 100, - AssignedInstanceId: "consumer-instance-2", - }, - }, - }, - }, - wantPartitionSlots: []*PartitionSlotToConsumerInstance{ - { - RangeStart: 0, - RangeStop: 50, - AssignedInstanceId: "consumer-instance-1", - }, - { - RangeStart: 50, - RangeStop: 100, - AssignedInstanceId: "consumer-instance-2", - }, - { - RangeStart: 100, - RangeStop: 150, - AssignedInstanceId: "consumer-instance-3", - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if gotPartitionSlots := doBalanceSticky(tt.args.partitions, tt.args.consumerInstanceIds, tt.args.prevMapping); !reflect.DeepEqual(gotPartitionSlots, tt.wantPartitionSlots) { - t.Errorf("doBalanceSticky() = %v, want %v", gotPartitionSlots, tt.wantPartitionSlots) - } - }) - } -} diff --git a/weed/mq/sub_coordinator/partition_list.go b/weed/mq/sub_coordinator/partition_list.go index 16bf1ff0c..38c130598 100644 --- a/weed/mq/sub_coordinator/partition_list.go +++ b/weed/mq/sub_coordinator/partition_list.go @@ -1,7 +1,5 @@ package sub_coordinator -import "time" - type PartitionSlotToConsumerInstance struct { RangeStart int32 RangeStop int32 @@ -16,10 +14,3 @@ type PartitionSlotToConsumerInstanceList struct { RingSize int32 Version int64 } - -func NewPartitionSlotToConsumerInstanceList(ringSize int32, version time.Time) *PartitionSlotToConsumerInstanceList { - return &PartitionSlotToConsumerInstanceList{ - RingSize: ringSize, - Version: version.UnixNano(), - } -} diff --git a/weed/mq/topic/local_partition_offset.go b/weed/mq/topic/local_partition_offset.go index 9c8a2dac4..ef7da3606 100644 --- a/weed/mq/topic/local_partition_offset.go +++ b/weed/mq/topic/local_partition_offset.go @@ -90,22 +90,3 @@ type OffsetAwarePublisher struct { partition *LocalPartition assignOffsetFn OffsetAssignmentFunc } - -// NewOffsetAwarePublisher creates a new offset-aware publisher -func NewOffsetAwarePublisher(partition *LocalPartition, assignOffsetFn OffsetAssignmentFunc) *OffsetAwarePublisher { - return &OffsetAwarePublisher{ - partition: partition, - assignOffsetFn: assignOffsetFn, - } -} - -// Publish publishes a message with automatic offset assignment -func (oap *OffsetAwarePublisher) Publish(message *mq_pb.DataMessage) error { - _, err := oap.partition.PublishWithOffset(message, oap.assignOffsetFn) - return err -} - -// GetPartition returns the underlying partition -func (oap *OffsetAwarePublisher) GetPartition() *LocalPartition { - return oap.partition -} diff --git a/weed/mq/topic/partition.go b/weed/mq/topic/partition.go index 658ec85c4..fc3b71aac 100644 --- a/weed/mq/topic/partition.go +++ b/weed/mq/topic/partition.go @@ -16,15 +16,6 @@ type Partition struct { UnixTimeNs int64 // in nanoseconds } -func NewPartition(rangeStart, rangeStop, ringSize int32, unixTimeNs int64) *Partition { - return &Partition{ - RangeStart: rangeStart, - RangeStop: rangeStop, - RingSize: ringSize, - UnixTimeNs: unixTimeNs, - } -} - func (partition Partition) Equals(other Partition) bool { if partition.RangeStart != other.RangeStart { return false @@ -57,24 +48,6 @@ func FromPbPartition(partition *schema_pb.Partition) Partition { } } -func SplitPartitions(targetCount int32, ts int64) []*Partition { - partitions := make([]*Partition, 0, targetCount) - partitionSize := PartitionCount / targetCount - for i := int32(0); i < targetCount; i++ { - partitionStop := (i + 1) * partitionSize - if i == targetCount-1 { - partitionStop = PartitionCount - } - partitions = append(partitions, &Partition{ - RangeStart: i * partitionSize, - RangeStop: partitionStop, - RingSize: PartitionCount, - UnixTimeNs: ts, - }) - } - return partitions -} - func (partition Partition) ToPbPartition() *schema_pb.Partition { return &schema_pb.Partition{ RangeStart: partition.RangeStart, diff --git a/weed/operation/assign_file_id.go b/weed/operation/assign_file_id.go index 7c2c71074..5609bf8ac 100644 --- a/weed/operation/assign_file_id.go +++ b/weed/operation/assign_file_id.go @@ -3,8 +3,6 @@ package operation import ( "context" "fmt" - "strings" - "sync" "time" "github.com/seaweedfs/seaweedfs/weed/pb" @@ -41,118 +39,6 @@ type AssignResult struct { Replicas []Location `json:"replicas,omitempty"` } -// This is a proxy to the master server, only for assigning volume ids. -// It runs via grpc to the master server in streaming mode. -// The connection to the master would only be re-established when the last connection has error. -type AssignProxy struct { - grpcConnection *grpc.ClientConn - pool chan *singleThreadAssignProxy -} - -func NewAssignProxy(masterFn GetMasterFn, grpcDialOption grpc.DialOption, concurrency int) (ap *AssignProxy, err error) { - ap = &AssignProxy{ - pool: make(chan *singleThreadAssignProxy, concurrency), - } - ap.grpcConnection, err = pb.GrpcDial(context.Background(), masterFn(context.Background()).ToGrpcAddress(), true, grpcDialOption) - if err != nil { - return nil, fmt.Errorf("fail to dial %s: %v", masterFn(context.Background()).ToGrpcAddress(), err) - } - for i := 0; i < concurrency; i++ { - ap.pool <- &singleThreadAssignProxy{} - } - return ap, nil -} - -func (ap *AssignProxy) Assign(primaryRequest *VolumeAssignRequest, alternativeRequests ...*VolumeAssignRequest) (ret *AssignResult, err error) { - p := <-ap.pool - defer func() { - ap.pool <- p - }() - - return p.doAssign(ap.grpcConnection, primaryRequest, alternativeRequests...) -} - -type singleThreadAssignProxy struct { - assignClient master_pb.Seaweed_StreamAssignClient - sync.Mutex -} - -func (ap *singleThreadAssignProxy) doAssign(grpcConnection *grpc.ClientConn, primaryRequest *VolumeAssignRequest, alternativeRequests ...*VolumeAssignRequest) (ret *AssignResult, err error) { - ap.Lock() - defer ap.Unlock() - - if ap.assignClient == nil { - client := master_pb.NewSeaweedClient(grpcConnection) - ap.assignClient, err = client.StreamAssign(context.Background()) - if err != nil { - ap.assignClient = nil - return nil, fmt.Errorf("fail to create stream assign client: %w", err) - } - } - - var requests []*VolumeAssignRequest - requests = append(requests, primaryRequest) - requests = append(requests, alternativeRequests...) - ret = &AssignResult{} - - for _, request := range requests { - if request == nil { - continue - } - req := &master_pb.AssignRequest{ - Count: request.Count, - Replication: request.Replication, - Collection: request.Collection, - Ttl: request.Ttl, - DiskType: request.DiskType, - DataCenter: request.DataCenter, - Rack: request.Rack, - DataNode: request.DataNode, - WritableVolumeCount: request.WritableVolumeCount, - } - if err = ap.assignClient.Send(req); err != nil { - ap.assignClient = nil - return nil, fmt.Errorf("StreamAssignSend: %w", err) - } - resp, grpcErr := ap.assignClient.Recv() - if grpcErr != nil { - ap.assignClient = nil - return nil, grpcErr - } - if resp.Error != "" { - // StreamAssign returns transient warmup errors as in-band responses. - // Wrap them as codes.Unavailable so the caller's retry logic can - // classify them as retriable. - if strings.Contains(resp.Error, "warming up") { - return nil, status.Errorf(codes.Unavailable, "StreamAssignRecv: %s", resp.Error) - } - return nil, fmt.Errorf("StreamAssignRecv: %v", resp.Error) - } - - ret.Count = resp.Count - ret.Fid = resp.Fid - ret.Url = resp.Location.Url - ret.PublicUrl = resp.Location.PublicUrl - ret.GrpcPort = int(resp.Location.GrpcPort) - ret.Error = resp.Error - ret.Auth = security.EncodedJwt(resp.Auth) - for _, r := range resp.Replicas { - ret.Replicas = append(ret.Replicas, Location{ - Url: r.Url, - PublicUrl: r.PublicUrl, - DataCenter: r.DataCenter, - }) - } - - if ret.Count <= 0 { - continue - } - break - } - - return -} - func Assign(ctx context.Context, masterFn GetMasterFn, grpcDialOption grpc.DialOption, primaryRequest *VolumeAssignRequest, alternativeRequests ...*VolumeAssignRequest) (*AssignResult, error) { var requests []*VolumeAssignRequest diff --git a/weed/operation/assign_file_id_test.go b/weed/operation/assign_file_id_test.go deleted file mode 100644 index ecfa7d6d0..000000000 --- a/weed/operation/assign_file_id_test.go +++ /dev/null @@ -1,70 +0,0 @@ -package operation - -import ( - "context" - "fmt" - "testing" - "time" - - "github.com/seaweedfs/seaweedfs/weed/pb" - "google.golang.org/grpc" -) - -func BenchmarkWithConcurrency(b *testing.B) { - concurrencyLevels := []int{1, 10, 100, 1000} - - ap, _ := NewAssignProxy(func(_ context.Context) pb.ServerAddress { - return pb.ServerAddress("localhost:9333") - }, grpc.WithInsecure(), 16) - - for _, concurrency := range concurrencyLevels { - b.Run( - fmt.Sprintf("Concurrency-%d", concurrency), - func(b *testing.B) { - for i := 0; i < b.N; i++ { - done := make(chan struct{}) - startTime := time.Now() - - for j := 0; j < concurrency; j++ { - go func() { - - ap.Assign(&VolumeAssignRequest{ - Count: 1, - }) - - done <- struct{}{} - }() - } - - for j := 0; j < concurrency; j++ { - <-done - } - - duration := time.Since(startTime) - b.Logf("Concurrency: %d, Duration: %v", concurrency, duration) - } - }, - ) - } -} - -func BenchmarkStreamAssign(b *testing.B) { - ap, _ := NewAssignProxy(func(_ context.Context) pb.ServerAddress { - return pb.ServerAddress("localhost:9333") - }, grpc.WithInsecure(), 16) - for i := 0; i < b.N; i++ { - ap.Assign(&VolumeAssignRequest{ - Count: 1, - }) - } -} - -func BenchmarkUnaryAssign(b *testing.B) { - for i := 0; i < b.N; i++ { - Assign(context.Background(), func(_ context.Context) pb.ServerAddress { - return pb.ServerAddress("localhost:9333") - }, grpc.WithInsecure(), &VolumeAssignRequest{ - Count: 1, - }) - } -} diff --git a/weed/pb/filer_pb/filer_client.go b/weed/pb/filer_pb/filer_client.go index c93417eee..3e7f9859e 100644 --- a/weed/pb/filer_pb/filer_client.go +++ b/weed/pb/filer_pb/filer_client.go @@ -93,11 +93,6 @@ func List(ctx context.Context, filerClient FilerClient, parentDirectoryPath, pre }) } -func doList(ctx context.Context, filerClient FilerClient, fullDirPath util.FullPath, prefix string, fn EachEntryFunction, startFrom string, inclusive bool, limit uint32) (err error) { - _, err = doListWithSnapshot(ctx, filerClient, fullDirPath, prefix, fn, startFrom, inclusive, limit, 0) - return err -} - func doListWithSnapshot(ctx context.Context, filerClient FilerClient, fullDirPath util.FullPath, prefix string, fn EachEntryFunction, startFrom string, inclusive bool, limit uint32, snapshotTsNs int64) (actualSnapshotTsNs int64, err error) { err = filerClient.WithFilerClient(false, func(client SeaweedFilerClient) error { actualSnapshotTsNs, err = DoSeaweedListWithSnapshot(ctx, client, fullDirPath, prefix, fn, startFrom, inclusive, limit, snapshotTsNs) @@ -212,26 +207,6 @@ func Exists(ctx context.Context, filerClient FilerClient, parentDirectoryPath st return } -func Touch(ctx context.Context, filerClient FilerClient, parentDirectoryPath string, entryName string, entry *Entry) (err error) { - - return filerClient.WithFilerClient(false, func(client SeaweedFilerClient) error { - - request := &UpdateEntryRequest{ - Directory: parentDirectoryPath, - Entry: entry, - } - - glog.V(4).InfofCtx(ctx, "touch entry %v/%v: %v", parentDirectoryPath, entryName, request) - if err := UpdateEntry(ctx, client, request); err != nil { - glog.V(0).InfofCtx(ctx, "touch exists entry %v: %v", request, err) - return fmt.Errorf("touch exists entry %s/%s: %v", parentDirectoryPath, entryName, err) - } - - return nil - }) - -} - func Mkdir(ctx context.Context, filerClient FilerClient, parentDirectoryPath string, dirName string, fn func(entry *Entry)) error { return filerClient.WithFilerClient(false, func(client SeaweedFilerClient) error { return DoMkdir(ctx, client, parentDirectoryPath, dirName, fn) @@ -349,59 +324,3 @@ func DoRemoveWithResponse(ctx context.Context, client SeaweedFilerClient, parent return resp, nil } } - -// DoDeleteEmptyParentDirectories recursively deletes empty parent directories. -// It stops at root "/" or at stopAtPath. -// For safety, dirPath must be under stopAtPath (when stopAtPath is provided). -// The checked map tracks already-processed directories to avoid redundant work in batch operations. -func DoDeleteEmptyParentDirectories(ctx context.Context, client SeaweedFilerClient, dirPath util.FullPath, stopAtPath util.FullPath, checked map[string]bool) { - if dirPath == "/" || dirPath == stopAtPath { - return - } - - // Skip if already checked (for batch delete optimization) - dirPathStr := string(dirPath) - if checked != nil { - if checked[dirPathStr] { - return - } - checked[dirPathStr] = true - } - - // Safety check: if stopAtPath is provided, dirPath must be under it (root "/" allows everything) - stopStr := string(stopAtPath) - if stopAtPath != "" && stopStr != "/" && !strings.HasPrefix(dirPathStr+"/", stopStr+"/") { - glog.V(1).InfofCtx(ctx, "DoDeleteEmptyParentDirectories: %s is not under %s, skipping", dirPath, stopAtPath) - return - } - - // Check if directory is empty by listing with limit 1 - isEmpty := true - err := SeaweedList(ctx, client, dirPathStr, "", func(entry *Entry, isLast bool) error { - isEmpty = false - return io.EOF // Use sentinel error to explicitly stop iteration - }, "", false, 1) - - if err != nil && err != io.EOF { - glog.V(3).InfofCtx(ctx, "DoDeleteEmptyParentDirectories: error checking %s: %v", dirPath, err) - return - } - - if !isEmpty { - // Directory is not empty, stop checking upward - glog.V(3).InfofCtx(ctx, "DoDeleteEmptyParentDirectories: directory %s is not empty, stopping cleanup", dirPath) - return - } - - // Directory is empty, try to delete it - glog.V(2).InfofCtx(ctx, "DoDeleteEmptyParentDirectories: deleting empty directory %s", dirPath) - parentDir, dirName := dirPath.DirAndName() - - if err := DoRemove(ctx, client, parentDir, dirName, false, false, false, false, nil); err == nil { - // Successfully deleted, continue checking upwards - DoDeleteEmptyParentDirectories(ctx, client, util.FullPath(parentDir), stopAtPath, checked) - } else { - // Failed to delete, stop cleanup - glog.V(3).InfofCtx(ctx, "DoDeleteEmptyParentDirectories: failed to delete %s: %v", dirPath, err) - } -} diff --git a/weed/pb/filer_pb/filer_pb_helper.go b/weed/pb/filer_pb/filer_pb_helper.go index 05d5f602a..b621e366a 100644 --- a/weed/pb/filer_pb/filer_pb_helper.go +++ b/weed/pb/filer_pb/filer_pb_helper.go @@ -111,15 +111,6 @@ func BeforeEntrySerialization(chunks []*FileChunk) { } } -func EnsureFid(chunk *FileChunk) { - if chunk.Fid != nil { - return - } - if fid, err := ToFileIdObject(chunk.FileId); err == nil { - chunk.Fid = fid - } -} - func AfterEntryDeserialization(chunks []*FileChunk) { for _, chunk := range chunks { @@ -309,16 +300,6 @@ func MetadataEventTouchesDirectory(event *SubscribeMetadataResponse, dir string) MetadataEventTargetDirectory(event) == dir } -func MetadataEventTouchesDirectoryPrefix(event *SubscribeMetadataResponse, prefix string) bool { - if strings.HasPrefix(MetadataEventSourceDirectory(event), prefix) { - return true - } - return event != nil && - event.EventNotification != nil && - event.EventNotification.NewEntry != nil && - strings.HasPrefix(MetadataEventTargetDirectory(event), prefix) -} - func MetadataEventMatchesSubscription(event *SubscribeMetadataResponse, pathPrefix string, pathPrefixes []string, directories []string) bool { if event == nil { return false diff --git a/weed/pb/filer_pb/filer_pb_helper_test.go b/weed/pb/filer_pb/filer_pb_helper_test.go deleted file mode 100644 index b38b094e3..000000000 --- a/weed/pb/filer_pb/filer_pb_helper_test.go +++ /dev/null @@ -1,87 +0,0 @@ -package filer_pb - -import ( - "testing" - - "google.golang.org/protobuf/proto" -) - -func TestFileIdSize(t *testing.T) { - fileIdStr := "11745,0293434534cbb9892b" - - fid, _ := ToFileIdObject(fileIdStr) - bytes, _ := proto.Marshal(fid) - - println(len(fileIdStr)) - println(len(bytes)) -} - -func TestMetadataEventMatchesSubscription(t *testing.T) { - event := &SubscribeMetadataResponse{ - Directory: "/tmp", - EventNotification: &EventNotification{ - OldEntry: &Entry{Name: "old-name"}, - NewEntry: &Entry{Name: "new-name"}, - NewParentPath: "/watched", - }, - } - - tests := []struct { - name string - pathPrefix string - pathPrefixes []string - directories []string - }{ - { - name: "primary path prefix matches rename target", - pathPrefix: "/watched/new-name", - }, - { - name: "additional path prefix matches rename target", - pathPrefix: "/data", - pathPrefixes: []string{"/watched"}, - }, - { - name: "directory watch matches rename target directory", - pathPrefix: "/data", - directories: []string{"/watched"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if !MetadataEventMatchesSubscription(event, tt.pathPrefix, tt.pathPrefixes, tt.directories) { - t.Fatalf("MetadataEventMatchesSubscription returned false") - } - }) - } -} - -func TestMetadataEventTouchesDirectoryHelpers(t *testing.T) { - renameInto := &SubscribeMetadataResponse{ - Directory: "/tmp", - EventNotification: &EventNotification{ - OldEntry: &Entry{Name: "filer.conf"}, - NewEntry: &Entry{Name: "filer.conf"}, - NewParentPath: "/etc/seaweedfs", - }, - } - if got := MetadataEventTargetDirectory(renameInto); got != "/etc/seaweedfs" { - t.Fatalf("MetadataEventTargetDirectory = %q, want /etc/seaweedfs", got) - } - if !MetadataEventTouchesDirectory(renameInto, "/etc/seaweedfs") { - t.Fatalf("expected rename target to touch /etc/seaweedfs") - } - - renameOut := &SubscribeMetadataResponse{ - Directory: "/etc/remote", - EventNotification: &EventNotification{ - OldEntry: &Entry{Name: "remote.conf"}, - NewEntry: &Entry{Name: "remote.conf"}, - NewParentPath: "/tmp", - }, - } - if !MetadataEventTouchesDirectoryPrefix(renameOut, "/etc/remote") { - t.Fatalf("expected rename source to touch /etc/remote") - } -} diff --git a/weed/pb/grpc_client_server.go b/weed/pb/grpc_client_server.go index 4b7c0852d..82b9a23f5 100644 --- a/weed/pb/grpc_client_server.go +++ b/weed/pb/grpc_client_server.go @@ -28,7 +28,6 @@ import ( "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" "github.com/seaweedfs/seaweedfs/weed/pb/master_pb" "github.com/seaweedfs/seaweedfs/weed/pb/mq_pb" - "github.com/seaweedfs/seaweedfs/weed/pb/worker_pb" ) const ( @@ -318,18 +317,6 @@ func WithGrpcClient(streamingMode bool, signature int32, fn func(*grpc.ClientCon } -func ParseServerAddress(server string, deltaPort int) (newServerAddress string, err error) { - - host, port, parseErr := hostAndPort(server) - if parseErr != nil { - return "", fmt.Errorf("server port parse error: %w", parseErr) - } - - newPort := int(port) + deltaPort - - return util.JoinHostPort(host, newPort), nil -} - func hostAndPort(address string) (host string, port uint64, err error) { colonIndex := strings.LastIndex(address, ":") if colonIndex < 0 { @@ -457,10 +444,3 @@ func WithOneOfGrpcFilerClients(streamingMode bool, filerAddresses []ServerAddres return err } - -func WithWorkerClient(streamingMode bool, workerAddress string, grpcDialOption grpc.DialOption, fn func(client worker_pb.WorkerServiceClient) error) error { - return WithGrpcClient(streamingMode, 0, func(grpcConnection *grpc.ClientConn) error { - client := worker_pb.NewWorkerServiceClient(grpcConnection) - return fn(client) - }, workerAddress, false, grpcDialOption) -} diff --git a/weed/pb/server_address.go b/weed/pb/server_address.go index 151323b03..7cdece5eb 100644 --- a/weed/pb/server_address.go +++ b/weed/pb/server_address.go @@ -157,14 +157,6 @@ func (sa ServerAddresses) ToAddressMap() (addresses map[string]ServerAddress) { return } -func (sa ServerAddresses) ToAddressStrings() (addresses []string) { - parts := strings.Split(string(sa), ",") - for _, address := range parts { - addresses = append(addresses, address) - } - return -} - func ToAddressStrings(addresses []ServerAddress) []string { var strings []string for _, addr := range addresses { @@ -172,20 +164,6 @@ func ToAddressStrings(addresses []ServerAddress) []string { } return strings } -func ToAddressStringsFromMap(addresses map[string]ServerAddress) []string { - var strings []string - for _, addr := range addresses { - strings = append(strings, string(addr)) - } - return strings -} -func FromAddressStrings(strings []string) []ServerAddress { - var addresses []ServerAddress - for _, addr := range strings { - addresses = append(addresses, ServerAddress(addr)) - } - return addresses -} func ParseUrl(input string) (address ServerAddress, path string, err error) { if !strings.HasPrefix(input, "http://") { diff --git a/weed/plugin/worker/iceberg/detection.go b/weed/plugin/worker/iceberg/detection.go index 80e7fcd61..8a279287c 100644 --- a/weed/plugin/worker/iceberg/detection.go +++ b/weed/plugin/worker/iceberg/detection.go @@ -449,58 +449,6 @@ func hasEligibleCompaction( return len(bins) > 0, nil } -func countDataManifestsForRewrite( - ctx context.Context, - filerClient filer_pb.SeaweedFilerClient, - bucketName, tablePath string, - manifests []iceberg.ManifestFile, - meta table.Metadata, - predicate *partitionPredicate, -) (int64, error) { - if predicate == nil { - return countDataManifests(manifests), nil - } - - specsByID := specByID(meta) - - var count int64 - for _, mf := range manifests { - if mf.ManifestContent() != iceberg.ManifestContentData { - continue - } - manifestData, err := loadFileByIcebergPath(ctx, filerClient, bucketName, tablePath, mf.FilePath()) - if err != nil { - return 0, fmt.Errorf("read manifest %s: %w", mf.FilePath(), err) - } - entries, err := iceberg.ReadManifest(mf, bytes.NewReader(manifestData), true) - if err != nil { - return 0, fmt.Errorf("parse manifest %s: %w", mf.FilePath(), err) - } - if len(entries) == 0 { - continue - } - spec, ok := specsByID[int(mf.PartitionSpecID())] - if !ok { - continue - } - allMatch := len(entries) > 0 - for _, entry := range entries { - match, err := predicate.Matches(spec, entry.DataFile().Partition()) - if err != nil { - return 0, err - } - if !match { - allMatch = false - break - } - } - if allMatch { - count++ - } - } - return count, nil -} - func compactionMinInputFiles(minInputFiles int64) (int, error) { // Ensure the configured value is positive and fits into the platform's int type if minInputFiles <= 0 { diff --git a/weed/plugin/worker/iceberg/planning_index.go b/weed/plugin/worker/iceberg/planning_index.go index 74e019354..a2401836a 100644 --- a/weed/plugin/worker/iceberg/planning_index.go +++ b/weed/plugin/worker/iceberg/planning_index.go @@ -137,26 +137,6 @@ func mergePlanningIndexSections(index, existing *planningIndex) *planningIndex { return index } -func buildPlanningIndex( - ctx context.Context, - filerClient filer_pb.SeaweedFilerClient, - bucketName, tablePath string, - meta table.Metadata, - config Config, - ops []string, -) (*planningIndex, error) { - currentSnap := meta.CurrentSnapshot() - if currentSnap == nil || currentSnap.ManifestList == "" { - return nil, nil - } - - manifests, err := loadCurrentManifests(ctx, filerClient, bucketName, tablePath, meta) - if err != nil { - return nil, err - } - return buildPlanningIndexFromManifests(ctx, filerClient, bucketName, tablePath, meta, config, ops, manifests) -} - func buildPlanningIndexFromManifests( ctx context.Context, filerClient filer_pb.SeaweedFilerClient, diff --git a/weed/plugin/worker/lifecycle/config.go b/weed/plugin/worker/lifecycle/config.go deleted file mode 100644 index 62e0b4dbf..000000000 --- a/weed/plugin/worker/lifecycle/config.go +++ /dev/null @@ -1,131 +0,0 @@ -package lifecycle - -import ( - "strconv" - "strings" - - "github.com/seaweedfs/seaweedfs/weed/glog" - "github.com/seaweedfs/seaweedfs/weed/pb/plugin_pb" -) - -const ( - jobType = "s3_lifecycle" - - defaultBatchSize = 1000 - defaultMaxDeletesPerBucket = 10000 - defaultDryRun = false - defaultDeleteMarkerCleanup = true - defaultAbortMPUDaysDefault = 7 - - MetricObjectsExpired = "objects_expired" - MetricObjectsScanned = "objects_scanned" - MetricBucketsScanned = "buckets_scanned" - MetricBucketsWithRules = "buckets_with_rules" - MetricDeleteMarkersClean = "delete_markers_cleaned" - MetricMPUAborted = "mpu_aborted" - MetricErrors = "errors" - MetricDurationMs = "duration_ms" -) - -// Config holds parsed worker config values for lifecycle management. -type Config struct { - BatchSize int64 - MaxDeletesPerBucket int64 - DryRun bool - DeleteMarkerCleanup bool - AbortMPUDays int64 -} - -// ParseConfig extracts a lifecycle Config from plugin config values. -func ParseConfig(values map[string]*plugin_pb.ConfigValue) Config { - cfg := Config{ - BatchSize: readInt64Config(values, "batch_size", defaultBatchSize), - MaxDeletesPerBucket: readInt64Config(values, "max_deletes_per_bucket", defaultMaxDeletesPerBucket), - DryRun: readBoolConfig(values, "dry_run", defaultDryRun), - DeleteMarkerCleanup: readBoolConfig(values, "delete_marker_cleanup", defaultDeleteMarkerCleanup), - AbortMPUDays: readInt64Config(values, "abort_mpu_days", defaultAbortMPUDaysDefault), - } - - if cfg.BatchSize <= 0 { - cfg.BatchSize = defaultBatchSize - } - if cfg.MaxDeletesPerBucket <= 0 { - cfg.MaxDeletesPerBucket = defaultMaxDeletesPerBucket - } - if cfg.AbortMPUDays < 0 { - cfg.AbortMPUDays = defaultAbortMPUDaysDefault - } - - return cfg -} - -func readStringConfig(values map[string]*plugin_pb.ConfigValue, field string, fallback string) string { - if values == nil { - return fallback - } - value := values[field] - if value == nil { - return fallback - } - switch kind := value.Kind.(type) { - case *plugin_pb.ConfigValue_StringValue: - return kind.StringValue - case *plugin_pb.ConfigValue_Int64Value: - return strconv.FormatInt(kind.Int64Value, 10) - default: - glog.V(1).Infof("readStringConfig: unexpected type %T for field %q", value.Kind, field) - } - return fallback -} - -func readBoolConfig(values map[string]*plugin_pb.ConfigValue, field string, fallback bool) bool { - if values == nil { - return fallback - } - value := values[field] - if value == nil { - return fallback - } - switch kind := value.Kind.(type) { - case *plugin_pb.ConfigValue_BoolValue: - return kind.BoolValue - case *plugin_pb.ConfigValue_StringValue: - s := strings.TrimSpace(strings.ToLower(kind.StringValue)) - if s == "true" || s == "1" || s == "yes" { - return true - } - if s == "false" || s == "0" || s == "no" { - return false - } - glog.V(1).Infof("readBoolConfig: unrecognized string value %q for field %q, using fallback %v", kind.StringValue, field, fallback) - case *plugin_pb.ConfigValue_Int64Value: - return kind.Int64Value != 0 - default: - glog.V(1).Infof("readBoolConfig: unexpected config value type %T for field %q, using fallback %v", value.Kind, field, fallback) - } - return fallback -} - -func readInt64Config(values map[string]*plugin_pb.ConfigValue, field string, fallback int64) int64 { - if values == nil { - return fallback - } - value := values[field] - if value == nil { - return fallback - } - switch kind := value.Kind.(type) { - case *plugin_pb.ConfigValue_Int64Value: - return kind.Int64Value - case *plugin_pb.ConfigValue_DoubleValue: - return int64(kind.DoubleValue) - case *plugin_pb.ConfigValue_StringValue: - parsed, err := strconv.ParseInt(strings.TrimSpace(kind.StringValue), 10, 64) - if err == nil { - return parsed - } - default: - glog.V(1).Infof("readInt64Config: unexpected config value type %T for field %q, using fallback %d", value.Kind, field, fallback) - } - return fallback -} diff --git a/weed/plugin/worker/lifecycle/detection.go b/weed/plugin/worker/lifecycle/detection.go deleted file mode 100644 index d8267b2f0..000000000 --- a/weed/plugin/worker/lifecycle/detection.go +++ /dev/null @@ -1,221 +0,0 @@ -package lifecycle - -import ( - "context" - "fmt" - "path" - "strings" - - "github.com/seaweedfs/seaweedfs/weed/filer" - "github.com/seaweedfs/seaweedfs/weed/glog" - "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" - "github.com/seaweedfs/seaweedfs/weed/pb/plugin_pb" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" - "github.com/seaweedfs/seaweedfs/weed/util/wildcard" -) - -const lifecycleXMLKey = "s3-bucket-lifecycle-configuration-xml" - -// detectBucketsWithLifecycleRules scans all S3 buckets to find those -// with lifecycle rules, either TTL entries in filer.conf or lifecycle -// XML stored in bucket metadata. -func (h *Handler) detectBucketsWithLifecycleRules( - ctx context.Context, - filerClient filer_pb.SeaweedFilerClient, - config Config, - bucketFilter string, - maxResults int, -) ([]*plugin_pb.JobProposal, error) { - // Load filer configuration to find TTL rules. - fc, err := loadFilerConf(ctx, filerClient) - if err != nil { - return nil, fmt.Errorf("load filer conf: %w", err) - } - - bucketsPath := defaultBucketsPath - bucketMatchers := wildcard.CompileWildcardMatchers(bucketFilter) - - // List all buckets. - bucketEntries, err := listFilerEntries(ctx, filerClient, bucketsPath, "") - if err != nil { - return nil, fmt.Errorf("list buckets at %s: %w", bucketsPath, err) - } - - var proposals []*plugin_pb.JobProposal - for _, entry := range bucketEntries { - select { - case <-ctx.Done(): - return proposals, ctx.Err() - default: - } - - if !entry.IsDirectory { - continue - } - bucketName := entry.Name - if !wildcard.MatchesAnyWildcard(bucketMatchers, bucketName) { - continue - } - - // Check for lifecycle rules from two sources: - // 1. filer.conf TTLs (legacy Expiration.Days fast path) - // 2. Stored lifecycle XML in bucket metadata (full rule support) - collection := bucketName - ttls := fc.GetCollectionTtls(collection) - - hasLifecycleXML := entry.Extended != nil && len(entry.Extended[lifecycleXMLKey]) > 0 - versioningStatus := "" - if entry.Extended != nil { - versioningStatus = string(entry.Extended[s3_constants.ExtVersioningKey]) - } - - ruleCount := int64(len(ttls)) - if !hasLifecycleXML && ruleCount == 0 { - continue - } - - glog.V(2).Infof("s3_lifecycle: bucket %s has %d TTL rule(s), lifecycle_xml=%v, versioning=%s", - bucketName, ruleCount, hasLifecycleXML, versioningStatus) - - proposal := &plugin_pb.JobProposal{ - ProposalId: fmt.Sprintf("s3_lifecycle:%s", bucketName), - JobType: jobType, - Summary: fmt.Sprintf("Lifecycle management for bucket %s", bucketName), - DedupeKey: fmt.Sprintf("s3_lifecycle:%s", bucketName), - Parameters: map[string]*plugin_pb.ConfigValue{ - "bucket": {Kind: &plugin_pb.ConfigValue_StringValue{StringValue: bucketName}}, - "buckets_path": {Kind: &plugin_pb.ConfigValue_StringValue{StringValue: bucketsPath}}, - "collection": {Kind: &plugin_pb.ConfigValue_StringValue{StringValue: collection}}, - "rule_count": {Kind: &plugin_pb.ConfigValue_Int64Value{Int64Value: ruleCount}}, - "has_lifecycle_xml": {Kind: &plugin_pb.ConfigValue_BoolValue{BoolValue: hasLifecycleXML}}, - "versioning_status": {Kind: &plugin_pb.ConfigValue_StringValue{StringValue: versioningStatus}}, - }, - Labels: map[string]string{ - "bucket": bucketName, - }, - } - - proposals = append(proposals, proposal) - if maxResults > 0 && len(proposals) >= maxResults { - break - } - } - - return proposals, nil -} - -const defaultBucketsPath = "/buckets" - -// loadFilerConf reads the filer configuration from the filer. -func loadFilerConf(ctx context.Context, client filer_pb.SeaweedFilerClient) (*filer.FilerConf, error) { - fc := filer.NewFilerConf() - - content, err := filer.ReadInsideFiler(ctx, client, filer.DirectoryEtcSeaweedFS, filer.FilerConfName) - if err != nil { - // filer.conf may not exist yet - return empty config. - glog.V(1).Infof("s3_lifecycle: filer.conf not found or unreadable: %v (using empty config)", err) - return fc, nil - } - if err := fc.LoadFromBytes(content); err != nil { - return nil, fmt.Errorf("parse filer.conf: %w", err) - } - - return fc, nil -} - -// listFilerEntries lists directory entries from the filer. -func listFilerEntries(ctx context.Context, client filer_pb.SeaweedFilerClient, dir, startFrom string) ([]*filer_pb.Entry, error) { - var entries []*filer_pb.Entry - err := filer_pb.SeaweedList(ctx, client, dir, "", func(entry *filer_pb.Entry, isLast bool) error { - entries = append(entries, entry) - return nil - }, startFrom, false, 10000) - return entries, err -} - -type expiredObject struct { - dir string - name string -} - -// listExpiredObjects scans a bucket directory tree for objects whose TTL -// has expired based on their TtlSec attribute set by PutBucketLifecycle. -func listExpiredObjects( - ctx context.Context, - client filer_pb.SeaweedFilerClient, - bucketsPath, bucket string, - limit int64, -) ([]expiredObject, int64, error) { - var expired []expiredObject - var scanned int64 - - bucketPath := path.Join(bucketsPath, bucket) - - // Walk the bucket directory tree using breadth-first traversal. - dirsToProcess := []string{bucketPath} - for len(dirsToProcess) > 0 { - select { - case <-ctx.Done(): - return expired, scanned, ctx.Err() - default: - } - - dir := dirsToProcess[0] - dirsToProcess = dirsToProcess[1:] - - limitReached := false - err := filer_pb.SeaweedList(ctx, client, dir, "", func(entry *filer_pb.Entry, isLast bool) error { - if entry.IsDirectory { - dirsToProcess = append(dirsToProcess, path.Join(dir, entry.Name)) - return nil - } - scanned++ - - if isExpiredByTTL(entry) { - expired = append(expired, expiredObject{ - dir: dir, - name: entry.Name, - }) - } - - if limit > 0 && int64(len(expired)) >= limit { - limitReached = true - return fmt.Errorf("limit reached") - } - return nil - }, "", false, 10000) - - if err != nil && !strings.Contains(err.Error(), "limit reached") { - return expired, scanned, fmt.Errorf("list %s: %w", dir, err) - } - - if limitReached || (limit > 0 && int64(len(expired)) >= limit) { - break - } - } - - return expired, scanned, nil -} - -// isExpiredByTTL checks if an entry is expired based on its TTL attribute. -// SeaweedFS sets TtlSec on entries when lifecycle rules are applied via -// PutBucketLifecycleConfiguration. An entry is expired when -// creation_time + TTL < now. -func isExpiredByTTL(entry *filer_pb.Entry) bool { - if entry == nil || entry.Attributes == nil { - return false - } - - ttlSec := entry.Attributes.TtlSec - if ttlSec <= 0 { - return false - } - - crTime := entry.Attributes.Crtime - if crTime <= 0 { - return false - } - - expirationUnix := crTime + int64(ttlSec) - return expirationUnix < nowUnix() -} diff --git a/weed/plugin/worker/lifecycle/detection_test.go b/weed/plugin/worker/lifecycle/detection_test.go deleted file mode 100644 index d9ff86688..000000000 --- a/weed/plugin/worker/lifecycle/detection_test.go +++ /dev/null @@ -1,132 +0,0 @@ -package lifecycle - -import ( - "testing" - - "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" -) - -func TestBucketHasLifecycleXML(t *testing.T) { - tests := []struct { - name string - extended map[string][]byte - want bool - }{ - { - name: "has_lifecycle_xml", - extended: map[string][]byte{lifecycleXMLKey: []byte("")}, - want: true, - }, - { - name: "empty_lifecycle_xml", - extended: map[string][]byte{lifecycleXMLKey: {}}, - want: false, - }, - { - name: "no_lifecycle_xml", - extended: map[string][]byte{"other-key": []byte("value")}, - want: false, - }, - { - name: "nil_extended", - extended: nil, - want: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := tt.extended != nil && len(tt.extended[lifecycleXMLKey]) > 0 - if got != tt.want { - t.Errorf("hasLifecycleXML = %v, want %v", got, tt.want) - } - }) - } -} - -func TestBucketVersioningStatus(t *testing.T) { - tests := []struct { - name string - extended map[string][]byte - want string - }{ - { - name: "versioning_enabled", - extended: map[string][]byte{ - s3_constants.ExtVersioningKey: []byte("Enabled"), - }, - want: "Enabled", - }, - { - name: "versioning_suspended", - extended: map[string][]byte{ - s3_constants.ExtVersioningKey: []byte("Suspended"), - }, - want: "Suspended", - }, - { - name: "no_versioning", - extended: map[string][]byte{}, - want: "", - }, - { - name: "nil_extended", - extended: nil, - want: "", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var got string - if tt.extended != nil { - got = string(tt.extended[s3_constants.ExtVersioningKey]) - } - if got != tt.want { - t.Errorf("versioningStatus = %q, want %q", got, tt.want) - } - }) - } -} - -func TestDetectionProposalParameters(t *testing.T) { - // Verify that bucket entries with lifecycle XML or TTL rules produce - // proposals with the expected parameters. - t.Run("bucket_with_lifecycle_xml_and_versioning", func(t *testing.T) { - entry := &filer_pb.Entry{ - Name: "my-bucket", - IsDirectory: true, - Extended: map[string][]byte{ - lifecycleXMLKey: []byte(`Enabled`), - s3_constants.ExtVersioningKey: []byte("Enabled"), - }, - } - - hasXML := entry.Extended != nil && len(entry.Extended[lifecycleXMLKey]) > 0 - versioning := "" - if entry.Extended != nil { - versioning = string(entry.Extended[s3_constants.ExtVersioningKey]) - } - - if !hasXML { - t.Error("expected hasLifecycleXML=true") - } - if versioning != "Enabled" { - t.Errorf("expected versioning=Enabled, got %q", versioning) - } - }) - - t.Run("bucket_without_lifecycle_or_ttl_is_skipped", func(t *testing.T) { - entry := &filer_pb.Entry{ - Name: "empty-bucket", - IsDirectory: true, - Extended: map[string][]byte{}, - } - - hasXML := entry.Extended != nil && len(entry.Extended[lifecycleXMLKey]) > 0 - ttlCount := 0 // simulated: no TTL rules in filer.conf - - if hasXML || ttlCount > 0 { - t.Error("expected bucket to be skipped (no lifecycle XML, no TTLs)") - } - }) -} diff --git a/weed/plugin/worker/lifecycle/execution.go b/weed/plugin/worker/lifecycle/execution.go deleted file mode 100644 index 183e7648e..000000000 --- a/weed/plugin/worker/lifecycle/execution.go +++ /dev/null @@ -1,878 +0,0 @@ -package lifecycle - -import ( - "context" - "errors" - "fmt" - "math" - "path" - "sort" - "strconv" - "strings" - "time" - - "github.com/seaweedfs/seaweedfs/weed/glog" - "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" - "github.com/seaweedfs/seaweedfs/weed/pb/plugin_pb" - pluginworker "github.com/seaweedfs/seaweedfs/weed/plugin/worker" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3lifecycle" -) - -var errLimitReached = errors.New("limit reached") - -type executionResult struct { - objectsExpired int64 - objectsScanned int64 - deleteMarkersClean int64 - mpuAborted int64 - errors int64 -} - -// executeLifecycleForBucket processes lifecycle rules for a single bucket: -// 1. Reads filer.conf to get TTL rules for the bucket's collection -// 2. Walks the bucket directory tree to find expired objects -// 3. Deletes expired objects (unless dry run) -func (h *Handler) executeLifecycleForBucket( - ctx context.Context, - filerClient filer_pb.SeaweedFilerClient, - config Config, - bucket, bucketsPath string, - sender pluginworker.ExecutionSender, - jobID string, -) (*executionResult, error) { - result := &executionResult{} - - // Try to load lifecycle rules from stored XML first (full rule evaluation). - // Fall back to filer.conf TTL-only evaluation only if no XML is configured. - // If XML exists but is malformed, fail closed (don't fall back to TTL, - // which could apply broader rules and delete objects the XML rules would keep). - // Transient filer errors fall back to TTL with a warning. - lifecycleRules, xmlErr := loadLifecycleRulesFromBucket(ctx, filerClient, bucketsPath, bucket) - if xmlErr != nil && errors.Is(xmlErr, errMalformedLifecycleXML) { - glog.Errorf("s3_lifecycle: bucket %s: %v (skipping bucket)", bucket, xmlErr) - return result, xmlErr - } - if xmlErr != nil { - glog.V(1).Infof("s3_lifecycle: bucket %s: transient error loading lifecycle XML: %v, falling back to TTL", bucket, xmlErr) - } - // lifecycleRules is non-nil when XML was present (even if empty/all disabled). - // Only fall back to TTL when XML was truly absent (nil). - xmlPresent := xmlErr == nil && lifecycleRules != nil - useRuleEval := xmlPresent && len(lifecycleRules) > 0 - - if !useRuleEval && !xmlPresent { - // Fall back to filer.conf TTL rules only when no lifecycle XML exists. - // When XML is present but has no effective rules, skip TTL fallback. - fc, err := loadFilerConf(ctx, filerClient) - if err != nil { - return result, fmt.Errorf("load filer conf: %w", err) - } - collection := bucket - ttlRules := fc.GetCollectionTtls(collection) - if len(ttlRules) == 0 { - glog.V(1).Infof("s3_lifecycle: bucket %s has no lifecycle rules, skipping", bucket) - return result, nil - } - } - - _ = sender.SendProgress(&plugin_pb.JobProgressUpdate{ - JobId: jobID, - JobType: jobType, - State: plugin_pb.JobState_JOB_STATE_RUNNING, - ProgressPercent: 10, - Stage: "scanning", - Message: fmt.Sprintf("scanning bucket %s for expired objects", bucket), - }) - - // Shared budget across all phases so we don't exceed MaxDeletesPerBucket. - remaining := config.MaxDeletesPerBucket - - // Find expired objects using rule-based evaluation or TTL fallback. - var expired []expiredObject - var scanned int64 - var err error - if useRuleEval { - expired, scanned, err = listExpiredObjectsByRules(ctx, filerClient, bucketsPath, bucket, lifecycleRules, remaining) - } else if !xmlPresent { - // TTL-only scan when no lifecycle XML exists. - expired, scanned, err = listExpiredObjects(ctx, filerClient, bucketsPath, bucket, remaining) - } - // When xmlPresent but no effective rules (all disabled), skip object scanning. - result.objectsScanned = scanned - if err != nil { - return result, fmt.Errorf("list expired objects: %w", err) - } - - if len(expired) > 0 { - glog.V(1).Infof("s3_lifecycle: bucket %s: found %d expired objects out of %d scanned", bucket, len(expired), scanned) - } else { - glog.V(1).Infof("s3_lifecycle: bucket %s: scanned %d objects, none expired", bucket, scanned) - } - - if config.DryRun && len(expired) > 0 { - result.objectsExpired = int64(len(expired)) - _ = sender.SendProgress(&plugin_pb.JobProgressUpdate{ - JobId: jobID, - JobType: jobType, - State: plugin_pb.JobState_JOB_STATE_RUNNING, - ProgressPercent: 100, - Stage: "dry_run", - Message: fmt.Sprintf("dry run: would delete %d expired objects", len(expired)), - }) - return result, nil - } - - // Delete expired objects in batches. - if len(expired) > 0 { - _ = sender.SendProgress(&plugin_pb.JobProgressUpdate{ - JobId: jobID, - JobType: jobType, - State: plugin_pb.JobState_JOB_STATE_RUNNING, - ProgressPercent: 50, - Stage: "deleting", - Message: fmt.Sprintf("deleting %d expired objects", len(expired)), - }) - - var batchSize int - if config.BatchSize <= 0 { - batchSize = defaultBatchSize - } else if config.BatchSize > math.MaxInt { - batchSize = math.MaxInt - } else { - batchSize = int(config.BatchSize) - } - - for i := 0; i < len(expired); i += batchSize { - select { - case <-ctx.Done(): - return result, ctx.Err() - default: - } - - end := i + batchSize - if end > len(expired) { - end = len(expired) - } - batch := expired[i:end] - - deleted, errs, batchErr := deleteExpiredObjects(ctx, filerClient, batch) - result.objectsExpired += int64(deleted) - result.errors += int64(errs) - - if batchErr != nil { - return result, batchErr - } - - progress := float64(end)/float64(len(expired))*50 + 50 // 50-100% - _ = sender.SendProgress(&plugin_pb.JobProgressUpdate{ - JobId: jobID, - JobType: jobType, - State: plugin_pb.JobState_JOB_STATE_RUNNING, - ProgressPercent: progress, - Stage: "deleting", - Message: fmt.Sprintf("deleted %d/%d expired objects", result.objectsExpired, len(expired)), - }) - } - - // Clean up .versions directories left empty after version deletion. - cleanupEmptyVersionsDirectories(ctx, filerClient, expired) - - remaining -= result.objectsExpired + result.errors - if remaining < 0 { - remaining = 0 - } - } - - // Delete marker cleanup. - if config.DeleteMarkerCleanup && remaining > 0 { - _ = sender.SendProgress(&plugin_pb.JobProgressUpdate{ - JobId: jobID, JobType: jobType, - State: plugin_pb.JobState_JOB_STATE_RUNNING, - Stage: "cleaning_delete_markers", Message: "cleaning expired delete markers", - }) - cleaned, cleanErrs, cleanCtxErr := cleanupDeleteMarkers(ctx, filerClient, bucketsPath, bucket, lifecycleRules, remaining) - result.deleteMarkersClean = int64(cleaned) - result.errors += int64(cleanErrs) - if cleanCtxErr != nil { - return result, cleanCtxErr - } - remaining -= int64(cleaned + cleanErrs) - if remaining < 0 { - remaining = 0 - } - } - - // Abort incomplete multipart uploads. - // When lifecycle XML exists, evaluate each upload against the rules - // (respecting per-rule prefix filters and DaysAfterInitiation). - // Fall back to worker config abort_mpu_days only when no lifecycle - // XML is configured for the bucket. - if xmlPresent && remaining > 0 { - _ = sender.SendProgress(&plugin_pb.JobProgressUpdate{ - JobId: jobID, JobType: jobType, - State: plugin_pb.JobState_JOB_STATE_RUNNING, - Stage: "aborting_mpus", Message: "evaluating MPU abort rules", - }) - aborted, abortErrs, abortCtxErr := abortMPUsByRules(ctx, filerClient, bucketsPath, bucket, lifecycleRules, remaining) - result.mpuAborted = int64(aborted) - result.errors += int64(abortErrs) - if abortCtxErr != nil { - return result, abortCtxErr - } - } else if !xmlPresent && config.AbortMPUDays > 0 && remaining > 0 { - _ = sender.SendProgress(&plugin_pb.JobProgressUpdate{ - JobId: jobID, JobType: jobType, - State: plugin_pb.JobState_JOB_STATE_RUNNING, - Stage: "aborting_mpus", Message: fmt.Sprintf("aborting multipart uploads older than %d days", config.AbortMPUDays), - }) - aborted, abortErrs, abortCtxErr := abortIncompleteMPUs(ctx, filerClient, bucketsPath, bucket, config.AbortMPUDays, remaining) - result.mpuAborted = int64(aborted) - result.errors += int64(abortErrs) - if abortCtxErr != nil { - return result, abortCtxErr - } - } - - return result, nil -} - -// cleanupDeleteMarkers scans versioned objects and removes delete markers -// that are the sole remaining version. This matches AWS S3 -// ExpiredObjectDeleteMarker semantics: a delete marker is only removed when -// it is the only version of an object (no non-current versions behind it). -// -// This phase should run AFTER NoncurrentVersionExpiration (PR 4) so that -// non-current versions have already been cleaned up, potentially leaving -// delete markers as sole versions eligible for removal. -func cleanupDeleteMarkers( - ctx context.Context, - client filer_pb.SeaweedFilerClient, - bucketsPath, bucket string, - rules []s3lifecycle.Rule, - limit int64, -) (cleaned, errors int, ctxErr error) { - bucketPath := path.Join(bucketsPath, bucket) - - dirsToProcess := []string{bucketPath} - for len(dirsToProcess) > 0 { - if ctx.Err() != nil { - return cleaned, errors, ctx.Err() - } - - dir := dirsToProcess[0] - dirsToProcess = dirsToProcess[1:] - - listErr := filer_pb.SeaweedList(ctx, client, dir, "", func(entry *filer_pb.Entry, isLast bool) error { - if entry.IsDirectory { - if dir == bucketPath && entry.Name == s3_constants.MultipartUploadsFolder { - return nil - } - if strings.HasSuffix(entry.Name, s3_constants.VersionsFolder) { - versionsDir := path.Join(dir, entry.Name) - // Check if the latest version is a delete marker. - latestIsMarker := string(entry.Extended[s3_constants.ExtLatestVersionIsDeleteMarker]) == "true" - if !latestIsMarker { - return nil - } - // Count versions in the directory. - versionCount := 0 - countErr := filer_pb.SeaweedList(ctx, client, versionsDir, "", func(ve *filer_pb.Entry, _ bool) error { - if !ve.IsDirectory { - versionCount++ - } - return nil - }, "", false, 10000) - if countErr != nil { - glog.V(1).Infof("s3_lifecycle: failed to count versions in %s: %v", versionsDir, countErr) - errors++ - return nil - } - // Only remove if the delete marker is the sole version. - if versionCount != 1 { - return nil - } - // Check that a matching ExpiredObjectDeleteMarker rule exists. - // The rule's prefix filter must match this object's key. - relDir := strings.TrimPrefix(versionsDir, bucketPath+"/") - objKey := strings.TrimSuffix(relDir, s3_constants.VersionsFolder) - if len(rules) > 0 && !matchesDeleteMarkerRule(rules, objKey) { - return nil - } - // Find and remove the sole delete marker entry. - removedHere := false - removeErr := filer_pb.SeaweedList(ctx, client, versionsDir, "", func(ve *filer_pb.Entry, _ bool) error { - if !ve.IsDirectory && isDeleteMarker(ve) { - if err := filer_pb.DoRemove(ctx, client, versionsDir, ve.Name, true, false, false, false, nil); err != nil { - glog.V(1).Infof("s3_lifecycle: failed to remove delete marker %s/%s: %v", versionsDir, ve.Name, err) - errors++ - } else { - cleaned++ - removedHere = true - } - } - return nil - }, "", false, 10) - if removeErr != nil { - glog.V(1).Infof("s3_lifecycle: failed to scan for delete marker in %s: %v", versionsDir, removeErr) - } - // Remove the now-empty .versions directory only if we - // actually deleted the marker in this specific directory. - if removedHere { - _ = filer_pb.DoRemove(ctx, client, dir, entry.Name, true, true, true, false, nil) - } - return nil - } - dirsToProcess = append(dirsToProcess, path.Join(dir, entry.Name)) - return nil - } - - // For non-versioned objects: only clean up if explicitly a delete marker - // and a matching rule exists. - relKey := strings.TrimPrefix(path.Join(dir, entry.Name), bucketPath+"/") - if isDeleteMarker(entry) && matchesDeleteMarkerRule(rules, relKey) { - if err := filer_pb.DoRemove(ctx, client, dir, entry.Name, true, false, false, false, nil); err != nil { - glog.V(1).Infof("s3_lifecycle: failed to remove delete marker %s/%s: %v", dir, entry.Name, err) - errors++ - } else { - cleaned++ - } - } - - if limit > 0 && int64(cleaned+errors) >= limit { - return fmt.Errorf("limit reached") - } - return nil - }, "", false, 10000) - - if listErr != nil && !strings.Contains(listErr.Error(), "limit reached") { - return cleaned, errors, fmt.Errorf("list %s: %w", dir, listErr) - } - - if limit > 0 && int64(cleaned+errors) >= limit { - break - } - } - return cleaned, errors, nil -} - -// isDeleteMarker checks if an entry is an S3 delete marker. -func isDeleteMarker(entry *filer_pb.Entry) bool { - if entry == nil || entry.Extended == nil { - return false - } - return string(entry.Extended[s3_constants.ExtDeleteMarkerKey]) == "true" -} - -// matchesDeleteMarkerRule checks if any enabled ExpiredObjectDeleteMarker rule -// matches the given object key using the full filter model (prefix, tags, size). -// When no lifecycle rules are provided (nil means no XML configured), -// falls back to legacy behavior (returns true to allow cleanup). -// A non-nil empty slice means XML was present but had no matching rules, -// so cleanup is not allowed. -func matchesDeleteMarkerRule(rules []s3lifecycle.Rule, objKey string) bool { - if rules == nil { - return true // legacy fallback: no lifecycle XML configured - } - // Delete markers have no size or tags, so build a minimal ObjectInfo. - obj := s3lifecycle.ObjectInfo{Key: objKey} - for _, r := range rules { - if r.Status == "Enabled" && r.ExpiredObjectDeleteMarker && s3lifecycle.MatchesFilter(r, obj) { - return true - } - } - return false -} - -// abortMPUsByRules scans the .uploads directory and evaluates each upload -// against lifecycle rules using EvaluateMPUAbort, which respects per-rule -// prefix filters and DaysAfterInitiation thresholds. -func abortMPUsByRules( - ctx context.Context, - client filer_pb.SeaweedFilerClient, - bucketsPath, bucket string, - rules []s3lifecycle.Rule, - limit int64, -) (aborted, errs int, ctxErr error) { - uploadsDir := path.Join(bucketsPath, bucket, ".uploads") - now := time.Now() - - listErr := filer_pb.SeaweedList(ctx, client, uploadsDir, "", func(entry *filer_pb.Entry, isLast bool) error { - if ctx.Err() != nil { - return ctx.Err() - } - if !entry.IsDirectory { - return nil - } - if entry.Attributes == nil || entry.Attributes.Crtime <= 0 { - return nil - } - - createdAt := time.Unix(entry.Attributes.Crtime, 0) - result := s3lifecycle.EvaluateMPUAbort(rules, entry.Name, createdAt, now) - if result.Action == s3lifecycle.ActionAbortMultipartUpload { - uploadPath := path.Join(uploadsDir, entry.Name) - if err := filer_pb.DoRemove(ctx, client, uploadsDir, entry.Name, true, true, true, false, nil); err != nil { - glog.V(1).Infof("s3_lifecycle: failed to abort MPU %s: %v", uploadPath, err) - errs++ - } else { - aborted++ - } - } - - if limit > 0 && int64(aborted+errs) >= limit { - return errLimitReached - } - return nil - }, "", false, 10000) - - if listErr != nil && !errors.Is(listErr, errLimitReached) { - return aborted, errs, fmt.Errorf("list uploads in %s: %w", uploadsDir, listErr) - } - return aborted, errs, nil -} - -// abortIncompleteMPUs scans the .uploads directory under a bucket and -// removes multipart upload entries older than the specified number of days. -func abortIncompleteMPUs( - ctx context.Context, - client filer_pb.SeaweedFilerClient, - bucketsPath, bucket string, - olderThanDays, limit int64, -) (aborted, errors int, ctxErr error) { - uploadsDir := path.Join(bucketsPath, bucket, ".uploads") - cutoff := time.Now().Add(-time.Duration(olderThanDays) * 24 * time.Hour) - - listErr := filer_pb.SeaweedList(ctx, client, uploadsDir, "", func(entry *filer_pb.Entry, isLast bool) error { - if ctx.Err() != nil { - return ctx.Err() - } - - if !entry.IsDirectory { - return nil - } - - // Each subdirectory under .uploads is one multipart upload. - // Check the directory creation time. - if entry.Attributes != nil && entry.Attributes.Crtime > 0 { - created := time.Unix(entry.Attributes.Crtime, 0) - if created.Before(cutoff) { - uploadPath := path.Join(uploadsDir, entry.Name) - if err := filer_pb.DoRemove(ctx, client, uploadsDir, entry.Name, true, true, true, false, nil); err != nil { - glog.V(1).Infof("s3_lifecycle: failed to abort MPU %s: %v", uploadPath, err) - errors++ - } else { - aborted++ - } - } - } - - if limit > 0 && int64(aborted+errors) >= limit { - return fmt.Errorf("limit reached") - } - return nil - }, "", false, 10000) - - if listErr != nil && !strings.Contains(listErr.Error(), "limit reached") { - return aborted, errors, fmt.Errorf("list uploads in %s: %w", uploadsDir, listErr) - } - - return aborted, errors, nil -} - -// deleteExpiredObjects deletes a batch of expired objects from the filer. -// Returns a non-nil error when the context is canceled mid-batch. -func deleteExpiredObjects( - ctx context.Context, - client filer_pb.SeaweedFilerClient, - objects []expiredObject, -) (deleted, errors int, ctxErr error) { - for _, obj := range objects { - if ctx.Err() != nil { - return deleted, errors, ctx.Err() - } - - err := filer_pb.DoRemove(ctx, client, obj.dir, obj.name, true, false, false, false, nil) - if err != nil { - glog.V(1).Infof("s3_lifecycle: failed to delete %s/%s: %v", obj.dir, obj.name, err) - errors++ - continue - } - deleted++ - } - return deleted, errors, nil -} - -// nowUnix returns the current time as a Unix timestamp. -func nowUnix() int64 { - return time.Now().Unix() -} - -// listExpiredObjectsByRules scans a bucket directory tree and evaluates -// lifecycle rules against each object using the s3lifecycle evaluator. -// This function handles non-versioned objects (IsLatest=true). Versioned -// objects in .versions directories are handled by processVersionsDirectory -// (added in a separate change for NoncurrentVersionExpiration support). -func listExpiredObjectsByRules( - ctx context.Context, - client filer_pb.SeaweedFilerClient, - bucketsPath, bucket string, - rules []s3lifecycle.Rule, - limit int64, -) ([]expiredObject, int64, error) { - var expired []expiredObject - var scanned int64 - - bucketPath := path.Join(bucketsPath, bucket) - now := time.Now() - needTags := s3lifecycle.HasTagRules(rules) - - dirsToProcess := []string{bucketPath} - for len(dirsToProcess) > 0 { - select { - case <-ctx.Done(): - return expired, scanned, ctx.Err() - default: - } - - dir := dirsToProcess[0] - dirsToProcess = dirsToProcess[1:] - - limitReached := false - err := filer_pb.SeaweedList(ctx, client, dir, "", func(entry *filer_pb.Entry, isLast bool) error { - if entry.IsDirectory { - if dir == bucketPath && entry.Name == s3_constants.MultipartUploadsFolder { - return nil // skip .uploads at bucket root only - } - if strings.HasSuffix(entry.Name, s3_constants.VersionsFolder) { - versionsDir := path.Join(dir, entry.Name) - - // Evaluate Expiration rules against the latest version. - // In versioned buckets, data lives in .versions/ directories, - // so we must evaluate the latest version here — it is never - // seen as a regular file entry in the parent directory. - if obj, ok := latestVersionExpiredByRules(ctx, client, entry, versionsDir, bucketPath, rules, now, needTags); ok { - expired = append(expired, obj) - scanned++ - if limit > 0 && int64(len(expired)) >= limit { - limitReached = true - return errLimitReached - } - } - - // Process noncurrent versions. - vExpired, vScanned, vErr := processVersionsDirectory(ctx, client, versionsDir, bucketPath, rules, now, needTags, limit-int64(len(expired))) - if vErr != nil { - glog.V(1).Infof("s3_lifecycle: %v", vErr) - return vErr - } - expired = append(expired, vExpired...) - scanned += vScanned - if limit > 0 && int64(len(expired)) >= limit { - limitReached = true - return errLimitReached - } - return nil - } - dirsToProcess = append(dirsToProcess, path.Join(dir, entry.Name)) - return nil - } - scanned++ - - // Skip objects already handled by TTL fast path. - if entry.Attributes != nil && entry.Attributes.TtlSec > 0 { - expirationUnix := entry.Attributes.Crtime + int64(entry.Attributes.TtlSec) - if expirationUnix > nowUnix() { - return nil // will be expired by RocksDB compaction - } - } - - // Build ObjectInfo for the evaluator. - relKey := strings.TrimPrefix(path.Join(dir, entry.Name), bucketPath+"/") - objInfo := s3lifecycle.ObjectInfo{ - Key: relKey, - IsLatest: true, // non-versioned objects are always "latest" - } - if entry.Attributes != nil { - objInfo.Size = int64(entry.Attributes.GetFileSize()) - if entry.Attributes.Mtime > 0 { - objInfo.ModTime = time.Unix(entry.Attributes.Mtime, 0) - } else if entry.Attributes.Crtime > 0 { - objInfo.ModTime = time.Unix(entry.Attributes.Crtime, 0) - } - } - if needTags { - objInfo.Tags = s3lifecycle.ExtractTags(entry.Extended) - } - - result := s3lifecycle.Evaluate(rules, objInfo, now) - if result.Action == s3lifecycle.ActionDeleteObject { - expired = append(expired, expiredObject{dir: dir, name: entry.Name}) - } - - if limit > 0 && int64(len(expired)) >= limit { - limitReached = true - return errLimitReached - } - return nil - }, "", false, 10000) - - if err != nil && !errors.Is(err, errLimitReached) { - return expired, scanned, fmt.Errorf("list %s: %w", dir, err) - } - - if limitReached || (limit > 0 && int64(len(expired)) >= limit) { - break - } - } - - return expired, scanned, nil -} - -// processVersionsDirectory evaluates NoncurrentVersionExpiration rules -// against all versions in a .versions directory. -func processVersionsDirectory( - ctx context.Context, - client filer_pb.SeaweedFilerClient, - versionsDir, bucketPath string, - rules []s3lifecycle.Rule, - now time.Time, - needTags bool, - limit int64, -) ([]expiredObject, int64, error) { - var expired []expiredObject - var scanned int64 - - // Check if any rule has NoncurrentVersionExpiration. - hasNoncurrentRules := false - for _, r := range rules { - if r.Status == "Enabled" && r.NoncurrentVersionExpirationDays > 0 { - hasNoncurrentRules = true - break - } - } - if !hasNoncurrentRules { - return nil, 0, nil - } - - // List all versions in this directory. - var versions []*filer_pb.Entry - listErr := filer_pb.SeaweedList(ctx, client, versionsDir, "", func(entry *filer_pb.Entry, isLast bool) error { - if !entry.IsDirectory { - versions = append(versions, entry) - } - return nil - }, "", false, 10000) - if listErr != nil { - return nil, 0, fmt.Errorf("list versions in %s: %w", versionsDir, listErr) - } - if len(versions) <= 1 { - return nil, 0, nil // only one version (the latest), nothing to expire - } - - // Sort by version timestamp, newest first. - sortVersionsByVersionId(versions) - - // Derive the object key from the .versions directory path. - // e.g., /buckets/mybucket/path/to/key.versions -> path/to/key - relDir := strings.TrimPrefix(versionsDir, bucketPath+"/") - objKey := strings.TrimSuffix(relDir, s3_constants.VersionsFolder) - - // Walk versions: first is latest, rest are non-current. - noncurrentIndex := 0 - for i := 1; i < len(versions); i++ { - entry := versions[i] - scanned++ - - // Skip delete markers from expiration evaluation, but count - // them toward NewerNoncurrentVersions so data versions get - // the correct noncurrent index. - if isDeleteMarker(entry) { - noncurrentIndex++ - continue - } - - // Determine successor's timestamp (the version that replaced this one). - successorEntry := versions[i-1] - successorVersionId := strings.TrimPrefix(successorEntry.Name, "v_") - successorTime := s3lifecycle.GetVersionTimestamp(successorVersionId) - if successorTime.IsZero() && successorEntry.Attributes != nil && successorEntry.Attributes.Mtime > 0 { - successorTime = time.Unix(successorEntry.Attributes.Mtime, 0) - } - - objInfo := s3lifecycle.ObjectInfo{ - Key: objKey, - IsLatest: false, - SuccessorModTime: successorTime, - NumVersions: len(versions), - NoncurrentIndex: noncurrentIndex, - } - if entry.Attributes != nil { - objInfo.Size = int64(entry.Attributes.GetFileSize()) - if entry.Attributes.Mtime > 0 { - objInfo.ModTime = time.Unix(entry.Attributes.Mtime, 0) - } - } - if needTags { - objInfo.Tags = s3lifecycle.ExtractTags(entry.Extended) - } - - // Evaluate using the detailed ShouldExpireNoncurrentVersion which - // handles NewerNoncurrentVersions. - for _, rule := range rules { - if s3lifecycle.ShouldExpireNoncurrentVersion(rule, objInfo, noncurrentIndex, now) { - expired = append(expired, expiredObject{dir: versionsDir, name: entry.Name}) - break - } - } - - noncurrentIndex++ - - if limit > 0 && int64(len(expired)) >= limit { - break - } - } - - return expired, scanned, nil -} - -// latestVersionExpiredByRules evaluates Expiration rules (Days/Date) against -// the latest version in a .versions directory. In versioned buckets all data -// lives inside .versions/ directories, so the latest version is never seen as -// a regular file entry during the bucket walk. Without this check, Expiration -// rules would never fire for versioned objects (issue #8757). -// -// The .versions directory entry caches metadata about the latest version in -// its Extended attributes, so we can evaluate expiration without an extra -// filer round-trip. -func latestVersionExpiredByRules( - ctx context.Context, - client filer_pb.SeaweedFilerClient, - dirEntry *filer_pb.Entry, - versionsDir, bucketPath string, - rules []s3lifecycle.Rule, - now time.Time, - needTags bool, -) (expiredObject, bool) { - if dirEntry.Extended == nil { - return expiredObject{}, false - } - - // Skip if the latest version is a delete marker — those are handled - // by the ExpiredObjectDeleteMarker rule in cleanupDeleteMarkers. - if string(dirEntry.Extended[s3_constants.ExtLatestVersionIsDeleteMarker]) == "true" { - return expiredObject{}, false - } - - latestFileName := string(dirEntry.Extended[s3_constants.ExtLatestVersionFileNameKey]) - if latestFileName == "" { - return expiredObject{}, false - } - - // Derive the object key: /buckets/b/path/key.versions → path/key - relDir := strings.TrimPrefix(versionsDir, bucketPath+"/") - objKey := strings.TrimSuffix(relDir, s3_constants.VersionsFolder) - - objInfo := s3lifecycle.ObjectInfo{ - Key: objKey, - IsLatest: true, - } - - // Populate ModTime from cached metadata. - if mtimeStr := string(dirEntry.Extended[s3_constants.ExtLatestVersionMtimeKey]); mtimeStr != "" { - if mtime, err := strconv.ParseInt(mtimeStr, 10, 64); err == nil { - objInfo.ModTime = time.Unix(mtime, 0) - } - } - if objInfo.ModTime.IsZero() && dirEntry.Attributes != nil && dirEntry.Attributes.Mtime > 0 { - objInfo.ModTime = time.Unix(dirEntry.Attributes.Mtime, 0) - } - - // Populate Size from cached metadata. - if sizeStr := string(dirEntry.Extended[s3_constants.ExtLatestVersionSizeKey]); sizeStr != "" { - if size, err := strconv.ParseInt(sizeStr, 10, 64); err == nil { - objInfo.Size = size - } - } - - if needTags { - // Tags are stored on the version file entry, not the .versions - // directory. Fetch the actual version file to get them. - resp, err := client.LookupDirectoryEntry(ctx, &filer_pb.LookupDirectoryEntryRequest{ - Directory: versionsDir, - Name: latestFileName, - }) - if err == nil && resp.Entry != nil { - objInfo.Tags = s3lifecycle.ExtractTags(resp.Entry.Extended) - } - } - - result := s3lifecycle.Evaluate(rules, objInfo, now) - if result.Action == s3lifecycle.ActionDeleteObject { - return expiredObject{dir: versionsDir, name: latestFileName}, true - } - - return expiredObject{}, false -} - -// cleanupEmptyVersionsDirectories removes .versions directories that became -// empty after their contents were deleted. This is called after -// deleteExpiredObjects to avoid leaving orphaned directories. -func cleanupEmptyVersionsDirectories( - ctx context.Context, - client filer_pb.SeaweedFilerClient, - deleted []expiredObject, -) int { - // Collect unique .versions directories that had entries deleted. - versionsDirs := map[string]struct{}{} - for _, obj := range deleted { - if strings.HasSuffix(obj.dir, s3_constants.VersionsFolder) { - versionsDirs[obj.dir] = struct{}{} - } - } - - cleaned := 0 - for vDir := range versionsDirs { - if ctx.Err() != nil { - break - } - // Check if the directory is now empty. - empty := true - listErr := filer_pb.SeaweedList(ctx, client, vDir, "", func(entry *filer_pb.Entry, isLast bool) error { - empty = false - return errLimitReached // stop after first entry - }, "", false, 1) - - if listErr != nil && !errors.Is(listErr, errLimitReached) { - glog.V(1).Infof("s3_lifecycle: failed to check if versions dir %s is empty: %v", vDir, listErr) - continue - } - - if !empty { - continue - } - - // Remove the empty .versions directory. - parentDir, dirName := path.Split(vDir) - parentDir = strings.TrimSuffix(parentDir, "/") - if err := filer_pb.DoRemove(ctx, client, parentDir, dirName, false, true, true, false, nil); err != nil { - glog.V(1).Infof("s3_lifecycle: failed to clean up empty versions dir %s: %v", vDir, err) - } else { - cleaned++ - } - } - return cleaned -} - -// sortVersionsByVersionId sorts version entries newest-first using full -// version ID comparison (matching compareVersionIds in s3api_version_id.go). -// This uses the complete version ID string, not just the decoded timestamp, -// so entries with the same timestamp prefix are correctly ordered by their -// random suffix. -func sortVersionsByVersionId(versions []*filer_pb.Entry) { - sort.Slice(versions, func(i, j int) bool { - vidI := strings.TrimPrefix(versions[i].Name, "v_") - vidJ := strings.TrimPrefix(versions[j].Name, "v_") - return s3lifecycle.CompareVersionIds(vidI, vidJ) < 0 - }) -} diff --git a/weed/plugin/worker/lifecycle/execution_test.go b/weed/plugin/worker/lifecycle/execution_test.go deleted file mode 100644 index cfcae7613..000000000 --- a/weed/plugin/worker/lifecycle/execution_test.go +++ /dev/null @@ -1,72 +0,0 @@ -package lifecycle - -import ( - "testing" - - "github.com/seaweedfs/seaweedfs/weed/s3api/s3lifecycle" -) - -func TestMatchesDeleteMarkerRule(t *testing.T) { - t.Run("nil_rules_legacy_fallback", func(t *testing.T) { - if !matchesDeleteMarkerRule(nil, "any/key") { - t.Error("nil rules should return true (legacy fallback)") - } - }) - - t.Run("empty_rules_xml_present_no_match", func(t *testing.T) { - rules := []s3lifecycle.Rule{} - if matchesDeleteMarkerRule(rules, "any/key") { - t.Error("empty rules (XML present) should return false") - } - }) - - t.Run("matching_prefix_rule", func(t *testing.T) { - rules := []s3lifecycle.Rule{ - {ID: "cleanup", Status: "Enabled", Prefix: "logs/", ExpiredObjectDeleteMarker: true}, - } - if !matchesDeleteMarkerRule(rules, "logs/app.log") { - t.Error("should match rule with matching prefix") - } - }) - - t.Run("non_matching_prefix", func(t *testing.T) { - rules := []s3lifecycle.Rule{ - {ID: "cleanup", Status: "Enabled", Prefix: "logs/", ExpiredObjectDeleteMarker: true}, - } - if matchesDeleteMarkerRule(rules, "data/file.txt") { - t.Error("should not match rule with non-matching prefix") - } - }) - - t.Run("disabled_rule", func(t *testing.T) { - rules := []s3lifecycle.Rule{ - {ID: "cleanup", Status: "Disabled", ExpiredObjectDeleteMarker: true}, - } - if matchesDeleteMarkerRule(rules, "any/key") { - t.Error("disabled rule should not match") - } - }) - - t.Run("rule_without_delete_marker_flag", func(t *testing.T) { - rules := []s3lifecycle.Rule{ - {ID: "expire", Status: "Enabled", ExpirationDays: 30}, - } - if matchesDeleteMarkerRule(rules, "any/key") { - t.Error("rule without ExpiredObjectDeleteMarker should not match") - } - }) - - t.Run("tag_filtered_rule_no_tags_on_marker", func(t *testing.T) { - rules := []s3lifecycle.Rule{ - { - ID: "tagged", Status: "Enabled", - ExpiredObjectDeleteMarker: true, - FilterTags: map[string]string{"env": "dev"}, - }, - } - // Delete markers have no tags, so a tag-filtered rule should not match. - if matchesDeleteMarkerRule(rules, "any/key") { - t.Error("tag-filtered rule should not match delete marker (no tags)") - } - }) -} diff --git a/weed/plugin/worker/lifecycle/handler.go b/weed/plugin/worker/lifecycle/handler.go deleted file mode 100644 index 22ab4d1ff..000000000 --- a/weed/plugin/worker/lifecycle/handler.go +++ /dev/null @@ -1,380 +0,0 @@ -package lifecycle - -import ( - "context" - "fmt" - "strings" - "time" - - "github.com/seaweedfs/seaweedfs/weed/glog" - "github.com/seaweedfs/seaweedfs/weed/pb" - "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" - "github.com/seaweedfs/seaweedfs/weed/pb/plugin_pb" - pluginworker "github.com/seaweedfs/seaweedfs/weed/plugin/worker" - "google.golang.org/grpc" - "google.golang.org/protobuf/types/known/timestamppb" -) - -func init() { - pluginworker.RegisterHandler(pluginworker.HandlerFactory{ - JobType: jobType, - Category: pluginworker.CategoryHeavy, - Aliases: []string{"lifecycle", "s3-lifecycle", "s3.lifecycle"}, - Build: func(opts pluginworker.HandlerBuildOptions) (pluginworker.JobHandler, error) { - return NewHandler(opts.GrpcDialOption), nil - }, - }) -} - -// Handler implements the JobHandler interface for S3 lifecycle management: -// object expiration, delete marker cleanup, and abort incomplete multipart uploads. -type Handler struct { - grpcDialOption grpc.DialOption -} - -const filerConnectTimeout = 5 * time.Second - -// NewHandler creates a new handler for S3 lifecycle management. -func NewHandler(grpcDialOption grpc.DialOption) *Handler { - return &Handler{grpcDialOption: grpcDialOption} -} - -func (h *Handler) Capability() *plugin_pb.JobTypeCapability { - return &plugin_pb.JobTypeCapability{ - JobType: jobType, - CanDetect: true, - CanExecute: true, - MaxDetectionConcurrency: 1, - MaxExecutionConcurrency: 4, - DisplayName: "S3 Lifecycle", - Description: "Manages S3 object lifecycle: expiration of objects based on TTL rules, delete marker cleanup, and abort of incomplete multipart uploads", - Weight: 40, - } -} - -func (h *Handler) Descriptor() *plugin_pb.JobTypeDescriptor { - return &plugin_pb.JobTypeDescriptor{ - JobType: jobType, - DisplayName: "S3 Lifecycle Management", - Description: "Automated S3 object lifecycle management: expire objects by TTL rules, clean up expired delete markers, and abort stale multipart uploads", - Icon: "fas fa-hourglass-half", - DescriptorVersion: 1, - AdminConfigForm: &plugin_pb.ConfigForm{ - FormId: "s3-lifecycle-admin", - Title: "S3 Lifecycle Admin Config", - Description: "Admin-side controls for S3 lifecycle management scope.", - Sections: []*plugin_pb.ConfigSection{ - { - SectionId: "scope", - Title: "Scope", - Description: "Which buckets to include in lifecycle management.", - Fields: []*plugin_pb.ConfigField{ - { - Name: "bucket_filter", - Label: "Bucket Filter", - Description: "Wildcard pattern for bucket names to include (e.g. \"prod-*\"). Empty means all buckets.", - FieldType: plugin_pb.ConfigFieldType_CONFIG_FIELD_TYPE_STRING, - Widget: plugin_pb.ConfigWidget_CONFIG_WIDGET_TEXT, - }, - }, - }, - }, - }, - WorkerConfigForm: &plugin_pb.ConfigForm{ - FormId: "s3-lifecycle-worker", - Title: "S3 Lifecycle Worker Config", - Description: "Worker-side controls for lifecycle execution behavior.", - Sections: []*plugin_pb.ConfigSection{ - { - SectionId: "execution", - Title: "Execution", - Description: "Controls for lifecycle rule execution.", - Fields: []*plugin_pb.ConfigField{ - { - Name: "batch_size", - Label: "Batch Size", - Description: "Number of entries to process per filer listing page.", - FieldType: plugin_pb.ConfigFieldType_CONFIG_FIELD_TYPE_INT64, - Widget: plugin_pb.ConfigWidget_CONFIG_WIDGET_NUMBER, - MinValue: configInt64(100), - MaxValue: configInt64(10000), - }, - { - Name: "max_deletes_per_bucket", - Label: "Max Deletes Per Bucket", - Description: "Maximum number of expired objects to delete per bucket in one execution run.", - FieldType: plugin_pb.ConfigFieldType_CONFIG_FIELD_TYPE_INT64, - Widget: plugin_pb.ConfigWidget_CONFIG_WIDGET_NUMBER, - MinValue: configInt64(100), - MaxValue: configInt64(1000000), - }, - { - Name: "dry_run", - Label: "Dry Run", - Description: "When enabled, detect expired objects but do not delete them.", - FieldType: plugin_pb.ConfigFieldType_CONFIG_FIELD_TYPE_BOOL, - Widget: plugin_pb.ConfigWidget_CONFIG_WIDGET_TOGGLE, - }, - { - Name: "delete_marker_cleanup", - Label: "Delete Marker Cleanup", - Description: "Remove expired delete markers that have no non-current versions.", - FieldType: plugin_pb.ConfigFieldType_CONFIG_FIELD_TYPE_BOOL, - Widget: plugin_pb.ConfigWidget_CONFIG_WIDGET_TOGGLE, - }, - { - Name: "abort_mpu_days", - Label: "Abort Incomplete MPU (days)", - Description: "Abort incomplete multipart uploads older than this many days. 0 disables.", - FieldType: plugin_pb.ConfigFieldType_CONFIG_FIELD_TYPE_INT64, - Widget: plugin_pb.ConfigWidget_CONFIG_WIDGET_NUMBER, - MinValue: configInt64(0), - MaxValue: configInt64(365), - }, - }, - }, - }, - }, - AdminRuntimeDefaults: &plugin_pb.AdminRuntimeDefaults{ - Enabled: true, - DetectionIntervalSeconds: 300, // 5 minutes - DetectionTimeoutSeconds: 60, - MaxJobsPerDetection: 100, - GlobalExecutionConcurrency: 2, - PerWorkerExecutionConcurrency: 2, - RetryLimit: 1, - RetryBackoffSeconds: 10, - }, - WorkerDefaultValues: map[string]*plugin_pb.ConfigValue{ - "batch_size": {Kind: &plugin_pb.ConfigValue_Int64Value{Int64Value: defaultBatchSize}}, - "max_deletes_per_bucket": {Kind: &plugin_pb.ConfigValue_Int64Value{Int64Value: defaultMaxDeletesPerBucket}}, - "dry_run": {Kind: &plugin_pb.ConfigValue_BoolValue{BoolValue: defaultDryRun}}, - "delete_marker_cleanup": {Kind: &plugin_pb.ConfigValue_BoolValue{BoolValue: defaultDeleteMarkerCleanup}}, - "abort_mpu_days": {Kind: &plugin_pb.ConfigValue_Int64Value{Int64Value: defaultAbortMPUDaysDefault}}, - }, - } -} - -func (h *Handler) Detect(ctx context.Context, req *plugin_pb.RunDetectionRequest, sender pluginworker.DetectionSender) error { - if req == nil { - return fmt.Errorf("nil detection request") - } - - config := ParseConfig(req.WorkerConfigValues) - - bucketFilter := readStringConfig(req.AdminConfigValues, "bucket_filter", "") - - filerAddresses := filerAddressesFromCluster(req.ClusterContext) - if len(filerAddresses) == 0 { - _ = sender.SendActivity(pluginworker.BuildDetectorActivity("skipped", "no filer addresses in cluster context", nil)) - return sendEmptyDetection(sender) - } - - _ = sender.SendActivity(pluginworker.BuildDetectorActivity("connecting", "connecting to filer", nil)) - - filerClient, filerConn, err := connectToFiler(ctx, filerAddresses, h.grpcDialOption) - if err != nil { - return fmt.Errorf("failed to connect to any filer: %v", err) - } - defer filerConn.Close() - - maxResults := int(req.MaxResults) - if maxResults <= 0 { - maxResults = 100 - } - - _ = sender.SendActivity(pluginworker.BuildDetectorActivity("scanning", "scanning buckets for lifecycle rules", nil)) - proposals, err := h.detectBucketsWithLifecycleRules(ctx, filerClient, config, bucketFilter, maxResults) - if err != nil { - _ = sender.SendActivity(pluginworker.BuildDetectorActivity("scan_error", fmt.Sprintf("error scanning buckets: %v", err), nil)) - return fmt.Errorf("detect lifecycle rules: %w", err) - } - - _ = sender.SendActivity(pluginworker.BuildDetectorActivity("scan_complete", - fmt.Sprintf("found %d bucket(s) with lifecycle rules", len(proposals)), - map[string]*plugin_pb.ConfigValue{ - "buckets_found": {Kind: &plugin_pb.ConfigValue_Int64Value{Int64Value: int64(len(proposals))}}, - })) - - if err := sender.SendProposals(&plugin_pb.DetectionProposals{ - JobType: jobType, - Proposals: proposals, - HasMore: len(proposals) >= maxResults, - }); err != nil { - return err - } - - return sender.SendComplete(&plugin_pb.DetectionComplete{ - JobType: jobType, - Success: true, - TotalProposals: int32(len(proposals)), - }) -} - -func (h *Handler) Execute(ctx context.Context, req *plugin_pb.ExecuteJobRequest, sender pluginworker.ExecutionSender) error { - if req == nil || req.Job == nil { - return fmt.Errorf("nil execution request") - } - - job := req.Job - config := ParseConfig(req.WorkerConfigValues) - - bucket := readParamString(job.Parameters, "bucket") - bucketsPath := readParamString(job.Parameters, "buckets_path") - if bucket == "" || bucketsPath == "" { - return fmt.Errorf("missing bucket or buckets_path parameter") - } - - filerAddresses := filerAddressesFromCluster(req.ClusterContext) - if len(filerAddresses) == 0 { - return fmt.Errorf("no filer addresses in cluster context") - } - - filerClient, filerConn, err := connectToFiler(ctx, filerAddresses, h.grpcDialOption) - if err != nil { - return fmt.Errorf("failed to connect to any filer: %v", err) - } - defer filerConn.Close() - - _ = sender.SendProgress(&plugin_pb.JobProgressUpdate{ - JobId: job.JobId, - JobType: jobType, - State: plugin_pb.JobState_JOB_STATE_ASSIGNED, - ProgressPercent: 0, - Stage: "starting", - Message: fmt.Sprintf("executing lifecycle rules for bucket %s", bucket), - }) - - start := time.Now() - result, execErr := h.executeLifecycleForBucket(ctx, filerClient, config, bucket, bucketsPath, sender, job.JobId) - elapsed := time.Since(start) - - metrics := map[string]*plugin_pb.ConfigValue{ - MetricDurationMs: {Kind: &plugin_pb.ConfigValue_Int64Value{Int64Value: elapsed.Milliseconds()}}, - } - if result != nil { - metrics[MetricObjectsExpired] = &plugin_pb.ConfigValue{Kind: &plugin_pb.ConfigValue_Int64Value{Int64Value: result.objectsExpired}} - metrics[MetricObjectsScanned] = &plugin_pb.ConfigValue{Kind: &plugin_pb.ConfigValue_Int64Value{Int64Value: result.objectsScanned}} - metrics[MetricDeleteMarkersClean] = &plugin_pb.ConfigValue{Kind: &plugin_pb.ConfigValue_Int64Value{Int64Value: result.deleteMarkersClean}} - metrics[MetricMPUAborted] = &plugin_pb.ConfigValue{Kind: &plugin_pb.ConfigValue_Int64Value{Int64Value: result.mpuAborted}} - metrics[MetricErrors] = &plugin_pb.ConfigValue{Kind: &plugin_pb.ConfigValue_Int64Value{Int64Value: result.errors}} - } - - var scanned, expired int64 - if result != nil { - scanned = result.objectsScanned - expired = result.objectsExpired - } - - success := execErr == nil && (result == nil || result.errors == 0) - message := fmt.Sprintf("bucket %s: scanned %d objects, expired %d", bucket, scanned, expired) - if result != nil && result.deleteMarkersClean > 0 { - message += fmt.Sprintf(", delete markers cleaned %d", result.deleteMarkersClean) - } - if result != nil && result.mpuAborted > 0 { - message += fmt.Sprintf(", MPUs aborted %d", result.mpuAborted) - } - if config.DryRun { - message += " (dry run)" - } - if result != nil && result.errors > 0 { - message += fmt.Sprintf(" (%d errors)", result.errors) - } - if execErr != nil { - message = fmt.Sprintf("lifecycle execution failed for bucket %s: %v", bucket, execErr) - } - - errMsg := "" - if execErr != nil { - errMsg = execErr.Error() - } else if result != nil && result.errors > 0 { - errMsg = fmt.Sprintf("%d objects failed to process", result.errors) - } - - return sender.SendCompleted(&plugin_pb.JobCompleted{ - JobId: job.JobId, - JobType: jobType, - Success: success, - ErrorMessage: errMsg, - Result: &plugin_pb.JobResult{ - Summary: message, - OutputValues: metrics, - }, - CompletedAt: timestamppb.Now(), - }) -} - -func connectToFiler(ctx context.Context, addresses []string, dialOption grpc.DialOption) (filer_pb.SeaweedFilerClient, *grpc.ClientConn, error) { - var lastErr error - for _, addr := range addresses { - grpcAddr := pb.ServerAddress(addr).ToGrpcAddress() - connCtx, cancel := context.WithTimeout(ctx, filerConnectTimeout) - conn, err := pb.GrpcDial(connCtx, grpcAddr, false, dialOption) - cancel() - if err != nil { - lastErr = err - glog.V(1).Infof("s3_lifecycle: failed to connect to filer %s (grpc %s): %v", addr, grpcAddr, err) - continue - } - // Verify the connection with a ping. - client := filer_pb.NewSeaweedFilerClient(conn) - pingCtx, pingCancel := context.WithTimeout(ctx, filerConnectTimeout) - _, pingErr := client.Ping(pingCtx, &filer_pb.PingRequest{}) - pingCancel() - if pingErr != nil { - _ = conn.Close() - lastErr = pingErr - glog.V(1).Infof("s3_lifecycle: filer %s ping failed: %v", grpcAddr, pingErr) - continue - } - return client, conn, nil - } - return nil, nil, lastErr -} - -func sendEmptyDetection(sender pluginworker.DetectionSender) error { - if err := sender.SendProposals(&plugin_pb.DetectionProposals{ - JobType: jobType, - Proposals: []*plugin_pb.JobProposal{}, - HasMore: false, - }); err != nil { - return err - } - return sender.SendComplete(&plugin_pb.DetectionComplete{ - JobType: jobType, - Success: true, - TotalProposals: 0, - }) -} - -func filerAddressesFromCluster(cc *plugin_pb.ClusterContext) []string { - if cc == nil { - return nil - } - var addrs []string - for _, addr := range cc.FilerGrpcAddresses { - trimmed := strings.TrimSpace(addr) - if trimmed != "" { - addrs = append(addrs, trimmed) - } - } - return addrs -} - -func readParamString(params map[string]*plugin_pb.ConfigValue, key string) string { - if params == nil { - return "" - } - v := params[key] - if v == nil { - return "" - } - if sv, ok := v.Kind.(*plugin_pb.ConfigValue_StringValue); ok { - return sv.StringValue - } - return "" -} - -func configInt64(v int64) *plugin_pb.ConfigValue { - return &plugin_pb.ConfigValue{Kind: &plugin_pb.ConfigValue_Int64Value{Int64Value: v}} -} diff --git a/weed/plugin/worker/lifecycle/integration_test.go b/weed/plugin/worker/lifecycle/integration_test.go deleted file mode 100644 index 60b11175c..000000000 --- a/weed/plugin/worker/lifecycle/integration_test.go +++ /dev/null @@ -1,781 +0,0 @@ -package lifecycle - -import ( - "context" - "fmt" - "math" - "net" - "sort" - "strconv" - "strings" - "sync" - "testing" - "time" - - "github.com/seaweedfs/seaweedfs/weed/pb" - "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3lifecycle" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/status" - "google.golang.org/protobuf/proto" -) - -// testFilerServer is an in-memory filer gRPC server for integration tests. -type testFilerServer struct { - filer_pb.UnimplementedSeaweedFilerServer - mu sync.RWMutex - entries map[string]*filer_pb.Entry // key: "dir\x00name" -} - -func newTestFilerServer() *testFilerServer { - return &testFilerServer{entries: make(map[string]*filer_pb.Entry)} -} - -func (s *testFilerServer) key(dir, name string) string { return dir + "\x00" + name } - -func (s *testFilerServer) splitKey(key string) (string, string) { - for i := range key { - if key[i] == '\x00' { - return key[:i], key[i+1:] - } - } - return key, "" -} - -func (s *testFilerServer) putEntry(dir string, entry *filer_pb.Entry) { - s.mu.Lock() - defer s.mu.Unlock() - s.entries[s.key(dir, entry.Name)] = proto.Clone(entry).(*filer_pb.Entry) -} - -func (s *testFilerServer) getEntry(dir, name string) *filer_pb.Entry { - s.mu.RLock() - defer s.mu.RUnlock() - e := s.entries[s.key(dir, name)] - if e == nil { - return nil - } - return proto.Clone(e).(*filer_pb.Entry) -} - -func (s *testFilerServer) hasEntry(dir, name string) bool { - s.mu.RLock() - defer s.mu.RUnlock() - _, ok := s.entries[s.key(dir, name)] - return ok -} - -func (s *testFilerServer) LookupDirectoryEntry(_ context.Context, req *filer_pb.LookupDirectoryEntryRequest) (*filer_pb.LookupDirectoryEntryResponse, error) { - s.mu.RLock() - defer s.mu.RUnlock() - entry, found := s.entries[s.key(req.Directory, req.Name)] - if !found { - return nil, status.Error(codes.NotFound, filer_pb.ErrNotFound.Error()) - } - return &filer_pb.LookupDirectoryEntryResponse{Entry: proto.Clone(entry).(*filer_pb.Entry)}, nil -} - -func (s *testFilerServer) ListEntries(req *filer_pb.ListEntriesRequest, stream grpc.ServerStreamingServer[filer_pb.ListEntriesResponse]) error { - // Snapshot entries under lock, then stream without holding the lock - // (streaming callbacks may trigger DeleteEntry which needs a write lock). - s.mu.RLock() - var names []string - for key := range s.entries { - dir, name := s.splitKey(key) - if dir == req.Directory { - if req.StartFromFileName != "" && name <= req.StartFromFileName { - continue - } - if req.Prefix != "" && !strings.HasPrefix(name, req.Prefix) { - continue - } - names = append(names, name) - } - } - sort.Strings(names) - - // Clone entries while still holding the lock. - type namedEntry struct { - name string - entry *filer_pb.Entry - } - snapshot := make([]namedEntry, 0, len(names)) - for _, name := range names { - if req.Limit > 0 && uint32(len(snapshot)) >= req.Limit { - break - } - snapshot = append(snapshot, namedEntry{ - name: name, - entry: proto.Clone(s.entries[s.key(req.Directory, name)]).(*filer_pb.Entry), - }) - } - s.mu.RUnlock() - - // Stream responses without holding any lock. - for _, ne := range snapshot { - if err := stream.Send(&filer_pb.ListEntriesResponse{Entry: ne.entry}); err != nil { - return err - } - } - return nil -} - -func (s *testFilerServer) CreateEntry(_ context.Context, req *filer_pb.CreateEntryRequest) (*filer_pb.CreateEntryResponse, error) { - s.mu.Lock() - defer s.mu.Unlock() - s.entries[s.key(req.Directory, req.Entry.Name)] = proto.Clone(req.Entry).(*filer_pb.Entry) - return &filer_pb.CreateEntryResponse{}, nil -} - -func (s *testFilerServer) DeleteEntry(_ context.Context, req *filer_pb.DeleteEntryRequest) (*filer_pb.DeleteEntryResponse, error) { - s.mu.Lock() - defer s.mu.Unlock() - k := s.key(req.Directory, req.Name) - if _, found := s.entries[k]; !found { - return nil, status.Error(codes.NotFound, filer_pb.ErrNotFound.Error()) - } - delete(s.entries, k) - if req.IsRecursive { - // Delete all descendants: any entry whose directory starts with - // the deleted path (handles nested subdirectories). - deletedPath := req.Directory + "/" + req.Name - for key := range s.entries { - dir, _ := s.splitKey(key) - if dir == deletedPath || strings.HasPrefix(dir, deletedPath+"/") { - delete(s.entries, key) - } - } - } - return &filer_pb.DeleteEntryResponse{}, nil -} - -// startTestFiler starts an in-memory filer gRPC server and returns a client. -func startTestFiler(t *testing.T) (*testFilerServer, filer_pb.SeaweedFilerClient) { - t.Helper() - - lis, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("listen: %v", err) - } - - server := newTestFilerServer() - grpcServer := pb.NewGrpcServer() - filer_pb.RegisterSeaweedFilerServer(grpcServer, server) - go func() { _ = grpcServer.Serve(lis) }() - - t.Cleanup(func() { - grpcServer.Stop() - _ = lis.Close() - }) - - host, portStr, err := net.SplitHostPort(lis.Addr().String()) - if err != nil { - t.Fatalf("split host port: %v", err) - } - port, err := strconv.Atoi(portStr) - if err != nil { - t.Fatalf("parse port: %v", err) - } - addr := pb.NewServerAddress(host, 1, port) - - conn, err := pb.GrpcDial(context.Background(), addr.ToGrpcAddress(), false, grpc.WithTransportCredentials(insecure.NewCredentials())) - if err != nil { - t.Fatalf("dial: %v", err) - } - t.Cleanup(func() { _ = conn.Close() }) - - return server, filer_pb.NewSeaweedFilerClient(conn) -} - -// Helper to create a version ID from a timestamp. -func testVersionId(ts time.Time) string { - inverted := math.MaxInt64 - ts.UnixNano() - return fmt.Sprintf("%016x", inverted) + "0000000000000000" -} - -func TestIntegration_ListExpiredObjectsByRules(t *testing.T) { - server, client := startTestFiler(t) - bucketsPath := "/buckets" - bucket := "test-bucket" - bucketDir := bucketsPath + "/" + bucket - - now := time.Now() - old := now.Add(-60 * 24 * time.Hour) // 60 days ago - recent := now.Add(-5 * 24 * time.Hour) // 5 days ago - - // Create bucket directory. - server.putEntry(bucketsPath, &filer_pb.Entry{Name: bucket, IsDirectory: true}) - - // Create objects. - server.putEntry(bucketDir, &filer_pb.Entry{ - Name: "old-file.txt", - Attributes: &filer_pb.FuseAttributes{Mtime: old.Unix(), FileSize: 1024}, - }) - server.putEntry(bucketDir, &filer_pb.Entry{ - Name: "recent-file.txt", - Attributes: &filer_pb.FuseAttributes{Mtime: recent.Unix(), FileSize: 1024}, - }) - - rules := []s3lifecycle.Rule{{ - ID: "expire-30d", Status: "Enabled", - ExpirationDays: 30, - }} - - expired, scanned, err := listExpiredObjectsByRules(context.Background(), client, bucketsPath, bucket, rules, 100) - if err != nil { - t.Fatalf("listExpiredObjectsByRules: %v", err) - } - - if scanned != 2 { - t.Errorf("expected 2 scanned, got %d", scanned) - } - if len(expired) != 1 { - t.Fatalf("expected 1 expired, got %d", len(expired)) - } - if expired[0].name != "old-file.txt" { - t.Errorf("expected old-file.txt expired, got %s", expired[0].name) - } -} - -func TestIntegration_ListExpiredObjectsByRules_TagFilter(t *testing.T) { - server, client := startTestFiler(t) - bucketsPath := "/buckets" - bucket := "tag-bucket" - bucketDir := bucketsPath + "/" + bucket - - old := time.Now().Add(-60 * 24 * time.Hour) - - server.putEntry(bucketsPath, &filer_pb.Entry{Name: bucket, IsDirectory: true}) - - // Object with matching tag. - server.putEntry(bucketDir, &filer_pb.Entry{ - Name: "tagged.txt", - Attributes: &filer_pb.FuseAttributes{Mtime: old.Unix(), FileSize: 100}, - Extended: map[string][]byte{"X-Amz-Tagging-env": []byte("dev")}, - }) - // Object without tag. - server.putEntry(bucketDir, &filer_pb.Entry{ - Name: "untagged.txt", - Attributes: &filer_pb.FuseAttributes{Mtime: old.Unix(), FileSize: 100}, - }) - - rules := []s3lifecycle.Rule{{ - ID: "tag-expire", Status: "Enabled", - ExpirationDays: 30, - FilterTags: map[string]string{"env": "dev"}, - }} - - expired, _, err := listExpiredObjectsByRules(context.Background(), client, bucketsPath, bucket, rules, 100) - if err != nil { - t.Fatalf("listExpiredObjectsByRules: %v", err) - } - - if len(expired) != 1 { - t.Fatalf("expected 1 expired (tagged only), got %d", len(expired)) - } - if expired[0].name != "tagged.txt" { - t.Errorf("expected tagged.txt, got %s", expired[0].name) - } -} - -func TestIntegration_ProcessVersionsDirectory(t *testing.T) { - server, client := startTestFiler(t) - bucketsPath := "/buckets" - bucket := "versioned-bucket" - bucketDir := bucketsPath + "/" + bucket - versionsDir := bucketDir + "/key.versions" - - now := time.Now() - t1 := now.Add(-90 * 24 * time.Hour) // oldest - t2 := now.Add(-60 * 24 * time.Hour) - t3 := now.Add(-1 * 24 * time.Hour) // newest (latest) - - vid1 := testVersionId(t1) - vid2 := testVersionId(t2) - vid3 := testVersionId(t3) - - server.putEntry(bucketsPath, &filer_pb.Entry{Name: bucket, IsDirectory: true}) - server.putEntry(bucketDir, &filer_pb.Entry{ - Name: "key.versions", IsDirectory: true, - Extended: map[string][]byte{ - s3_constants.ExtLatestVersionIdKey: []byte(vid3), - }, - }) - - // Three versions: vid3 (latest), vid2 (noncurrent), vid1 (noncurrent) - server.putEntry(versionsDir, &filer_pb.Entry{ - Name: "v_" + vid3, - Attributes: &filer_pb.FuseAttributes{Mtime: t3.Unix(), FileSize: 100}, - Extended: map[string][]byte{ - s3_constants.ExtVersionIdKey: []byte(vid3), - }, - }) - server.putEntry(versionsDir, &filer_pb.Entry{ - Name: "v_" + vid2, - Attributes: &filer_pb.FuseAttributes{Mtime: t2.Unix(), FileSize: 100}, - Extended: map[string][]byte{ - s3_constants.ExtVersionIdKey: []byte(vid2), - }, - }) - server.putEntry(versionsDir, &filer_pb.Entry{ - Name: "v_" + vid1, - Attributes: &filer_pb.FuseAttributes{Mtime: t1.Unix(), FileSize: 100}, - Extended: map[string][]byte{ - s3_constants.ExtVersionIdKey: []byte(vid1), - }, - }) - - rules := []s3lifecycle.Rule{{ - ID: "noncurrent-30d", Status: "Enabled", - NoncurrentVersionExpirationDays: 30, - }} - - expired, scanned, err := processVersionsDirectory( - context.Background(), client, versionsDir, bucketDir, - rules, now, false, 100, - ) - if err != nil { - t.Fatalf("processVersionsDirectory: %v", err) - } - - // vid3 is latest (not expired). vid2 became noncurrent when vid3 was created - // (1 day ago), so vid2 is NOT old enough (< 30 days noncurrent). - // vid1 became noncurrent when vid2 was created (60 days ago), so vid1 IS expired. - if scanned != 2 { - t.Errorf("expected 2 scanned (non-current versions), got %d", scanned) - } - if len(expired) != 1 { - t.Fatalf("expected 1 expired (only vid1), got %d", len(expired)) - } - if expired[0].name != "v_"+vid1 { - t.Errorf("expected v_%s expired, got %s", vid1, expired[0].name) - } -} - -func TestIntegration_ProcessVersionsDirectory_NewerNoncurrentVersions(t *testing.T) { - server, client := startTestFiler(t) - bucketsPath := "/buckets" - bucket := "keep-n-bucket" - bucketDir := bucketsPath + "/" + bucket - versionsDir := bucketDir + "/obj.versions" - - now := time.Now() - // Create 5 versions, all old enough to expire by days alone. - versions := make([]time.Time, 5) - vids := make([]string, 5) - for i := 0; i < 5; i++ { - versions[i] = now.Add(time.Duration(-(90 - i*10)) * 24 * time.Hour) - vids[i] = testVersionId(versions[i]) - } - // vids[4] is newest (latest), vids[0] is oldest - - server.putEntry(bucketsPath, &filer_pb.Entry{Name: bucket, IsDirectory: true}) - server.putEntry(bucketDir, &filer_pb.Entry{ - Name: "obj.versions", IsDirectory: true, - Extended: map[string][]byte{ - s3_constants.ExtLatestVersionIdKey: []byte(vids[4]), - }, - }) - - for i, vid := range vids { - server.putEntry(versionsDir, &filer_pb.Entry{ - Name: "v_" + vid, - Attributes: &filer_pb.FuseAttributes{Mtime: versions[i].Unix(), FileSize: 100}, - Extended: map[string][]byte{ - s3_constants.ExtVersionIdKey: []byte(vid), - }, - }) - } - - rules := []s3lifecycle.Rule{{ - ID: "keep-2", Status: "Enabled", - NoncurrentVersionExpirationDays: 7, - NewerNoncurrentVersions: 2, - }} - - expired, _, err := processVersionsDirectory( - context.Background(), client, versionsDir, bucketDir, - rules, now, false, 100, - ) - if err != nil { - t.Fatalf("processVersionsDirectory: %v", err) - } - - // 4 noncurrent versions (vids[0..3]). Keep newest 2 (vids[3], vids[2]). - // Expire vids[1] and vids[0]. - if len(expired) != 2 { - t.Fatalf("expected 2 expired (keep 2 newest noncurrent), got %d", len(expired)) - } - expiredNames := map[string]bool{} - for _, e := range expired { - expiredNames[e.name] = true - } - if !expiredNames["v_"+vids[0]] { - t.Errorf("expected vids[0] (oldest) to be expired") - } - if !expiredNames["v_"+vids[1]] { - t.Errorf("expected vids[1] to be expired") - } -} - -func TestIntegration_AbortMPUsByRules(t *testing.T) { - server, client := startTestFiler(t) - bucketsPath := "/buckets" - bucket := "mpu-bucket" - uploadsDir := bucketsPath + "/" + bucket + "/.uploads" - - now := time.Now() - old := now.Add(-10 * 24 * time.Hour) - recent := now.Add(-2 * 24 * time.Hour) - - server.putEntry(bucketsPath, &filer_pb.Entry{Name: bucket, IsDirectory: true}) - server.putEntry(bucketsPath+"/"+bucket, &filer_pb.Entry{Name: ".uploads", IsDirectory: true}) - - // Old upload under logs/ prefix. - server.putEntry(uploadsDir, &filer_pb.Entry{ - Name: "logs_upload1", IsDirectory: true, - Attributes: &filer_pb.FuseAttributes{Crtime: old.Unix()}, - }) - // Recent upload under logs/ prefix. - server.putEntry(uploadsDir, &filer_pb.Entry{ - Name: "logs_upload2", IsDirectory: true, - Attributes: &filer_pb.FuseAttributes{Crtime: recent.Unix()}, - }) - // Old upload under data/ prefix (should not match logs/ rule). - server.putEntry(uploadsDir, &filer_pb.Entry{ - Name: "data_upload1", IsDirectory: true, - Attributes: &filer_pb.FuseAttributes{Crtime: old.Unix()}, - }) - - rules := []s3lifecycle.Rule{{ - ID: "abort-logs", Status: "Enabled", - Prefix: "logs", - AbortMPUDaysAfterInitiation: 7, - }} - - aborted, errs, err := abortMPUsByRules(context.Background(), client, bucketsPath, bucket, rules, 100) - if err != nil { - t.Fatalf("abortMPUsByRules: %v", err) - } - if errs != 0 { - t.Errorf("expected 0 errors, got %d", errs) - } - - // Only logs_upload1 should be aborted (old + matches prefix). - // logs_upload2 is too recent, data_upload1 doesn't match prefix. - if aborted != 1 { - t.Errorf("expected 1 aborted, got %d", aborted) - } - - // Verify the correct upload was removed. - if server.hasEntry(uploadsDir, "logs_upload1") { - t.Error("logs_upload1 should have been removed") - } - if !server.hasEntry(uploadsDir, "logs_upload2") { - t.Error("logs_upload2 should still exist") - } - if !server.hasEntry(uploadsDir, "data_upload1") { - t.Error("data_upload1 should still exist (wrong prefix)") - } -} - -func TestIntegration_DeleteExpiredObjects(t *testing.T) { - server, client := startTestFiler(t) - bucketsPath := "/buckets" - bucket := "delete-bucket" - bucketDir := bucketsPath + "/" + bucket - - now := time.Now() - old := now.Add(-60 * 24 * time.Hour) - - server.putEntry(bucketsPath, &filer_pb.Entry{Name: bucket, IsDirectory: true}) - server.putEntry(bucketDir, &filer_pb.Entry{ - Name: "to-delete.txt", - Attributes: &filer_pb.FuseAttributes{Mtime: old.Unix(), FileSize: 100}, - }) - server.putEntry(bucketDir, &filer_pb.Entry{ - Name: "to-keep.txt", - Attributes: &filer_pb.FuseAttributes{Mtime: now.Unix(), FileSize: 100}, - }) - - rules := []s3lifecycle.Rule{{ - ID: "expire", Status: "Enabled", - ExpirationDays: 30, - }} - - expired, _, err := listExpiredObjectsByRules(context.Background(), client, bucketsPath, bucket, rules, 100) - if err != nil { - t.Fatalf("list: %v", err) - } - - // Actually delete them. - deleted, errs, err := deleteExpiredObjects(context.Background(), client, expired) - if err != nil { - t.Fatalf("delete: %v", err) - } - if deleted != 1 || errs != 0 { - t.Errorf("expected 1 deleted 0 errors, got %d deleted %d errors", deleted, errs) - } - - if server.hasEntry(bucketDir, "to-delete.txt") { - t.Error("to-delete.txt should have been removed") - } - if !server.hasEntry(bucketDir, "to-keep.txt") { - t.Error("to-keep.txt should still exist") - } -} - -// TestIntegration_VersionedBucket_ExpirationDays verifies that Expiration.Days -// rules correctly detect and delete the latest version in a versioned bucket -// where all data lives in .versions/ directories (issue #8757). -func TestIntegration_VersionedBucket_ExpirationDays(t *testing.T) { - server, client := startTestFiler(t) - bucketsPath := "/buckets" - bucket := "versioned-expire" - bucketDir := bucketsPath + "/" + bucket - - now := time.Now() - old := now.Add(-60 * 24 * time.Hour) // 60 days ago — should expire - recent := now.Add(-5 * 24 * time.Hour) // 5 days ago — should NOT expire - - vidOld := testVersionId(old) - vidRecent := testVersionId(recent) - - server.putEntry(bucketsPath, &filer_pb.Entry{Name: bucket, IsDirectory: true}) - - // --- Single-version object (old, should expire) --- - server.putEntry(bucketDir, &filer_pb.Entry{ - Name: "old-file.txt" + s3_constants.VersionsFolder, IsDirectory: true, - Extended: map[string][]byte{ - s3_constants.ExtLatestVersionIdKey: []byte(vidOld), - s3_constants.ExtLatestVersionFileNameKey: []byte("v_" + vidOld), - s3_constants.ExtLatestVersionMtimeKey: []byte(strconv.FormatInt(old.Unix(), 10)), - s3_constants.ExtLatestVersionSizeKey: []byte("3400000000"), - s3_constants.ExtLatestVersionIsDeleteMarker: []byte("false"), - }, - }) - oldVersionsDir := bucketDir + "/old-file.txt" + s3_constants.VersionsFolder - server.putEntry(oldVersionsDir, &filer_pb.Entry{ - Name: "v_" + vidOld, - Attributes: &filer_pb.FuseAttributes{Mtime: old.Unix(), FileSize: 3400000000}, - Extended: map[string][]byte{ - s3_constants.ExtVersionIdKey: []byte(vidOld), - }, - }) - - // --- Single-version object (recent, should NOT expire) --- - server.putEntry(bucketDir, &filer_pb.Entry{ - Name: "recent-file.txt" + s3_constants.VersionsFolder, IsDirectory: true, - Extended: map[string][]byte{ - s3_constants.ExtLatestVersionIdKey: []byte(vidRecent), - s3_constants.ExtLatestVersionFileNameKey: []byte("v_" + vidRecent), - s3_constants.ExtLatestVersionMtimeKey: []byte(strconv.FormatInt(recent.Unix(), 10)), - s3_constants.ExtLatestVersionSizeKey: []byte("3400000000"), - s3_constants.ExtLatestVersionIsDeleteMarker: []byte("false"), - }, - }) - recentVersionsDir := bucketDir + "/recent-file.txt" + s3_constants.VersionsFolder - server.putEntry(recentVersionsDir, &filer_pb.Entry{ - Name: "v_" + vidRecent, - Attributes: &filer_pb.FuseAttributes{Mtime: recent.Unix(), FileSize: 3400000000}, - Extended: map[string][]byte{ - s3_constants.ExtVersionIdKey: []byte(vidRecent), - }, - }) - - // --- Object with delete marker as latest (should NOT be expired by Expiration.Days) --- - vidMarker := testVersionId(old) - server.putEntry(bucketDir, &filer_pb.Entry{ - Name: "deleted-obj.txt" + s3_constants.VersionsFolder, IsDirectory: true, - Extended: map[string][]byte{ - s3_constants.ExtLatestVersionIdKey: []byte(vidMarker), - s3_constants.ExtLatestVersionFileNameKey: []byte("v_" + vidMarker), - s3_constants.ExtLatestVersionMtimeKey: []byte(strconv.FormatInt(old.Unix(), 10)), - s3_constants.ExtLatestVersionIsDeleteMarker: []byte("true"), - }, - }) - - rules := []s3lifecycle.Rule{{ - ID: "expire-30d", Status: "Enabled", - ExpirationDays: 30, - }} - - expired, scanned, err := listExpiredObjectsByRules(context.Background(), client, bucketsPath, bucket, rules, 100) - if err != nil { - t.Fatalf("listExpiredObjectsByRules: %v", err) - } - - // Only old-file.txt's latest version should be expired. - // recent-file.txt is too young; deleted-obj.txt is a delete marker. - if len(expired) != 1 { - t.Fatalf("expected 1 expired, got %d: %+v", len(expired), expired) - } - if expired[0].dir != oldVersionsDir { - t.Errorf("expected dir=%s, got %s", oldVersionsDir, expired[0].dir) - } - if expired[0].name != "v_"+vidOld { - t.Errorf("expected name=v_%s, got %s", vidOld, expired[0].name) - } - // The old-file.txt latest version should count as scanned. - if scanned < 1 { - t.Errorf("expected at least 1 scanned, got %d", scanned) - } -} - -// TestIntegration_VersionedBucket_ExpirationDays_DeleteAndCleanup verifies -// end-to-end deletion and .versions directory cleanup for a single-version -// versioned object expired by Expiration.Days. -func TestIntegration_VersionedBucket_ExpirationDays_DeleteAndCleanup(t *testing.T) { - server, client := startTestFiler(t) - bucketsPath := "/buckets" - bucket := "versioned-cleanup" - bucketDir := bucketsPath + "/" + bucket - - now := time.Now() - old := now.Add(-60 * 24 * time.Hour) - vidOld := testVersionId(old) - - server.putEntry(bucketsPath, &filer_pb.Entry{Name: bucket, IsDirectory: true}) - - // Single-version object that should expire. - versionsDir := bucketDir + "/data.bin" + s3_constants.VersionsFolder - server.putEntry(bucketDir, &filer_pb.Entry{ - Name: "data.bin" + s3_constants.VersionsFolder, IsDirectory: true, - Extended: map[string][]byte{ - s3_constants.ExtLatestVersionIdKey: []byte(vidOld), - s3_constants.ExtLatestVersionFileNameKey: []byte("v_" + vidOld), - s3_constants.ExtLatestVersionMtimeKey: []byte(strconv.FormatInt(old.Unix(), 10)), - s3_constants.ExtLatestVersionSizeKey: []byte("1024"), - s3_constants.ExtLatestVersionIsDeleteMarker: []byte("false"), - }, - }) - server.putEntry(versionsDir, &filer_pb.Entry{ - Name: "v_" + vidOld, - Attributes: &filer_pb.FuseAttributes{Mtime: old.Unix(), FileSize: 1024}, - Extended: map[string][]byte{ - s3_constants.ExtVersionIdKey: []byte(vidOld), - }, - }) - - rules := []s3lifecycle.Rule{{ - ID: "expire-30d", Status: "Enabled", - ExpirationDays: 30, - }} - - // Step 1: Detect expired. - expired, _, err := listExpiredObjectsByRules(context.Background(), client, bucketsPath, bucket, rules, 100) - if err != nil { - t.Fatalf("list: %v", err) - } - if len(expired) != 1 { - t.Fatalf("expected 1 expired, got %d", len(expired)) - } - - // Step 2: Delete the expired version file. - deleted, errs, delErr := deleteExpiredObjects(context.Background(), client, expired) - if delErr != nil { - t.Fatalf("delete: %v", delErr) - } - if deleted != 1 || errs != 0 { - t.Errorf("expected 1 deleted 0 errors, got %d deleted %d errors", deleted, errs) - } - - // Version file should be gone. - if server.hasEntry(versionsDir, "v_"+vidOld) { - t.Error("version file should have been removed") - } - - // Step 3: Cleanup empty .versions directory. - cleaned := cleanupEmptyVersionsDirectories(context.Background(), client, expired) - if cleaned != 1 { - t.Errorf("expected 1 directory cleaned, got %d", cleaned) - } - - // The .versions directory itself should be gone. - if server.hasEntry(bucketDir, "data.bin"+s3_constants.VersionsFolder) { - t.Error(".versions directory should have been removed after cleanup") - } -} - -// TestIntegration_VersionedBucket_MultiVersion_ExpirationDays verifies that -// when a multi-version object's latest version expires, only the latest -// version is deleted and noncurrent versions remain. -func TestIntegration_VersionedBucket_MultiVersion_ExpirationDays(t *testing.T) { - server, client := startTestFiler(t) - bucketsPath := "/buckets" - bucket := "versioned-multi" - bucketDir := bucketsPath + "/" + bucket - - now := time.Now() - tOld := now.Add(-60 * 24 * time.Hour) - tNoncurrent := now.Add(-90 * 24 * time.Hour) - vidLatest := testVersionId(tOld) - vidNoncurrent := testVersionId(tNoncurrent) - - server.putEntry(bucketsPath, &filer_pb.Entry{Name: bucket, IsDirectory: true}) - - versionsDir := bucketDir + "/multi.txt" + s3_constants.VersionsFolder - server.putEntry(bucketDir, &filer_pb.Entry{ - Name: "multi.txt" + s3_constants.VersionsFolder, IsDirectory: true, - Extended: map[string][]byte{ - s3_constants.ExtLatestVersionIdKey: []byte(vidLatest), - s3_constants.ExtLatestVersionFileNameKey: []byte("v_" + vidLatest), - s3_constants.ExtLatestVersionMtimeKey: []byte(strconv.FormatInt(tOld.Unix(), 10)), - s3_constants.ExtLatestVersionSizeKey: []byte("500"), - s3_constants.ExtLatestVersionIsDeleteMarker: []byte("false"), - }, - }) - server.putEntry(versionsDir, &filer_pb.Entry{ - Name: "v_" + vidLatest, - Attributes: &filer_pb.FuseAttributes{Mtime: tOld.Unix(), FileSize: 500}, - Extended: map[string][]byte{s3_constants.ExtVersionIdKey: []byte(vidLatest)}, - }) - server.putEntry(versionsDir, &filer_pb.Entry{ - Name: "v_" + vidNoncurrent, - Attributes: &filer_pb.FuseAttributes{Mtime: tNoncurrent.Unix(), FileSize: 500}, - Extended: map[string][]byte{s3_constants.ExtVersionIdKey: []byte(vidNoncurrent)}, - }) - - rules := []s3lifecycle.Rule{{ - ID: "expire-30d", Status: "Enabled", - ExpirationDays: 30, - }} - - expired, _, err := listExpiredObjectsByRules(context.Background(), client, bucketsPath, bucket, rules, 100) - if err != nil { - t.Fatalf("list: %v", err) - } - // Only the latest version should be detected as expired. - if len(expired) != 1 { - t.Fatalf("expected 1 expired (latest only), got %d", len(expired)) - } - if expired[0].name != "v_"+vidLatest { - t.Errorf("expected latest version expired, got %s", expired[0].name) - } - - // Delete it. - deleted, errs, delErr := deleteExpiredObjects(context.Background(), client, expired) - if delErr != nil { - t.Fatalf("delete: %v", delErr) - } - if deleted != 1 || errs != 0 { - t.Errorf("expected 1 deleted 0 errors, got %d deleted %d errors", deleted, errs) - } - - // Noncurrent version should still exist. - if !server.hasEntry(versionsDir, "v_"+vidNoncurrent) { - t.Error("noncurrent version should still exist") - } - - // .versions directory should NOT be cleaned up (not empty). - cleaned := cleanupEmptyVersionsDirectories(context.Background(), client, expired) - if cleaned != 0 { - t.Errorf("expected 0 directories cleaned (not empty), got %d", cleaned) - } - if !server.hasEntry(bucketDir, "multi.txt"+s3_constants.VersionsFolder) { - t.Error(".versions directory should still exist (has noncurrent version)") - } -} diff --git a/weed/plugin/worker/lifecycle/rules.go b/weed/plugin/worker/lifecycle/rules.go deleted file mode 100644 index c3855f22c..000000000 --- a/weed/plugin/worker/lifecycle/rules.go +++ /dev/null @@ -1,199 +0,0 @@ -package lifecycle - -import ( - "bytes" - "context" - "encoding/xml" - "errors" - "fmt" - "time" - - "github.com/seaweedfs/seaweedfs/weed/glog" - "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3lifecycle" -) - -// lifecycleConfig mirrors the XML structure just enough to parse rules. -// We define a minimal local struct to avoid importing the s3api package -// (which would create a circular dependency if s3api ever imports the worker). -type lifecycleConfig struct { - XMLName xml.Name `xml:"LifecycleConfiguration"` - Rules []lifecycleConfigRule `xml:"Rule"` -} - -type lifecycleConfigRule struct { - ID string `xml:"ID"` - Status string `xml:"Status"` - Filter lifecycleFilter `xml:"Filter"` - Prefix string `xml:"Prefix"` - Expiration lifecycleExpiration `xml:"Expiration"` - NoncurrentVersionExpiration noncurrentVersionExpiration `xml:"NoncurrentVersionExpiration"` - AbortIncompleteMultipartUpload abortMPU `xml:"AbortIncompleteMultipartUpload"` -} - -type lifecycleFilter struct { - Prefix string `xml:"Prefix"` - Tag lifecycleTag `xml:"Tag"` - And lifecycleAnd `xml:"And"` - ObjectSizeGreaterThan int64 `xml:"ObjectSizeGreaterThan"` - ObjectSizeLessThan int64 `xml:"ObjectSizeLessThan"` -} - -type lifecycleAnd struct { - Prefix string `xml:"Prefix"` - Tags []lifecycleTag `xml:"Tag"` - ObjectSizeGreaterThan int64 `xml:"ObjectSizeGreaterThan"` - ObjectSizeLessThan int64 `xml:"ObjectSizeLessThan"` -} - -type lifecycleTag struct { - Key string `xml:"Key"` - Value string `xml:"Value"` -} - -type lifecycleExpiration struct { - Days int `xml:"Days"` - Date string `xml:"Date"` - ExpiredObjectDeleteMarker bool `xml:"ExpiredObjectDeleteMarker"` -} - -type noncurrentVersionExpiration struct { - NoncurrentDays int `xml:"NoncurrentDays"` - NewerNoncurrentVersions int `xml:"NewerNoncurrentVersions"` -} - -type abortMPU struct { - DaysAfterInitiation int `xml:"DaysAfterInitiation"` -} - -// errMalformedLifecycleXML indicates the lifecycle XML exists but could not be parsed. -// Callers should fail closed (not fall back to TTL) to avoid broader deletions. -var errMalformedLifecycleXML = errors.New("malformed lifecycle XML") - -// loadLifecycleRulesFromBucket reads the lifecycle XML from a bucket's -// metadata and converts it to evaluator-friendly rules. -// -// Returns: -// - (rules, nil) when lifecycle XML is configured and parseable -// - (nil, nil) when no lifecycle XML is configured (caller should use TTL fallback) -// - (nil, errMalformedLifecycleXML) when XML exists but is malformed (fail closed) -// - (nil, err) for transient filer errors (caller should use TTL fallback with warning) -func loadLifecycleRulesFromBucket( - ctx context.Context, - client filer_pb.SeaweedFilerClient, - bucketsPath, bucket string, -) ([]s3lifecycle.Rule, error) { - bucketDir := bucketsPath - resp, err := filer_pb.LookupEntry(ctx, client, &filer_pb.LookupDirectoryEntryRequest{ - Directory: bucketDir, - Name: bucket, - }) - if err != nil { - // Transient filer error — not the same as malformed XML. - return nil, fmt.Errorf("lookup bucket %s: %w", bucket, err) - } - if resp.Entry == nil || resp.Entry.Extended == nil { - return nil, nil - } - xmlData := resp.Entry.Extended[lifecycleXMLKey] - if len(xmlData) == 0 { - return nil, nil - } - rules, parseErr := parseLifecycleXML(xmlData) - if parseErr != nil { - return nil, fmt.Errorf("%w: bucket %s: %v", errMalformedLifecycleXML, bucket, parseErr) - } - // Return non-nil empty slice when XML was present but yielded no rules - // (e.g., all rules disabled). This lets callers distinguish "no XML" (nil) - // from "XML present, no effective rules" (empty slice). - if rules == nil { - rules = []s3lifecycle.Rule{} - } - return rules, nil -} - -// parseLifecycleXML parses lifecycle configuration XML and converts it -// to evaluator-friendly rules. -func parseLifecycleXML(data []byte) ([]s3lifecycle.Rule, error) { - var config lifecycleConfig - if err := xml.NewDecoder(bytes.NewReader(data)).Decode(&config); err != nil { - return nil, fmt.Errorf("decode lifecycle XML: %w", err) - } - - var rules []s3lifecycle.Rule - for _, r := range config.Rules { - rule := s3lifecycle.Rule{ - ID: r.ID, - Status: r.Status, - } - - // Resolve prefix: Filter.And.Prefix > Filter.Prefix > Rule.Prefix - switch { - case r.Filter.And.Prefix != "" || len(r.Filter.And.Tags) > 0 || - r.Filter.And.ObjectSizeGreaterThan > 0 || r.Filter.And.ObjectSizeLessThan > 0: - rule.Prefix = r.Filter.And.Prefix - rule.FilterTags = tagsToMap(r.Filter.And.Tags) - rule.FilterSizeGreaterThan = r.Filter.And.ObjectSizeGreaterThan - rule.FilterSizeLessThan = r.Filter.And.ObjectSizeLessThan - case r.Filter.Tag.Key != "": - rule.Prefix = r.Filter.Prefix - rule.FilterTags = map[string]string{r.Filter.Tag.Key: r.Filter.Tag.Value} - rule.FilterSizeGreaterThan = r.Filter.ObjectSizeGreaterThan - rule.FilterSizeLessThan = r.Filter.ObjectSizeLessThan - default: - if r.Filter.Prefix != "" { - rule.Prefix = r.Filter.Prefix - } else { - rule.Prefix = r.Prefix - } - rule.FilterSizeGreaterThan = r.Filter.ObjectSizeGreaterThan - rule.FilterSizeLessThan = r.Filter.ObjectSizeLessThan - } - - rule.ExpirationDays = r.Expiration.Days - rule.ExpiredObjectDeleteMarker = r.Expiration.ExpiredObjectDeleteMarker - rule.NoncurrentVersionExpirationDays = r.NoncurrentVersionExpiration.NoncurrentDays - rule.NewerNoncurrentVersions = r.NoncurrentVersionExpiration.NewerNoncurrentVersions - rule.AbortMPUDaysAfterInitiation = r.AbortIncompleteMultipartUpload.DaysAfterInitiation - - // Parse Date if present. - if r.Expiration.Date != "" { - // Date may be RFC3339 or ISO 8601 date-only. - parsed, parseErr := parseExpirationDate(r.Expiration.Date) - if parseErr != nil { - glog.V(1).Infof("s3_lifecycle: skipping rule %s: invalid expiration date %q: %v", r.ID, r.Expiration.Date, parseErr) - continue - } - rule.ExpirationDate = parsed - } - - rules = append(rules, rule) - } - return rules, nil -} - -func tagsToMap(tags []lifecycleTag) map[string]string { - if len(tags) == 0 { - return nil - } - m := make(map[string]string, len(tags)) - for _, t := range tags { - m[t.Key] = t.Value - } - return m -} - -func parseExpirationDate(s string) (time.Time, error) { - // Try RFC3339 first, then ISO 8601 date-only. - formats := []string{ - "2006-01-02T15:04:05Z07:00", - "2006-01-02", - } - for _, f := range formats { - t, err := time.Parse(f, s) - if err == nil { - return t, nil - } - } - return time.Time{}, fmt.Errorf("unrecognized date format: %s", s) -} diff --git a/weed/plugin/worker/lifecycle/rules_test.go b/weed/plugin/worker/lifecycle/rules_test.go deleted file mode 100644 index ab57137a7..000000000 --- a/weed/plugin/worker/lifecycle/rules_test.go +++ /dev/null @@ -1,256 +0,0 @@ -package lifecycle - -import ( - "testing" - "time" -) - -func TestParseLifecycleXML_CompleteConfig(t *testing.T) { - xml := []byte(` - - rotation - - Enabled - 30 - - 7 - 2 - - - 3 - - -`) - - rules, err := parseLifecycleXML(xml) - if err != nil { - t.Fatalf("parseLifecycleXML: %v", err) - } - if len(rules) != 1 { - t.Fatalf("expected 1 rule, got %d", len(rules)) - } - - r := rules[0] - if r.ID != "rotation" { - t.Errorf("expected ID 'rotation', got %q", r.ID) - } - if r.Status != "Enabled" { - t.Errorf("expected Status 'Enabled', got %q", r.Status) - } - if r.ExpirationDays != 30 { - t.Errorf("expected ExpirationDays=30, got %d", r.ExpirationDays) - } - if r.NoncurrentVersionExpirationDays != 7 { - t.Errorf("expected NoncurrentVersionExpirationDays=7, got %d", r.NoncurrentVersionExpirationDays) - } - if r.NewerNoncurrentVersions != 2 { - t.Errorf("expected NewerNoncurrentVersions=2, got %d", r.NewerNoncurrentVersions) - } - if r.AbortMPUDaysAfterInitiation != 3 { - t.Errorf("expected AbortMPUDaysAfterInitiation=3, got %d", r.AbortMPUDaysAfterInitiation) - } -} - -func TestParseLifecycleXML_PrefixFilter(t *testing.T) { - xml := []byte(` - - logs - Enabled - logs/ - 7 - -`) - - rules, err := parseLifecycleXML(xml) - if err != nil { - t.Fatalf("parseLifecycleXML: %v", err) - } - if len(rules) != 1 { - t.Fatalf("expected 1 rule, got %d", len(rules)) - } - if rules[0].Prefix != "logs/" { - t.Errorf("expected Prefix='logs/', got %q", rules[0].Prefix) - } -} - -func TestParseLifecycleXML_LegacyPrefix(t *testing.T) { - // Old-style at rule level instead of inside . - xml := []byte(` - - old - Enabled - archive/ - 90 - -`) - - rules, err := parseLifecycleXML(xml) - if err != nil { - t.Fatalf("parseLifecycleXML: %v", err) - } - if len(rules) != 1 { - t.Fatalf("expected 1 rule, got %d", len(rules)) - } - if rules[0].Prefix != "archive/" { - t.Errorf("expected Prefix='archive/', got %q", rules[0].Prefix) - } -} - -func TestParseLifecycleXML_TagFilter(t *testing.T) { - xml := []byte(` - - tag-rule - Enabled - - envdev - - 1 - -`) - - rules, err := parseLifecycleXML(xml) - if err != nil { - t.Fatalf("parseLifecycleXML: %v", err) - } - if len(rules) != 1 { - t.Fatalf("expected 1 rule, got %d", len(rules)) - } - r := rules[0] - if len(r.FilterTags) != 1 || r.FilterTags["env"] != "dev" { - t.Errorf("expected FilterTags={env:dev}, got %v", r.FilterTags) - } -} - -func TestParseLifecycleXML_AndFilter(t *testing.T) { - xml := []byte(` - - and-rule - Enabled - - - data/ - envstaging - 1024 - - - 14 - -`) - - rules, err := parseLifecycleXML(xml) - if err != nil { - t.Fatalf("parseLifecycleXML: %v", err) - } - if len(rules) != 1 { - t.Fatalf("expected 1 rule, got %d", len(rules)) - } - r := rules[0] - if r.Prefix != "data/" { - t.Errorf("expected Prefix='data/', got %q", r.Prefix) - } - if r.FilterTags["env"] != "staging" { - t.Errorf("expected tag env=staging, got %v", r.FilterTags) - } - if r.FilterSizeGreaterThan != 1024 { - t.Errorf("expected FilterSizeGreaterThan=1024, got %d", r.FilterSizeGreaterThan) - } -} - -func TestParseLifecycleXML_ExpirationDate(t *testing.T) { - xml := []byte(` - - date-rule - Enabled - - 2026-06-01T00:00:00Z - -`) - - rules, err := parseLifecycleXML(xml) - if err != nil { - t.Fatalf("parseLifecycleXML: %v", err) - } - expected := time.Date(2026, 6, 1, 0, 0, 0, 0, time.UTC) - if !rules[0].ExpirationDate.Equal(expected) { - t.Errorf("expected ExpirationDate=%v, got %v", expected, rules[0].ExpirationDate) - } -} - -func TestParseLifecycleXML_ExpiredObjectDeleteMarker(t *testing.T) { - xml := []byte(` - - marker-cleanup - Enabled - - true - -`) - - rules, err := parseLifecycleXML(xml) - if err != nil { - t.Fatalf("parseLifecycleXML: %v", err) - } - if !rules[0].ExpiredObjectDeleteMarker { - t.Error("expected ExpiredObjectDeleteMarker=true") - } -} - -func TestParseLifecycleXML_MultipleRules(t *testing.T) { - xml := []byte(` - - rule1 - Enabled - logs/ - 7 - - - rule2 - Disabled - temp/ - 1 - - - rule3 - Enabled - - 365 - -`) - - rules, err := parseLifecycleXML(xml) - if err != nil { - t.Fatalf("parseLifecycleXML: %v", err) - } - if len(rules) != 3 { - t.Fatalf("expected 3 rules, got %d", len(rules)) - } - if rules[1].Status != "Disabled" { - t.Errorf("expected rule2 Status=Disabled, got %q", rules[1].Status) - } -} - -func TestParseExpirationDate(t *testing.T) { - tests := []struct { - name string - input string - want time.Time - wantErr bool - }{ - {"rfc3339_utc", "2026-06-01T00:00:00Z", time.Date(2026, 6, 1, 0, 0, 0, 0, time.UTC), false}, - {"rfc3339_offset", "2026-06-01T00:00:00+05:00", time.Date(2026, 6, 1, 0, 0, 0, 0, time.FixedZone("", 5*3600)), false}, - {"date_only", "2026-06-01", time.Date(2026, 6, 1, 0, 0, 0, 0, time.UTC), false}, - {"invalid", "not-a-date", time.Time{}, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := parseExpirationDate(tt.input) - if (err != nil) != tt.wantErr { - t.Errorf("parseExpirationDate(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) - return - } - if !tt.wantErr && !got.Equal(tt.want) { - t.Errorf("parseExpirationDate(%q) = %v, want %v", tt.input, got, tt.want) - } - }) - } -} diff --git a/weed/plugin/worker/lifecycle/version_test.go b/weed/plugin/worker/lifecycle/version_test.go deleted file mode 100644 index 43cc0d93b..000000000 --- a/weed/plugin/worker/lifecycle/version_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package lifecycle - -import ( - "fmt" - "math" - "strings" - "testing" - "time" - - "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3lifecycle" -) - -// makeVersionId creates a new-format version ID from a timestamp. -func makeVersionId(t time.Time) string { - inverted := math.MaxInt64 - t.UnixNano() - return fmt.Sprintf("%016x", inverted) + "0000000000000000" -} - -func TestSortVersionsByVersionId(t *testing.T) { - t1 := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) - t2 := time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC) - t3 := time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC) - - vid1 := makeVersionId(t1) - vid2 := makeVersionId(t2) - vid3 := makeVersionId(t3) - - entries := []*filer_pb.Entry{ - {Name: "v_" + vid1}, - {Name: "v_" + vid3}, - {Name: "v_" + vid2}, - } - - sortVersionsByVersionId(entries) - - // Should be sorted newest first: t3, t2, t1. - expected := []string{"v_" + vid3, "v_" + vid2, "v_" + vid1} - for i, want := range expected { - if entries[i].Name != want { - t.Errorf("entries[%d].Name = %s, want %s", i, entries[i].Name, want) - } - } -} - -func TestSortVersionsByVersionId_SameTimestampDifferentSuffix(t *testing.T) { - // Two versions with the same timestamp prefix but different random suffix. - // The sort must still produce a deterministic order. - base := makeVersionId(time.Date(2026, 6, 1, 0, 0, 0, 0, time.UTC)) - vid1 := base[:16] + "aaaaaaaaaaaaaaaa" - vid2 := base[:16] + "bbbbbbbbbbbbbbbb" - - entries := []*filer_pb.Entry{ - {Name: "v_" + vid2}, - {Name: "v_" + vid1}, - } - - sortVersionsByVersionId(entries) - - // New format: smaller hex = newer. vid1 ("aaa...") < vid2 ("bbb...") so vid1 is newer. - if strings.TrimPrefix(entries[0].Name, "v_") != vid1 { - t.Errorf("expected vid1 (newer) first, got %s", entries[0].Name) - } -} - -func TestCompareVersionIdsMixedFormats(t *testing.T) { - // Old format: raw nanosecond timestamp (below threshold ~0x17...). - // New format: inverted timestamp (above threshold ~0x68...). - oldTs := time.Date(2023, 6, 15, 12, 0, 0, 0, time.UTC) - newTs := time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC) - - oldFormatId := fmt.Sprintf("%016x", oldTs.UnixNano()) + "abcdef0123456789" - newFormatId := makeVersionId(newTs) // uses inverted timestamp - - // newTs is more recent, so newFormatId should sort as "newer". - cmp := s3lifecycle.CompareVersionIds(newFormatId, oldFormatId) - if cmp >= 0 { - t.Errorf("expected new-format ID (2026) to be newer than old-format ID (2023), got cmp=%d", cmp) - } - - // Reverse comparison. - cmp2 := s3lifecycle.CompareVersionIds(oldFormatId, newFormatId) - if cmp2 <= 0 { - t.Errorf("expected old-format ID (2023) to be older than new-format ID (2026), got cmp=%d", cmp2) - } - - // Sort a mixed slice: should be newest-first. - entries := []*filer_pb.Entry{ - {Name: "v_" + oldFormatId}, - {Name: "v_" + newFormatId}, - } - sortVersionsByVersionId(entries) - - if strings.TrimPrefix(entries[0].Name, "v_") != newFormatId { - t.Errorf("expected new-format (newer) entry first after sort") - } -} - -func TestVersionsDirectoryNaming(t *testing.T) { - if s3_constants.VersionsFolder != ".versions" { - t.Fatalf("unexpected VersionsFolder constant: %q", s3_constants.VersionsFolder) - } - - versionsDir := "/buckets/mybucket/path/to/key.versions" - bucketPath := "/buckets/mybucket" - relDir := strings.TrimPrefix(versionsDir, bucketPath+"/") - objKey := strings.TrimSuffix(relDir, s3_constants.VersionsFolder) - if objKey != "path/to/key" { - t.Errorf("expected 'path/to/key', got %q", objKey) - } -} diff --git a/weed/query/engine/aggregations.go b/weed/query/engine/aggregations.go index 6b58517e1..54130212e 100644 --- a/weed/query/engine/aggregations.go +++ b/weed/query/engine/aggregations.go @@ -74,11 +74,6 @@ func (opt *FastPathOptimizer) DetermineStrategy(aggregations []AggregationSpec) return strategy } -// CollectDataSources gathers information about available data sources for a topic -func (opt *FastPathOptimizer) CollectDataSources(ctx context.Context, hybridScanner *HybridMessageScanner) (*TopicDataSources, error) { - return opt.CollectDataSourcesWithTimeFilter(ctx, hybridScanner, 0, 0) -} - // CollectDataSourcesWithTimeFilter gathers information about available data sources for a topic // with optional time filtering to skip irrelevant parquet files func (opt *FastPathOptimizer) CollectDataSourcesWithTimeFilter(ctx context.Context, hybridScanner *HybridMessageScanner, startTimeNs, stopTimeNs int64) (*TopicDataSources, error) { diff --git a/weed/query/engine/engine.go b/weed/query/engine/engine.go index ac66a7453..7a1c783ba 100644 --- a/weed/query/engine/engine.go +++ b/weed/query/engine/engine.go @@ -539,20 +539,6 @@ func NewSQLEngine(masterAddress string) *SQLEngine { } } -// NewSQLEngineWithCatalog creates a new SQL execution engine with a custom catalog -// Used for testing or when you want to provide a pre-configured catalog -func NewSQLEngineWithCatalog(catalog *SchemaCatalog) *SQLEngine { - // Initialize global HTTP client if not already done - // This is needed for reading partition data from the filer - if util_http.GetGlobalHttpClient() == nil { - util_http.InitGlobalHttpClient() - } - - return &SQLEngine{ - catalog: catalog, - } -} - // GetCatalog returns the schema catalog for external access func (e *SQLEngine) GetCatalog() *SchemaCatalog { return e.catalog @@ -3682,11 +3668,6 @@ type ExecutionPlanBuilder struct { engine *SQLEngine } -// NewExecutionPlanBuilder creates a new execution plan builder -func NewExecutionPlanBuilder(engine *SQLEngine) *ExecutionPlanBuilder { - return &ExecutionPlanBuilder{engine: engine} -} - // BuildAggregationPlan builds an execution plan for aggregation queries func (builder *ExecutionPlanBuilder) BuildAggregationPlan( stmt *SelectStatement, diff --git a/weed/query/engine/engine_test.go b/weed/query/engine/engine_test.go deleted file mode 100644 index 42a5f4911..000000000 --- a/weed/query/engine/engine_test.go +++ /dev/null @@ -1,1329 +0,0 @@ -package engine - -import ( - "context" - "encoding/binary" - "errors" - "testing" - - "github.com/seaweedfs/seaweedfs/weed/mq/topic" - "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" - "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "google.golang.org/protobuf/proto" -) - -// Mock implementations for testing -type MockHybridMessageScanner struct { - mock.Mock - topic topic.Topic -} - -func (m *MockHybridMessageScanner) ReadParquetStatistics(partitionPath string) ([]*ParquetFileStats, error) { - args := m.Called(partitionPath) - return args.Get(0).([]*ParquetFileStats), args.Error(1) -} - -type MockSQLEngine struct { - *SQLEngine - mockPartitions map[string][]string - mockParquetSourceFiles map[string]map[string]bool - mockLiveLogRowCounts map[string]int64 - mockColumnStats map[string]map[string]*ParquetColumnStats -} - -func NewMockSQLEngine() *MockSQLEngine { - return &MockSQLEngine{ - SQLEngine: &SQLEngine{ - catalog: &SchemaCatalog{ - databases: make(map[string]*DatabaseInfo), - currentDatabase: "test", - }, - }, - mockPartitions: make(map[string][]string), - mockParquetSourceFiles: make(map[string]map[string]bool), - mockLiveLogRowCounts: make(map[string]int64), - mockColumnStats: make(map[string]map[string]*ParquetColumnStats), - } -} - -func (m *MockSQLEngine) discoverTopicPartitions(namespace, topicName string) ([]string, error) { - key := namespace + "." + topicName - if partitions, exists := m.mockPartitions[key]; exists { - return partitions, nil - } - return []string{"partition-1", "partition-2"}, nil -} - -func (m *MockSQLEngine) extractParquetSourceFiles(fileStats []*ParquetFileStats) map[string]bool { - if len(fileStats) == 0 { - return make(map[string]bool) - } - return map[string]bool{"converted-log-1": true} -} - -func (m *MockSQLEngine) countLiveLogRowsExcludingParquetSources(ctx context.Context, partition string, parquetSources map[string]bool) (int64, error) { - if count, exists := m.mockLiveLogRowCounts[partition]; exists { - return count, nil - } - return 25, nil -} - -func (m *MockSQLEngine) computeLiveLogMinMax(partition, column string, parquetSources map[string]bool) (interface{}, interface{}, error) { - switch column { - case "id": - return int64(1), int64(50), nil - case "value": - return 10.5, 99.9, nil - default: - return nil, nil, nil - } -} - -func (m *MockSQLEngine) getSystemColumnGlobalMin(column string, allFileStats map[string][]*ParquetFileStats) interface{} { - return int64(1000000000) -} - -func (m *MockSQLEngine) getSystemColumnGlobalMax(column string, allFileStats map[string][]*ParquetFileStats) interface{} { - return int64(2000000000) -} - -func createMockColumnStats(column string, minVal, maxVal interface{}) *ParquetColumnStats { - return &ParquetColumnStats{ - ColumnName: column, - MinValue: convertToSchemaValue(minVal), - MaxValue: convertToSchemaValue(maxVal), - NullCount: 0, - } -} - -func convertToSchemaValue(val interface{}) *schema_pb.Value { - switch v := val.(type) { - case int64: - return &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: v}} - case float64: - return &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: v}} - case string: - return &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: v}} - } - return nil -} - -// Test FastPathOptimizer -func TestFastPathOptimizer_DetermineStrategy(t *testing.T) { - engine := NewMockSQLEngine() - optimizer := NewFastPathOptimizer(engine.SQLEngine) - - tests := []struct { - name string - aggregations []AggregationSpec - expected AggregationStrategy - }{ - { - name: "Supported aggregations", - aggregations: []AggregationSpec{ - {Function: FuncCOUNT, Column: "*"}, - {Function: FuncMAX, Column: "id"}, - {Function: FuncMIN, Column: "value"}, - }, - expected: AggregationStrategy{ - CanUseFastPath: true, - Reason: "all_aggregations_supported", - UnsupportedSpecs: []AggregationSpec{}, - }, - }, - { - name: "Unsupported aggregation", - aggregations: []AggregationSpec{ - {Function: FuncCOUNT, Column: "*"}, - {Function: FuncAVG, Column: "value"}, // Not supported - }, - expected: AggregationStrategy{ - CanUseFastPath: false, - Reason: "unsupported_aggregation_functions", - }, - }, - { - name: "Empty aggregations", - aggregations: []AggregationSpec{}, - expected: AggregationStrategy{ - CanUseFastPath: true, - Reason: "all_aggregations_supported", - UnsupportedSpecs: []AggregationSpec{}, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - strategy := optimizer.DetermineStrategy(tt.aggregations) - - assert.Equal(t, tt.expected.CanUseFastPath, strategy.CanUseFastPath) - assert.Equal(t, tt.expected.Reason, strategy.Reason) - if !tt.expected.CanUseFastPath { - assert.NotEmpty(t, strategy.UnsupportedSpecs) - } - }) - } -} - -// Test AggregationComputer -func TestAggregationComputer_ComputeFastPathAggregations(t *testing.T) { - engine := NewMockSQLEngine() - computer := NewAggregationComputer(engine.SQLEngine) - - dataSources := &TopicDataSources{ - ParquetFiles: map[string][]*ParquetFileStats{ - "/topics/test/topic1/partition-1": { - { - RowCount: 30, - ColumnStats: map[string]*ParquetColumnStats{ - "id": createMockColumnStats("id", int64(10), int64(40)), - }, - }, - }, - }, - ParquetRowCount: 30, - LiveLogRowCount: 25, - PartitionsCount: 1, - } - - partitions := []string{"/topics/test/topic1/partition-1"} - - tests := []struct { - name string - aggregations []AggregationSpec - validate func(t *testing.T, results []AggregationResult) - }{ - { - name: "COUNT aggregation", - aggregations: []AggregationSpec{ - {Function: FuncCOUNT, Column: "*"}, - }, - validate: func(t *testing.T, results []AggregationResult) { - assert.Len(t, results, 1) - assert.Equal(t, int64(55), results[0].Count) // 30 + 25 - }, - }, - { - name: "MAX aggregation", - aggregations: []AggregationSpec{ - {Function: FuncMAX, Column: "id"}, - }, - validate: func(t *testing.T, results []AggregationResult) { - assert.Len(t, results, 1) - // Should be max of parquet stats (40) - mock doesn't combine with live log - assert.Equal(t, int64(40), results[0].Max) - }, - }, - { - name: "MIN aggregation", - aggregations: []AggregationSpec{ - {Function: FuncMIN, Column: "id"}, - }, - validate: func(t *testing.T, results []AggregationResult) { - assert.Len(t, results, 1) - // Should be min of parquet stats (10) - mock doesn't combine with live log - assert.Equal(t, int64(10), results[0].Min) - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() - results, err := computer.ComputeFastPathAggregations(ctx, tt.aggregations, dataSources, partitions) - - assert.NoError(t, err) - tt.validate(t, results) - }) - } -} - -// Test case-insensitive column lookup and null handling for MIN/MAX aggregations -func TestAggregationComputer_MinMaxEdgeCases(t *testing.T) { - engine := NewMockSQLEngine() - computer := NewAggregationComputer(engine.SQLEngine) - - tests := []struct { - name string - dataSources *TopicDataSources - aggregations []AggregationSpec - validate func(t *testing.T, results []AggregationResult, err error) - }{ - { - name: "Case insensitive column lookup", - dataSources: &TopicDataSources{ - ParquetFiles: map[string][]*ParquetFileStats{ - "/topics/test/partition-1": { - { - RowCount: 50, - ColumnStats: map[string]*ParquetColumnStats{ - "ID": createMockColumnStats("ID", int64(5), int64(95)), // Uppercase column name - }, - }, - }, - }, - ParquetRowCount: 50, - LiveLogRowCount: 0, - PartitionsCount: 1, - }, - aggregations: []AggregationSpec{ - {Function: FuncMIN, Column: "id"}, // lowercase column name - {Function: FuncMAX, Column: "id"}, - }, - validate: func(t *testing.T, results []AggregationResult, err error) { - assert.NoError(t, err) - assert.Len(t, results, 2) - assert.Equal(t, int64(5), results[0].Min, "MIN should work with case-insensitive lookup") - assert.Equal(t, int64(95), results[1].Max, "MAX should work with case-insensitive lookup") - }, - }, - { - name: "Null column stats handling", - dataSources: &TopicDataSources{ - ParquetFiles: map[string][]*ParquetFileStats{ - "/topics/test/partition-1": { - { - RowCount: 50, - ColumnStats: map[string]*ParquetColumnStats{ - "id": { - ColumnName: "id", - MinValue: nil, // Null min value - MaxValue: nil, // Null max value - NullCount: 50, - RowCount: 50, - }, - }, - }, - }, - }, - ParquetRowCount: 50, - LiveLogRowCount: 0, - PartitionsCount: 1, - }, - aggregations: []AggregationSpec{ - {Function: FuncMIN, Column: "id"}, - {Function: FuncMAX, Column: "id"}, - }, - validate: func(t *testing.T, results []AggregationResult, err error) { - assert.NoError(t, err) - assert.Len(t, results, 2) - // When stats are null, should fall back to system column or return nil - // This tests that we don't crash on null stats - }, - }, - { - name: "Mixed data types - string column", - dataSources: &TopicDataSources{ - ParquetFiles: map[string][]*ParquetFileStats{ - "/topics/test/partition-1": { - { - RowCount: 30, - ColumnStats: map[string]*ParquetColumnStats{ - "name": createMockColumnStats("name", "Alice", "Zoe"), - }, - }, - }, - }, - ParquetRowCount: 30, - LiveLogRowCount: 0, - PartitionsCount: 1, - }, - aggregations: []AggregationSpec{ - {Function: FuncMIN, Column: "name"}, - {Function: FuncMAX, Column: "name"}, - }, - validate: func(t *testing.T, results []AggregationResult, err error) { - assert.NoError(t, err) - assert.Len(t, results, 2) - assert.Equal(t, "Alice", results[0].Min) - assert.Equal(t, "Zoe", results[1].Max) - }, - }, - { - name: "Mixed data types - float column", - dataSources: &TopicDataSources{ - ParquetFiles: map[string][]*ParquetFileStats{ - "/topics/test/partition-1": { - { - RowCount: 25, - ColumnStats: map[string]*ParquetColumnStats{ - "price": createMockColumnStats("price", float64(19.99), float64(299.50)), - }, - }, - }, - }, - ParquetRowCount: 25, - LiveLogRowCount: 0, - PartitionsCount: 1, - }, - aggregations: []AggregationSpec{ - {Function: FuncMIN, Column: "price"}, - {Function: FuncMAX, Column: "price"}, - }, - validate: func(t *testing.T, results []AggregationResult, err error) { - assert.NoError(t, err) - assert.Len(t, results, 2) - assert.Equal(t, float64(19.99), results[0].Min) - assert.Equal(t, float64(299.50), results[1].Max) - }, - }, - { - name: "Column not found in parquet stats", - dataSources: &TopicDataSources{ - ParquetFiles: map[string][]*ParquetFileStats{ - "/topics/test/partition-1": { - { - RowCount: 20, - ColumnStats: map[string]*ParquetColumnStats{ - "id": createMockColumnStats("id", int64(1), int64(100)), - // Note: "nonexistent_column" is not in stats - }, - }, - }, - }, - ParquetRowCount: 20, - LiveLogRowCount: 10, // Has live logs to fall back to - PartitionsCount: 1, - }, - aggregations: []AggregationSpec{ - {Function: FuncMIN, Column: "nonexistent_column"}, - {Function: FuncMAX, Column: "nonexistent_column"}, - }, - validate: func(t *testing.T, results []AggregationResult, err error) { - assert.NoError(t, err) - assert.Len(t, results, 2) - // Should fall back to live log processing or return nil - // The key is that it shouldn't crash - }, - }, - { - name: "Multiple parquet files with different ranges", - dataSources: &TopicDataSources{ - ParquetFiles: map[string][]*ParquetFileStats{ - "/topics/test/partition-1": { - { - RowCount: 30, - ColumnStats: map[string]*ParquetColumnStats{ - "score": createMockColumnStats("score", int64(10), int64(50)), - }, - }, - { - RowCount: 40, - ColumnStats: map[string]*ParquetColumnStats{ - "score": createMockColumnStats("score", int64(5), int64(75)), // Lower min, higher max - }, - }, - }, - }, - ParquetRowCount: 70, - LiveLogRowCount: 0, - PartitionsCount: 1, - }, - aggregations: []AggregationSpec{ - {Function: FuncMIN, Column: "score"}, - {Function: FuncMAX, Column: "score"}, - }, - validate: func(t *testing.T, results []AggregationResult, err error) { - assert.NoError(t, err) - assert.Len(t, results, 2) - assert.Equal(t, int64(5), results[0].Min, "Should find global minimum across all files") - assert.Equal(t, int64(75), results[1].Max, "Should find global maximum across all files") - }, - }, - } - - partitions := []string{"/topics/test/partition-1"} - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() - results, err := computer.ComputeFastPathAggregations(ctx, tt.aggregations, tt.dataSources, partitions) - tt.validate(t, results, err) - }) - } -} - -// Test the specific bug where MIN/MAX was returning empty values -func TestAggregationComputer_MinMaxEmptyValuesBugFix(t *testing.T) { - engine := NewMockSQLEngine() - computer := NewAggregationComputer(engine.SQLEngine) - - // This test specifically addresses the bug where MIN/MAX returned empty - // due to improper null checking and extraction logic - dataSources := &TopicDataSources{ - ParquetFiles: map[string][]*ParquetFileStats{ - "/topics/test/test-topic/partition1": { - { - RowCount: 100, - ColumnStats: map[string]*ParquetColumnStats{ - "id": { - ColumnName: "id", - MinValue: &schema_pb.Value{Kind: &schema_pb.Value_Int32Value{Int32Value: 0}}, // Min should be 0 - MaxValue: &schema_pb.Value{Kind: &schema_pb.Value_Int32Value{Int32Value: 99}}, // Max should be 99 - NullCount: 0, - RowCount: 100, - }, - }, - }, - }, - }, - ParquetRowCount: 100, - LiveLogRowCount: 0, // No live logs, pure parquet stats - PartitionsCount: 1, - } - - partitions := []string{"/topics/test/test-topic/partition1"} - - tests := []struct { - name string - aggregSpec AggregationSpec - expected interface{} - }{ - { - name: "MIN should return 0 not empty", - aggregSpec: AggregationSpec{Function: FuncMIN, Column: "id"}, - expected: int32(0), // Should extract the actual minimum value - }, - { - name: "MAX should return 99 not empty", - aggregSpec: AggregationSpec{Function: FuncMAX, Column: "id"}, - expected: int32(99), // Should extract the actual maximum value - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() - results, err := computer.ComputeFastPathAggregations(ctx, []AggregationSpec{tt.aggregSpec}, dataSources, partitions) - - assert.NoError(t, err) - assert.Len(t, results, 1) - - // Verify the result is not nil/empty - if tt.aggregSpec.Function == FuncMIN { - assert.NotNil(t, results[0].Min, "MIN result should not be nil") - assert.Equal(t, tt.expected, results[0].Min) - } else if tt.aggregSpec.Function == FuncMAX { - assert.NotNil(t, results[0].Max, "MAX result should not be nil") - assert.Equal(t, tt.expected, results[0].Max) - } - }) - } -} - -// Test the formatAggregationResult function with MIN/MAX edge cases -func TestSQLEngine_FormatAggregationResult_MinMax(t *testing.T) { - engine := NewTestSQLEngine() - - tests := []struct { - name string - spec AggregationSpec - result AggregationResult - expected string - }{ - { - name: "MIN with zero value should not be empty", - spec: AggregationSpec{Function: FuncMIN, Column: "id"}, - result: AggregationResult{Min: int32(0)}, - expected: "0", - }, - { - name: "MAX with large value", - spec: AggregationSpec{Function: FuncMAX, Column: "id"}, - result: AggregationResult{Max: int32(99)}, - expected: "99", - }, - { - name: "MIN with negative value", - spec: AggregationSpec{Function: FuncMIN, Column: "score"}, - result: AggregationResult{Min: int64(-50)}, - expected: "-50", - }, - { - name: "MAX with float value", - spec: AggregationSpec{Function: FuncMAX, Column: "price"}, - result: AggregationResult{Max: float64(299.99)}, - expected: "299.99", - }, - { - name: "MIN with string value", - spec: AggregationSpec{Function: FuncMIN, Column: "name"}, - result: AggregationResult{Min: "Alice"}, - expected: "Alice", - }, - { - name: "MIN with nil should return NULL", - spec: AggregationSpec{Function: FuncMIN, Column: "missing"}, - result: AggregationResult{Min: nil}, - expected: "", // NULL values display as empty - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - sqlValue := engine.formatAggregationResult(tt.spec, tt.result) - assert.Equal(t, tt.expected, sqlValue.String()) - }) - } -} - -// Test the direct formatAggregationResult scenario that was originally broken -func TestSQLEngine_MinMaxBugFixIntegration(t *testing.T) { - // This test focuses on the core bug fix without the complexity of table discovery - // It directly tests the scenario where MIN/MAX returned empty due to the bug - - engine := NewTestSQLEngine() - - // Test the direct formatting path that was failing - tests := []struct { - name string - aggregSpec AggregationSpec - aggResult AggregationResult - expectedEmpty bool - expectedValue string - }{ - { - name: "MIN with zero should not be empty (the original bug)", - aggregSpec: AggregationSpec{Function: FuncMIN, Column: "id", Alias: "MIN(id)"}, - aggResult: AggregationResult{Min: int32(0)}, // This was returning empty before fix - expectedEmpty: false, - expectedValue: "0", - }, - { - name: "MAX with valid value should not be empty", - aggregSpec: AggregationSpec{Function: FuncMAX, Column: "id", Alias: "MAX(id)"}, - aggResult: AggregationResult{Max: int32(99)}, - expectedEmpty: false, - expectedValue: "99", - }, - { - name: "MIN with negative value should work", - aggregSpec: AggregationSpec{Function: FuncMIN, Column: "score", Alias: "MIN(score)"}, - aggResult: AggregationResult{Min: int64(-10)}, - expectedEmpty: false, - expectedValue: "-10", - }, - { - name: "MIN with nil should be empty (expected behavior)", - aggregSpec: AggregationSpec{Function: FuncMIN, Column: "missing", Alias: "MIN(missing)"}, - aggResult: AggregationResult{Min: nil}, - expectedEmpty: true, - expectedValue: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Test the formatAggregationResult function directly - sqlValue := engine.formatAggregationResult(tt.aggregSpec, tt.aggResult) - result := sqlValue.String() - - if tt.expectedEmpty { - assert.Empty(t, result, "Result should be empty for nil values") - } else { - assert.NotEmpty(t, result, "Result should not be empty") - assert.Equal(t, tt.expectedValue, result) - } - }) - } -} - -// Test the tryFastParquetAggregation method specifically for the bug -func TestSQLEngine_FastParquetAggregationBugFix(t *testing.T) { - // This test verifies that the fast path aggregation logic works correctly - // and doesn't return nil/empty values when it should return actual data - - engine := NewMockSQLEngine() - computer := NewAggregationComputer(engine.SQLEngine) - - // Create realistic data sources that mimic the user's scenario - dataSources := &TopicDataSources{ - ParquetFiles: map[string][]*ParquetFileStats{ - "/topics/test/test-topic/v2025-09-01-22-54-02/0000-0630": { - { - RowCount: 100, - ColumnStats: map[string]*ParquetColumnStats{ - "id": { - ColumnName: "id", - MinValue: &schema_pb.Value{Kind: &schema_pb.Value_Int32Value{Int32Value: 0}}, - MaxValue: &schema_pb.Value{Kind: &schema_pb.Value_Int32Value{Int32Value: 99}}, - NullCount: 0, - RowCount: 100, - }, - }, - }, - }, - }, - ParquetRowCount: 100, - LiveLogRowCount: 0, // Pure parquet scenario - PartitionsCount: 1, - } - - partitions := []string{"/topics/test/test-topic/v2025-09-01-22-54-02/0000-0630"} - - tests := []struct { - name string - aggregations []AggregationSpec - validateResults func(t *testing.T, results []AggregationResult) - }{ - { - name: "Single MIN aggregation should return value not nil", - aggregations: []AggregationSpec{ - {Function: FuncMIN, Column: "id", Alias: "MIN(id)"}, - }, - validateResults: func(t *testing.T, results []AggregationResult) { - assert.Len(t, results, 1) - assert.NotNil(t, results[0].Min, "MIN result should not be nil") - assert.Equal(t, int32(0), results[0].Min, "MIN should return the correct minimum value") - }, - }, - { - name: "Single MAX aggregation should return value not nil", - aggregations: []AggregationSpec{ - {Function: FuncMAX, Column: "id", Alias: "MAX(id)"}, - }, - validateResults: func(t *testing.T, results []AggregationResult) { - assert.Len(t, results, 1) - assert.NotNil(t, results[0].Max, "MAX result should not be nil") - assert.Equal(t, int32(99), results[0].Max, "MAX should return the correct maximum value") - }, - }, - { - name: "Combined MIN/MAX should both return values", - aggregations: []AggregationSpec{ - {Function: FuncMIN, Column: "id", Alias: "MIN(id)"}, - {Function: FuncMAX, Column: "id", Alias: "MAX(id)"}, - }, - validateResults: func(t *testing.T, results []AggregationResult) { - assert.Len(t, results, 2) - assert.NotNil(t, results[0].Min, "MIN result should not be nil") - assert.NotNil(t, results[1].Max, "MAX result should not be nil") - assert.Equal(t, int32(0), results[0].Min) - assert.Equal(t, int32(99), results[1].Max) - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() - results, err := computer.ComputeFastPathAggregations(ctx, tt.aggregations, dataSources, partitions) - - assert.NoError(t, err, "ComputeFastPathAggregations should not error") - tt.validateResults(t, results) - }) - } -} - -// Test ExecutionPlanBuilder -func TestExecutionPlanBuilder_BuildAggregationPlan(t *testing.T) { - engine := NewMockSQLEngine() - builder := NewExecutionPlanBuilder(engine.SQLEngine) - - // Parse a simple SELECT statement using the native parser - stmt, err := ParseSQL("SELECT COUNT(*) FROM test_topic") - assert.NoError(t, err) - selectStmt := stmt.(*SelectStatement) - - aggregations := []AggregationSpec{ - {Function: FuncCOUNT, Column: "*"}, - } - - strategy := AggregationStrategy{ - CanUseFastPath: true, - Reason: "all_aggregations_supported", - } - - dataSources := &TopicDataSources{ - ParquetRowCount: 100, - LiveLogRowCount: 50, - PartitionsCount: 3, - ParquetFiles: map[string][]*ParquetFileStats{ - "partition-1": {{RowCount: 50}}, - "partition-2": {{RowCount: 50}}, - }, - } - - plan := builder.BuildAggregationPlan(selectStmt, aggregations, strategy, dataSources) - - assert.Equal(t, "SELECT", plan.QueryType) - assert.Equal(t, "hybrid_fast_path", plan.ExecutionStrategy) - assert.Contains(t, plan.DataSources, "parquet_stats") - assert.Contains(t, plan.DataSources, "live_logs") - assert.Equal(t, 3, plan.PartitionsScanned) - assert.Equal(t, 2, plan.ParquetFilesScanned) - assert.Contains(t, plan.OptimizationsUsed, "parquet_statistics") - assert.Equal(t, []string{"COUNT(*)"}, plan.Aggregations) - assert.Equal(t, int64(50), plan.TotalRowsProcessed) // Only live logs scanned -} - -// Test Error Types -func TestErrorTypes(t *testing.T) { - t.Run("AggregationError", func(t *testing.T) { - err := AggregationError{ - Operation: "MAX", - Column: "id", - Cause: errors.New("column not found"), - } - - expected := "aggregation error in MAX(id): column not found" - assert.Equal(t, expected, err.Error()) - }) - - t.Run("DataSourceError", func(t *testing.T) { - err := DataSourceError{ - Source: "partition_discovery:test.topic1", - Cause: errors.New("network timeout"), - } - - expected := "data source error in partition_discovery:test.topic1: network timeout" - assert.Equal(t, expected, err.Error()) - }) - - t.Run("OptimizationError", func(t *testing.T) { - err := OptimizationError{ - Strategy: "fast_path_aggregation", - Reason: "unsupported function: AVG", - } - - expected := "optimization failed for fast_path_aggregation: unsupported function: AVG" - assert.Equal(t, expected, err.Error()) - }) -} - -// Integration Tests -func TestIntegration_FastPathOptimization(t *testing.T) { - engine := NewMockSQLEngine() - - // Setup components - optimizer := NewFastPathOptimizer(engine.SQLEngine) - computer := NewAggregationComputer(engine.SQLEngine) - - // Mock data setup - aggregations := []AggregationSpec{ - {Function: FuncCOUNT, Column: "*"}, - {Function: FuncMAX, Column: "id"}, - } - - // Step 1: Determine strategy - strategy := optimizer.DetermineStrategy(aggregations) - assert.True(t, strategy.CanUseFastPath) - - // Step 2: Mock data sources - dataSources := &TopicDataSources{ - ParquetFiles: map[string][]*ParquetFileStats{ - "/topics/test/topic1/partition-1": {{ - RowCount: 75, - ColumnStats: map[string]*ParquetColumnStats{ - "id": createMockColumnStats("id", int64(1), int64(100)), - }, - }}, - }, - ParquetRowCount: 75, - LiveLogRowCount: 25, - PartitionsCount: 1, - } - - partitions := []string{"/topics/test/topic1/partition-1"} - - // Step 3: Compute aggregations - ctx := context.Background() - results, err := computer.ComputeFastPathAggregations(ctx, aggregations, dataSources, partitions) - assert.NoError(t, err) - assert.Len(t, results, 2) - assert.Equal(t, int64(100), results[0].Count) // 75 + 25 - assert.Equal(t, int64(100), results[1].Max) // From parquet stats mock -} - -func TestIntegration_FallbackToFullScan(t *testing.T) { - engine := NewMockSQLEngine() - optimizer := NewFastPathOptimizer(engine.SQLEngine) - - // Unsupported aggregations - aggregations := []AggregationSpec{ - {Function: "AVG", Column: "value"}, // Not supported - } - - // Step 1: Strategy should reject fast path - strategy := optimizer.DetermineStrategy(aggregations) - assert.False(t, strategy.CanUseFastPath) - assert.Equal(t, "unsupported_aggregation_functions", strategy.Reason) - assert.NotEmpty(t, strategy.UnsupportedSpecs) -} - -// Benchmark Tests -func BenchmarkFastPathOptimizer_DetermineStrategy(b *testing.B) { - engine := NewMockSQLEngine() - optimizer := NewFastPathOptimizer(engine.SQLEngine) - - aggregations := []AggregationSpec{ - {Function: FuncCOUNT, Column: "*"}, - {Function: FuncMAX, Column: "id"}, - {Function: "MIN", Column: "value"}, - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - strategy := optimizer.DetermineStrategy(aggregations) - _ = strategy.CanUseFastPath - } -} - -func BenchmarkAggregationComputer_ComputeFastPathAggregations(b *testing.B) { - engine := NewMockSQLEngine() - computer := NewAggregationComputer(engine.SQLEngine) - - dataSources := &TopicDataSources{ - ParquetFiles: map[string][]*ParquetFileStats{ - "partition-1": {{ - RowCount: 1000, - ColumnStats: map[string]*ParquetColumnStats{ - "id": createMockColumnStats("id", int64(1), int64(1000)), - }, - }}, - }, - ParquetRowCount: 1000, - LiveLogRowCount: 100, - } - - aggregations := []AggregationSpec{ - {Function: FuncCOUNT, Column: "*"}, - {Function: FuncMAX, Column: "id"}, - } - - partitions := []string{"partition-1"} - ctx := context.Background() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - results, err := computer.ComputeFastPathAggregations(ctx, aggregations, dataSources, partitions) - if err != nil { - b.Fatal(err) - } - _ = results - } -} - -// Tests for convertLogEntryToRecordValue - Protocol Buffer parsing bug fix -func TestSQLEngine_ConvertLogEntryToRecordValue_ValidProtobuf(t *testing.T) { - engine := NewTestSQLEngine() - - // Create a valid RecordValue protobuf with user data - originalRecord := &schema_pb.RecordValue{ - Fields: map[string]*schema_pb.Value{ - "id": {Kind: &schema_pb.Value_Int32Value{Int32Value: 42}}, - "name": {Kind: &schema_pb.Value_StringValue{StringValue: "test-user"}}, - "score": {Kind: &schema_pb.Value_DoubleValue{DoubleValue: 95.5}}, - }, - } - - // Serialize the protobuf (this is what MQ actually stores) - protobufData, err := proto.Marshal(originalRecord) - assert.NoError(t, err) - - // Create a LogEntry with the serialized data - logEntry := &filer_pb.LogEntry{ - TsNs: 1609459200000000000, // 2021-01-01 00:00:00 UTC - PartitionKeyHash: 123, - Data: protobufData, // Protocol buffer data (not JSON!) - Key: []byte("test-key-001"), - } - - // Test the conversion - result, source, err := engine.convertLogEntryToRecordValue(logEntry) - - // Verify no error - assert.NoError(t, err) - assert.Equal(t, "live_log", source) - assert.NotNil(t, result) - assert.NotNil(t, result.Fields) - - // Verify system columns are added correctly - assert.Contains(t, result.Fields, SW_COLUMN_NAME_TIMESTAMP) - assert.Contains(t, result.Fields, SW_COLUMN_NAME_KEY) - assert.Equal(t, int64(1609459200000000000), result.Fields[SW_COLUMN_NAME_TIMESTAMP].GetInt64Value()) - assert.Equal(t, []byte("test-key-001"), result.Fields[SW_COLUMN_NAME_KEY].GetBytesValue()) - - // Verify user data is preserved - assert.Contains(t, result.Fields, "id") - assert.Contains(t, result.Fields, "name") - assert.Contains(t, result.Fields, "score") - assert.Equal(t, int32(42), result.Fields["id"].GetInt32Value()) - assert.Equal(t, "test-user", result.Fields["name"].GetStringValue()) - assert.Equal(t, 95.5, result.Fields["score"].GetDoubleValue()) -} - -func TestSQLEngine_ConvertLogEntryToRecordValue_InvalidProtobuf(t *testing.T) { - engine := NewTestSQLEngine() - - // Create LogEntry with invalid protobuf data (this would cause the original JSON parsing bug) - logEntry := &filer_pb.LogEntry{ - TsNs: 1609459200000000000, - PartitionKeyHash: 123, - Data: []byte{0x17, 0x00, 0xFF, 0xFE}, // Invalid protobuf data (starts with \x17 like in the original error) - Key: []byte("test-key"), - } - - // Test the conversion - result, source, err := engine.convertLogEntryToRecordValue(logEntry) - - // Should return error for invalid protobuf - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to unmarshal log entry protobuf") - assert.Nil(t, result) - assert.Empty(t, source) -} - -func TestSQLEngine_ConvertLogEntryToRecordValue_EmptyProtobuf(t *testing.T) { - engine := NewTestSQLEngine() - - // Create a minimal valid RecordValue (empty fields) - emptyRecord := &schema_pb.RecordValue{ - Fields: map[string]*schema_pb.Value{}, - } - protobufData, err := proto.Marshal(emptyRecord) - assert.NoError(t, err) - - logEntry := &filer_pb.LogEntry{ - TsNs: 1609459200000000000, - PartitionKeyHash: 456, - Data: protobufData, - Key: []byte("empty-key"), - } - - // Test the conversion - result, source, err := engine.convertLogEntryToRecordValue(logEntry) - - // Should succeed and add system columns - assert.NoError(t, err) - assert.Equal(t, "live_log", source) - assert.NotNil(t, result) - assert.NotNil(t, result.Fields) - - // Should have system columns - assert.Contains(t, result.Fields, SW_COLUMN_NAME_TIMESTAMP) - assert.Contains(t, result.Fields, SW_COLUMN_NAME_KEY) - assert.Equal(t, int64(1609459200000000000), result.Fields[SW_COLUMN_NAME_TIMESTAMP].GetInt64Value()) - assert.Equal(t, []byte("empty-key"), result.Fields[SW_COLUMN_NAME_KEY].GetBytesValue()) - - // Should have no user fields - userFieldCount := 0 - for fieldName := range result.Fields { - if fieldName != SW_COLUMN_NAME_TIMESTAMP && fieldName != SW_COLUMN_NAME_KEY { - userFieldCount++ - } - } - assert.Equal(t, 0, userFieldCount) -} - -func TestSQLEngine_ConvertLogEntryToRecordValue_NilFieldsMap(t *testing.T) { - engine := NewTestSQLEngine() - - // Create RecordValue with nil Fields map (edge case) - recordWithNilFields := &schema_pb.RecordValue{ - Fields: nil, // This should be handled gracefully - } - protobufData, err := proto.Marshal(recordWithNilFields) - assert.NoError(t, err) - - logEntry := &filer_pb.LogEntry{ - TsNs: 1609459200000000000, - PartitionKeyHash: 789, - Data: protobufData, - Key: []byte("nil-fields-key"), - } - - // Test the conversion - result, source, err := engine.convertLogEntryToRecordValue(logEntry) - - // Should succeed and create Fields map - assert.NoError(t, err) - assert.Equal(t, "live_log", source) - assert.NotNil(t, result) - assert.NotNil(t, result.Fields) // Should be created by the function - - // Should have system columns - assert.Contains(t, result.Fields, SW_COLUMN_NAME_TIMESTAMP) - assert.Contains(t, result.Fields, SW_COLUMN_NAME_KEY) - assert.Equal(t, int64(1609459200000000000), result.Fields[SW_COLUMN_NAME_TIMESTAMP].GetInt64Value()) - assert.Equal(t, []byte("nil-fields-key"), result.Fields[SW_COLUMN_NAME_KEY].GetBytesValue()) -} - -func TestSQLEngine_ConvertLogEntryToRecordValue_SystemColumnOverride(t *testing.T) { - engine := NewTestSQLEngine() - - // Create RecordValue that already has system column names (should be overridden) - recordWithSystemCols := &schema_pb.RecordValue{ - Fields: map[string]*schema_pb.Value{ - "user_field": {Kind: &schema_pb.Value_StringValue{StringValue: "user-data"}}, - SW_COLUMN_NAME_TIMESTAMP: {Kind: &schema_pb.Value_Int64Value{Int64Value: 999999999}}, // Should be overridden - SW_COLUMN_NAME_KEY: {Kind: &schema_pb.Value_StringValue{StringValue: "old-key"}}, // Should be overridden - }, - } - protobufData, err := proto.Marshal(recordWithSystemCols) - assert.NoError(t, err) - - logEntry := &filer_pb.LogEntry{ - TsNs: 1609459200000000000, - PartitionKeyHash: 100, - Data: protobufData, - Key: []byte("actual-key"), - } - - // Test the conversion - result, source, err := engine.convertLogEntryToRecordValue(logEntry) - - // Should succeed - assert.NoError(t, err) - assert.Equal(t, "live_log", source) - assert.NotNil(t, result) - - // System columns should use LogEntry values, not protobuf values - assert.Equal(t, int64(1609459200000000000), result.Fields[SW_COLUMN_NAME_TIMESTAMP].GetInt64Value()) - assert.Equal(t, []byte("actual-key"), result.Fields[SW_COLUMN_NAME_KEY].GetBytesValue()) - - // User field should be preserved - assert.Contains(t, result.Fields, "user_field") - assert.Equal(t, "user-data", result.Fields["user_field"].GetStringValue()) -} - -func TestSQLEngine_ConvertLogEntryToRecordValue_ComplexDataTypes(t *testing.T) { - engine := NewTestSQLEngine() - - // Test with various data types - complexRecord := &schema_pb.RecordValue{ - Fields: map[string]*schema_pb.Value{ - "int32_field": {Kind: &schema_pb.Value_Int32Value{Int32Value: -42}}, - "int64_field": {Kind: &schema_pb.Value_Int64Value{Int64Value: 9223372036854775807}}, - "float_field": {Kind: &schema_pb.Value_FloatValue{FloatValue: 3.14159}}, - "double_field": {Kind: &schema_pb.Value_DoubleValue{DoubleValue: 2.718281828}}, - "bool_field": {Kind: &schema_pb.Value_BoolValue{BoolValue: true}}, - "string_field": {Kind: &schema_pb.Value_StringValue{StringValue: "test string with unicode party"}}, - "bytes_field": {Kind: &schema_pb.Value_BytesValue{BytesValue: []byte{0x01, 0x02, 0x03}}}, - }, - } - protobufData, err := proto.Marshal(complexRecord) - assert.NoError(t, err) - - logEntry := &filer_pb.LogEntry{ - TsNs: 1609459200000000000, - PartitionKeyHash: 200, - Data: protobufData, - Key: []byte("complex-key"), - } - - // Test the conversion - result, source, err := engine.convertLogEntryToRecordValue(logEntry) - - // Should succeed - assert.NoError(t, err) - assert.Equal(t, "live_log", source) - assert.NotNil(t, result) - - // Verify all data types are preserved - assert.Equal(t, int32(-42), result.Fields["int32_field"].GetInt32Value()) - assert.Equal(t, int64(9223372036854775807), result.Fields["int64_field"].GetInt64Value()) - assert.Equal(t, float32(3.14159), result.Fields["float_field"].GetFloatValue()) - assert.Equal(t, 2.718281828, result.Fields["double_field"].GetDoubleValue()) - assert.Equal(t, true, result.Fields["bool_field"].GetBoolValue()) - assert.Equal(t, "test string with unicode party", result.Fields["string_field"].GetStringValue()) - assert.Equal(t, []byte{0x01, 0x02, 0x03}, result.Fields["bytes_field"].GetBytesValue()) - - // System columns should still be present - assert.Contains(t, result.Fields, SW_COLUMN_NAME_TIMESTAMP) - assert.Contains(t, result.Fields, SW_COLUMN_NAME_KEY) -} - -// Tests for log buffer deduplication functionality -func TestSQLEngine_GetLogBufferStartFromFile_BinaryFormat(t *testing.T) { - engine := NewTestSQLEngine() - - // Create sample buffer start (binary format) - bufferStartBytes := make([]byte, 8) - binary.BigEndian.PutUint64(bufferStartBytes, uint64(1609459100000000001)) - - // Create file entry with buffer start + some chunks - entry := &filer_pb.Entry{ - Name: "test-log-file", - Extended: map[string][]byte{ - "buffer_start": bufferStartBytes, - }, - Chunks: []*filer_pb.FileChunk{ - {FileId: "chunk1", Offset: 0, Size: 1000}, - {FileId: "chunk2", Offset: 1000, Size: 1000}, - {FileId: "chunk3", Offset: 2000, Size: 1000}, - }, - } - - // Test extraction - result, err := engine.getLogBufferStartFromFile(entry) - assert.NoError(t, err) - assert.NotNil(t, result) - assert.Equal(t, int64(1609459100000000001), result.StartIndex) - - // Test extraction works correctly with the binary format -} - -func TestSQLEngine_GetLogBufferStartFromFile_NoMetadata(t *testing.T) { - engine := NewTestSQLEngine() - - // Create file entry without buffer start - entry := &filer_pb.Entry{ - Name: "test-log-file", - Extended: nil, - } - - // Test extraction - result, err := engine.getLogBufferStartFromFile(entry) - assert.NoError(t, err) - assert.Nil(t, result) -} - -func TestSQLEngine_GetLogBufferStartFromFile_InvalidData(t *testing.T) { - engine := NewTestSQLEngine() - - // Create file entry with invalid buffer start (wrong size) - entry := &filer_pb.Entry{ - Name: "test-log-file", - Extended: map[string][]byte{ - "buffer_start": []byte("invalid-binary"), - }, - } - - // Test extraction - result, err := engine.getLogBufferStartFromFile(entry) - assert.Error(t, err) - assert.Contains(t, err.Error(), "invalid buffer_start format: expected 8 bytes") - assert.Nil(t, result) -} - -func TestSQLEngine_BuildLogBufferDeduplicationMap_NoBrokerClient(t *testing.T) { - engine := NewTestSQLEngine() - engine.catalog.brokerClient = nil // Simulate no broker client - - ctx := context.Background() - result, err := engine.buildLogBufferDeduplicationMap(ctx, "/topics/test/test-topic") - - assert.NoError(t, err) - assert.NotNil(t, result) - assert.Empty(t, result) -} - -func TestSQLEngine_LogBufferDeduplication_ServerRestartScenario(t *testing.T) { - // Simulate scenario: Buffer indexes are now initialized with process start time - // This tests that buffer start indexes are globally unique across server restarts - - // Before server restart: Process 1 buffer start (3 chunks) - beforeRestartStart := LogBufferStart{ - StartIndex: 1609459100000000000, // Process 1 start time - } - - // After server restart: Process 2 buffer start (3 chunks) - afterRestartStart := LogBufferStart{ - StartIndex: 1609459300000000000, // Process 2 start time (DIFFERENT) - } - - // Simulate 3 chunks for each file - chunkCount := int64(3) - - // Calculate end indexes for range comparison - beforeEnd := beforeRestartStart.StartIndex + chunkCount - 1 // [start, start+2] - afterStart := afterRestartStart.StartIndex // [start, start+2] - - // Test range overlap detection (should NOT overlap) - overlaps := beforeRestartStart.StartIndex <= (afterStart+chunkCount-1) && beforeEnd >= afterStart - assert.False(t, overlaps, "Buffer ranges after restart should not overlap") - - // Verify the start indexes are globally unique - assert.NotEqual(t, beforeRestartStart.StartIndex, afterRestartStart.StartIndex, "Start indexes should be different") - assert.Less(t, beforeEnd, afterStart, "Ranges should be completely separate") - - // Expected values: - // Before restart: [1609459100000000000, 1609459100000000002] - // After restart: [1609459300000000000, 1609459300000000002] - expectedBeforeEnd := int64(1609459100000000002) - expectedAfterStart := int64(1609459300000000000) - - assert.Equal(t, expectedBeforeEnd, beforeEnd) - assert.Equal(t, expectedAfterStart, afterStart) - - // This demonstrates that buffer start indexes initialized with process start time - // prevent false positive duplicates across server restarts -} - -// TestGetSQLValAlias tests the getSQLValAlias function, particularly for SQL injection prevention -func TestGetSQLValAlias(t *testing.T) { - engine := &SQLEngine{} - - tests := []struct { - name string - sqlVal *SQLVal - expected string - desc string - }{ - { - name: "simple string", - sqlVal: &SQLVal{ - Type: StrVal, - Val: []byte("hello"), - }, - expected: "'hello'", - desc: "Simple string should be wrapped in single quotes", - }, - { - name: "string with single quote", - sqlVal: &SQLVal{ - Type: StrVal, - Val: []byte("don't"), - }, - expected: "'don''t'", - desc: "String with single quote should have the quote escaped by doubling it", - }, - { - name: "string with multiple single quotes", - sqlVal: &SQLVal{ - Type: StrVal, - Val: []byte("'malicious'; DROP TABLE users; --"), - }, - expected: "'''malicious''; DROP TABLE users; --'", - desc: "String with SQL injection attempt should have all single quotes properly escaped", - }, - { - name: "empty string", - sqlVal: &SQLVal{ - Type: StrVal, - Val: []byte(""), - }, - expected: "''", - desc: "Empty string should result in empty quoted string", - }, - { - name: "integer value", - sqlVal: &SQLVal{ - Type: IntVal, - Val: []byte("123"), - }, - expected: "123", - desc: "Integer value should not be quoted", - }, - { - name: "float value", - sqlVal: &SQLVal{ - Type: FloatVal, - Val: []byte("123.45"), - }, - expected: "123.45", - desc: "Float value should not be quoted", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := engine.getSQLValAlias(tt.sqlVal) - assert.Equal(t, tt.expected, result, tt.desc) - }) - } -} diff --git a/weed/query/engine/errors.go b/weed/query/engine/errors.go index 6a297d92f..2c68ab10d 100644 --- a/weed/query/engine/errors.go +++ b/weed/query/engine/errors.go @@ -44,7 +44,7 @@ type ParseError struct { func (e ParseError) Error() string { if e.Cause != nil { - return fmt.Sprintf("SQL parse error: %s (%v)", e.Message, e.Cause) + return fmt.Sprintf("SQL parse error: %s (caused by: %v)", e.Message, e.Cause) } return fmt.Sprintf("SQL parse error: %s", e.Message) } diff --git a/weed/query/engine/execution_plan_fast_path_test.go b/weed/query/engine/execution_plan_fast_path_test.go deleted file mode 100644 index c0f08fa21..000000000 --- a/weed/query/engine/execution_plan_fast_path_test.go +++ /dev/null @@ -1,133 +0,0 @@ -package engine - -import ( - "testing" - - "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" - "github.com/stretchr/testify/assert" -) - -// TestExecutionPlanFastPathDisplay tests that the execution plan correctly shows -// "Parquet Statistics (fast path)" when fast path is used, not "Parquet Files (full scan)" -func TestExecutionPlanFastPathDisplay(t *testing.T) { - engine := NewMockSQLEngine() - - // Create realistic data sources for fast path scenario - dataSources := &TopicDataSources{ - ParquetFiles: map[string][]*ParquetFileStats{ - "/topics/test/topic/partition-1": { - { - RowCount: 500, - ColumnStats: map[string]*ParquetColumnStats{ - "id": { - ColumnName: "id", - MinValue: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 1}}, - MaxValue: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 500}}, - NullCount: 0, - RowCount: 500, - }, - }, - }, - }, - }, - ParquetRowCount: 500, - LiveLogRowCount: 0, // Pure parquet scenario - ideal for fast path - PartitionsCount: 1, - } - - t.Run("Fast path execution plan shows correct data sources", func(t *testing.T) { - optimizer := NewFastPathOptimizer(engine.SQLEngine) - - aggregations := []AggregationSpec{ - {Function: FuncCOUNT, Column: "*", Alias: "COUNT(*)"}, - } - - // Test the strategy determination - strategy := optimizer.DetermineStrategy(aggregations) - assert.True(t, strategy.CanUseFastPath, "Strategy should allow fast path for COUNT(*)") - assert.Equal(t, "all_aggregations_supported", strategy.Reason) - - // Test data source list building - builder := &ExecutionPlanBuilder{} - dataSources := &TopicDataSources{ - ParquetFiles: map[string][]*ParquetFileStats{ - "/topics/test/topic/partition-1": { - {RowCount: 500}, - }, - }, - ParquetRowCount: 500, - LiveLogRowCount: 0, - PartitionsCount: 1, - } - - dataSourcesList := builder.buildDataSourcesList(strategy, dataSources) - - // When fast path is used, should show "parquet_stats" not "parquet_files" - assert.Contains(t, dataSourcesList, "parquet_stats", - "Data sources should contain 'parquet_stats' when fast path is used") - assert.NotContains(t, dataSourcesList, "parquet_files", - "Data sources should NOT contain 'parquet_files' when fast path is used") - - // Test that the formatting works correctly - formattedSource := engine.SQLEngine.formatDataSource("parquet_stats") - assert.Equal(t, "Parquet Statistics (fast path)", formattedSource, - "parquet_stats should format to 'Parquet Statistics (fast path)'") - - formattedFullScan := engine.SQLEngine.formatDataSource("parquet_files") - assert.Equal(t, "Parquet Files (full scan)", formattedFullScan, - "parquet_files should format to 'Parquet Files (full scan)'") - }) - - t.Run("Slow path execution plan shows full scan data sources", func(t *testing.T) { - builder := &ExecutionPlanBuilder{} - - // Create strategy that cannot use fast path - strategy := AggregationStrategy{ - CanUseFastPath: false, - Reason: "unsupported_aggregation_functions", - } - - dataSourcesList := builder.buildDataSourcesList(strategy, dataSources) - - // When slow path is used, should show "parquet_files" and "live_logs" - assert.Contains(t, dataSourcesList, "parquet_files", - "Slow path should contain 'parquet_files'") - assert.Contains(t, dataSourcesList, "live_logs", - "Slow path should contain 'live_logs'") - assert.NotContains(t, dataSourcesList, "parquet_stats", - "Slow path should NOT contain 'parquet_stats'") - }) - - t.Run("Data source formatting works correctly", func(t *testing.T) { - // Test just the data source formatting which is the key fix - - // Test parquet_stats formatting (fast path) - fastPathFormatted := engine.SQLEngine.formatDataSource("parquet_stats") - assert.Equal(t, "Parquet Statistics (fast path)", fastPathFormatted, - "parquet_stats should format to show fast path usage") - - // Test parquet_files formatting (slow path) - slowPathFormatted := engine.SQLEngine.formatDataSource("parquet_files") - assert.Equal(t, "Parquet Files (full scan)", slowPathFormatted, - "parquet_files should format to show full scan") - - // Test that data sources list is built correctly for fast path - builder := &ExecutionPlanBuilder{} - fastStrategy := AggregationStrategy{CanUseFastPath: true} - - fastSources := builder.buildDataSourcesList(fastStrategy, dataSources) - assert.Contains(t, fastSources, "parquet_stats", - "Fast path should include parquet_stats") - assert.NotContains(t, fastSources, "parquet_files", - "Fast path should NOT include parquet_files") - - // Test that data sources list is built correctly for slow path - slowStrategy := AggregationStrategy{CanUseFastPath: false} - - slowSources := builder.buildDataSourcesList(slowStrategy, dataSources) - assert.Contains(t, slowSources, "parquet_files", - "Slow path should include parquet_files") - assert.NotContains(t, slowSources, "parquet_stats", - "Slow path should NOT include parquet_stats") - }) -} diff --git a/weed/query/engine/fast_path_fix_test.go b/weed/query/engine/fast_path_fix_test.go deleted file mode 100644 index 3769e9215..000000000 --- a/weed/query/engine/fast_path_fix_test.go +++ /dev/null @@ -1,193 +0,0 @@ -package engine - -import ( - "context" - "testing" - - "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" - "github.com/stretchr/testify/assert" -) - -// TestFastPathCountFixRealistic tests the specific scenario mentioned in the bug report: -// Fast path returning 0 for COUNT(*) when slow path returns 1803 -func TestFastPathCountFixRealistic(t *testing.T) { - engine := NewMockSQLEngine() - - // Set up debug mode to see our new logging - ctx := context.WithValue(context.Background(), "debug", true) - - // Create realistic data sources that mimic a scenario with 1803 rows - dataSources := &TopicDataSources{ - ParquetFiles: map[string][]*ParquetFileStats{ - "/topics/test/large-topic/0000-1023": { - { - RowCount: 800, - ColumnStats: map[string]*ParquetColumnStats{ - "id": { - ColumnName: "id", - MinValue: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 1}}, - MaxValue: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 800}}, - NullCount: 0, - RowCount: 800, - }, - }, - }, - { - RowCount: 500, - ColumnStats: map[string]*ParquetColumnStats{ - "id": { - ColumnName: "id", - MinValue: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 801}}, - MaxValue: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 1300}}, - NullCount: 0, - RowCount: 500, - }, - }, - }, - }, - "/topics/test/large-topic/1024-2047": { - { - RowCount: 300, - ColumnStats: map[string]*ParquetColumnStats{ - "id": { - ColumnName: "id", - MinValue: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 1301}}, - MaxValue: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 1600}}, - NullCount: 0, - RowCount: 300, - }, - }, - }, - }, - }, - ParquetRowCount: 1600, // 800 + 500 + 300 - LiveLogRowCount: 203, // Additional live log data - PartitionsCount: 2, - LiveLogFilesCount: 15, - } - - partitions := []string{ - "/topics/test/large-topic/0000-1023", - "/topics/test/large-topic/1024-2047", - } - - t.Run("COUNT(*) should return correct total (1803)", func(t *testing.T) { - computer := NewAggregationComputer(engine.SQLEngine) - - aggregations := []AggregationSpec{ - {Function: FuncCOUNT, Column: "*", Alias: "COUNT(*)"}, - } - - results, err := computer.ComputeFastPathAggregations(ctx, aggregations, dataSources, partitions) - - assert.NoError(t, err, "Fast path aggregation should not error") - assert.Len(t, results, 1, "Should return one result") - - // This is the key test - before our fix, this was returning 0 - expectedCount := int64(1803) // 1600 (parquet) + 203 (live log) - actualCount := results[0].Count - - assert.Equal(t, expectedCount, actualCount, - "COUNT(*) should return %d (1600 parquet + 203 live log), but got %d", - expectedCount, actualCount) - }) - - t.Run("MIN/MAX should work with multiple partitions", func(t *testing.T) { - computer := NewAggregationComputer(engine.SQLEngine) - - aggregations := []AggregationSpec{ - {Function: FuncMIN, Column: "id", Alias: "MIN(id)"}, - {Function: FuncMAX, Column: "id", Alias: "MAX(id)"}, - } - - results, err := computer.ComputeFastPathAggregations(ctx, aggregations, dataSources, partitions) - - assert.NoError(t, err, "Fast path aggregation should not error") - assert.Len(t, results, 2, "Should return two results") - - // MIN should be the lowest across all parquet files - assert.Equal(t, int64(1), results[0].Min, "MIN should be 1") - - // MAX should be the highest across all parquet files - assert.Equal(t, int64(1600), results[1].Max, "MAX should be 1600") - }) -} - -// TestFastPathDataSourceDiscoveryLogging tests that our debug logging works correctly -func TestFastPathDataSourceDiscoveryLogging(t *testing.T) { - // This test verifies that our enhanced data source collection structure is correct - - t.Run("DataSources structure validation", func(t *testing.T) { - // Test the TopicDataSources structure initialization - dataSources := &TopicDataSources{ - ParquetFiles: make(map[string][]*ParquetFileStats), - ParquetRowCount: 0, - LiveLogRowCount: 0, - LiveLogFilesCount: 0, - PartitionsCount: 0, - } - - assert.NotNil(t, dataSources, "Data sources should not be nil") - assert.NotNil(t, dataSources.ParquetFiles, "ParquetFiles map should be initialized") - assert.GreaterOrEqual(t, dataSources.PartitionsCount, 0, "PartitionsCount should be non-negative") - assert.GreaterOrEqual(t, dataSources.ParquetRowCount, int64(0), "ParquetRowCount should be non-negative") - assert.GreaterOrEqual(t, dataSources.LiveLogRowCount, int64(0), "LiveLogRowCount should be non-negative") - }) -} - -// TestFastPathValidationLogic tests the enhanced validation we added -func TestFastPathValidationLogic(t *testing.T) { - t.Run("Validation catches data source vs computation mismatch", func(t *testing.T) { - // Create a scenario where data sources and computation might be inconsistent - dataSources := &TopicDataSources{ - ParquetFiles: make(map[string][]*ParquetFileStats), - ParquetRowCount: 1000, // Data sources say 1000 rows - LiveLogRowCount: 0, - PartitionsCount: 1, - } - - // But aggregation result says different count (simulating the original bug) - aggResults := []AggregationResult{ - {Count: 0}, // Bug: returns 0 when data sources show 1000 - } - - // This simulates the validation logic from tryFastParquetAggregation - totalRows := dataSources.ParquetRowCount + dataSources.LiveLogRowCount - countResult := aggResults[0].Count - - // Our validation should catch this mismatch - assert.NotEqual(t, totalRows, countResult, - "This test simulates the bug: data sources show %d but COUNT returns %d", - totalRows, countResult) - - // In the real code, this would trigger a fallback to slow path - validationPassed := (countResult == totalRows) - assert.False(t, validationPassed, "Validation should fail for inconsistent data") - }) - - t.Run("Validation passes for consistent data", func(t *testing.T) { - // Create a scenario where everything is consistent - dataSources := &TopicDataSources{ - ParquetFiles: make(map[string][]*ParquetFileStats), - ParquetRowCount: 1000, - LiveLogRowCount: 803, - PartitionsCount: 1, - } - - // Aggregation result matches data sources - aggResults := []AggregationResult{ - {Count: 1803}, // Correct: matches 1000 + 803 - } - - totalRows := dataSources.ParquetRowCount + dataSources.LiveLogRowCount - countResult := aggResults[0].Count - - // Our validation should pass this - assert.Equal(t, totalRows, countResult, - "Validation should pass when data sources (%d) match COUNT result (%d)", - totalRows, countResult) - - validationPassed := (countResult == totalRows) - assert.True(t, validationPassed, "Validation should pass for consistent data") - }) -} diff --git a/weed/query/engine/parquet_scanner.go b/weed/query/engine/parquet_scanner.go index 9bcced904..4c33df76f 100644 --- a/weed/query/engine/parquet_scanner.go +++ b/weed/query/engine/parquet_scanner.go @@ -1,280 +1,14 @@ package engine import ( - "context" "fmt" "math/big" "time" - "github.com/seaweedfs/seaweedfs/weed/mq/schema" - "github.com/seaweedfs/seaweedfs/weed/mq/topic" - "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" - "github.com/seaweedfs/seaweedfs/weed/pb/mq_pb" "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" "github.com/seaweedfs/seaweedfs/weed/query/sqltypes" - "github.com/seaweedfs/seaweedfs/weed/util/chunk_cache" ) -// ParquetScanner scans MQ topic Parquet files for SELECT queries -// Assumptions: -// 1. All MQ messages are stored in Parquet format in topic partitions -// 2. Each partition directory contains dated Parquet files -// 3. System columns (_ts_ns, _key) are added to user schema -// 4. Predicate pushdown is used for efficient scanning -type ParquetScanner struct { - filerClient filer_pb.FilerClient - chunkCache chunk_cache.ChunkCache - topic topic.Topic - recordSchema *schema_pb.RecordType - parquetLevels *schema.ParquetLevels -} - -// NewParquetScanner creates a scanner for a specific MQ topic -// Assumption: Topic exists and has Parquet files in partition directories -func NewParquetScanner(filerClient filer_pb.FilerClient, namespace, topicName string) (*ParquetScanner, error) { - // Check if filerClient is available - if filerClient == nil { - return nil, fmt.Errorf("filerClient is required but not available") - } - - // Create topic reference - t := topic.Topic{ - Namespace: namespace, - Name: topicName, - } - - // Read topic configuration to get schema - var topicConf *mq_pb.ConfigureTopicResponse - var err error - if err := filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error { - topicConf, err = t.ReadConfFile(client) - return err - }); err != nil { - return nil, fmt.Errorf("failed to read topic config: %v", err) - } - - // Build complete schema with system columns - prefer flat schema if available - var recordType *schema_pb.RecordType - - if topicConf.GetMessageRecordType() != nil { - // New flat schema format - use directly - recordType = topicConf.GetMessageRecordType() - } - - if recordType == nil || len(recordType.Fields) == 0 { - // For topics without schema, create a minimal schema with system fields and _value - recordType = schema.RecordTypeBegin(). - WithField(SW_COLUMN_NAME_TIMESTAMP, schema.TypeInt64). - WithField(SW_COLUMN_NAME_KEY, schema.TypeBytes). - WithField(SW_COLUMN_NAME_VALUE, schema.TypeBytes). // Raw message value - RecordTypeEnd() - } else { - // Add system columns that MQ adds to all records - recordType = schema.NewRecordTypeBuilder(recordType). - WithField(SW_COLUMN_NAME_TIMESTAMP, schema.TypeInt64). - WithField(SW_COLUMN_NAME_KEY, schema.TypeBytes). - RecordTypeEnd() - } - - // Convert to Parquet levels for efficient reading - parquetLevels, err := schema.ToParquetLevels(recordType) - if err != nil { - return nil, fmt.Errorf("failed to create Parquet levels: %v", err) - } - - return &ParquetScanner{ - filerClient: filerClient, - chunkCache: chunk_cache.NewChunkCacheInMemory(256), // Same as MQ logstore - topic: t, - recordSchema: recordType, - parquetLevels: parquetLevels, - }, nil -} - -// ScanOptions configure how the scanner reads data -type ScanOptions struct { - // Time range filtering (Unix nanoseconds) - StartTimeNs int64 - StopTimeNs int64 - - // Column projection - if empty, select all columns - Columns []string - - // Row limit - 0 means no limit - Limit int - - // Predicate for WHERE clause filtering - Predicate func(*schema_pb.RecordValue) bool -} - -// ScanResult represents a single scanned record -type ScanResult struct { - Values map[string]*schema_pb.Value // Column name -> value - Timestamp int64 // Message timestamp (_ts_ns) - Key []byte // Message key (_key) -} - -// Scan reads records from the topic's Parquet files -// Assumptions: -// 1. Scans all partitions of the topic -// 2. Applies time filtering at Parquet level for efficiency -// 3. Applies predicates and projections after reading -func (ps *ParquetScanner) Scan(ctx context.Context, options ScanOptions) ([]ScanResult, error) { - var results []ScanResult - - // Get all partitions for this topic - // TODO: Implement proper partition discovery - // For now, assume partition 0 exists - partitions := []topic.Partition{{RangeStart: 0, RangeStop: 1000}} - - for _, partition := range partitions { - partitionResults, err := ps.scanPartition(ctx, partition, options) - if err != nil { - return nil, fmt.Errorf("failed to scan partition %v: %v", partition, err) - } - - results = append(results, partitionResults...) - - // Apply global limit across all partitions - if options.Limit > 0 && len(results) >= options.Limit { - results = results[:options.Limit] - break - } - } - - return results, nil -} - -// scanPartition scans a specific topic partition -func (ps *ParquetScanner) scanPartition(ctx context.Context, partition topic.Partition, options ScanOptions) ([]ScanResult, error) { - // partitionDir := topic.PartitionDir(ps.topic, partition) // TODO: Use for actual file listing - - var results []ScanResult - - // List Parquet files in partition directory - // TODO: Implement proper file listing with date range filtering - // For now, this is a placeholder that would list actual Parquet files - - // Simulate file processing - in real implementation, this would: - // 1. List files in partitionDir via filerClient - // 2. Filter files by date range if time filtering is enabled - // 3. Process each Parquet file in chronological order - - // Placeholder: Create sample data for testing - if len(results) == 0 { - // Generate sample data for demonstration - sampleData := ps.generateSampleData(options) - results = append(results, sampleData...) - } - - return results, nil -} - -// generateSampleData creates sample data for testing when no real Parquet files exist -func (ps *ParquetScanner) generateSampleData(options ScanOptions) []ScanResult { - now := time.Now().UnixNano() - - sampleData := []ScanResult{ - { - Values: map[string]*schema_pb.Value{ - "user_id": {Kind: &schema_pb.Value_Int32Value{Int32Value: 1001}}, - "event_type": {Kind: &schema_pb.Value_StringValue{StringValue: "login"}}, - "data": {Kind: &schema_pb.Value_StringValue{StringValue: `{"ip": "192.168.1.1"}`}}, - }, - Timestamp: now - 3600000000000, // 1 hour ago - Key: []byte("user-1001"), - }, - { - Values: map[string]*schema_pb.Value{ - "user_id": {Kind: &schema_pb.Value_Int32Value{Int32Value: 1002}}, - "event_type": {Kind: &schema_pb.Value_StringValue{StringValue: "page_view"}}, - "data": {Kind: &schema_pb.Value_StringValue{StringValue: `{"page": "/dashboard"}`}}, - }, - Timestamp: now - 1800000000000, // 30 minutes ago - Key: []byte("user-1002"), - }, - { - Values: map[string]*schema_pb.Value{ - "user_id": {Kind: &schema_pb.Value_Int32Value{Int32Value: 1001}}, - "event_type": {Kind: &schema_pb.Value_StringValue{StringValue: "logout"}}, - "data": {Kind: &schema_pb.Value_StringValue{StringValue: `{"session_duration": 3600}`}}, - }, - Timestamp: now - 900000000000, // 15 minutes ago - Key: []byte("user-1001"), - }, - } - - // Apply predicate filtering if specified - if options.Predicate != nil { - var filtered []ScanResult - for _, result := range sampleData { - // Convert to RecordValue for predicate testing - recordValue := &schema_pb.RecordValue{Fields: make(map[string]*schema_pb.Value)} - for k, v := range result.Values { - recordValue.Fields[k] = v - } - recordValue.Fields[SW_COLUMN_NAME_TIMESTAMP] = &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: result.Timestamp}} - recordValue.Fields[SW_COLUMN_NAME_KEY] = &schema_pb.Value{Kind: &schema_pb.Value_BytesValue{BytesValue: result.Key}} - - if options.Predicate(recordValue) { - filtered = append(filtered, result) - } - } - sampleData = filtered - } - - // Apply limit - if options.Limit > 0 && len(sampleData) > options.Limit { - sampleData = sampleData[:options.Limit] - } - - return sampleData -} - -// ConvertToSQLResult converts ScanResults to SQL query results -func (ps *ParquetScanner) ConvertToSQLResult(results []ScanResult, columns []string) *QueryResult { - if len(results) == 0 { - return &QueryResult{ - Columns: columns, - Rows: [][]sqltypes.Value{}, - } - } - - // Determine columns if not specified - if len(columns) == 0 { - columnSet := make(map[string]bool) - for _, result := range results { - for columnName := range result.Values { - columnSet[columnName] = true - } - } - - columns = make([]string, 0, len(columnSet)) - for columnName := range columnSet { - columns = append(columns, columnName) - } - } - - // Convert to SQL rows - rows := make([][]sqltypes.Value, len(results)) - for i, result := range results { - row := make([]sqltypes.Value, len(columns)) - for j, columnName := range columns { - if value, exists := result.Values[columnName]; exists { - row[j] = convertSchemaValueToSQL(value) - } else { - row[j] = sqltypes.NULL - } - } - rows[i] = row - } - - return &QueryResult{ - Columns: columns, - Rows: rows, - } -} - // convertSchemaValueToSQL converts schema_pb.Value to sqltypes.Value func convertSchemaValueToSQL(value *schema_pb.Value) sqltypes.Value { if value == nil { diff --git a/weed/query/engine/partition_path_fix_test.go b/weed/query/engine/partition_path_fix_test.go deleted file mode 100644 index 8d92136e6..000000000 --- a/weed/query/engine/partition_path_fix_test.go +++ /dev/null @@ -1,117 +0,0 @@ -package engine - -import ( - "strings" - "testing" - - "github.com/stretchr/testify/assert" -) - -// TestPartitionPathHandling tests that partition paths are handled correctly -// whether discoverTopicPartitions returns relative or absolute paths -func TestPartitionPathHandling(t *testing.T) { - engine := NewMockSQLEngine() - - t.Run("Mock discoverTopicPartitions returns correct paths", func(t *testing.T) { - // Test that our mock engine handles absolute paths correctly - engine.mockPartitions["test.user_events"] = []string{ - "/topics/test/user_events/v2025-09-03-15-36-29/0000-2520", - "/topics/test/user_events/v2025-09-03-15-36-29/2521-5040", - } - - partitions, err := engine.discoverTopicPartitions("test", "user_events") - assert.NoError(t, err, "Should discover partitions without error") - assert.Equal(t, 2, len(partitions), "Should return 2 partitions") - assert.Contains(t, partitions[0], "/topics/test/user_events/", "Should contain absolute path") - }) - - t.Run("Mock discoverTopicPartitions handles relative paths", func(t *testing.T) { - // Test relative paths scenario - engine.mockPartitions["test.user_events"] = []string{ - "v2025-09-03-15-36-29/0000-2520", - "v2025-09-03-15-36-29/2521-5040", - } - - partitions, err := engine.discoverTopicPartitions("test", "user_events") - assert.NoError(t, err, "Should discover partitions without error") - assert.Equal(t, 2, len(partitions), "Should return 2 partitions") - assert.True(t, !strings.HasPrefix(partitions[0], "/topics/"), "Should be relative path") - }) - - t.Run("Partition path building logic works correctly", func(t *testing.T) { - topicBasePath := "/topics/test/user_events" - - testCases := []struct { - name string - relativePartition string - expectedPath string - }{ - { - name: "Absolute path - use as-is", - relativePartition: "/topics/test/user_events/v2025-09-03-15-36-29/0000-2520", - expectedPath: "/topics/test/user_events/v2025-09-03-15-36-29/0000-2520", - }, - { - name: "Relative path - build full path", - relativePartition: "v2025-09-03-15-36-29/0000-2520", - expectedPath: "/topics/test/user_events/v2025-09-03-15-36-29/0000-2520", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var partitionPath string - - // This is the same logic from our fixed code - if strings.HasPrefix(tc.relativePartition, "/topics/") { - // Already a full path - use as-is - partitionPath = tc.relativePartition - } else { - // Relative path - build full path - partitionPath = topicBasePath + "/" + tc.relativePartition - } - - assert.Equal(t, tc.expectedPath, partitionPath, - "Partition path should be built correctly") - - // Ensure no double slashes - assert.NotContains(t, partitionPath, "//", - "Partition path should not contain double slashes") - }) - } - }) -} - -// TestPartitionPathLogic tests the core logic for handling partition paths -func TestPartitionPathLogic(t *testing.T) { - t.Run("Building partition paths from discovered partitions", func(t *testing.T) { - // Test the specific partition path building that was causing issues - - topicBasePath := "/topics/ecommerce/user_events" - - // This simulates the discoverTopicPartitions returning absolute paths (realistic scenario) - relativePartitions := []string{ - "/topics/ecommerce/user_events/v2025-09-03-15-36-29/0000-2520", - } - - // This is the code from our fix - test it directly - partitions := make([]string, len(relativePartitions)) - for i, relPartition := range relativePartitions { - // Handle both relative and absolute partition paths from discoverTopicPartitions - if strings.HasPrefix(relPartition, "/topics/") { - // Already a full path - use as-is - partitions[i] = relPartition - } else { - // Relative path - build full path - partitions[i] = topicBasePath + "/" + relPartition - } - } - - // Verify the path was handled correctly - expectedPath := "/topics/ecommerce/user_events/v2025-09-03-15-36-29/0000-2520" - assert.Equal(t, expectedPath, partitions[0], "Absolute path should be used as-is") - - // Ensure no double slashes (this was the original bug) - assert.NotContains(t, partitions[0], "//", "Path should not contain double slashes") - }) -} diff --git a/weed/query/sqltypes/type.go b/weed/query/sqltypes/type.go index f4f3dd471..2a0f40386 100644 --- a/weed/query/sqltypes/type.go +++ b/weed/query/sqltypes/type.go @@ -56,11 +56,6 @@ func IsBinary(t Type) bool { return int(t)&flagIsBinary == flagIsBinary } -// isNumber returns true if the type is any type of number. -func isNumber(t Type) bool { - return IsIntegral(t) || IsFloat(t) || t == Decimal -} - // IsTemporal returns true if Value is time type. func IsTemporal(t Type) bool { switch t { diff --git a/weed/query/sqltypes/value.go b/weed/query/sqltypes/value.go index 012de2b45..7c2599652 100644 --- a/weed/query/sqltypes/value.go +++ b/weed/query/sqltypes/value.go @@ -1,9 +1,7 @@ package sqltypes import ( - "fmt" "strconv" - "time" ) var ( @@ -19,32 +17,6 @@ type Value struct { val []byte } -// NewValue builds a Value using typ and val. If the value and typ -// don't match, it returns an error. -func NewValue(typ Type, val []byte) (v Value, err error) { - switch { - case IsSigned(typ): - if _, err := strconv.ParseInt(string(val), 0, 64); err != nil { - return NULL, err - } - return MakeTrusted(typ, val), nil - case IsUnsigned(typ): - if _, err := strconv.ParseUint(string(val), 0, 64); err != nil { - return NULL, err - } - return MakeTrusted(typ, val), nil - case IsFloat(typ) || typ == Decimal: - if _, err := strconv.ParseFloat(string(val), 64); err != nil { - return NULL, err - } - return MakeTrusted(typ, val), nil - case IsQuoted(typ) || typ == Bit || typ == Null: - return MakeTrusted(typ, val), nil - } - // All other types are unsafe or invalid. - return NULL, fmt.Errorf("invalid type specified for MakeValue: %v", typ) -} - // MakeTrusted makes a new Value based on the type. // This function should only be used if you know the value // and type conform to the rules. Every place this function is @@ -71,11 +43,6 @@ func NewInt32(v int32) Value { return MakeTrusted(Int32, strconv.AppendInt(nil, int64(v), 10)) } -// NewUint64 builds an Uint64 Value. -func NewUint64(v uint64) Value { - return MakeTrusted(Uint64, strconv.AppendUint(nil, v, 10)) -} - // NewFloat32 builds an Float64 Value. func NewFloat32(v float32) Value { return MakeTrusted(Float32, strconv.AppendFloat(nil, float64(v), 'f', -1, 64)) @@ -97,136 +64,11 @@ func NewVarBinary(v string) Value { return MakeTrusted(VarBinary, []byte(v)) } -// NewIntegral builds an integral type from a string representation. -// The type will be Int64 or Uint64. Int64 will be preferred where possible. -func NewIntegral(val string) (n Value, err error) { - signed, err := strconv.ParseInt(val, 0, 64) - if err == nil { - return MakeTrusted(Int64, strconv.AppendInt(nil, signed, 10)), nil - } - unsigned, err := strconv.ParseUint(val, 0, 64) - if err != nil { - return Value{}, err - } - return MakeTrusted(Uint64, strconv.AppendUint(nil, unsigned, 10)), nil -} - // MakeString makes a VarBinary Value. func MakeString(val []byte) Value { return MakeTrusted(VarBinary, val) } -// BuildValue builds a value from any go type. sqltype.Value is -// also allowed. -func BuildValue(goval interface{}) (v Value, err error) { - // Look for the most common types first. - switch goval := goval.(type) { - case nil: - // no op - case []byte: - v = MakeTrusted(VarBinary, goval) - case int64: - v = MakeTrusted(Int64, strconv.AppendInt(nil, int64(goval), 10)) - case uint64: - v = MakeTrusted(Uint64, strconv.AppendUint(nil, uint64(goval), 10)) - case float64: - v = MakeTrusted(Float64, strconv.AppendFloat(nil, goval, 'f', -1, 64)) - case int: - v = MakeTrusted(Int64, strconv.AppendInt(nil, int64(goval), 10)) - case int8: - v = MakeTrusted(Int8, strconv.AppendInt(nil, int64(goval), 10)) - case int16: - v = MakeTrusted(Int16, strconv.AppendInt(nil, int64(goval), 10)) - case int32: - v = MakeTrusted(Int32, strconv.AppendInt(nil, int64(goval), 10)) - case uint: - v = MakeTrusted(Uint64, strconv.AppendUint(nil, uint64(goval), 10)) - case uint8: - v = MakeTrusted(Uint8, strconv.AppendUint(nil, uint64(goval), 10)) - case uint16: - v = MakeTrusted(Uint16, strconv.AppendUint(nil, uint64(goval), 10)) - case uint32: - v = MakeTrusted(Uint32, strconv.AppendUint(nil, uint64(goval), 10)) - case float32: - v = MakeTrusted(Float32, strconv.AppendFloat(nil, float64(goval), 'f', -1, 64)) - case string: - v = MakeTrusted(VarBinary, []byte(goval)) - case time.Time: - v = MakeTrusted(Datetime, []byte(goval.Format("2006-01-02 15:04:05"))) - case Value: - v = goval - case *BindVariable: - return ValueFromBytes(goval.Type, goval.Value) - default: - return v, fmt.Errorf("unexpected type %T: %v", goval, goval) - } - return v, nil -} - -// BuildConverted is like BuildValue except that it tries to -// convert a string or []byte to an integral if the target type -// is an integral. We don't perform other implicit conversions -// because they're unsafe. -func BuildConverted(typ Type, goval interface{}) (v Value, err error) { - if IsIntegral(typ) { - switch goval := goval.(type) { - case []byte: - return ValueFromBytes(typ, goval) - case string: - return ValueFromBytes(typ, []byte(goval)) - case Value: - if goval.IsQuoted() { - return ValueFromBytes(typ, goval.Raw()) - } - } - } - return BuildValue(goval) -} - -// ValueFromBytes builds a Value using typ and val. It ensures that val -// matches the requested type. If type is an integral it's converted to -// a canonical form. Otherwise, the original representation is preserved. -func ValueFromBytes(typ Type, val []byte) (v Value, err error) { - switch { - case IsSigned(typ): - signed, err := strconv.ParseInt(string(val), 0, 64) - if err != nil { - return NULL, err - } - v = MakeTrusted(typ, strconv.AppendInt(nil, signed, 10)) - case IsUnsigned(typ): - unsigned, err := strconv.ParseUint(string(val), 0, 64) - if err != nil { - return NULL, err - } - v = MakeTrusted(typ, strconv.AppendUint(nil, unsigned, 10)) - case IsFloat(typ) || typ == Decimal: - _, err := strconv.ParseFloat(string(val), 64) - if err != nil { - return NULL, err - } - // After verification, we preserve the original representation. - fallthrough - default: - v = MakeTrusted(typ, val) - } - return v, nil -} - -// BuildIntegral builds an integral type from a string representation. -// The type will be Int64 or Uint64. Int64 will be preferred where possible. -func BuildIntegral(val string) (n Value, err error) { - signed, err := strconv.ParseInt(val, 0, 64) - if err == nil { - return MakeTrusted(Int64, strconv.AppendInt(nil, signed, 10)), nil - } - unsigned, err := strconv.ParseUint(val, 0, 64) - if err != nil { - return Value{}, err - } - return MakeTrusted(Uint64, strconv.AppendUint(nil, unsigned, 10)), nil -} - // Type returns the type of Value. func (v Value) Type() Type { return v.typ @@ -247,15 +89,6 @@ func (v Value) Len() int { // Values represents the array of Value. type Values []Value -// Len implements the interface. -func (vs Values) Len() int { - len := 0 - for _, v := range vs { - len += v.Len() - } - return len -} - // String returns the raw value as a string. func (v Value) String() string { return BytesToString(v.val) diff --git a/weed/remote_storage/remote_storage.go b/weed/remote_storage/remote_storage.go index e23fd81df..0a6a63e1d 100644 --- a/weed/remote_storage/remote_storage.go +++ b/weed/remote_storage/remote_storage.go @@ -120,17 +120,6 @@ func GetAllRemoteStorageNames() string { return strings.Join(storageNames, "|") } -func GetRemoteStorageNamesHasBucket() string { - var storageNames []string - for k, m := range RemoteStorageClientMakers { - if m.HasBucket() { - storageNames = append(storageNames, k) - } - } - sort.Strings(storageNames) - return strings.Join(storageNames, "|") -} - func ParseRemoteLocation(remoteConfType string, remote string) (remoteStorageLocation *remote_pb.RemoteStorageLocation, err error) { maker, found := RemoteStorageClientMakers[remoteConfType] if !found { diff --git a/weed/s3api/auth_credentials.go b/weed/s3api/auth_credentials.go index a09f6c6d4..ec950cbab 100644 --- a/weed/s3api/auth_credentials.go +++ b/weed/s3api/auth_credentials.go @@ -144,6 +144,10 @@ func (c *Credential) isCredentialExpired() bool { } // NewIdentityAccessManagement creates a new IAM manager +func NewIdentityAccessManagement(option *S3ApiServerOption, filerClient *wdclient.FilerClient) *IdentityAccessManagement { + return NewIdentityAccessManagementWithStore(option, filerClient, "") +} + // SetFilerClient updates the filer client and its associated credential store func (iam *IdentityAccessManagement) SetFilerClient(filerClient *wdclient.FilerClient) { iam.m.Lock() @@ -196,10 +200,6 @@ func parseExternalUrlToHost(externalUrl string) (string, error) { return net.JoinHostPort(host, port), nil } -func NewIdentityAccessManagement(option *S3ApiServerOption, filerClient *wdclient.FilerClient) *IdentityAccessManagement { - return NewIdentityAccessManagementWithStore(option, filerClient, "") -} - func NewIdentityAccessManagementWithStore(option *S3ApiServerOption, filerClient *wdclient.FilerClient, explicitStore string) *IdentityAccessManagement { var externalHost string if option.ExternalUrl != "" { diff --git a/weed/s3api/auth_credentials_test.go b/weed/s3api/auth_credentials_test.go deleted file mode 100644 index 1e84b93db..000000000 --- a/weed/s3api/auth_credentials_test.go +++ /dev/null @@ -1,1393 +0,0 @@ -package s3api - -import ( - "context" - "crypto/tls" - "fmt" - "net/http" - "os" - "reflect" - "sync" - "testing" - - "github.com/seaweedfs/seaweedfs/weed/credential" - "github.com/seaweedfs/seaweedfs/weed/credential/memory" - "github.com/seaweedfs/seaweedfs/weed/pb/iam_pb" - "github.com/seaweedfs/seaweedfs/weed/s3api/policy_engine" - . "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" - "github.com/seaweedfs/seaweedfs/weed/util/wildcard" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - jsonpb "google.golang.org/protobuf/encoding/protojson" - - _ "github.com/seaweedfs/seaweedfs/weed/credential/filer_etc" -) - -type loadConfigurationDropsPoliciesStore struct { - *memory.MemoryStore - loadManagedPoliciesCalled bool -} - -func (store *loadConfigurationDropsPoliciesStore) LoadConfiguration(ctx context.Context) (*iam_pb.S3ApiConfiguration, error) { - config, err := store.MemoryStore.LoadConfiguration(ctx) - if err != nil { - return nil, err - } - stripped := *config - stripped.Policies = nil - return &stripped, nil -} - -func (store *loadConfigurationDropsPoliciesStore) LoadManagedPolicies(ctx context.Context) ([]*iam_pb.Policy, error) { - store.loadManagedPoliciesCalled = true - - config, err := store.MemoryStore.LoadConfiguration(ctx) - if err != nil { - return nil, err - } - - policies := make([]*iam_pb.Policy, 0, len(config.Policies)) - for _, policy := range config.Policies { - policies = append(policies, &iam_pb.Policy{ - Name: policy.Name, - Content: policy.Content, - }) - } - - return policies, nil -} - -type inlinePolicyRuntimeStore struct { - *memory.MemoryStore - inlinePolicies map[string]map[string]policy_engine.PolicyDocument -} - -func (store *inlinePolicyRuntimeStore) LoadInlinePolicies(ctx context.Context) (map[string]map[string]policy_engine.PolicyDocument, error) { - _ = ctx - return store.inlinePolicies, nil -} - -func newPolicyAuthRequest(t *testing.T, method string) *http.Request { - t.Helper() - req, err := http.NewRequest(method, "http://s3.amazonaws.com/test-bucket/test-object", nil) - require.NoError(t, err) - return req -} - -func TestIdentityListFileFormat(t *testing.T) { - - s3ApiConfiguration := &iam_pb.S3ApiConfiguration{} - - identity1 := &iam_pb.Identity{ - Name: "some_name", - Credentials: []*iam_pb.Credential{ - { - AccessKey: "some_access_key1", - SecretKey: "some_secret_key2", - }, - }, - Actions: []string{ - ACTION_ADMIN, - ACTION_READ, - ACTION_WRITE, - }, - } - identity2 := &iam_pb.Identity{ - Name: "some_read_only_user", - Credentials: []*iam_pb.Credential{ - { - AccessKey: "some_access_key1", - SecretKey: "some_secret_key1", - }, - }, - Actions: []string{ - ACTION_READ, - }, - } - identity3 := &iam_pb.Identity{ - Name: "some_normal_user", - Credentials: []*iam_pb.Credential{ - { - AccessKey: "some_access_key2", - SecretKey: "some_secret_key2", - }, - }, - Actions: []string{ - ACTION_READ, - ACTION_WRITE, - }, - } - - s3ApiConfiguration.Identities = append(s3ApiConfiguration.Identities, identity1) - s3ApiConfiguration.Identities = append(s3ApiConfiguration.Identities, identity2) - s3ApiConfiguration.Identities = append(s3ApiConfiguration.Identities, identity3) - - m := jsonpb.MarshalOptions{ - EmitUnpopulated: true, - Indent: " ", - } - - text, _ := m.Marshal(s3ApiConfiguration) - - println(string(text)) - -} - -func TestCanDo(t *testing.T) { - ident1 := &Identity{ - Name: "anything", - Actions: []Action{ - "Write:bucket1/a/b/c/*", - "Write:bucket1/a/b/other", - }, - } - // object specific - assert.Equal(t, true, ident1.CanDo(ACTION_WRITE, "bucket1", "/a/b/c/d.txt")) - assert.Equal(t, true, ident1.CanDo(ACTION_WRITE, "bucket1", "/a/b/c/d/e.txt")) - assert.Equal(t, false, ident1.CanDo(ACTION_DELETE_BUCKET, "bucket1", "")) - assert.Equal(t, false, ident1.CanDo(ACTION_WRITE, "bucket1", "/a/b/other/some"), "action without *") - assert.Equal(t, false, ident1.CanDo(ACTION_WRITE, "bucket1", "/a/b/*"), "action on parent directory") - - // bucket specific - ident2 := &Identity{ - Name: "anything", - Actions: []Action{ - "Read:bucket1", - "Write:bucket1/*", - "WriteAcp:bucket1", - }, - } - assert.Equal(t, true, ident2.CanDo(ACTION_READ, "bucket1", "/a/b/c/d.txt")) - assert.Equal(t, true, ident2.CanDo(ACTION_WRITE, "bucket1", "/a/b/c/d.txt")) - assert.Equal(t, true, ident2.CanDo(ACTION_WRITE_ACP, "bucket1", "")) - assert.Equal(t, false, ident2.CanDo(ACTION_READ_ACP, "bucket1", "")) - assert.Equal(t, false, ident2.CanDo(ACTION_LIST, "bucket1", "/a/b/c/d.txt")) - - // across buckets - ident3 := &Identity{ - Name: "anything", - Actions: []Action{ - "Read", - "Write", - }, - } - assert.Equal(t, true, ident3.CanDo(ACTION_READ, "bucket1", "/a/b/c/d.txt")) - assert.Equal(t, true, ident3.CanDo(ACTION_WRITE, "bucket1", "/a/b/c/d.txt")) - assert.Equal(t, false, ident3.CanDo(ACTION_LIST, "bucket1", "/a/b/other/some")) - assert.Equal(t, false, ident3.CanDo(ACTION_WRITE_ACP, "bucket1", "")) - - // partial buckets - ident4 := &Identity{ - Name: "anything", - Actions: []Action{ - "Read:special_*", - "ReadAcp:special_*", - }, - } - assert.Equal(t, true, ident4.CanDo(ACTION_READ, "special_bucket", "/a/b/c/d.txt")) - assert.Equal(t, true, ident4.CanDo(ACTION_READ_ACP, "special_bucket", "")) - assert.Equal(t, false, ident4.CanDo(ACTION_READ, "bucket1", "/a/b/c/d.txt")) - - // admin buckets - ident5 := &Identity{ - Name: "anything", - Actions: []Action{ - "Admin:special_*", - }, - } - assert.Equal(t, true, ident5.CanDo(ACTION_READ, "special_bucket", "/a/b/c/d.txt")) - assert.Equal(t, true, ident5.CanDo(ACTION_READ_ACP, "special_bucket", "")) - assert.Equal(t, true, ident5.CanDo(ACTION_WRITE, "special_bucket", "/a/b/c/d.txt")) - assert.Equal(t, true, ident5.CanDo(ACTION_WRITE_ACP, "special_bucket", "")) - - // anonymous buckets - ident6 := &Identity{ - Name: "anonymous", - Actions: []Action{ - "Read", - }, - } - assert.Equal(t, true, ident6.CanDo(ACTION_READ, "anything_bucket", "/a/b/c/d.txt")) - - //test deleteBucket operation - ident7 := &Identity{ - Name: "anything", - Actions: []Action{ - "DeleteBucket:bucket1", - }, - } - assert.Equal(t, true, ident7.CanDo(ACTION_DELETE_BUCKET, "bucket1", "")) -} - -func TestMatchWildcardPattern(t *testing.T) { - tests := []struct { - pattern string - target string - match bool - }{ - // Basic * wildcard tests - {"Bucket/*", "Bucket/a/b", true}, - {"Bucket/*", "x/Bucket/a", false}, - {"Bucket/*/admin", "Bucket/x/admin", true}, - {"Bucket/*/admin", "Bucket/x/y/admin", true}, - {"Bucket/*/admin", "Bucket////x////uwu////y////admin", true}, - {"abc*def", "abcXYZdef", true}, - {"abc*def", "abcXYZdefZZ", false}, - {"syr/*", "syr/a/b", true}, - - // ? wildcard tests (matches exactly one character) - {"ab?d", "abcd", true}, - {"ab?d", "abXd", true}, - {"ab?d", "abd", false}, // ? must match exactly one character - {"ab?d", "abcXd", false}, // ? matches only one character - {"a?c", "abc", true}, - {"a?c", "aXc", true}, - {"a?c", "ac", false}, - {"???", "abc", true}, - {"???", "ab", false}, - {"???", "abcd", false}, - - // Combined * and ? wildcards - {"a*?", "ab", true}, // * matches empty, ? matches 'b' - {"a*?", "abc", true}, // * matches 'b', ? matches 'c' - {"a*?", "a", false}, // ? must match something - {"a?*", "ab", true}, // ? matches 'b', * matches empty - {"a?*", "abc", true}, // ? matches 'b', * matches 'c' - {"a?*b", "aXb", true}, // ? matches 'X', * matches empty - {"a?*b", "aXYZb", true}, - {"*?*", "a", true}, - {"*?*", "", false}, // ? requires at least one character - - // Edge cases: * matches empty string - {"a*b", "ab", true}, // * matches empty string - {"a**b", "ab", true}, // multiple stars match empty - {"a**b", "axb", true}, // multiple stars match 'x' - {"a**b", "axyb", true}, - {"*", "", true}, - {"*", "anything", true}, - {"**", "", true}, - {"**", "anything", true}, - - // Edge cases: empty strings - {"", "", true}, - {"a", "", false}, - {"", "a", false}, - - // Trailing * matches empty - {"a*", "a", true}, - {"a*", "abc", true}, - {"abc*", "abc", true}, - {"abc*", "abcdef", true}, - - // Leading * matches empty - {"*a", "a", true}, - {"*a", "XXXa", true}, - {"*abc", "abc", true}, - {"*abc", "XXXabc", true}, - - // Multiple wildcards - {"*a*", "a", true}, - {"*a*", "Xa", true}, - {"*a*", "aX", true}, - {"*a*", "XaX", true}, - {"*a*b*", "ab", true}, - {"*a*b*", "XaYbZ", true}, - - // Exact match (no wildcards) - {"exact", "exact", true}, - {"exact", "notexact", false}, - {"exact", "exactnot", false}, - - // S3-style action patterns - {"Read:bucket*", "Read:bucket-test", true}, - {"Read:bucket*", "Read:bucket", true}, - {"Write:bucket/path/*", "Write:bucket/path/file.txt", true}, - {"Admin:*", "Admin:anything", true}, - } - - for _, tt := range tests { - t.Run(tt.pattern+"_"+tt.target, func(t *testing.T) { - result := wildcard.MatchesWildcard(tt.pattern, tt.target) - if result != tt.match { - t.Errorf("wildcard.MatchesWildcard(%q, %q) = %v, want %v", tt.pattern, tt.target, result, tt.match) - } - }) - } -} - -func TestVerifyActionPermissionPolicyFallback(t *testing.T) { - buildRequest := func(t *testing.T, method string) *http.Request { - t.Helper() - req, err := http.NewRequest(method, "http://s3.amazonaws.com/test-bucket/test-object", nil) - assert.NoError(t, err) - return req - } - - t.Run("policy allow grants access", func(t *testing.T) { - iam := &IdentityAccessManagement{} - err := iam.PutPolicy("allowGet", `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:GetObject","Resource":"arn:aws:s3:::test-bucket/*"}]}`) - assert.NoError(t, err) - - identity := &Identity{ - Name: "policy-user", - Account: &AccountAdmin, - PolicyNames: []string{"allowGet"}, - } - - errCode := iam.VerifyActionPermission(buildRequest(t, http.MethodGet), identity, Action(ACTION_READ), "test-bucket", "test-object") - assert.Equal(t, s3err.ErrNone, errCode) - }) - - t.Run("explicit deny overrides allow", func(t *testing.T) { - iam := &IdentityAccessManagement{} - err := iam.PutPolicy("allowAllGet", `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:GetObject","Resource":"arn:aws:s3:::test-bucket/*"}]}`) - assert.NoError(t, err) - err = iam.PutPolicy("denySecret", `{"Version":"2012-10-17","Statement":[{"Effect":"Deny","Action":"s3:GetObject","Resource":"arn:aws:s3:::test-bucket/secret.txt"}]}`) - assert.NoError(t, err) - - identity := &Identity{ - Name: "policy-user", - Account: &AccountAdmin, - PolicyNames: []string{"allowAllGet", "denySecret"}, - } - - errCode := iam.VerifyActionPermission(buildRequest(t, http.MethodGet), identity, Action(ACTION_READ), "test-bucket", "secret.txt") - assert.Equal(t, s3err.ErrAccessDenied, errCode) - }) - - t.Run("implicit deny when no statement matches", func(t *testing.T) { - iam := &IdentityAccessManagement{} - err := iam.PutPolicy("allowOtherBucket", `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:GetObject","Resource":"arn:aws:s3:::other-bucket/*"}]}`) - assert.NoError(t, err) - - identity := &Identity{ - Name: "policy-user", - Account: &AccountAdmin, - PolicyNames: []string{"allowOtherBucket"}, - } - - errCode := iam.VerifyActionPermission(buildRequest(t, http.MethodGet), identity, Action(ACTION_READ), "test-bucket", "test-object") - assert.Equal(t, s3err.ErrAccessDenied, errCode) - }) - - t.Run("invalid policy document does not allow", func(t *testing.T) { - iam := &IdentityAccessManagement{} - err := iam.PutPolicy("invalidPolicy", "{not-json") - assert.NoError(t, err) - - identity := &Identity{ - Name: "policy-user", - Account: &AccountAdmin, - PolicyNames: []string{"invalidPolicy"}, - } - - errCode := iam.VerifyActionPermission(buildRequest(t, http.MethodGet), identity, Action(ACTION_READ), "test-bucket", "test-object") - assert.Equal(t, s3err.ErrAccessDenied, errCode) - }) - - t.Run("notresource excludes denied object", func(t *testing.T) { - iam := &IdentityAccessManagement{} - err := iam.PutPolicy("denyNotResource", `{"Version":"2012-10-17","Statement":[{"Effect":"Deny","Action":"s3:GetObject","NotResource":"arn:aws:s3:::test-bucket/public/*"}]}`) - assert.NoError(t, err) - err = iam.PutPolicy("allowAllGet", `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:GetObject","Resource":"arn:aws:s3:::test-bucket/*"}]}`) - assert.NoError(t, err) - - identity := &Identity{ - Name: "policy-user", - Account: &AccountAdmin, - PolicyNames: []string{"allowAllGet", "denyNotResource"}, - } - - errCode := iam.VerifyActionPermission(buildRequest(t, http.MethodGet), identity, Action(ACTION_READ), "test-bucket", "private/secret.txt") - assert.Equal(t, s3err.ErrAccessDenied, errCode) - - errCode = iam.VerifyActionPermission(buildRequest(t, http.MethodGet), identity, Action(ACTION_READ), "test-bucket", "public/readme.txt") - assert.Equal(t, s3err.ErrNone, errCode) - }) - - t.Run("condition securetransport enforced", func(t *testing.T) { - iam := &IdentityAccessManagement{} - err := iam.PutPolicy("allowTLSOnly", `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:GetObject","Resource":"arn:aws:s3:::test-bucket/*","Condition":{"Bool":{"aws:SecureTransport":"true"}}}]}`) - assert.NoError(t, err) - - identity := &Identity{ - Name: "policy-user", - Account: &AccountAdmin, - PolicyNames: []string{"allowTLSOnly"}, - } - - httpReq := buildRequest(t, http.MethodGet) - errCode := iam.VerifyActionPermission(httpReq, identity, Action(ACTION_READ), "test-bucket", "test-object") - assert.Equal(t, s3err.ErrAccessDenied, errCode) - - httpsReq := buildRequest(t, http.MethodGet) - httpsReq.TLS = &tls.ConnectionState{} - errCode = iam.VerifyActionPermission(httpsReq, identity, Action(ACTION_READ), "test-bucket", "test-object") - assert.Equal(t, s3err.ErrNone, errCode) - }) - - t.Run("attached policies override coarse legacy actions", func(t *testing.T) { - iam := &IdentityAccessManagement{} - err := iam.PutPolicy("putOnly", `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:PutObject","Resource":"arn:aws:s3:::test-bucket/*"}]}`) - assert.NoError(t, err) - - identity := &Identity{ - Name: "policy-user", - Account: &AccountAdmin, - Actions: []Action{"Write:test-bucket"}, - PolicyNames: []string{"putOnly"}, - } - - putErrCode := iam.VerifyActionPermission(buildRequest(t, http.MethodPut), identity, Action(ACTION_WRITE), "test-bucket", "test-object") - assert.Equal(t, s3err.ErrNone, putErrCode) - - deleteErrCode := iam.VerifyActionPermission(buildRequest(t, http.MethodDelete), identity, Action(ACTION_WRITE), "test-bucket", "test-object") - assert.Equal(t, s3err.ErrAccessDenied, deleteErrCode) - }) - - t.Run("valid policy updated to invalid denies access", func(t *testing.T) { - iam := &IdentityAccessManagement{} - err := iam.PutPolicy("myPolicy", `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:GetObject","Resource":"arn:aws:s3:::test-bucket/*"}]}`) - assert.NoError(t, err) - - identity := &Identity{ - Name: "policy-user", - Account: &AccountAdmin, - PolicyNames: []string{"myPolicy"}, - } - - errCode := iam.VerifyActionPermission(buildRequest(t, http.MethodGet), identity, Action(ACTION_READ), "test-bucket", "test-object") - assert.Equal(t, s3err.ErrNone, errCode) - - // Update to invalid JSON — should revoke access. - err = iam.PutPolicy("myPolicy", "{broken") - assert.NoError(t, err) - - errCode = iam.VerifyActionPermission(buildRequest(t, http.MethodGet), identity, Action(ACTION_READ), "test-bucket", "test-object") - assert.Equal(t, s3err.ErrAccessDenied, errCode) - }) - - t.Run("actions based path still works", func(t *testing.T) { - iam := &IdentityAccessManagement{} - identity := &Identity{ - Name: "legacy-user", - Account: &AccountAdmin, - Actions: []Action{"Read:test-bucket"}, - } - - errCode := iam.VerifyActionPermission(buildRequest(t, http.MethodGet), identity, Action(ACTION_READ), "test-bucket", "any-object") - assert.Equal(t, s3err.ErrNone, errCode) - }) -} - -func TestLoadS3ApiConfigurationFromCredentialManagerHydratesManagedPolicies(t *testing.T) { - baseStore := &memory.MemoryStore{} - assert.NoError(t, baseStore.Initialize(nil, "")) - - store := &loadConfigurationDropsPoliciesStore{MemoryStore: baseStore} - cm := &credential.CredentialManager{Store: store} - - config := &iam_pb.S3ApiConfiguration{ - Identities: []*iam_pb.Identity{ - { - Name: "managed-user", - PolicyNames: []string{"managedGet"}, - Credentials: []*iam_pb.Credential{ - {AccessKey: "AKIAMANAGED000001", SecretKey: "managed-secret"}, - }, - }, - }, - Policies: []*iam_pb.Policy{ - { - Name: "managedGet", - Content: `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:GetObject","Resource":"arn:aws:s3:::test-bucket/*"}]}`, - }, - }, - } - assert.NoError(t, cm.SaveConfiguration(context.Background(), config)) - - iam := &IdentityAccessManagement{credentialManager: cm} - assert.NoError(t, iam.LoadS3ApiConfigurationFromCredentialManager()) - assert.True(t, store.loadManagedPoliciesCalled) - - identity := iam.lookupByIdentityName("managed-user") - if !assert.NotNil(t, identity) { - return - } - - errCode := iam.VerifyActionPermission(newPolicyAuthRequest(t, http.MethodGet), identity, Action(ACTION_READ), "test-bucket", "test-object") - assert.Equal(t, s3err.ErrNone, errCode) -} - -func TestLoadS3ApiConfigurationFromCredentialManagerHydratesManagedPoliciesThroughPropagatingStore(t *testing.T) { - baseStore := &memory.MemoryStore{} - assert.NoError(t, baseStore.Initialize(nil, "")) - - upstream := &loadConfigurationDropsPoliciesStore{MemoryStore: baseStore} - wrappedStore := credential.NewPropagatingCredentialStore(upstream, nil, nil) - cm := &credential.CredentialManager{Store: wrappedStore} - - config := &iam_pb.S3ApiConfiguration{ - Identities: []*iam_pb.Identity{ - { - Name: "managed-user", - PolicyNames: []string{"managedGet"}, - Credentials: []*iam_pb.Credential{ - {AccessKey: "AKIAMANAGED000010", SecretKey: "managed-secret"}, - }, - }, - }, - Policies: []*iam_pb.Policy{ - { - Name: "managedGet", - Content: `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:GetObject","Resource":"arn:aws:s3:::test-bucket/*"}]}`, - }, - }, - } - assert.NoError(t, cm.SaveConfiguration(context.Background(), config)) - - iam := &IdentityAccessManagement{credentialManager: cm} - assert.NoError(t, iam.LoadS3ApiConfigurationFromCredentialManager()) - assert.True(t, upstream.loadManagedPoliciesCalled) - - identity := iam.lookupByIdentityName("managed-user") - if !assert.NotNil(t, identity) { - return - } - - errCode := iam.VerifyActionPermission(newPolicyAuthRequest(t, http.MethodGet), identity, Action(ACTION_READ), "test-bucket", "test-object") - assert.Equal(t, s3err.ErrNone, errCode) -} - -func TestLoadS3ApiConfigurationFromCredentialManagerSyncsPoliciesToIAMManager(t *testing.T) { - ctx := context.Background() - baseStore := &memory.MemoryStore{} - assert.NoError(t, baseStore.Initialize(nil, "")) - - cm := &credential.CredentialManager{Store: baseStore} - config := &iam_pb.S3ApiConfiguration{ - Identities: []*iam_pb.Identity{ - { - Name: "managed-user", - PolicyNames: []string{"managedPut"}, - Credentials: []*iam_pb.Credential{ - {AccessKey: "AKIAMANAGED000002", SecretKey: "managed-secret"}, - }, - }, - }, - Policies: []*iam_pb.Policy{ - { - Name: "managedPut", - Content: `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":["s3:PutObject","s3:ListBucket"],"Resource":["arn:aws:s3:::cli-allowed-bucket","arn:aws:s3:::cli-allowed-bucket/*"]}]}`, - }, - }, - } - assert.NoError(t, cm.SaveConfiguration(ctx, config)) - - iamManager, err := loadIAMManagerFromConfig("", func() string { return "localhost:8888" }, func() string { - return "fallback-key-for-zero-config" - }) - assert.NoError(t, err) - iamManager.SetUserStore(cm) - - iam := &IdentityAccessManagement{credentialManager: cm} - iam.SetIAMIntegration(NewS3IAMIntegration(iamManager, "")) - - assert.NoError(t, iam.LoadS3ApiConfigurationFromCredentialManager()) - - identity := iam.lookupByIdentityName("managed-user") - if !assert.NotNil(t, identity) { - return - } - - allowedErrCode := iam.VerifyActionPermission(newPolicyAuthRequest(t, http.MethodPut), identity, Action(ACTION_WRITE), "cli-allowed-bucket", "test-object") - assert.Equal(t, s3err.ErrNone, allowedErrCode) - - forbiddenErrCode := iam.VerifyActionPermission(newPolicyAuthRequest(t, http.MethodPut), identity, Action(ACTION_WRITE), "cli-forbidden-bucket", "test-object") - assert.Equal(t, s3err.ErrAccessDenied, forbiddenErrCode) -} - -func TestLoadS3ApiConfigurationFromCredentialManagerHydratesInlinePolicies(t *testing.T) { - baseStore := &memory.MemoryStore{} - assert.NoError(t, baseStore.Initialize(nil, "")) - - inlinePolicy := policy_engine.PolicyDocument{ - Version: policy_engine.PolicyVersion2012_10_17, - Statement: []policy_engine.PolicyStatement{ - { - Effect: policy_engine.PolicyEffectAllow, - Action: policy_engine.NewStringOrStringSlice("s3:PutObject"), - Resource: policy_engine.NewStringOrStringSlicePtr("arn:aws:s3:::test-bucket/*"), - }, - }, - } - - store := &inlinePolicyRuntimeStore{ - MemoryStore: baseStore, - inlinePolicies: map[string]map[string]policy_engine.PolicyDocument{ - "inline-user": { - "PutOnly": inlinePolicy, - }, - }, - } - cm := &credential.CredentialManager{Store: store} - - config := &iam_pb.S3ApiConfiguration{ - Identities: []*iam_pb.Identity{ - { - Name: "inline-user", - Actions: []string{"Write:test-bucket"}, - Credentials: []*iam_pb.Credential{ - {AccessKey: "AKIAINLINE0000001", SecretKey: "inline-secret"}, - }, - }, - }, - } - assert.NoError(t, cm.SaveConfiguration(context.Background(), config)) - - iam := &IdentityAccessManagement{credentialManager: cm} - assert.NoError(t, iam.LoadS3ApiConfigurationFromCredentialManager()) - - identity := iam.lookupByIdentityName("inline-user") - if !assert.NotNil(t, identity) { - return - } - assert.Contains(t, identity.PolicyNames, inlinePolicyRuntimeName("inline-user", "PutOnly")) - - putErrCode := iam.VerifyActionPermission(newPolicyAuthRequest(t, http.MethodPut), identity, Action(ACTION_WRITE), "test-bucket", "test-object") - assert.Equal(t, s3err.ErrNone, putErrCode) - - deleteErrCode := iam.VerifyActionPermission(newPolicyAuthRequest(t, http.MethodDelete), identity, Action(ACTION_WRITE), "test-bucket", "test-object") - assert.Equal(t, s3err.ErrAccessDenied, deleteErrCode) -} - -func TestLoadS3ApiConfigurationFromCredentialManagerHydratesInlinePoliciesThroughPropagatingStore(t *testing.T) { - baseStore := &memory.MemoryStore{} - assert.NoError(t, baseStore.Initialize(nil, "")) - - inlinePolicy := policy_engine.PolicyDocument{ - Version: policy_engine.PolicyVersion2012_10_17, - Statement: []policy_engine.PolicyStatement{ - { - Effect: policy_engine.PolicyEffectAllow, - Action: policy_engine.NewStringOrStringSlice("s3:PutObject"), - Resource: policy_engine.NewStringOrStringSlicePtr("arn:aws:s3:::test-bucket/*"), - }, - }, - } - - upstream := &inlinePolicyRuntimeStore{ - MemoryStore: baseStore, - inlinePolicies: map[string]map[string]policy_engine.PolicyDocument{ - "inline-user": { - "PutOnly": inlinePolicy, - }, - }, - } - wrappedStore := credential.NewPropagatingCredentialStore(upstream, nil, nil) - cm := &credential.CredentialManager{Store: wrappedStore} - - config := &iam_pb.S3ApiConfiguration{ - Identities: []*iam_pb.Identity{ - { - Name: "inline-user", - Actions: []string{"Write:test-bucket"}, - Credentials: []*iam_pb.Credential{ - {AccessKey: "AKIAINLINE0000010", SecretKey: "inline-secret"}, - }, - }, - }, - } - assert.NoError(t, cm.SaveConfiguration(context.Background(), config)) - - iam := &IdentityAccessManagement{credentialManager: cm} - assert.NoError(t, iam.LoadS3ApiConfigurationFromCredentialManager()) - - identity := iam.lookupByIdentityName("inline-user") - if !assert.NotNil(t, identity) { - return - } - assert.Contains(t, identity.PolicyNames, inlinePolicyRuntimeName("inline-user", "PutOnly")) - - putErrCode := iam.VerifyActionPermission(newPolicyAuthRequest(t, http.MethodPut), identity, Action(ACTION_WRITE), "test-bucket", "test-object") - assert.Equal(t, s3err.ErrNone, putErrCode) - - deleteErrCode := iam.VerifyActionPermission(newPolicyAuthRequest(t, http.MethodDelete), identity, Action(ACTION_WRITE), "test-bucket", "test-object") - assert.Equal(t, s3err.ErrAccessDenied, deleteErrCode) -} - -func TestLoadConfigurationDropsPoliciesStoreDoesNotMutateSourceConfig(t *testing.T) { - baseStore := &memory.MemoryStore{} - require.NoError(t, baseStore.Initialize(nil, "")) - - config := &iam_pb.S3ApiConfiguration{ - Policies: []*iam_pb.Policy{ - {Name: "managedGet", Content: `{"Version":"2012-10-17","Statement":[]}`}, - }, - } - require.NoError(t, baseStore.SaveConfiguration(context.Background(), config)) - - store := &loadConfigurationDropsPoliciesStore{MemoryStore: baseStore} - - stripped, err := store.LoadConfiguration(context.Background()) - require.NoError(t, err) - assert.Nil(t, stripped.Policies) - - source, err := baseStore.LoadConfiguration(context.Background()) - require.NoError(t, err) - require.Len(t, source.Policies, 1) - assert.Equal(t, "managedGet", source.Policies[0].Name) -} - -func TestMergePoliciesIntoConfigurationSkipsNilPolicies(t *testing.T) { - config := &iam_pb.S3ApiConfiguration{ - Policies: []*iam_pb.Policy{ - nil, - {Name: "existing", Content: "old"}, - }, - } - - mergePoliciesIntoConfiguration(config, []*iam_pb.Policy{ - nil, - {Name: "", Content: "ignored"}, - {Name: "existing", Content: "updated"}, - {Name: "new", Content: "created"}, - }) - - require.Len(t, config.Policies, 3) - assert.Nil(t, config.Policies[0]) - assert.Equal(t, "existing", config.Policies[1].Name) - assert.Equal(t, "updated", config.Policies[1].Content) - assert.Equal(t, "new", config.Policies[2].Name) - assert.Equal(t, "created", config.Policies[2].Content) -} - -type LoadS3ApiConfigurationTestCase struct { - pbAccount *iam_pb.Account - pbIdent *iam_pb.Identity - expectIdent *Identity -} - -func TestLoadS3ApiConfiguration(t *testing.T) { - specifiedAccount := Account{ - Id: "specifiedAccountID", - DisplayName: "specifiedAccountName", - EmailAddress: "specifiedAccounEmail@example.com", - } - pbSpecifiedAccount := iam_pb.Account{ - Id: "specifiedAccountID", - DisplayName: "specifiedAccountName", - EmailAddress: "specifiedAccounEmail@example.com", - } - testCases := map[string]*LoadS3ApiConfigurationTestCase{ - "notSpecifyAccountId": { - pbIdent: &iam_pb.Identity{ - Name: "notSpecifyAccountId", - Actions: []string{ - "Read", - "Write", - }, - Credentials: []*iam_pb.Credential{ - { - AccessKey: "some_access_key1", - SecretKey: "some_secret_key2", - }, - }, - }, - expectIdent: &Identity{ - Name: "notSpecifyAccountId", - Account: &AccountAdmin, - PrincipalArn: fmt.Sprintf("arn:aws:iam::%s:user/notSpecifyAccountId", defaultAccountID), - Actions: []Action{ - "Read", - "Write", - }, - Credentials: []*Credential{ - { - AccessKey: "some_access_key1", - SecretKey: "some_secret_key2", - }, - }, - }, - }, - "specifiedAccountID": { - pbAccount: &pbSpecifiedAccount, - pbIdent: &iam_pb.Identity{ - Name: "specifiedAccountID", - Account: &pbSpecifiedAccount, - Actions: []string{ - "Read", - "Write", - }, - }, - expectIdent: &Identity{ - Name: "specifiedAccountID", - Account: &specifiedAccount, - PrincipalArn: fmt.Sprintf("arn:aws:iam::%s:user/specifiedAccountID", defaultAccountID), - Actions: []Action{ - "Read", - "Write", - }, - }, - }, - "anonymous": { - pbIdent: &iam_pb.Identity{ - Name: "anonymous", - Actions: []string{ - "Read", - "Write", - }, - }, - expectIdent: &Identity{ - Name: "anonymous", - Account: &AccountAnonymous, - PrincipalArn: "*", - Actions: []Action{ - "Read", - "Write", - }, - }, - }, - } - - config := &iam_pb.S3ApiConfiguration{ - Identities: make([]*iam_pb.Identity, 0), - } - for _, v := range testCases { - config.Identities = append(config.Identities, v.pbIdent) - if v.pbAccount != nil { - config.Accounts = append(config.Accounts, v.pbAccount) - } - } - - iam := IdentityAccessManagement{} - err := iam.loadS3ApiConfiguration(config) - if err != nil { - return - } - - for _, ident := range iam.identities { - tc := testCases[ident.Name] - if !reflect.DeepEqual(ident, tc.expectIdent) { - t.Errorf("not expect for ident name %s", ident.Name) - } - } -} - -func TestNewIdentityAccessManagementWithStoreEnvVars(t *testing.T) { - // Save original environment - originalAccessKeyId := os.Getenv("AWS_ACCESS_KEY_ID") - originalSecretAccessKey := os.Getenv("AWS_SECRET_ACCESS_KEY") - - // Clean up after test - defer func() { - if originalAccessKeyId != "" { - os.Setenv("AWS_ACCESS_KEY_ID", originalAccessKeyId) - } else { - os.Unsetenv("AWS_ACCESS_KEY_ID") - } - if originalSecretAccessKey != "" { - os.Setenv("AWS_SECRET_ACCESS_KEY", originalSecretAccessKey) - } else { - os.Unsetenv("AWS_SECRET_ACCESS_KEY") - } - }() - - tests := []struct { - name string - accessKeyId string - secretAccessKey string - expectEnvIdentity bool - expectedName string - description string - }{ - { - name: "Environment variables used as fallback", - accessKeyId: "AKIA1234567890ABCDEF", - secretAccessKey: "secret123456789012345678901234567890abcdef12", - expectEnvIdentity: true, - expectedName: "admin-AKIA1234", - description: "When no config file and no filer config, environment variables should be used", - }, - { - name: "Short access key fallback", - accessKeyId: "SHORT", - secretAccessKey: "secret123456789012345678901234567890abcdef12", - expectEnvIdentity: true, - expectedName: "admin-SHORT", - description: "Short access keys should work correctly as fallback", - }, - { - name: "No env vars means no identities", - accessKeyId: "", - secretAccessKey: "", - expectEnvIdentity: false, - expectedName: "", - description: "When no env vars and no config, should have no identities", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Reset the memory store to avoid test pollution - if store := credential.Stores[0]; store.GetName() == credential.StoreTypeMemory { - if memStore, ok := store.(interface{ Reset() }); ok { - memStore.Reset() - } - } - - // Set up environment variables - if tt.accessKeyId != "" { - os.Setenv("AWS_ACCESS_KEY_ID", tt.accessKeyId) - } else { - os.Unsetenv("AWS_ACCESS_KEY_ID") - } - if tt.secretAccessKey != "" { - os.Setenv("AWS_SECRET_ACCESS_KEY", tt.secretAccessKey) - } else { - os.Unsetenv("AWS_SECRET_ACCESS_KEY") - } - - // Create IAM instance with memory store for testing (no config file) - option := &S3ApiServerOption{ - Config: "", // No config file - this should trigger environment variable fallback - } - iam := NewIdentityAccessManagementWithStore(option, nil, string(credential.StoreTypeMemory)) - - if tt.expectEnvIdentity { - // Should have exactly one identity from environment variables - assert.Len(t, iam.identities, 1, "Should have exactly one identity from environment variables") - - identity := iam.identities[0] - assert.Equal(t, tt.expectedName, identity.Name, "Identity name should match expected") - assert.Len(t, identity.Credentials, 1, "Should have one credential") - assert.Equal(t, tt.accessKeyId, identity.Credentials[0].AccessKey, "Access key should match environment variable") - assert.Equal(t, tt.secretAccessKey, identity.Credentials[0].SecretKey, "Secret key should match environment variable") - assert.Contains(t, identity.Actions, Action(ACTION_ADMIN), "Should have admin action") - } else { - // When no env vars, should have no identities (since no config file) - assert.Len(t, iam.identities, 0, "Should have no identities when no env vars and no config file") - } - }) - } -} - -// TestConfigFileWithNoIdentitiesAllowsEnvVars tests that when a config file exists -// but contains no identities (e.g., only KMS settings), environment variables should still work. -// This test validates the fix for issue #7311. -func TestConfigFileWithNoIdentitiesAllowsEnvVars(t *testing.T) { - // Reset the memory store to avoid test pollution - if store := credential.Stores[0]; store.GetName() == credential.StoreTypeMemory { - if memStore, ok := store.(interface{ Reset() }); ok { - memStore.Reset() - } - } - - // Set environment variables - testAccessKey := "AKIATEST1234567890AB" - testSecretKey := "testSecret1234567890123456789012345678901234" - t.Setenv("AWS_ACCESS_KEY_ID", testAccessKey) - t.Setenv("AWS_SECRET_ACCESS_KEY", testSecretKey) - - // Create a temporary config file with only KMS settings (no identities) - configContent := `{ - "kms": { - "default": { - "provider": "local", - "config": { - "keyPath": "/tmp/test-key" - } - } - } -}` - tmpFile, err := os.CreateTemp("", "s3-config-*.json") - assert.NoError(t, err, "Should create temp config file") - defer os.Remove(tmpFile.Name()) - - _, err = tmpFile.Write([]byte(configContent)) - assert.NoError(t, err, "Should write config content") - tmpFile.Close() - - // Create IAM instance with config file that has no identities - option := &S3ApiServerOption{ - Config: tmpFile.Name(), - } - iam := NewIdentityAccessManagementWithStore(option, nil, string(credential.StoreTypeMemory)) - - // Should have exactly one identity from environment variables - assert.Len(t, iam.identities, 1, "Should have exactly one identity from environment variables even when config file exists with no identities") - - identity := iam.identities[0] - assert.Equal(t, "admin-AKIATEST", identity.Name, "Identity name should be based on access key") - assert.Len(t, identity.Credentials, 1, "Should have one credential") - assert.Equal(t, testAccessKey, identity.Credentials[0].AccessKey, "Access key should match environment variable") - assert.Equal(t, testSecretKey, identity.Credentials[0].SecretKey, "Secret key should match environment variable") - assert.Contains(t, identity.Actions, Action(ACTION_ADMIN), "Should have admin action") -} - -// TestBucketLevelListPermissions tests that bucket-level List permissions work correctly -// This test validates the fix for issue #7066 -func TestBucketLevelListPermissions(t *testing.T) { - // Test the functionality that was broken in issue #7066 - - t.Run("Bucket Wildcard Permissions", func(t *testing.T) { - // Create identity with bucket-level List permission using wildcards - identity := &Identity{ - Name: "bucket-user", - Actions: []Action{ - "List:mybucket*", - "Read:mybucket*", - "ReadAcp:mybucket*", - "Write:mybucket*", - "WriteAcp:mybucket*", - "Tagging:mybucket*", - }, - } - - // Test cases for bucket-level wildcard permissions - testCases := []struct { - name string - action Action - bucket string - object string - shouldAllow bool - description string - }{ - { - name: "exact bucket match", - action: "List", - bucket: "mybucket", - object: "", - shouldAllow: true, - description: "Should allow access to exact bucket name", - }, - { - name: "bucket with suffix", - action: "List", - bucket: "mybucket-prod", - object: "", - shouldAllow: true, - description: "Should allow access to bucket with matching prefix", - }, - { - name: "bucket with numbers", - action: "List", - bucket: "mybucket123", - object: "", - shouldAllow: true, - description: "Should allow access to bucket with numbers", - }, - { - name: "different bucket", - action: "List", - bucket: "otherbucket", - object: "", - shouldAllow: false, - description: "Should deny access to bucket with different prefix", - }, - { - name: "partial match", - action: "List", - bucket: "notmybucket", - object: "", - shouldAllow: false, - description: "Should deny access to bucket that contains but doesn't start with the prefix", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - result := identity.CanDo(tc.action, tc.bucket, tc.object) - assert.Equal(t, tc.shouldAllow, result, tc.description) - }) - } - }) - - t.Run("Global List Permission", func(t *testing.T) { - // Create identity with global List permission - identity := &Identity{ - Name: "global-user", - Actions: []Action{ - "List", - }, - } - - // Should allow access to any bucket - testCases := []string{"anybucket", "mybucket", "test-bucket", "prod-data"} - - for _, bucket := range testCases { - result := identity.CanDo("List", bucket, "") - assert.True(t, result, "Global List permission should allow access to bucket %s", bucket) - } - }) - - t.Run("No Wildcard Exact Match", func(t *testing.T) { - // Create identity with exact bucket permission (no wildcard) - identity := &Identity{ - Name: "exact-user", - Actions: []Action{ - "List:specificbucket", - }, - } - - // Should only allow access to the exact bucket - assert.True(t, identity.CanDo("List", "specificbucket", ""), "Should allow access to exact bucket") - assert.False(t, identity.CanDo("List", "specificbucket-test", ""), "Should deny access to bucket with suffix") - assert.False(t, identity.CanDo("List", "otherbucket", ""), "Should deny access to different bucket") - }) - - t.Log("This test validates the fix for issue #7066") - t.Log("Bucket-level List permissions like 'List:bucket*' work correctly") - t.Log("ListBucketsHandler now uses consistent authentication flow") -} - -// TestListBucketsAuthRequest tests that authRequest works correctly for ListBuckets operations -// This test validates that the fix for the regression identified in PR #7067 works correctly -func TestListBucketsAuthRequest(t *testing.T) { - t.Run("ListBuckets special case handling", func(t *testing.T) { - // Create identity with bucket-specific permissions (no global List permission) - identity := &Identity{ - Name: "bucket-user", - Account: &AccountAdmin, - Actions: []Action{ - Action("List:mybucket*"), - Action("Read:mybucket*"), - }, - } - - // Test 1: ListBuckets operation should succeed (bucket = "") - // This would have failed before the fix because CanDo("List", "", "") would return false - // After the fix, it bypasses the CanDo check for ListBuckets operations - - // Simulate what happens in authRequest for ListBuckets: - // action = ACTION_LIST, bucket = "", object = "" - - // Before fix: identity.CanDo(ACTION_LIST, "", "") would fail - // After fix: the CanDo check should be bypassed - - // Test the individual CanDo method to show it would fail without the special case - result := identity.CanDo(Action(ACTION_LIST), "", "") - assert.False(t, result, "CanDo should return false for empty bucket with bucket-specific permissions") - - // Test with a specific bucket that matches the permission - result2 := identity.CanDo(Action(ACTION_LIST), "mybucket", "") - assert.True(t, result2, "CanDo should return true for matching bucket") - - // Test with a specific bucket that doesn't match - result3 := identity.CanDo(Action(ACTION_LIST), "otherbucket", "") - assert.False(t, result3, "CanDo should return false for non-matching bucket") - }) - - t.Run("Object listing maintains permission enforcement", func(t *testing.T) { - // Create identity with bucket-specific permissions - identity := &Identity{ - Name: "bucket-user", - Account: &AccountAdmin, - Actions: []Action{ - Action("List:mybucket*"), - }, - } - - // For object listing operations, the normal permission checks should still apply - // These operations have a specific bucket in the URL - - // Should succeed for allowed bucket - result1 := identity.CanDo(Action(ACTION_LIST), "mybucket", "prefix/") - assert.True(t, result1, "Should allow listing objects in permitted bucket") - - result2 := identity.CanDo(Action(ACTION_LIST), "mybucket-prod", "") - assert.True(t, result2, "Should allow listing objects in wildcard-matched bucket") - - // Should fail for disallowed bucket - result3 := identity.CanDo(Action(ACTION_LIST), "otherbucket", "") - assert.False(t, result3, "Should deny listing objects in non-permitted bucket") - }) - - t.Log("This test validates the fix for the regression identified in PR #7067") - t.Log("ListBuckets operation bypasses global permission check when bucket is empty") - t.Log("Object listing still properly enforces bucket-level permissions") -} - -// TestSignatureVerificationDoesNotCheckPermissions tests that signature verification -// only validates the signature and identity, not permissions. Permissions should be -// checked later in authRequest based on the actual operation. -// This test validates the fix for issue #7334 -func TestSignatureVerificationDoesNotCheckPermissions(t *testing.T) { - t.Run("List-only user can authenticate via signature", func(t *testing.T) { - // Create IAM with a user that only has List permissions on specific buckets - iam := &IdentityAccessManagement{ - hashes: make(map[string]*sync.Pool), - hashCounters: make(map[string]*int32), - } - - err := iam.loadS3ApiConfiguration(&iam_pb.S3ApiConfiguration{ - Identities: []*iam_pb.Identity{ - { - Name: "list-only-user", - Credentials: []*iam_pb.Credential{ - { - AccessKey: "list_access_key", - SecretKey: "list_secret_key", - }, - }, - Actions: []string{ - "List:bucket-123", - "Read:bucket-123", - }, - }, - }, - }) - assert.NoError(t, err) - - // Before the fix, signature verification would fail because it checked for Write permission - // After the fix, signature verification should succeed (only checking signature validity) - // The actual permission check happens later in authRequest with the correct action - - // The user should be able to authenticate (signature verification passes) - // But authorization for specific actions is checked separately - identity, cred, found := iam.lookupByAccessKey("list_access_key") - assert.True(t, found, "Should find the user by access key") - assert.Equal(t, "list-only-user", identity.Name) - assert.Equal(t, "list_secret_key", cred.SecretKey) - - // User should have the correct permissions - assert.True(t, identity.CanDo(Action(ACTION_LIST), "bucket-123", "")) - assert.True(t, identity.CanDo(Action(ACTION_READ), "bucket-123", "")) - - // User should NOT have write permissions - assert.False(t, identity.CanDo(Action(ACTION_WRITE), "bucket-123", "")) - }) - - t.Log("This test validates the fix for issue #7334") - t.Log("Signature verification no longer checks for Write permission") - t.Log("This allows list-only and read-only users to authenticate via AWS Signature V4") -} - -func TestStaticIdentityProtection(t *testing.T) { - iam := NewIdentityAccessManagement(&S3ApiServerOption{}, nil) - - // Add a static identity - staticIdent := &Identity{ - Name: "static-user", - IsStatic: true, - } - iam.m.Lock() - if iam.nameToIdentity == nil { - iam.nameToIdentity = make(map[string]*Identity) - } - iam.identities = append(iam.identities, staticIdent) - iam.nameToIdentity[staticIdent.Name] = staticIdent - iam.m.Unlock() - - // Add a dynamic identity - dynamicIdent := &Identity{ - Name: "dynamic-user", - IsStatic: false, - } - iam.m.Lock() - iam.identities = append(iam.identities, dynamicIdent) - iam.nameToIdentity[dynamicIdent.Name] = dynamicIdent - iam.m.Unlock() - - // Try to remove static identity - iam.RemoveIdentity("static-user") - - // Verify static identity still exists - iam.m.RLock() - _, ok := iam.nameToIdentity["static-user"] - iam.m.RUnlock() - assert.True(t, ok, "Static identity should not be removed") - - // Try to remove dynamic identity - iam.RemoveIdentity("dynamic-user") - - // Verify dynamic identity is removed - iam.m.RLock() - _, ok = iam.nameToIdentity["dynamic-user"] - iam.m.RUnlock() - assert.False(t, ok, "Dynamic identity should have been removed") -} - -func TestParseExternalUrlToHost(t *testing.T) { - tests := []struct { - name string - input string - expected string - expectErr bool - }{ - { - name: "empty string", - input: "", - expected: "", - }, - { - name: "HTTPS with default port stripped", - input: "https://api.example.com:443", - expected: "api.example.com", - }, - { - name: "HTTP with default port stripped", - input: "http://api.example.com:80", - expected: "api.example.com", - }, - { - name: "HTTPS with non-standard port preserved", - input: "https://api.example.com:9000", - expected: "api.example.com:9000", - }, - { - name: "HTTP with non-standard port preserved", - input: "http://api.example.com:8080", - expected: "api.example.com:8080", - }, - { - name: "HTTPS without port", - input: "https://api.example.com", - expected: "api.example.com", - }, - { - name: "HTTP without port", - input: "http://api.example.com", - expected: "api.example.com", - }, - { - name: "IPv6 with non-standard port", - input: "https://[::1]:9000", - expected: "[::1]:9000", - }, - { - name: "IPv6 with default HTTPS port stripped", - input: "https://[::1]:443", - expected: "::1", - }, - { - name: "IPv6 without port", - input: "https://[::1]", - expected: "::1", - }, - { - name: "invalid URL", - input: "://not-a-url", - expectErr: true, - }, - { - name: "missing host", - input: "https://", - expectErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := parseExternalUrlToHost(tt.input) - if tt.expectErr { - assert.Error(t, err) - return - } - assert.NoError(t, err) - assert.Equal(t, tt.expected, result) - }) - } -} diff --git a/weed/s3api/bucket_metadata.go b/weed/s3api/bucket_metadata.go index f13fbb949..ce0162be8 100644 --- a/weed/s3api/bucket_metadata.go +++ b/weed/s3api/bucket_metadata.go @@ -224,12 +224,6 @@ func (r *BucketRegistry) removeMetadataCache(bucket string) { delete(r.metadataCache, bucket) } -func (r *BucketRegistry) markNotFound(bucket string) { - r.notFoundLock.Lock() - defer r.notFoundLock.Unlock() - r.notFound[bucket] = struct{}{} -} - func (r *BucketRegistry) unMarkNotFound(bucket string) { r.notFoundLock.Lock() defer r.notFoundLock.Unlock() diff --git a/weed/s3api/filer_multipart_test.go b/weed/s3api/filer_multipart_test.go deleted file mode 100644 index 92ecbeba9..000000000 --- a/weed/s3api/filer_multipart_test.go +++ /dev/null @@ -1,267 +0,0 @@ -package s3api - -import ( - "encoding/hex" - "net/http" - "testing" - "time" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/s3" - "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" - "github.com/stretchr/testify/assert" -) - -func TestInitiateMultipartUploadResult(t *testing.T) { - - expected := ` -example-bucketexample-objectVXBsb2FkIElEIGZvciA2aWWpbmcncyBteS1tb3ZpZS5tMnRzIHVwbG9hZA` - response := &InitiateMultipartUploadResult{ - CreateMultipartUploadOutput: s3.CreateMultipartUploadOutput{ - Bucket: aws.String("example-bucket"), - Key: aws.String("example-object"), - UploadId: aws.String("VXBsb2FkIElEIGZvciA2aWWpbmcncyBteS1tb3ZpZS5tMnRzIHVwbG9hZA"), - }, - } - - encoded := string(s3err.EncodeXMLResponse(response)) - if encoded != expected { - t.Errorf("unexpected output: %s\nexpecting:%s", encoded, expected) - } - -} - -func TestListPartsResult(t *testing.T) { - - expected := ` -"12345678"1970-01-01T00:00:00Z1123` - response := &ListPartsResult{ - Part: []*s3.Part{ - { - PartNumber: aws.Int64(int64(1)), - LastModified: aws.Time(time.Unix(0, 0).UTC()), - Size: aws.Int64(int64(123)), - ETag: aws.String("\"12345678\""), - }, - }, - } - - encoded := string(s3err.EncodeXMLResponse(response)) - if encoded != expected { - t.Errorf("unexpected output: %s\nexpecting:%s", encoded, expected) - } - -} - -func TestCompleteMultipartResultIncludesVersionId(t *testing.T) { - r := &http.Request{Host: "localhost", Header: make(http.Header)} - input := &s3.CompleteMultipartUploadInput{ - Bucket: aws.String("example-bucket"), - Key: aws.String("example-object"), - } - - entry := &filer_pb.Entry{ - Extended: map[string][]byte{ - s3_constants.ExtVersionIdKey: []byte("version-123"), - }, - } - - result := completeMultipartResult(r, input, "\"etag-value\"", entry) - if assert.NotNil(t, result.VersionId) { - assert.Equal(t, "version-123", *result.VersionId) - } -} - -func TestCompleteMultipartResultOmitsNullVersionId(t *testing.T) { - r := &http.Request{Host: "localhost", Header: make(http.Header)} - input := &s3.CompleteMultipartUploadInput{ - Bucket: aws.String("example-bucket"), - Key: aws.String("example-object"), - } - - entry := &filer_pb.Entry{ - Extended: map[string][]byte{ - s3_constants.ExtVersionIdKey: []byte("null"), - }, - } - - result := completeMultipartResult(r, input, "\"etag-value\"", entry) - assert.Nil(t, result.VersionId) -} - -func Test_parsePartNumber(t *testing.T) { - tests := []struct { - name string - fileName string - partNum int - }{ - { - "first", - "0001_uuid.part", - 1, - }, - { - "second", - "0002.part", - 2, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - partNumber, _ := parsePartNumber(tt.fileName) - assert.Equalf(t, tt.partNum, partNumber, "parsePartNumber(%v)", tt.fileName) - }) - } -} - -func TestGetEntryNameAndDir(t *testing.T) { - s3a := &S3ApiServer{ - option: &S3ApiServerOption{ - BucketsPath: "/buckets", - }, - } - - tests := []struct { - name string - bucket string - key string - expectedName string - expectedDirEnd string // We check the suffix since dir includes BucketsPath - }{ - { - name: "simple file at root", - bucket: "test-bucket", - key: "/file.txt", - expectedName: "file.txt", - expectedDirEnd: "/buckets/test-bucket", - }, - { - name: "file in subdirectory", - bucket: "test-bucket", - key: "/folder/file.txt", - expectedName: "file.txt", - expectedDirEnd: "/buckets/test-bucket/folder", - }, - { - name: "file in nested subdirectory", - bucket: "test-bucket", - key: "/folder/subfolder/file.txt", - expectedName: "file.txt", - expectedDirEnd: "/buckets/test-bucket/folder/subfolder", - }, - { - name: "key without leading slash", - bucket: "test-bucket", - key: "folder/file.txt", - expectedName: "file.txt", - expectedDirEnd: "/buckets/test-bucket/folder", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - input := &s3.CompleteMultipartUploadInput{ - Bucket: aws.String(tt.bucket), - Key: aws.String(tt.key), - } - entryName, dirName := s3a.getEntryNameAndDir(input) - assert.Equal(t, tt.expectedName, entryName, "entry name mismatch") - assert.Equal(t, tt.expectedDirEnd, dirName, "directory mismatch") - }) - } -} - -func TestValidateCompletePartETag(t *testing.T) { - t.Run("matches_composite_etag_from_extended", func(t *testing.T) { - entry := &filer_pb.Entry{ - Extended: map[string][]byte{ - s3_constants.ExtETagKey: []byte("ea58527f14c6ae0dd53089966e44941b-2"), - }, - Attributes: &filer_pb.FuseAttributes{}, - } - match, invalid, part, stored := validateCompletePartETag(`"ea58527f14c6ae0dd53089966e44941b-2"`, entry) - assert.True(t, match) - assert.False(t, invalid) - assert.Equal(t, "ea58527f14c6ae0dd53089966e44941b-2", part) - assert.Equal(t, "ea58527f14c6ae0dd53089966e44941b-2", stored) - }) - - t.Run("matches_md5_from_attributes", func(t *testing.T) { - md5Bytes, err := hex.DecodeString("324b2665939fde5b8678d3a8b5c46970") - assert.NoError(t, err) - entry := &filer_pb.Entry{ - Attributes: &filer_pb.FuseAttributes{ - Md5: md5Bytes, - }, - } - match, invalid, part, stored := validateCompletePartETag("324b2665939fde5b8678d3a8b5c46970", entry) - assert.True(t, match) - assert.False(t, invalid) - assert.Equal(t, "324b2665939fde5b8678d3a8b5c46970", part) - assert.Equal(t, "324b2665939fde5b8678d3a8b5c46970", stored) - }) - - t.Run("detects_mismatch", func(t *testing.T) { - entry := &filer_pb.Entry{ - Extended: map[string][]byte{ - s3_constants.ExtETagKey: []byte("67fdd2e302502ff9f9b606bc036e6892-2"), - }, - Attributes: &filer_pb.FuseAttributes{}, - } - match, invalid, _, _ := validateCompletePartETag("686f7d71bacdcd539dd4e17a0d7f1e5f-2", entry) - assert.False(t, match) - assert.False(t, invalid) - }) - - t.Run("flags_empty_client_etag_as_invalid", func(t *testing.T) { - entry := &filer_pb.Entry{ - Extended: map[string][]byte{ - s3_constants.ExtETagKey: []byte("67fdd2e302502ff9f9b606bc036e6892-2"), - }, - Attributes: &filer_pb.FuseAttributes{}, - } - match, invalid, _, _ := validateCompletePartETag(`""`, entry) - assert.False(t, match) - assert.True(t, invalid) - }) -} - -func TestCompleteMultipartUploadRejectsOutOfOrderParts(t *testing.T) { - s3a := NewS3ApiServerForTest() - input := &s3.CompleteMultipartUploadInput{ - Bucket: aws.String("bucket"), - Key: aws.String("object"), - UploadId: aws.String("upload"), - } - parts := &CompleteMultipartUpload{ - Parts: []CompletedPart{ - {PartNumber: 2, ETag: "\"etag-2\""}, - {PartNumber: 1, ETag: "\"etag-1\""}, - }, - } - - result, errCode := s3a.completeMultipartUpload(&http.Request{Header: make(http.Header)}, input, parts) - assert.Nil(t, result) - assert.Equal(t, s3err.ErrInvalidPartOrder, errCode) -} - -func TestCompleteMultipartUploadAllowsDuplicatePartNumbers(t *testing.T) { - s3a := NewS3ApiServerForTest() - input := &s3.CompleteMultipartUploadInput{ - Bucket: aws.String("bucket"), - Key: aws.String("object"), - UploadId: aws.String("upload"), - } - parts := &CompleteMultipartUpload{ - Parts: []CompletedPart{ - {PartNumber: 1, ETag: "\"etag-older\""}, - {PartNumber: 1, ETag: "\"etag-newer\""}, - }, - } - - result, errCode := s3a.completeMultipartUpload(&http.Request{Header: make(http.Header)}, input, parts) - assert.Nil(t, result) - assert.Equal(t, s3err.ErrNoSuchUpload, errCode) -} diff --git a/weed/s3api/iam_optional_test.go b/weed/s3api/iam_optional_test.go index 583b14791..4d35e4df9 100644 --- a/weed/s3api/iam_optional_test.go +++ b/weed/s3api/iam_optional_test.go @@ -3,9 +3,22 @@ package s3api import ( "testing" + "github.com/seaweedfs/seaweedfs/weed/credential" "github.com/stretchr/testify/assert" ) +// resetMemoryStore resets the shared in-memory credential store so that tests +// that rely on an empty store are not polluted by earlier tests. +func resetMemoryStore() { + for _, store := range credential.Stores { + if store.GetName() == credential.StoreTypeMemory { + if resettable, ok := store.(interface{ Reset() }); ok { + resettable.Reset() + } + } + } +} + func TestLoadIAMManagerWithNoConfig(t *testing.T) { // Verify that IAM can be initialized without any config option := &S3ApiServerOption{ @@ -17,6 +30,9 @@ func TestLoadIAMManagerWithNoConfig(t *testing.T) { } func TestLoadIAMManagerFromConfig_EmptyConfigWithFallbackKey(t *testing.T) { + // Reset the shared memory store to avoid state leaking from other tests. + resetMemoryStore() + // Initialize IAM with empty config — no anonymous identity is configured, // so LookupAnonymous should return not-found. option := &S3ApiServerOption{ diff --git a/weed/s3api/iceberg/commit_helpers.go b/weed/s3api/iceberg/commit_helpers.go index 958cc753f..742f24cae 100644 --- a/weed/s3api/iceberg/commit_helpers.go +++ b/weed/s3api/iceberg/commit_helpers.go @@ -6,8 +6,6 @@ import ( "errors" "fmt" "net/http" - "path" - "strconv" "strings" "github.com/apache/iceberg-go/table" @@ -25,10 +23,6 @@ type icebergRequestError struct { message string } -func (e *icebergRequestError) Error() string { - return e.message -} - type createOnCommitInput struct { bucketARN string markerBucket string @@ -88,19 +82,6 @@ func isS3TablesAlreadyExists(err error) bool { (tableErr.Type == s3tables.ErrCodeTableAlreadyExists || tableErr.Type == s3tables.ErrCodeNamespaceAlreadyExists || strings.Contains(strings.ToLower(tableErr.Message), "already exists")) } -func parseMetadataVersionFromLocation(metadataLocation string) int { - base := path.Base(metadataLocation) - if !strings.HasPrefix(base, "v") || !strings.HasSuffix(base, ".metadata.json") { - return 0 - } - rawVersion := strings.TrimPrefix(strings.TrimSuffix(base, ".metadata.json"), "v") - version, err := strconv.Atoi(rawVersion) - if err != nil || version <= 0 { - return 0 - } - return version -} - func (s *Server) finalizeCreateOnCommit(ctx context.Context, input createOnCommitInput) (*CommitTableResponse, *icebergRequestError) { builder, err := table.MetadataBuilderFromBase(input.baseMetadata, input.baseMetadataLoc) if err != nil { diff --git a/weed/s3api/iceberg/iceberg_stage_create_helpers_test.go b/weed/s3api/iceberg/iceberg_stage_create_helpers_test.go deleted file mode 100644 index 81fd7aba9..000000000 --- a/weed/s3api/iceberg/iceberg_stage_create_helpers_test.go +++ /dev/null @@ -1,76 +0,0 @@ -package iceberg - -import ( - "strings" - "testing" - - "github.com/apache/iceberg-go/table" - "github.com/google/uuid" -) - -func TestHasAssertCreateRequirement(t *testing.T) { - requirements := table.Requirements{table.AssertCreate()} - if !hasAssertCreateRequirement(requirements) { - t.Fatalf("hasAssertCreateRequirement() = false, want true") - } - - requirements = table.Requirements{table.AssertDefaultSortOrderID(0)} - if hasAssertCreateRequirement(requirements) { - t.Fatalf("hasAssertCreateRequirement() = true, want false") - } -} - -func TestParseMetadataVersionFromLocation(t *testing.T) { - testCases := []struct { - location string - version int - }{ - {location: "s3://b/ns/t/metadata/v1.metadata.json", version: 1}, - {location: "s3://b/ns/t/metadata/v25.metadata.json", version: 25}, - {location: "v1.metadata.json", version: 1}, - {location: "s3://b/ns/t/metadata/v0.metadata.json", version: 0}, - {location: "s3://b/ns/t/metadata/v-1.metadata.json", version: 0}, - {location: "s3://b/ns/t/metadata/vABC.metadata.json", version: 0}, - {location: "s3://b/ns/t/metadata/current.json", version: 0}, - {location: "", version: 0}, - } - - for _, tc := range testCases { - t.Run(tc.location, func(t *testing.T) { - if got := parseMetadataVersionFromLocation(tc.location); got != tc.version { - t.Errorf("parseMetadataVersionFromLocation(%q) = %d, want %d", tc.location, got, tc.version) - } - }) - } -} - -func TestStageCreateMarkerNamespaceKey(t *testing.T) { - key := stageCreateMarkerNamespaceKey([]string{"a", "b"}) - if key == "a\x1fb" { - t.Fatalf("stageCreateMarkerNamespaceKey() returned unescaped namespace key %q", key) - } - if !strings.Contains(key, "%1F") { - t.Fatalf("stageCreateMarkerNamespaceKey() = %q, want escaped unit separator", key) - } -} - -func TestStageCreateMarkerDir(t *testing.T) { - dir := stageCreateMarkerDir("warehouse", []string{"ns"}, "orders") - if !strings.Contains(dir, stageCreateMarkerDirName) { - t.Fatalf("stageCreateMarkerDir() = %q, want marker dir segment %q", dir, stageCreateMarkerDirName) - } - if !strings.HasSuffix(dir, "/orders") { - t.Fatalf("stageCreateMarkerDir() = %q, want suffix /orders", dir) - } -} - -func TestStageCreateStagedTablePath(t *testing.T) { - tableUUID := uuid.MustParse("11111111-2222-3333-4444-555555555555") - stagedPath := stageCreateStagedTablePath([]string{"ns"}, "orders", tableUUID) - if !strings.Contains(stagedPath, stageCreateMarkerDirName) { - t.Fatalf("stageCreateStagedTablePath() = %q, want marker dir segment %q", stagedPath, stageCreateMarkerDirName) - } - if !strings.HasSuffix(stagedPath, "/"+tableUUID.String()) { - t.Fatalf("stageCreateStagedTablePath() = %q, want UUID suffix %q", stagedPath, tableUUID.String()) - } -} diff --git a/weed/s3api/object_lock_utils.go b/weed/s3api/object_lock_utils.go index 9455cb12c..d58bc7b8e 100644 --- a/weed/s3api/object_lock_utils.go +++ b/weed/s3api/object_lock_utils.go @@ -2,8 +2,6 @@ package s3api import ( "context" - "encoding/xml" - "fmt" "strconv" "time" @@ -35,21 +33,6 @@ func StoreVersioningInExtended(entry *filer_pb.Entry, enabled bool) error { return nil } -// LoadVersioningFromExtended loads versioning configuration from entry extended attributes -func LoadVersioningFromExtended(entry *filer_pb.Entry) (bool, bool) { - if entry == nil || entry.Extended == nil { - return false, false // not found, default to suspended - } - - // Check for S3 API compatible key - if versioningBytes, exists := entry.Extended[s3_constants.ExtVersioningKey]; exists { - enabled := string(versioningBytes) == s3_constants.VersioningEnabled - return enabled, true - } - - return false, false // not found -} - // GetVersioningStatus returns the versioning status as a string: "", "Enabled", or "Suspended" // Empty string means versioning was never enabled func GetVersioningStatus(entry *filer_pb.Entry) string { @@ -90,15 +73,6 @@ func CreateObjectLockConfiguration(enabled bool, mode string, days int, years in return config } -// ObjectLockConfigurationToXML converts ObjectLockConfiguration to XML bytes -func ObjectLockConfigurationToXML(config *ObjectLockConfiguration) ([]byte, error) { - if config == nil { - return nil, fmt.Errorf("object lock configuration is nil") - } - - return xml.Marshal(config) -} - // StoreObjectLockConfigurationInExtended stores Object Lock configuration in entry extended attributes func StoreObjectLockConfigurationInExtended(entry *filer_pb.Entry, config *ObjectLockConfiguration) error { if entry.Extended == nil { @@ -379,18 +353,6 @@ func validateDefaultRetention(retention *DefaultRetention) error { return nil } -// ==================================================================== -// SHARED OBJECT LOCK CHECKING FUNCTIONS -// ==================================================================== -// These functions delegate to s3_objectlock package to avoid code duplication. -// They are kept here for backward compatibility with existing callers. - -// EntryHasActiveLock checks if an entry has an active retention or legal hold -// Delegates to s3_objectlock.EntryHasActiveLock -func EntryHasActiveLock(entry *filer_pb.Entry, currentTime time.Time) bool { - return s3_objectlock.EntryHasActiveLock(entry, currentTime) -} - // HasObjectsWithActiveLocks checks if any objects in the bucket have active retention or legal hold // Delegates to s3_objectlock.HasObjectsWithActiveLocks func HasObjectsWithActiveLocks(ctx context.Context, client filer_pb.SeaweedFilerClient, bucketPath string) (bool, error) { diff --git a/weed/s3api/policy/post-policy.go b/weed/s3api/policy/post-policy.go deleted file mode 100644 index 3250cdf49..000000000 --- a/weed/s3api/policy/post-policy.go +++ /dev/null @@ -1,321 +0,0 @@ -package policy - -/* - * MinIO Go Library for Amazon S3 Compatible Cloud Storage - * Copyright 2015-2017 MinIO, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import ( - "encoding/base64" - "fmt" - "net/http" - "strings" - "time" - - "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" -) - -// expirationDateFormat date format for expiration key in json policy. -const expirationDateFormat = "2006-01-02T15:04:05.999Z" - -// policyCondition explanation: -// http://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-HTTPPOSTConstructPolicy.html -// -// Example: -// -// policyCondition { -// matchType: "$eq", -// key: "$Content-Type", -// value: "image/png", -// } -type policyCondition struct { - matchType string - condition string - value string -} - -// PostPolicy - Provides strict static type conversion and validation -// for Amazon S3's POST policy JSON string. -type PostPolicy struct { - // Expiration date and time of the POST policy. - expiration time.Time - // Collection of different policy conditions. - conditions []policyCondition - // ContentLengthRange minimum and maximum allowable size for the - // uploaded content. - contentLengthRange struct { - min int64 - max int64 - } - - // Post form data. - formData map[string]string -} - -// NewPostPolicy - Instantiate new post policy. -func NewPostPolicy() *PostPolicy { - p := &PostPolicy{} - p.conditions = make([]policyCondition, 0) - p.formData = make(map[string]string) - return p -} - -// SetExpires - Sets expiration time for the new policy. -func (p *PostPolicy) SetExpires(t time.Time) error { - if t.IsZero() { - return errInvalidArgument("No expiry time set.") - } - p.expiration = t - return nil -} - -// SetKey - Sets an object name for the policy based upload. -func (p *PostPolicy) SetKey(key string) error { - if strings.TrimSpace(key) == "" || key == "" { - return errInvalidArgument("Object name is empty.") - } - policyCond := policyCondition{ - matchType: "eq", - condition: "$key", - value: key, - } - if err := p.addNewPolicy(policyCond); err != nil { - return err - } - p.formData["key"] = key - return nil -} - -// SetKeyStartsWith - Sets an object name that an policy based upload -// can start with. -func (p *PostPolicy) SetKeyStartsWith(keyStartsWith string) error { - if strings.TrimSpace(keyStartsWith) == "" || keyStartsWith == "" { - return errInvalidArgument("Object prefix is empty.") - } - policyCond := policyCondition{ - matchType: "starts-with", - condition: "$key", - value: keyStartsWith, - } - if err := p.addNewPolicy(policyCond); err != nil { - return err - } - p.formData["key"] = keyStartsWith - return nil -} - -// SetBucket - Sets bucket at which objects will be uploaded to. -func (p *PostPolicy) SetBucket(bucketName string) error { - if strings.TrimSpace(bucketName) == "" || bucketName == "" { - return errInvalidArgument("Bucket name is empty.") - } - policyCond := policyCondition{ - matchType: "eq", - condition: "$bucket", - value: bucketName, - } - if err := p.addNewPolicy(policyCond); err != nil { - return err - } - p.formData["bucket"] = bucketName - return nil -} - -// SetCondition - Sets condition for credentials, date and algorithm -func (p *PostPolicy) SetCondition(matchType, condition, value string) error { - if strings.TrimSpace(value) == "" || value == "" { - return errInvalidArgument("No value specified for condition") - } - - policyCond := policyCondition{ - matchType: matchType, - condition: "$" + condition, - value: value, - } - if condition == "X-Amz-Credential" || condition == "X-Amz-Date" || condition == "X-Amz-Algorithm" { - if err := p.addNewPolicy(policyCond); err != nil { - return err - } - p.formData[condition] = value - return nil - } - return errInvalidArgument("Invalid condition in policy") -} - -// SetContentType - Sets content-type of the object for this policy -// based upload. -func (p *PostPolicy) SetContentType(contentType string) error { - if strings.TrimSpace(contentType) == "" || contentType == "" { - return errInvalidArgument("No content type specified.") - } - policyCond := policyCondition{ - matchType: "eq", - condition: "$Content-Type", - value: contentType, - } - if err := p.addNewPolicy(policyCond); err != nil { - return err - } - p.formData["Content-Type"] = contentType - return nil -} - -// SetContentLengthRange - Set new min and max content length -// condition for all incoming uploads. -func (p *PostPolicy) SetContentLengthRange(min, max int64) error { - if min > max { - return errInvalidArgument("Minimum limit is larger than maximum limit.") - } - if min < 0 { - return errInvalidArgument("Minimum limit cannot be negative.") - } - if max < 0 { - return errInvalidArgument("Maximum limit cannot be negative.") - } - p.contentLengthRange.min = min - p.contentLengthRange.max = max - return nil -} - -// SetSuccessActionRedirect - Sets the redirect success url of the object for this policy -// based upload. -func (p *PostPolicy) SetSuccessActionRedirect(redirect string) error { - if strings.TrimSpace(redirect) == "" || redirect == "" { - return errInvalidArgument("Redirect is empty") - } - policyCond := policyCondition{ - matchType: "eq", - condition: "$success_action_redirect", - value: redirect, - } - if err := p.addNewPolicy(policyCond); err != nil { - return err - } - p.formData["success_action_redirect"] = redirect - return nil -} - -// SetSuccessStatusAction - Sets the status success code of the object for this policy -// based upload. -func (p *PostPolicy) SetSuccessStatusAction(status string) error { - if strings.TrimSpace(status) == "" || status == "" { - return errInvalidArgument("Status is empty") - } - policyCond := policyCondition{ - matchType: "eq", - condition: "$success_action_status", - value: status, - } - if err := p.addNewPolicy(policyCond); err != nil { - return err - } - p.formData["success_action_status"] = status - return nil -} - -// SetUserMetadata - Set user metadata as a key/value couple. -// Can be retrieved through a HEAD request or an event. -func (p *PostPolicy) SetUserMetadata(key string, value string) error { - if strings.TrimSpace(key) == "" || key == "" { - return errInvalidArgument("Key is empty") - } - if strings.TrimSpace(value) == "" || value == "" { - return errInvalidArgument("Value is empty") - } - headerName := fmt.Sprintf("x-amz-meta-%s", key) - policyCond := policyCondition{ - matchType: "eq", - condition: fmt.Sprintf("$%s", headerName), - value: value, - } - if err := p.addNewPolicy(policyCond); err != nil { - return err - } - p.formData[headerName] = value - return nil -} - -// SetUserData - Set user data as a key/value couple. -// Can be retrieved through a HEAD request or an event. -func (p *PostPolicy) SetUserData(key string, value string) error { - if key == "" { - return errInvalidArgument("Key is empty") - } - if value == "" { - return errInvalidArgument("Value is empty") - } - headerName := fmt.Sprintf("x-amz-%s", key) - policyCond := policyCondition{ - matchType: "eq", - condition: fmt.Sprintf("$%s", headerName), - value: value, - } - if err := p.addNewPolicy(policyCond); err != nil { - return err - } - p.formData[headerName] = value - return nil -} - -// addNewPolicy - internal helper to validate adding new policies. -func (p *PostPolicy) addNewPolicy(policyCond policyCondition) error { - if policyCond.matchType == "" || policyCond.condition == "" || policyCond.value == "" { - return errInvalidArgument("Policy fields are empty.") - } - p.conditions = append(p.conditions, policyCond) - return nil -} - -// String function for printing policy in json formatted string. -func (p PostPolicy) String() string { - return string(p.marshalJSON()) -} - -// marshalJSON - Provides Marshaled JSON in bytes. -func (p PostPolicy) marshalJSON() []byte { - expirationStr := `"expiration":"` + p.expiration.Format(expirationDateFormat) + `"` - var conditionsStr string - conditions := []string{} - for _, po := range p.conditions { - conditions = append(conditions, fmt.Sprintf("[\"%s\",\"%s\",\"%s\"]", po.matchType, po.condition, po.value)) - } - if p.contentLengthRange.min != 0 || p.contentLengthRange.max != 0 { - conditions = append(conditions, fmt.Sprintf("[\"content-length-range\", %d, %d]", - p.contentLengthRange.min, p.contentLengthRange.max)) - } - if len(conditions) > 0 { - conditionsStr = `"conditions":[` + strings.Join(conditions, ",") + "]" - } - retStr := "{" - retStr = retStr + expirationStr + "," - retStr = retStr + conditionsStr - retStr = retStr + "}" - return []byte(retStr) -} - -// base64 - Produces base64 of PostPolicy's Marshaled json. -func (p PostPolicy) base64() string { - return base64.StdEncoding.EncodeToString(p.marshalJSON()) -} - -// errInvalidArgument - Invalid argument response. -func errInvalidArgument(message string) error { - return s3err.RESTErrorResponse{ - StatusCode: http.StatusBadRequest, - Code: "InvalidArgument", - Message: message, - RequestID: "client", - } -} diff --git a/weed/s3api/policy/postpolicyform_test.go b/weed/s3api/policy/postpolicyform_test.go deleted file mode 100644 index 1a9d78b0e..000000000 --- a/weed/s3api/policy/postpolicyform_test.go +++ /dev/null @@ -1,106 +0,0 @@ -package policy - -/* - * MinIO Cloud Storage, (C) 2016 MinIO, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import ( - "encoding/base64" - "fmt" - "net/http" - "testing" - "time" -) - -// Test Post Policy parsing and checking conditions -func TestPostPolicyForm(t *testing.T) { - pp := NewPostPolicy() - pp.SetBucket("testbucket") - pp.SetContentType("image/jpeg") - pp.SetUserMetadata("uuid", "14365123651274") - pp.SetKeyStartsWith("user/user1/filename") - pp.SetContentLengthRange(1048579, 10485760) - pp.SetSuccessStatusAction("201") - - type testCase struct { - Bucket string - Key string - XAmzDate string - XAmzAlgorithm string - XAmzCredential string - XAmzMetaUUID string - ContentType string - SuccessActionStatus string - Policy string - Expired bool - expectedErr error - } - - testCases := []testCase{ - // Everything is fine with this test - {Bucket: "testbucket", Key: "user/user1/filename/${filename}/myfile.txt", XAmzMetaUUID: "14365123651274", SuccessActionStatus: "201", XAmzCredential: "KVGKMDUQ23TCZXTLTHLP/20160727/us-east-1/s3/aws4_request", XAmzDate: "20160727T000000Z", XAmzAlgorithm: "AWS4-HMAC-SHA256", ContentType: "image/jpeg", expectedErr: nil}, - // Expired policy document - {Bucket: "testbucket", Key: "user/user1/filename/${filename}/myfile.txt", XAmzMetaUUID: "14365123651274", SuccessActionStatus: "201", XAmzCredential: "KVGKMDUQ23TCZXTLTHLP/20160727/us-east-1/s3/aws4_request", XAmzDate: "20160727T000000Z", XAmzAlgorithm: "AWS4-HMAC-SHA256", ContentType: "image/jpeg", Expired: true, expectedErr: fmt.Errorf("Invalid according to Policy: Policy expired")}, - // Different AMZ date - {Bucket: "testbucket", Key: "user/user1/filename/${filename}/myfile.txt", XAmzMetaUUID: "14365123651274", XAmzDate: "2017T000000Z", XAmzAlgorithm: "AWS4-HMAC-SHA256", ContentType: "image/jpeg", expectedErr: fmt.Errorf("Invalid according to Policy: Policy Condition failed")}, - // Key which doesn't start with user/user1/filename - {Bucket: "testbucket", Key: "myfile.txt", XAmzDate: "20160727T000000Z", XAmzMetaUUID: "14365123651274", XAmzAlgorithm: "AWS4-HMAC-SHA256", ContentType: "image/jpeg", expectedErr: fmt.Errorf("Invalid according to Policy: Policy Condition failed")}, - // Incorrect bucket name. - {Bucket: "incorrect", Key: "user/user1/filename/myfile.txt", XAmzMetaUUID: "14365123651274", XAmzDate: "20160727T000000Z", XAmzAlgorithm: "AWS4-HMAC-SHA256", ContentType: "image/jpeg", expectedErr: fmt.Errorf("Invalid according to Policy: Policy Condition failed")}, - // Incorrect key name - {Bucket: "testbucket", Key: "incorrect", XAmzDate: "20160727T000000Z", XAmzMetaUUID: "14365123651274", XAmzAlgorithm: "AWS4-HMAC-SHA256", ContentType: "image/jpeg", expectedErr: fmt.Errorf("Invalid according to Policy: Policy Condition failed")}, - // Incorrect date - {Bucket: "testbucket", Key: "user/user1/filename/${filename}/myfile.txt", XAmzMetaUUID: "14365123651274", XAmzDate: "incorrect", XAmzAlgorithm: "AWS4-HMAC-SHA256", ContentType: "image/jpeg", expectedErr: fmt.Errorf("Invalid according to Policy: Policy Condition failed")}, - // Incorrect ContentType - {Bucket: "testbucket", Key: "user/user1/filename/${filename}/myfile.txt", XAmzMetaUUID: "14365123651274", XAmzDate: "20160727T000000Z", XAmzAlgorithm: "AWS4-HMAC-SHA256", ContentType: "incorrect", expectedErr: fmt.Errorf("Invalid according to Policy: Policy Condition failed")}, - // Incorrect Metadata - {Bucket: "testbucket", Key: "user/user1/filename/${filename}/myfile.txt", XAmzMetaUUID: "151274", SuccessActionStatus: "201", XAmzCredential: "KVGKMDUQ23TCZXTLTHLP/20160727/us-east-1/s3/aws4_request", XAmzDate: "20160727T000000Z", XAmzAlgorithm: "AWS4-HMAC-SHA256", ContentType: "image/jpeg", expectedErr: fmt.Errorf("Invalid according to Policy: Policy Condition failed: [eq, $x-amz-meta-uuid, 14365123651274]")}, - } - // Validate all the test cases. - for i, tt := range testCases { - formValues := make(http.Header) - formValues.Set("Bucket", tt.Bucket) - formValues.Set("Key", tt.Key) - formValues.Set("Content-Type", tt.ContentType) - formValues.Set("X-Amz-Date", tt.XAmzDate) - formValues.Set("X-Amz-Meta-Uuid", tt.XAmzMetaUUID) - formValues.Set("X-Amz-Algorithm", tt.XAmzAlgorithm) - formValues.Set("X-Amz-Credential", tt.XAmzCredential) - if tt.Expired { - // Expired already. - pp.SetExpires(time.Now().UTC().AddDate(0, 0, -10)) - } else { - // Expires in 10 days. - pp.SetExpires(time.Now().UTC().AddDate(0, 0, 10)) - } - - formValues.Set("Policy", base64.StdEncoding.EncodeToString([]byte(pp.String()))) - formValues.Set("Success_action_status", tt.SuccessActionStatus) - policyBytes, err := base64.StdEncoding.DecodeString(base64.StdEncoding.EncodeToString([]byte(pp.String()))) - if err != nil { - t.Fatal(err) - } - - postPolicyForm, err := ParsePostPolicyForm(string(policyBytes)) - if err != nil { - t.Fatal(err) - } - - err = CheckPostPolicy(formValues, postPolicyForm) - if err != nil && tt.expectedErr != nil && err.Error() != tt.expectedErr.Error() { - t.Fatalf("Test %d:, Expected %s, got %s", i+1, tt.expectedErr.Error(), err.Error()) - } - } -} diff --git a/weed/s3api/policy_engine/conditions.go b/weed/s3api/policy_engine/conditions.go index b32f11594..af55b06c2 100644 --- a/weed/s3api/policy_engine/conditions.go +++ b/weed/s3api/policy_engine/conditions.go @@ -125,22 +125,6 @@ func (c *NormalizedValueCache) evictLeastRecentlyUsed() { delete(c.cache, tail.key) } -// Clear clears all cached values -func (c *NormalizedValueCache) Clear() { - c.mu.Lock() - defer c.mu.Unlock() - c.cache = make(map[string]*LRUNode) - c.head.next = c.tail - c.tail.prev = c.head -} - -// GetStats returns cache statistics -func (c *NormalizedValueCache) GetStats() (size int, maxSize int) { - c.mu.RLock() - defer c.mu.RUnlock() - return len(c.cache), c.maxSize -} - // Global cache instance with size limit var normalizedValueCache = NewNormalizedValueCache(1000) @@ -769,34 +753,3 @@ func EvaluateConditions(conditions PolicyConditions, contextValues map[string][] return true } - -// EvaluateConditionsLegacy evaluates conditions using the old interface{} format for backward compatibility -// objectEntry is the object's metadata from entry.Extended (can be nil) -func EvaluateConditionsLegacy(conditions map[string]interface{}, contextValues map[string][]string, objectEntry map[string][]byte) bool { - if len(conditions) == 0 { - return true // No conditions means always true - } - - for operator, conditionMap := range conditions { - conditionEvaluator, err := GetConditionEvaluator(operator) - if err != nil { - glog.Warningf("Unsupported condition operator: %s", operator) - continue - } - - conditionMapTyped, ok := conditionMap.(map[string]interface{}) - if !ok { - glog.Warningf("Invalid condition format for operator: %s", operator) - continue - } - - for key, value := range conditionMapTyped { - contextVals := getConditionContextValue(key, contextValues, objectEntry) - if !conditionEvaluator.Evaluate(value, contextVals) { - return false // If any condition fails, the whole condition block fails - } - } - } - - return true -} diff --git a/weed/s3api/policy_engine/engine.go b/weed/s3api/policy_engine/engine.go index bf66ebfd2..d39b4b2ce 100644 --- a/weed/s3api/policy_engine/engine.go +++ b/weed/s3api/policy_engine/engine.go @@ -610,92 +610,6 @@ func BuildActionName(action string) string { return fmt.Sprintf("s3:%s", action) } -// IsReadAction checks if an action is a read action -func IsReadAction(action string) bool { - readActions := []string{ - "s3:GetObject", - "s3:GetObjectVersion", - "s3:GetObjectAcl", - "s3:GetObjectVersionAcl", - "s3:GetObjectTagging", - "s3:GetObjectVersionTagging", - "s3:ListBucket", - "s3:ListBucketVersions", - "s3:GetBucketLocation", - "s3:GetBucketVersioning", - "s3:GetBucketAcl", - "s3:GetBucketCors", - "s3:GetBucketPolicy", - "s3:GetBucketTagging", - "s3:GetBucketNotification", - "s3:GetBucketObjectLockConfiguration", - "s3:GetObjectRetention", - "s3:GetObjectLegalHold", - } - - for _, readAction := range readActions { - if action == readAction { - return true - } - } - return false -} - -// IsWriteAction checks if an action is a write action -func IsWriteAction(action string) bool { - writeActions := []string{ - "s3:PutObject", - "s3:PutObjectAcl", - "s3:PutObjectTagging", - "s3:DeleteObject", - "s3:DeleteObjectVersion", - "s3:DeleteObjectTagging", - "s3:AbortMultipartUpload", - "s3:ListMultipartUploads", - "s3:ListParts", - "s3:PutBucketAcl", - "s3:PutBucketCors", - "s3:PutBucketPolicy", - "s3:PutBucketTagging", - "s3:PutBucketNotification", - "s3:PutBucketVersioning", - "s3:DeleteBucketPolicy", - "s3:DeleteBucketTagging", - "s3:DeleteBucketCors", - "s3:PutBucketObjectLockConfiguration", - "s3:PutObjectRetention", - "s3:PutObjectLegalHold", - "s3:BypassGovernanceRetention", - } - - for _, writeAction := range writeActions { - if action == writeAction { - return true - } - } - return false -} - -// GetBucketNameFromArn extracts bucket name from ARN -func GetBucketNameFromArn(arn string) string { - if strings.HasPrefix(arn, "arn:aws:s3:::") { - parts := strings.SplitN(arn[13:], "/", 2) - return parts[0] - } - return "" -} - -// GetObjectNameFromArn extracts object name from ARN -func GetObjectNameFromArn(arn string) string { - if strings.HasPrefix(arn, "arn:aws:s3:::") { - parts := strings.SplitN(arn[13:], "/", 2) - if len(parts) > 1 { - return parts[1] - } - } - return "" -} - // GetPolicyStatements returns all policy statements for a bucket func (engine *PolicyEngine) GetPolicyStatements(bucketName string) []PolicyStatement { engine.mutex.RLock() diff --git a/weed/s3api/policy_engine/engine_test.go b/weed/s3api/policy_engine/engine_test.go index 1ad8c434a..452c01775 100644 --- a/weed/s3api/policy_engine/engine_test.go +++ b/weed/s3api/policy_engine/engine_test.go @@ -6,7 +6,6 @@ import ( "testing" "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" "github.com/seaweedfs/seaweedfs/weed/util/wildcard" ) @@ -226,47 +225,6 @@ func TestConditionEvaluators(t *testing.T) { } } -func TestConvertIdentityToPolicy(t *testing.T) { - identityActions := []string{ - "Read:bucket1/*", - "Write:bucket1/*", - "Admin:bucket2", - } - - policy, err := ConvertIdentityToPolicy(identityActions) - if err != nil { - t.Fatalf("Failed to convert identity to policy: %v", err) - } - - if policy.Version != "2012-10-17" { - t.Errorf("Expected version 2012-10-17, got %s", policy.Version) - } - - if len(policy.Statement) != 3 { - t.Errorf("Expected 3 statements, got %d", len(policy.Statement)) - } - - // Check first statement (Read) - stmt := policy.Statement[0] - if stmt.Effect != PolicyEffectAllow { - t.Errorf("Expected Allow effect, got %s", stmt.Effect) - } - - actions := normalizeToStringSlice(stmt.Action) - // Read action now includes: GetObject, GetObjectVersion, ListBucket, ListBucketVersions, - // GetObjectAcl, GetObjectVersionAcl, GetObjectTagging, GetObjectVersionTagging, - // GetBucketLocation, GetBucketVersioning, GetBucketAcl, GetBucketCors, GetBucketTagging, GetBucketNotification - if len(actions) != 14 { - t.Errorf("Expected 14 read actions, got %d: %v", len(actions), actions) - } - - resources := normalizeToStringSlice(stmt.Resource) - // Read action now includes both bucket ARN (for ListBucket*) and object ARN (for GetObject*) - if len(resources) != 2 { - t.Errorf("Expected 2 resources (bucket and bucket/*), got %d: %v", len(resources), resources) - } -} - func TestPolicyValidation(t *testing.T) { tests := []struct { name string @@ -794,41 +752,6 @@ func TestCompilePolicy(t *testing.T) { } } -// TestNewPolicyBackedIAMWithLegacy tests the constructor overload -func TestNewPolicyBackedIAMWithLegacy(t *testing.T) { - // Mock legacy IAM - mockLegacyIAM := &MockLegacyIAM{} - - // Test the new constructor - policyBackedIAM := NewPolicyBackedIAMWithLegacy(mockLegacyIAM) - - // Verify that the legacy IAM is set - if policyBackedIAM.legacyIAM != mockLegacyIAM { - t.Errorf("Expected legacy IAM to be set, but it wasn't") - } - - // Verify that the policy engine is initialized - if policyBackedIAM.policyEngine == nil { - t.Errorf("Expected policy engine to be initialized, but it wasn't") - } - - // Compare with the traditional approach - traditionalIAM := NewPolicyBackedIAM() - traditionalIAM.SetLegacyIAM(mockLegacyIAM) - - // Both should behave the same - if policyBackedIAM.legacyIAM != traditionalIAM.legacyIAM { - t.Errorf("Expected both approaches to result in the same legacy IAM") - } -} - -// MockLegacyIAM implements the LegacyIAM interface for testing -type MockLegacyIAM struct{} - -func (m *MockLegacyIAM) authRequest(r *http.Request, action Action) (Identity, s3err.ErrorCode) { - return nil, s3err.ErrNone -} - // TestExistingObjectTagCondition tests s3:ExistingObjectTag/ condition support func TestExistingObjectTagCondition(t *testing.T) { engine := NewPolicyEngine() diff --git a/weed/s3api/policy_engine/integration.go b/weed/s3api/policy_engine/integration.go deleted file mode 100644 index d1d36d02a..000000000 --- a/weed/s3api/policy_engine/integration.go +++ /dev/null @@ -1,642 +0,0 @@ -package policy_engine - -import ( - "fmt" - "net/http" - "strings" - - "github.com/seaweedfs/seaweedfs/weed/glog" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" -) - -// Action represents an S3 action - this should match the type in auth_credentials.go -type Action string - -// Identity represents a user identity - this should match the type in auth_credentials.go -type Identity interface { - CanDo(action Action, bucket string, objectKey string) bool -} - -// PolicyBackedIAM provides policy-based access control with fallback to legacy IAM -type PolicyBackedIAM struct { - policyEngine *PolicyEngine - legacyIAM LegacyIAM // Interface to delegate to existing IAM system -} - -// LegacyIAM interface for delegating to existing IAM implementation -type LegacyIAM interface { - authRequest(r *http.Request, action Action) (Identity, s3err.ErrorCode) -} - -// NewPolicyBackedIAM creates a new policy-backed IAM system -func NewPolicyBackedIAM() *PolicyBackedIAM { - return &PolicyBackedIAM{ - policyEngine: NewPolicyEngine(), - legacyIAM: nil, // Will be set when integrated with existing IAM - } -} - -// NewPolicyBackedIAMWithLegacy creates a new policy-backed IAM system with legacy IAM set -func NewPolicyBackedIAMWithLegacy(legacyIAM LegacyIAM) *PolicyBackedIAM { - return &PolicyBackedIAM{ - policyEngine: NewPolicyEngine(), - legacyIAM: legacyIAM, - } -} - -// SetLegacyIAM sets the legacy IAM system for fallback -func (p *PolicyBackedIAM) SetLegacyIAM(legacyIAM LegacyIAM) { - p.legacyIAM = legacyIAM -} - -// SetBucketPolicy sets the policy for a bucket -func (p *PolicyBackedIAM) SetBucketPolicy(bucketName string, policyJSON string) error { - return p.policyEngine.SetBucketPolicy(bucketName, policyJSON) -} - -// GetBucketPolicy gets the policy for a bucket -func (p *PolicyBackedIAM) GetBucketPolicy(bucketName string) (*PolicyDocument, error) { - return p.policyEngine.GetBucketPolicy(bucketName) -} - -// DeleteBucketPolicy deletes the policy for a bucket -func (p *PolicyBackedIAM) DeleteBucketPolicy(bucketName string) error { - return p.policyEngine.DeleteBucketPolicy(bucketName) -} - -// CanDo checks if a principal can perform an action on a resource -func (p *PolicyBackedIAM) CanDo(action, bucketName, objectName, principal string, r *http.Request) bool { - // If there's a bucket policy, evaluate it - if p.policyEngine.HasPolicyForBucket(bucketName) { - result := p.policyEngine.EvaluatePolicyForRequest(bucketName, objectName, action, principal, r) - switch result { - case PolicyResultAllow: - return true - case PolicyResultDeny: - return false - case PolicyResultIndeterminate: - // Fall through to legacy system - } - } - - // No bucket policy or indeterminate result, use legacy conversion - return p.evaluateLegacyAction(action, bucketName, objectName, principal) -} - -// evaluateLegacyAction evaluates actions using legacy identity-based rules -func (p *PolicyBackedIAM) evaluateLegacyAction(action, bucketName, objectName, principal string) bool { - // If we have a legacy IAM system to delegate to, use it - if p.legacyIAM != nil { - // Create a dummy request for legacy evaluation - // In real implementation, this would use the actual request - r := &http.Request{ - Header: make(http.Header), - } - - // Convert the action string to Action type - legacyAction := Action(action) - - // Use legacy IAM to check permission - identity, errCode := p.legacyIAM.authRequest(r, legacyAction) - if errCode != s3err.ErrNone { - return false - } - - // If we have an identity, check if it can perform the action - if identity != nil { - return identity.CanDo(legacyAction, bucketName, objectName) - } - } - - // No legacy IAM available, convert to policy and evaluate - return p.evaluateUsingPolicyConversion(action, bucketName, objectName, principal) -} - -// evaluateUsingPolicyConversion converts legacy action to policy and evaluates -func (p *PolicyBackedIAM) evaluateUsingPolicyConversion(action, bucketName, objectName, principal string) bool { - // For now, use a conservative approach for legacy actions - // In a real implementation, this would integrate with the existing identity system - glog.V(2).Infof("Legacy action evaluation for %s on %s/%s by %s", action, bucketName, objectName, principal) - - // Return false to maintain security until proper legacy integration is implemented - // This ensures no unintended access is granted - return false -} - -// extractBucketAndPrefix extracts bucket name and prefix from a resource pattern. -// Examples: -// -// "bucket" -> bucket="bucket", prefix="" -// "bucket/*" -> bucket="bucket", prefix="" -// "bucket/prefix/*" -> bucket="bucket", prefix="prefix" -// "bucket/a/b/c/*" -> bucket="bucket", prefix="a/b/c" -func extractBucketAndPrefix(pattern string) (string, string) { - // Validate input - pattern = strings.TrimSpace(pattern) - if pattern == "" || pattern == "/" { - return "", "" - } - - // Remove trailing /* if present - pattern = strings.TrimSuffix(pattern, "/*") - - // Remove a single trailing slash to avoid empty path segments - if strings.HasSuffix(pattern, "/") { - pattern = pattern[:len(pattern)-1] - } - if pattern == "" { - return "", "" - } - - // Split on the first / - parts := strings.SplitN(pattern, "/", 2) - bucket := strings.TrimSpace(parts[0]) - if bucket == "" { - return "", "" - } - - if len(parts) == 1 { - // No slash, entire pattern is bucket - return bucket, "" - } - // Has slash, first part is bucket, rest is prefix - prefix := strings.Trim(parts[1], "/") - return bucket, prefix -} - -// buildObjectResourceArn generates ARNs for object-level access. -// It properly handles both bucket-level (all objects) and prefix-level access. -// Returns empty slice if bucket is invalid to prevent generating malformed ARNs. -func buildObjectResourceArn(resourcePattern string) []string { - bucket, prefix := extractBucketAndPrefix(resourcePattern) - // If bucket is empty, the pattern is invalid; avoid generating malformed ARNs - if bucket == "" { - return []string{} - } - if prefix != "" { - // Prefix-based access: restrict to objects under this prefix - return []string{fmt.Sprintf("arn:aws:s3:::%s/%s/*", bucket, prefix)} - } - // Bucket-level access: all objects in bucket - return []string{fmt.Sprintf("arn:aws:s3:::%s/*", bucket)} -} - -// ConvertIdentityToPolicy converts a legacy identity action to an AWS policy -func ConvertIdentityToPolicy(identityActions []string) (*PolicyDocument, error) { - statements := make([]PolicyStatement, 0) - - for _, action := range identityActions { - stmt, err := convertSingleAction(action) - if err != nil { - glog.Warningf("Failed to convert action %s: %v", action, err) - continue - } - if stmt != nil { - statements = append(statements, *stmt) - } - } - - if len(statements) == 0 { - return nil, fmt.Errorf("no valid statements generated") - } - - return &PolicyDocument{ - Version: PolicyVersion2012_10_17, - Statement: statements, - }, nil -} - -// convertSingleAction converts a single legacy action to a policy statement. -// action format: "ActionType:ResourcePattern" (e.g., "Write:bucket/prefix/*") -func convertSingleAction(action string) (*PolicyStatement, error) { - parts := strings.Split(action, ":") - if len(parts) != 2 { - return nil, fmt.Errorf("invalid action format: %s", action) - } - - actionType := parts[0] - resourcePattern := parts[1] - - var s3Actions []string - var resources []string - - switch actionType { - case "Read": - // Read includes both object-level (GetObject, GetObjectAcl, GetObjectTagging, GetObjectVersions) - // and bucket-level operations (ListBucket, GetBucketLocation, GetBucketVersioning, GetBucketCors, etc.) - s3Actions = []string{ - "s3:GetObject", - "s3:GetObjectVersion", - "s3:GetObjectAcl", - "s3:GetObjectVersionAcl", - "s3:GetObjectTagging", - "s3:GetObjectVersionTagging", - "s3:ListBucket", - "s3:ListBucketVersions", - "s3:GetBucketLocation", - "s3:GetBucketVersioning", - "s3:GetBucketAcl", - "s3:GetBucketCors", - "s3:GetBucketTagging", - "s3:GetBucketNotification", - } - bucket, _ := extractBucketAndPrefix(resourcePattern) - objectResources := buildObjectResourceArn(resourcePattern) - // Include both bucket ARN (for ListBucket* and Get*Bucket operations) and object ARNs (for GetObject* operations) - if bucket != "" { - resources = append([]string{fmt.Sprintf("arn:aws:s3:::%s", bucket)}, objectResources...) - } else { - resources = objectResources - } - - case "Write": - // Write includes object-level writes (PutObject, DeleteObject, PutObjectAcl, DeleteObjectVersion, DeleteObjectTagging, PutObjectTagging) - // and bucket-level writes (PutBucketVersioning, PutBucketCors, DeleteBucketCors, PutBucketAcl, PutBucketTagging, DeleteBucketTagging, PutBucketNotification) - // and multipart upload operations (AbortMultipartUpload, ListMultipartUploads, ListParts). - // ListMultipartUploads and ListParts are included because they are part of the multipart upload workflow - // and require Write permissions to be meaningful (no point listing uploads if you can't abort/complete them). - s3Actions = []string{ - "s3:PutObject", - "s3:PutObjectAcl", - "s3:PutObjectTagging", - "s3:DeleteObject", - "s3:DeleteObjectVersion", - "s3:DeleteObjectTagging", - "s3:AbortMultipartUpload", - "s3:ListMultipartUploads", - "s3:ListParts", - "s3:PutBucketAcl", - "s3:PutBucketCors", - "s3:PutBucketTagging", - "s3:PutBucketNotification", - "s3:PutBucketVersioning", - "s3:DeleteBucketTagging", - "s3:DeleteBucketCors", - } - bucket, _ := extractBucketAndPrefix(resourcePattern) - objectResources := buildObjectResourceArn(resourcePattern) - // Include bucket ARN so bucket-level write operations (e.g., PutBucketVersioning, PutBucketCors) - // have the correct resource, while still allowing object-level writes. - if bucket != "" { - resources = append([]string{fmt.Sprintf("arn:aws:s3:::%s", bucket)}, objectResources...) - } else { - resources = objectResources - } - - case "Admin": - s3Actions = []string{"s3:*"} - bucket, prefix := extractBucketAndPrefix(resourcePattern) - if bucket == "" { - // Invalid pattern, return error - return nil, fmt.Errorf("Admin action requires a valid bucket name") - } - if prefix != "" { - // Subpath admin access: restrict to objects under this prefix - resources = []string{ - fmt.Sprintf("arn:aws:s3:::%s", bucket), - fmt.Sprintf("arn:aws:s3:::%s/%s/*", bucket, prefix), - } - } else { - // Bucket-level admin access: full bucket permissions - resources = []string{ - fmt.Sprintf("arn:aws:s3:::%s", bucket), - fmt.Sprintf("arn:aws:s3:::%s/*", bucket), - } - } - - case "List": - // List includes bucket listing operations and also ListAllMyBuckets - s3Actions = []string{"s3:ListBucket", "s3:ListBucketVersions", "s3:ListAllMyBuckets"} - // ListBucket actions only require bucket ARN, not object-level ARNs - bucket, _ := extractBucketAndPrefix(resourcePattern) - if bucket != "" { - resources = []string{fmt.Sprintf("arn:aws:s3:::%s", bucket)} - } else { - // Invalid pattern, return empty resources to fail validation - resources = []string{} - } - - case "Tagging": - // Tagging includes both object-level and bucket-level tagging operations - s3Actions = []string{ - "s3:GetObjectTagging", - "s3:PutObjectTagging", - "s3:DeleteObjectTagging", - "s3:GetBucketTagging", - "s3:PutBucketTagging", - "s3:DeleteBucketTagging", - } - bucket, _ := extractBucketAndPrefix(resourcePattern) - objectResources := buildObjectResourceArn(resourcePattern) - // Include bucket ARN so bucket-level tagging operations have the correct resource - if bucket != "" { - resources = append([]string{fmt.Sprintf("arn:aws:s3:::%s", bucket)}, objectResources...) - } else { - resources = objectResources - } - - case "BypassGovernanceRetention": - s3Actions = []string{"s3:BypassGovernanceRetention"} - resources = buildObjectResourceArn(resourcePattern) - - case "GetObjectRetention": - s3Actions = []string{"s3:GetObjectRetention"} - resources = buildObjectResourceArn(resourcePattern) - - case "PutObjectRetention": - s3Actions = []string{"s3:PutObjectRetention"} - resources = buildObjectResourceArn(resourcePattern) - - case "GetObjectLegalHold": - s3Actions = []string{"s3:GetObjectLegalHold"} - resources = buildObjectResourceArn(resourcePattern) - - case "PutObjectLegalHold": - s3Actions = []string{"s3:PutObjectLegalHold"} - resources = buildObjectResourceArn(resourcePattern) - - case "GetBucketObjectLockConfiguration": - s3Actions = []string{"s3:GetBucketObjectLockConfiguration"} - bucket, _ := extractBucketAndPrefix(resourcePattern) - if bucket != "" { - resources = []string{fmt.Sprintf("arn:aws:s3:::%s", bucket)} - } else { - // Invalid pattern, return empty resources to fail validation - resources = []string{} - } - - case "PutBucketObjectLockConfiguration": - s3Actions = []string{"s3:PutBucketObjectLockConfiguration"} - bucket, _ := extractBucketAndPrefix(resourcePattern) - if bucket != "" { - resources = []string{fmt.Sprintf("arn:aws:s3:::%s", bucket)} - } else { - // Invalid pattern, return empty resources to fail validation - resources = []string{} - } - - default: - return nil, fmt.Errorf("unknown action type: %s", actionType) - } - - return &PolicyStatement{ - Effect: PolicyEffectAllow, - Action: NewStringOrStringSlice(s3Actions...), - Resource: NewStringOrStringSlicePtr(resources...), - }, nil -} - -// GetActionMappings returns the mapping of legacy actions to S3 actions -func GetActionMappings() map[string][]string { - return map[string][]string{ - "Read": { - "s3:GetObject", - "s3:GetObjectVersion", - "s3:GetObjectAcl", - "s3:GetObjectVersionAcl", - "s3:GetObjectTagging", - "s3:GetObjectVersionTagging", - "s3:ListBucket", - "s3:ListBucketVersions", - "s3:GetBucketLocation", - "s3:GetBucketVersioning", - "s3:GetBucketAcl", - "s3:GetBucketCors", - "s3:GetBucketTagging", - "s3:GetBucketNotification", - }, - "Write": { - "s3:PutObject", - "s3:PutObjectAcl", - "s3:PutObjectTagging", - "s3:DeleteObject", - "s3:DeleteObjectVersion", - "s3:DeleteObjectTagging", - "s3:AbortMultipartUpload", - "s3:ListMultipartUploads", - "s3:ListParts", - "s3:PutBucketAcl", - "s3:PutBucketCors", - "s3:PutBucketTagging", - "s3:PutBucketNotification", - "s3:PutBucketVersioning", - "s3:DeleteBucketTagging", - "s3:DeleteBucketCors", - }, - "Admin": { - "s3:*", - }, - "List": { - "s3:ListBucket", - "s3:ListBucketVersions", - "s3:ListAllMyBuckets", - }, - "Tagging": { - "s3:GetObjectTagging", - "s3:PutObjectTagging", - "s3:DeleteObjectTagging", - "s3:GetBucketTagging", - "s3:PutBucketTagging", - "s3:DeleteBucketTagging", - }, - "BypassGovernanceRetention": { - "s3:BypassGovernanceRetention", - }, - "GetObjectRetention": { - "s3:GetObjectRetention", - }, - "PutObjectRetention": { - "s3:PutObjectRetention", - }, - "GetObjectLegalHold": { - "s3:GetObjectLegalHold", - }, - "PutObjectLegalHold": { - "s3:PutObjectLegalHold", - }, - "GetBucketObjectLockConfiguration": { - "s3:GetBucketObjectLockConfiguration", - }, - "PutBucketObjectLockConfiguration": { - "s3:PutBucketObjectLockConfiguration", - }, - } -} - -// ValidateActionMapping validates that a legacy action can be mapped to S3 actions -func ValidateActionMapping(action string) error { - mappings := GetActionMappings() - - parts := strings.Split(action, ":") - if len(parts) != 2 { - return fmt.Errorf("invalid action format: %s, expected format: 'ActionType:Resource'", action) - } - - actionType := parts[0] - resource := parts[1] - - if _, exists := mappings[actionType]; !exists { - return fmt.Errorf("unknown action type: %s", actionType) - } - - if resource == "" { - return fmt.Errorf("resource cannot be empty") - } - - return nil -} - -// ConvertLegacyActions converts an array of legacy actions to S3 actions -func ConvertLegacyActions(legacyActions []string) ([]string, error) { - mappings := GetActionMappings() - s3Actions := make([]string, 0) - - for _, legacyAction := range legacyActions { - if err := ValidateActionMapping(legacyAction); err != nil { - return nil, err - } - - parts := strings.Split(legacyAction, ":") - actionType := parts[0] - - if actionType == "Admin" { - // Admin gives all permissions, so we can just return s3:* - return []string{"s3:*"}, nil - } - - if mapped, exists := mappings[actionType]; exists { - s3Actions = append(s3Actions, mapped...) - } - } - - // Remove duplicates - uniqueActions := make([]string, 0) - seen := make(map[string]bool) - for _, action := range s3Actions { - if !seen[action] { - uniqueActions = append(uniqueActions, action) - seen[action] = true - } - } - - return uniqueActions, nil -} - -// GetResourcesFromLegacyAction extracts resources from a legacy action. -// It delegates to convertSingleAction to ensure consistent resource ARN generation -// across the codebase and avoid duplicating action-type-specific logic. -func GetResourcesFromLegacyAction(legacyAction string) ([]string, error) { - stmt, err := convertSingleAction(legacyAction) - if err != nil { - return nil, err - } - return stmt.Resource.Strings(), nil -} - -// CreatePolicyFromLegacyIdentity creates a policy document from legacy identity actions -func CreatePolicyFromLegacyIdentity(identityName string, actions []string) (*PolicyDocument, error) { - statements := make([]PolicyStatement, 0) - - // Group actions by resource pattern - resourceActions := make(map[string][]string) - - for _, action := range actions { - // Validate action format before processing - if err := ValidateActionMapping(action); err != nil { - glog.Warningf("Skipping invalid action %q for identity %q: %v", action, identityName, err) - continue - } - - parts := strings.Split(action, ":") - if len(parts) != 2 { - continue - } - - resourcePattern := parts[1] - actionType := parts[0] - - if _, exists := resourceActions[resourcePattern]; !exists { - resourceActions[resourcePattern] = make([]string, 0) - } - resourceActions[resourcePattern] = append(resourceActions[resourcePattern], actionType) - } - - // Create statements for each resource pattern - for resourcePattern, actionTypes := range resourceActions { - s3Actions := make([]string, 0) - resourceSet := make(map[string]struct{}) - - // Collect S3 actions and aggregate resource ARNs from all action types. - // Different action types have different resource ARN requirements: - // - List: bucket-level ARNs only - // - Read/Write/Tagging: object-level ARNs - // - Admin: full bucket access - // We must merge all required ARNs for the combined policy statement. - for _, actionType := range actionTypes { - if actionType == "Admin" { - s3Actions = []string{"s3:*"} - // Admin action determines the resources, so we can break after processing it. - res, err := GetResourcesFromLegacyAction(fmt.Sprintf("Admin:%s", resourcePattern)) - if err != nil { - glog.Warningf("Failed to get resources for Admin action on %s: %v", resourcePattern, err) - resourceSet = nil // Invalidate to skip this statement - break - } - for _, r := range res { - resourceSet[r] = struct{}{} - } - break - } - - if mapped, exists := GetActionMappings()[actionType]; exists { - s3Actions = append(s3Actions, mapped...) - res, err := GetResourcesFromLegacyAction(fmt.Sprintf("%s:%s", actionType, resourcePattern)) - if err != nil { - glog.Warningf("Failed to get resources for %s action on %s: %v", actionType, resourcePattern, err) - resourceSet = nil // Invalidate to skip this statement - break - } - for _, r := range res { - resourceSet[r] = struct{}{} - } - } - } - - if resourceSet == nil || len(s3Actions) == 0 { - continue - } - - resources := make([]string, 0, len(resourceSet)) - for r := range resourceSet { - resources = append(resources, r) - } - - statement := PolicyStatement{ - Sid: fmt.Sprintf("%s-%s", identityName, strings.ReplaceAll(resourcePattern, "/", "-")), - Effect: PolicyEffectAllow, - Action: NewStringOrStringSlice(s3Actions...), - Resource: NewStringOrStringSlicePtr(resources...), - } - - statements = append(statements, statement) - } - - if len(statements) == 0 { - return nil, fmt.Errorf("no valid statements generated for identity %s", identityName) - } - - return &PolicyDocument{ - Version: PolicyVersion2012_10_17, - Statement: statements, - }, nil -} - -// HasPolicyForBucket checks if a bucket has a policy -func (p *PolicyBackedIAM) HasPolicyForBucket(bucketName string) bool { - return p.policyEngine.HasPolicyForBucket(bucketName) -} - -// GetPolicyEngine returns the underlying policy engine -func (p *PolicyBackedIAM) GetPolicyEngine() *PolicyEngine { - return p.policyEngine -} diff --git a/weed/s3api/policy_engine/integration_test.go b/weed/s3api/policy_engine/integration_test.go deleted file mode 100644 index 6e74e51cb..000000000 --- a/weed/s3api/policy_engine/integration_test.go +++ /dev/null @@ -1,373 +0,0 @@ -package policy_engine - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -// TestConvertSingleActionDeleteObject tests support for s3:DeleteObject action (Issue #7864) -func TestConvertSingleActionDeleteObject(t *testing.T) { - // Test that Write action includes DeleteObject S3 action - stmt, err := convertSingleAction("Write:bucket") - assert.NoError(t, err) - assert.NotNil(t, stmt) - - // Check that s3:DeleteObject is included in the actions - actions := stmt.Action.Strings() - assert.Contains(t, actions, "s3:DeleteObject", "Write action should include s3:DeleteObject") - assert.Contains(t, actions, "s3:PutObject", "Write action should include s3:PutObject") -} - -// TestConvertSingleActionSubpath tests subpath handling for legacy actions (Issue #7864) -func TestConvertSingleActionSubpath(t *testing.T) { - testCases := []struct { - name string - action string - expectedActions []string - expectedResources []string - description string - }{ - { - name: "Write_on_bucket", - action: "Write:mybucket", - expectedActions: []string{"s3:PutObject", "s3:DeleteObject", "s3:PutObjectAcl", "s3:DeleteObjectVersion", "s3:PutObjectTagging", "s3:DeleteObjectTagging", "s3:AbortMultipartUpload", "s3:ListMultipartUploads", "s3:ListParts", "s3:PutBucketAcl", "s3:PutBucketCors", "s3:PutBucketTagging", "s3:PutBucketNotification", "s3:PutBucketVersioning", "s3:DeleteBucketTagging", "s3:DeleteBucketCors"}, - expectedResources: []string{"arn:aws:s3:::mybucket", "arn:aws:s3:::mybucket/*"}, - description: "Write permission on bucket should include bucket and object ARNs", - }, - { - name: "Write_on_bucket_with_wildcard", - action: "Write:mybucket/*", - expectedActions: []string{"s3:PutObject", "s3:DeleteObject", "s3:PutObjectAcl", "s3:DeleteObjectVersion", "s3:PutObjectTagging", "s3:DeleteObjectTagging", "s3:AbortMultipartUpload", "s3:ListMultipartUploads", "s3:ListParts", "s3:PutBucketAcl", "s3:PutBucketCors", "s3:PutBucketTagging", "s3:PutBucketNotification", "s3:PutBucketVersioning", "s3:DeleteBucketTagging", "s3:DeleteBucketCors"}, - expectedResources: []string{"arn:aws:s3:::mybucket", "arn:aws:s3:::mybucket/*"}, - description: "Write permission with /* should include bucket and object ARNs", - }, - { - name: "Write_on_subpath", - action: "Write:mybucket/sub_path/*", - expectedActions: []string{"s3:PutObject", "s3:DeleteObject", "s3:PutObjectAcl", "s3:DeleteObjectVersion", "s3:PutObjectTagging", "s3:DeleteObjectTagging", "s3:AbortMultipartUpload", "s3:ListMultipartUploads", "s3:ListParts", "s3:PutBucketAcl", "s3:PutBucketCors", "s3:PutBucketTagging", "s3:PutBucketNotification", "s3:PutBucketVersioning", "s3:DeleteBucketTagging", "s3:DeleteBucketCors"}, - expectedResources: []string{"arn:aws:s3:::mybucket", "arn:aws:s3:::mybucket/sub_path/*"}, - description: "Write permission on subpath should include bucket and subpath objects ARNs", - }, - { - name: "Read_on_subpath", - action: "Read:mybucket/documents/*", - expectedActions: []string{"s3:GetObject", "s3:GetObjectVersion", "s3:ListBucket", "s3:ListBucketVersions", "s3:GetObjectAcl", "s3:GetObjectVersionAcl", "s3:GetObjectTagging", "s3:GetObjectVersionTagging", "s3:GetBucketLocation", "s3:GetBucketVersioning", "s3:GetBucketAcl", "s3:GetBucketCors", "s3:GetBucketTagging", "s3:GetBucketNotification"}, - expectedResources: []string{"arn:aws:s3:::mybucket", "arn:aws:s3:::mybucket/documents/*"}, - description: "Read permission on subpath should include bucket ARN and subpath objects", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - stmt, err := convertSingleAction(tc.action) - assert.NoError(t, err, tc.description) - assert.NotNil(t, stmt) - - // Check actions - actions := stmt.Action.Strings() - for _, expectedAction := range tc.expectedActions { - assert.Contains(t, actions, expectedAction, - "Action %s should be included for %s", expectedAction, tc.action) - } - - // Check resources - verify all expected resources are present - resources := stmt.Resource.Strings() - assert.ElementsMatch(t, resources, tc.expectedResources, - "Resources should match exactly for %s. Got %v, expected %v", tc.action, resources, tc.expectedResources) - }) - } -} - -// TestConvertSingleActionSubpathDeleteAllowed tests that DeleteObject works on subpaths -func TestConvertSingleActionSubpathDeleteAllowed(t *testing.T) { - // This test specifically addresses Issue #7864 part 1: - // "when a user is granted permission to a subpath, eg s3.configure -user someuser - // -actions Write -buckets some_bucket/sub_path/* -apply - // the user will only be able to put, but not delete object under somebucket/sub_path" - - stmt, err := convertSingleAction("Write:some_bucket/sub_path/*") - assert.NoError(t, err) - - // The fix: s3:DeleteObject should be in the allowed actions - actions := stmt.Action.Strings() - assert.Contains(t, actions, "s3:DeleteObject", - "Write permission on subpath should allow deletion of objects in that path") - - // The resource should be restricted to the subpath - resources := stmt.Resource.Strings() - assert.Contains(t, resources, "arn:aws:s3:::some_bucket/sub_path/*", - "Delete permission should apply to objects under the subpath") -} - -// TestConvertSingleActionNestedPaths tests deeply nested paths -func TestConvertSingleActionNestedPaths(t *testing.T) { - testCases := []struct { - action string - expectedResources []string - }{ - { - action: "Write:bucket/a/b/c/*", - expectedResources: []string{"arn:aws:s3:::bucket", "arn:aws:s3:::bucket/a/b/c/*"}, - }, - { - action: "Read:bucket/data/documents/2024/*", - expectedResources: []string{"arn:aws:s3:::bucket", "arn:aws:s3:::bucket/data/documents/2024/*"}, - }, - } - - for _, tc := range testCases { - stmt, err := convertSingleAction(tc.action) - assert.NoError(t, err) - - resources := stmt.Resource.Strings() - assert.ElementsMatch(t, resources, tc.expectedResources) - } -} - -// TestGetResourcesFromLegacyAction tests that GetResourcesFromLegacyAction generates -// action-appropriate resources consistent with convertSingleAction -func TestGetResourcesFromLegacyAction(t *testing.T) { - testCases := []struct { - name string - action string - expectedResources []string - description string - }{ - // List actions - bucket-only (no object ARNs) - { - name: "List_on_bucket", - action: "List:mybucket", - expectedResources: []string{"arn:aws:s3:::mybucket"}, - description: "List action should only have bucket ARN", - }, - { - name: "List_on_bucket_with_wildcard", - action: "List:mybucket/*", - expectedResources: []string{"arn:aws:s3:::mybucket"}, - description: "List action should only have bucket ARN regardless of wildcard", - }, - // Read actions - bucket and object-level ARNs (includes List* and Get* operations) - { - name: "Read_on_bucket", - action: "Read:mybucket", - expectedResources: []string{"arn:aws:s3:::mybucket", "arn:aws:s3:::mybucket/*"}, - description: "Read action should have both bucket and object ARNs", - }, - { - name: "Read_on_subpath", - action: "Read:mybucket/documents/*", - expectedResources: []string{"arn:aws:s3:::mybucket", "arn:aws:s3:::mybucket/documents/*"}, - description: "Read action on subpath should have bucket ARN and object ARN for subpath", - }, - // Write actions - bucket and object ARNs (includes bucket-level operations) - { - name: "Write_on_subpath", - action: "Write:mybucket/sub_path/*", - expectedResources: []string{"arn:aws:s3:::mybucket", "arn:aws:s3:::mybucket/sub_path/*"}, - description: "Write action should have bucket and object ARNs", - }, - // Admin actions - both bucket and object ARNs - { - name: "Admin_on_bucket", - action: "Admin:mybucket", - expectedResources: []string{"arn:aws:s3:::mybucket", "arn:aws:s3:::mybucket/*"}, - description: "Admin action should have both bucket and object ARNs", - }, - { - name: "Admin_on_subpath", - action: "Admin:mybucket/admin/section/*", - expectedResources: []string{"arn:aws:s3:::mybucket", "arn:aws:s3:::mybucket/admin/section/*"}, - description: "Admin action on subpath should restrict to subpath, preventing privilege escalation", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - resources, err := GetResourcesFromLegacyAction(tc.action) - assert.NoError(t, err, tc.description) - assert.ElementsMatch(t, resources, tc.expectedResources, - "Resources should match expected. Got %v, expected %v", resources, tc.expectedResources) - - // Also verify consistency with convertSingleAction where applicable - stmt, err := convertSingleAction(tc.action) - assert.NoError(t, err) - - stmtResources := stmt.Resource.Strings() - assert.ElementsMatch(t, resources, stmtResources, - "GetResourcesFromLegacyAction should match convertSingleAction resources for %s", tc.action) - }) - } -} - -// TestExtractBucketAndPrefixEdgeCases validates edge case handling in extractBucketAndPrefix -func TestExtractBucketAndPrefixEdgeCases(t *testing.T) { - testCases := []struct { - name string - pattern string - expectedBucket string - expectedPrefix string - description string - }{ - { - name: "Empty string", - pattern: "", - expectedBucket: "", - expectedPrefix: "", - description: "Empty pattern should return empty strings", - }, - { - name: "Whitespace only", - pattern: " ", - expectedBucket: "", - expectedPrefix: "", - description: "Whitespace-only pattern should return empty strings", - }, - { - name: "Slash only", - pattern: "/", - expectedBucket: "", - expectedPrefix: "", - description: "Slash-only pattern should return empty strings", - }, - { - name: "Double slash prefix", - pattern: "bucket//prefix/*", - expectedBucket: "bucket", - expectedPrefix: "prefix", - description: "Double slash should be normalized (trailing slashes removed)", - }, - { - name: "Normal bucket", - pattern: "mybucket", - expectedBucket: "mybucket", - expectedPrefix: "", - description: "Bucket-only pattern should work correctly", - }, - { - name: "Bucket with prefix", - pattern: "mybucket/myprefix/*", - expectedBucket: "mybucket", - expectedPrefix: "myprefix", - description: "Bucket with prefix should be parsed correctly", - }, - { - name: "Nested prefix", - pattern: "mybucket/a/b/c/*", - expectedBucket: "mybucket", - expectedPrefix: "a/b/c", - description: "Nested prefix should be preserved", - }, - { - name: "Bucket with trailing slash", - pattern: "mybucket/", - expectedBucket: "mybucket", - expectedPrefix: "", - description: "Trailing slash on bucket should be normalized", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - bucket, prefix := extractBucketAndPrefix(tc.pattern) - assert.Equal(t, tc.expectedBucket, bucket, tc.description) - assert.Equal(t, tc.expectedPrefix, prefix, tc.description) - }) - } -} - -// TestCreatePolicyFromLegacyIdentityMultipleActions validates correct resource ARN aggregation -// when multiple action types target the same resource pattern -func TestCreatePolicyFromLegacyIdentityMultipleActions(t *testing.T) { - testCases := []struct { - name string - identityName string - actions []string - expectedStatements int - expectedActionsInStmt1 []string - expectedResourcesInStmt1 []string - description string - }{ - { - name: "List_and_Write_on_subpath", - identityName: "data-manager", - actions: []string{"List:mybucket/data/*", "Write:mybucket/data/*"}, - expectedStatements: 1, - expectedActionsInStmt1: []string{ - "s3:ListBucket", "s3:ListBucketVersions", "s3:ListAllMyBuckets", - "s3:PutObject", "s3:DeleteObject", "s3:PutObjectAcl", "s3:DeleteObjectVersion", - "s3:PutObjectTagging", "s3:DeleteObjectTagging", "s3:AbortMultipartUpload", - "s3:ListMultipartUploads", "s3:ListParts", "s3:PutBucketAcl", "s3:PutBucketCors", - "s3:PutBucketTagging", "s3:PutBucketNotification", "s3:PutBucketVersioning", - "s3:DeleteBucketTagging", "s3:DeleteBucketCors", - }, - expectedResourcesInStmt1: []string{ - "arn:aws:s3:::mybucket", // From List and Write actions - "arn:aws:s3:::mybucket/data/*", // From Write action - }, - description: "List + Write on same subpath should aggregate all actions and both bucket and object ARNs", - }, - { - name: "Read_and_Tagging_on_bucket", - identityName: "tag-reader", - actions: []string{"Read:mybucket", "Tagging:mybucket"}, - expectedStatements: 1, - expectedActionsInStmt1: []string{ - "s3:GetObject", "s3:GetObjectVersion", - "s3:ListBucket", "s3:ListBucketVersions", - "s3:GetObjectAcl", "s3:GetObjectVersionAcl", - "s3:GetObjectTagging", "s3:GetObjectVersionTagging", - "s3:PutObjectTagging", "s3:DeleteObjectTagging", - "s3:GetBucketLocation", "s3:GetBucketVersioning", - "s3:GetBucketAcl", "s3:GetBucketCors", "s3:GetBucketTagging", - "s3:GetBucketNotification", "s3:PutBucketTagging", "s3:DeleteBucketTagging", - }, - expectedResourcesInStmt1: []string{ - "arn:aws:s3:::mybucket", - "arn:aws:s3:::mybucket/*", - }, - description: "Read + Tagging on same bucket should aggregate all bucket and object-level actions and ARNs", - }, - { - name: "Admin_with_other_actions", - identityName: "admin-user", - actions: []string{"Admin:mybucket/admin/*", "Write:mybucket/admin/*"}, - expectedStatements: 1, - expectedActionsInStmt1: []string{"s3:*"}, - expectedResourcesInStmt1: []string{ - "arn:aws:s3:::mybucket", - "arn:aws:s3:::mybucket/admin/*", - }, - description: "Admin action should dominate and set s3:*, other actions still processed for resources", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - policy, err := CreatePolicyFromLegacyIdentity(tc.identityName, tc.actions) - assert.NoError(t, err, tc.description) - assert.NotNil(t, policy) - - // Check statement count - assert.Equal(t, tc.expectedStatements, len(policy.Statement), - "Expected %d statement(s), got %d", tc.expectedStatements, len(policy.Statement)) - - if tc.expectedStatements > 0 { - stmt := policy.Statement[0] - - // Check actions - actualActions := stmt.Action.Strings() - for _, expectedAction := range tc.expectedActionsInStmt1 { - assert.Contains(t, actualActions, expectedAction, - "Action %s should be included in statement", expectedAction) - } - - // Check resources - all expected resources should be present - actualResources := stmt.Resource.Strings() - assert.ElementsMatch(t, tc.expectedResourcesInStmt1, actualResources, - "Statement should aggregate all required resource ARNs. Got %v, expected %v", - actualResources, tc.expectedResourcesInStmt1) - } - }) - } -} diff --git a/weed/s3api/policy_engine/types.go b/weed/s3api/policy_engine/types.go index 862023b34..f1623ff15 100644 --- a/weed/s3api/policy_engine/types.go +++ b/weed/s3api/policy_engine/types.go @@ -490,11 +490,6 @@ func GetBucketFromResource(resource string) string { return "" } -// IsObjectResource checks if resource refers to objects -func IsObjectResource(resource string) bool { - return strings.Contains(resource, "/") -} - // MatchesAction checks if an action matches any of the compiled action matchers. // It also implicitly grants multipart upload actions if s3:PutObject is allowed, // since multipart upload is an implementation detail of putting objects. diff --git a/weed/s3api/s3_bucket_encryption.go b/weed/s3api/s3_bucket_encryption.go index 5a9fb7499..10901f8ac 100644 --- a/weed/s3api/s3_bucket_encryption.go +++ b/weed/s3api/s3_bucket_encryption.go @@ -288,70 +288,3 @@ func (s3a *S3ApiServer) GetDefaultEncryptionHeaders(bucket string) map[string]st return headers } - -// IsDefaultEncryptionEnabled checks if default encryption is enabled for a configuration -func IsDefaultEncryptionEnabled(config *s3_pb.EncryptionConfiguration) bool { - return config != nil && config.SseAlgorithm != "" -} - -// GetDefaultEncryptionHeaders generates default encryption headers from configuration -func GetDefaultEncryptionHeaders(config *s3_pb.EncryptionConfiguration) map[string]string { - if config == nil || config.SseAlgorithm == "" { - return nil - } - - headers := make(map[string]string) - headers[s3_constants.AmzServerSideEncryption] = config.SseAlgorithm - - if config.SseAlgorithm == "aws:kms" && config.KmsKeyId != "" { - headers[s3_constants.AmzServerSideEncryptionAwsKmsKeyId] = config.KmsKeyId - } - - return headers -} - -// encryptionConfigFromXMLBytes parses XML bytes to encryption configuration -func encryptionConfigFromXMLBytes(xmlBytes []byte) (*s3_pb.EncryptionConfiguration, error) { - var xmlConfig ServerSideEncryptionConfiguration - if err := xml.Unmarshal(xmlBytes, &xmlConfig); err != nil { - return nil, err - } - - // Validate namespace - should be empty or the standard AWS namespace - if xmlConfig.XMLName.Space != "" && xmlConfig.XMLName.Space != "http://s3.amazonaws.com/doc/2006-03-01/" { - return nil, fmt.Errorf("invalid XML namespace: %s", xmlConfig.XMLName.Space) - } - - // Validate the configuration - if len(xmlConfig.Rules) == 0 { - return nil, fmt.Errorf("encryption configuration must have at least one rule") - } - - rule := xmlConfig.Rules[0] - if rule.ApplyServerSideEncryptionByDefault.SSEAlgorithm == "" { - return nil, fmt.Errorf("encryption algorithm is required") - } - - // Validate algorithm - validAlgorithms := map[string]bool{ - "AES256": true, - "aws:kms": true, - } - - if !validAlgorithms[rule.ApplyServerSideEncryptionByDefault.SSEAlgorithm] { - return nil, fmt.Errorf("unsupported encryption algorithm: %s", rule.ApplyServerSideEncryptionByDefault.SSEAlgorithm) - } - - config := encryptionConfigFromXML(&xmlConfig) - return config, nil -} - -// encryptionConfigToXMLBytes converts encryption configuration to XML bytes -func encryptionConfigToXMLBytes(config *s3_pb.EncryptionConfiguration) ([]byte, error) { - if config == nil { - return nil, fmt.Errorf("encryption configuration is nil") - } - - xmlConfig := encryptionConfigToXML(config) - return xml.Marshal(xmlConfig) -} diff --git a/weed/s3api/s3_iam_middleware.go b/weed/s3api/s3_iam_middleware.go index 7820f3803..af454dee8 100644 --- a/weed/s3api/s3_iam_middleware.go +++ b/weed/s3api/s3_iam_middleware.go @@ -13,7 +13,6 @@ import ( "github.com/seaweedfs/seaweedfs/weed/iam/integration" "github.com/seaweedfs/seaweedfs/weed/iam/providers" "github.com/seaweedfs/seaweedfs/weed/iam/sts" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" ) @@ -381,52 +380,6 @@ func buildS3ResourceArn(bucket string, objectKey string) string { return "arn:aws:s3:::" + bucket + "/" + objectKey } -// mapLegacyActionToIAM provides fallback mapping for legacy actions -// This ensures backward compatibility while the system transitions to granular actions -func mapLegacyActionToIAM(legacyAction Action) string { - switch legacyAction { - case s3_constants.ACTION_READ: - return "s3:GetObject" // Fallback for unmapped read operations - case s3_constants.ACTION_WRITE: - return "s3:PutObject" // Fallback for unmapped write operations - case s3_constants.ACTION_LIST: - return "s3:ListBucket" // Fallback for unmapped list operations - case s3_constants.ACTION_TAGGING: - return "s3:GetObjectTagging" // Fallback for unmapped tagging operations - case s3_constants.ACTION_READ_ACP: - return "s3:GetObjectAcl" // Fallback for unmapped ACL read operations - case s3_constants.ACTION_WRITE_ACP: - return "s3:PutObjectAcl" // Fallback for unmapped ACL write operations - case s3_constants.ACTION_DELETE_BUCKET: - return "s3:DeleteBucket" // Fallback for unmapped bucket delete operations - case s3_constants.ACTION_ADMIN: - return "s3:*" // Fallback for unmapped admin operations - - // Handle granular multipart actions (already correctly mapped) - case s3_constants.S3_ACTION_CREATE_MULTIPART: - return s3_constants.S3_ACTION_CREATE_MULTIPART - case s3_constants.S3_ACTION_UPLOAD_PART: - return s3_constants.S3_ACTION_UPLOAD_PART - case s3_constants.S3_ACTION_COMPLETE_MULTIPART: - return s3_constants.S3_ACTION_COMPLETE_MULTIPART - case s3_constants.S3_ACTION_ABORT_MULTIPART: - return s3_constants.S3_ACTION_ABORT_MULTIPART - case s3_constants.S3_ACTION_LIST_MULTIPART_UPLOADS: - return s3_constants.S3_ACTION_LIST_MULTIPART_UPLOADS - case s3_constants.S3_ACTION_LIST_PARTS: - return s3_constants.S3_ACTION_LIST_PARTS - - default: - // If it's already a properly formatted S3 action, return as-is - actionStr := string(legacyAction) - if strings.HasPrefix(actionStr, "s3:") { - return actionStr - } - // Fallback: convert to S3 action format - return "s3:" + actionStr - } -} - // extractRequestContext extracts request context for policy conditions func extractRequestContext(r *http.Request) map[string]interface{} { context := make(map[string]interface{}) @@ -553,79 +506,6 @@ type EnhancedS3ApiServer struct { iamIntegration IAMIntegration } -// NewEnhancedS3ApiServer creates an S3 API server with IAM integration -func NewEnhancedS3ApiServer(baseServer *S3ApiServer, iamManager *integration.IAMManager) *EnhancedS3ApiServer { - // Set the IAM integration on the base server - baseServer.SetIAMIntegration(iamManager) - - return &EnhancedS3ApiServer{ - S3ApiServer: baseServer, - iamIntegration: NewS3IAMIntegration(iamManager, "localhost:8888"), - } -} - -// AuthenticateJWTRequest handles JWT authentication for S3 requests -func (enhanced *EnhancedS3ApiServer) AuthenticateJWTRequest(r *http.Request) (*Identity, s3err.ErrorCode) { - ctx := r.Context() - - // Use our IAM integration for JWT authentication - iamIdentity, errCode := enhanced.iamIntegration.AuthenticateJWT(ctx, r) - if errCode != s3err.ErrNone { - return nil, errCode - } - - // Convert IAMIdentity to the existing Identity structure - identity := &Identity{ - Name: iamIdentity.Name, - Account: iamIdentity.Account, - // Note: Actions will be determined by policy evaluation - Actions: []Action{}, // Empty - authorization handled by policy engine - PolicyNames: iamIdentity.PolicyNames, - } - - // Store session token for later authorization - r.Header.Set("X-SeaweedFS-Session-Token", iamIdentity.SessionToken) - r.Header.Set("X-SeaweedFS-Principal", iamIdentity.Principal) - - return identity, s3err.ErrNone -} - -// AuthorizeRequest handles authorization for S3 requests using policy engine -func (enhanced *EnhancedS3ApiServer) AuthorizeRequest(r *http.Request, identity *Identity, action Action) s3err.ErrorCode { - ctx := r.Context() - - // Get session info from request headers (set during authentication) - sessionToken := r.Header.Get("X-SeaweedFS-Session-Token") - principal := r.Header.Get("X-SeaweedFS-Principal") - - if sessionToken == "" || principal == "" { - glog.V(3).Info("No session information available for authorization") - return s3err.ErrAccessDenied - } - - // Extract bucket and object from request - bucket, object := s3_constants.GetBucketAndObject(r) - prefix := s3_constants.GetPrefix(r) - - // For List operations, use prefix for permission checking if available - if action == s3_constants.ACTION_LIST && object == "" && prefix != "" { - object = prefix - } else if (object == "/" || object == "") && prefix != "" { - object = prefix - } - - // Create IAM identity for authorization - iamIdentity := &IAMIdentity{ - Name: identity.Name, - Principal: principal, - SessionToken: sessionToken, - Account: identity.Account, - } - - // Use our IAM integration for authorization - return enhanced.iamIntegration.AuthorizeAction(ctx, iamIdentity, action, bucket, object, r) -} - // OIDCIdentity represents an identity validated through OIDC type OIDCIdentity struct { UserID string diff --git a/weed/s3api/s3_iam_simple_test.go b/weed/s3api/s3_iam_simple_test.go deleted file mode 100644 index c2c68321f..000000000 --- a/weed/s3api/s3_iam_simple_test.go +++ /dev/null @@ -1,584 +0,0 @@ -package s3api - -import ( - "context" - "net/http" - "net/http/httptest" - "net/url" - "testing" - "time" - - "github.com/seaweedfs/seaweedfs/weed/iam/integration" - "github.com/seaweedfs/seaweedfs/weed/iam/policy" - "github.com/seaweedfs/seaweedfs/weed/iam/sts" - "github.com/seaweedfs/seaweedfs/weed/iam/utils" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func newTestS3IAMManagerWithDefaultEffect(t *testing.T, defaultEffect string) *integration.IAMManager { - t.Helper() - - iamManager := integration.NewIAMManager() - config := &integration.IAMConfig{ - STS: &sts.STSConfig{ - TokenDuration: sts.FlexibleDuration{Duration: time.Hour}, - MaxSessionLength: sts.FlexibleDuration{Duration: time.Hour * 12}, - Issuer: "test-sts", - SigningKey: []byte("test-signing-key-32-characters-long"), - }, - Policy: &policy.PolicyEngineConfig{ - DefaultEffect: defaultEffect, - StoreType: "memory", - }, - Roles: &integration.RoleStoreConfig{ - StoreType: "memory", - }, - } - - err := iamManager.Initialize(config, func() string { - return "localhost:8888" - }) - require.NoError(t, err) - - return iamManager -} - -func newTestS3IAMManager(t *testing.T) *integration.IAMManager { - t.Helper() - return newTestS3IAMManagerWithDefaultEffect(t, "Deny") -} - -// TestS3IAMMiddleware tests the basic S3 IAM middleware functionality -func TestS3IAMMiddleware(t *testing.T) { - iamManager := newTestS3IAMManager(t) - - // Create S3 IAM integration - s3IAMIntegration := NewS3IAMIntegration(iamManager, "localhost:8888") - - // Test that integration is created successfully - assert.NotNil(t, s3IAMIntegration) - assert.True(t, s3IAMIntegration.enabled) -} - -func TestS3IAMMiddlewareStaticV4ManagedPolicies(t *testing.T) { - ctx := context.Background() - iamManager := newTestS3IAMManager(t) - - allowPolicy := &policy.PolicyDocument{ - Version: "2012-10-17", - Statement: []policy.Statement{ - { - Effect: "Allow", - Action: policy.StringList{"s3:PutObject", "s3:ListBucket"}, - Resource: policy.StringList{"arn:aws:s3:::cli-allowed-bucket", "arn:aws:s3:::cli-allowed-bucket/*"}, - }, - }, - } - require.NoError(t, iamManager.CreatePolicy(ctx, "localhost:8888", "cli-bucket-access-policy", allowPolicy)) - - s3IAMIntegration := NewS3IAMIntegration(iamManager, "localhost:8888") - identity := &IAMIdentity{ - Name: "cli-test-user", - Principal: "arn:aws:iam::000000000000:user/cli-test-user", - PolicyNames: []string{"cli-bucket-access-policy"}, - } - - putReq := httptest.NewRequest(http.MethodPut, "http://example.com/cli-allowed-bucket/test-file.txt", http.NoBody) - putErrCode := s3IAMIntegration.AuthorizeAction(ctx, identity, s3_constants.ACTION_WRITE, "cli-allowed-bucket", "test-file.txt", putReq) - assert.Equal(t, s3err.ErrNone, putErrCode) - - listReq := httptest.NewRequest(http.MethodGet, "http://example.com/cli-allowed-bucket/", http.NoBody) - listErrCode := s3IAMIntegration.AuthorizeAction(ctx, identity, s3_constants.ACTION_LIST, "cli-allowed-bucket", "", listReq) - assert.Equal(t, s3err.ErrNone, listErrCode) -} - -func TestS3IAMMiddlewareAttachedPoliciesRestrictDefaultAllow(t *testing.T) { - ctx := context.Background() - iamManager := newTestS3IAMManagerWithDefaultEffect(t, "Allow") - - allowPolicy := &policy.PolicyDocument{ - Version: "2012-10-17", - Statement: []policy.Statement{ - { - Effect: "Allow", - Action: policy.StringList{"s3:PutObject", "s3:ListBucket"}, - Resource: policy.StringList{"arn:aws:s3:::cli-allowed-bucket", "arn:aws:s3:::cli-allowed-bucket/*"}, - }, - }, - } - require.NoError(t, iamManager.CreatePolicy(ctx, "localhost:8888", "cli-bucket-access-policy", allowPolicy)) - - s3IAMIntegration := NewS3IAMIntegration(iamManager, "localhost:8888") - identity := &IAMIdentity{ - Name: "cli-test-user", - Principal: "arn:aws:iam::000000000000:user/cli-test-user", - PolicyNames: []string{"cli-bucket-access-policy"}, - } - - allowedReq := httptest.NewRequest(http.MethodPut, "http://example.com/cli-allowed-bucket/test-file.txt", http.NoBody) - allowedErrCode := s3IAMIntegration.AuthorizeAction(ctx, identity, s3_constants.ACTION_WRITE, "cli-allowed-bucket", "test-file.txt", allowedReq) - assert.Equal(t, s3err.ErrNone, allowedErrCode) - - forbiddenReq := httptest.NewRequest(http.MethodPut, "http://example.com/cli-forbidden-bucket/forbidden-file.txt", http.NoBody) - forbiddenErrCode := s3IAMIntegration.AuthorizeAction(ctx, identity, s3_constants.ACTION_WRITE, "cli-forbidden-bucket", "forbidden-file.txt", forbiddenReq) - assert.Equal(t, s3err.ErrAccessDenied, forbiddenErrCode) - - forbiddenListReq := httptest.NewRequest(http.MethodGet, "http://example.com/cli-forbidden-bucket/", http.NoBody) - forbiddenListErrCode := s3IAMIntegration.AuthorizeAction(ctx, identity, s3_constants.ACTION_LIST, "cli-forbidden-bucket", "", forbiddenListReq) - assert.Equal(t, s3err.ErrAccessDenied, forbiddenListErrCode) -} - -// TestS3IAMMiddlewareJWTAuth tests JWT authentication -func TestS3IAMMiddlewareJWTAuth(t *testing.T) { - // Skip for now since it requires full setup - t.Skip("JWT authentication test requires full IAM setup") - - // Create IAM integration - s3iam := NewS3IAMIntegration(nil, "localhost:8888") // Disabled integration - - // Create test request with JWT token - req := httptest.NewRequest("GET", "/test-bucket/test-object", http.NoBody) - req.Header.Set("Authorization", "Bearer test-token") - - // Test authentication (should return not implemented when disabled) - ctx := context.Background() - identity, errCode := s3iam.AuthenticateJWT(ctx, req) - - assert.Nil(t, identity) - assert.NotEqual(t, errCode, 0) // Should return an error -} - -// TestBuildS3ResourceArn tests resource ARN building -func TestBuildS3ResourceArn(t *testing.T) { - tests := []struct { - name string - bucket string - object string - expected string - }{ - { - name: "empty bucket and object", - bucket: "", - object: "", - expected: "arn:aws:s3:::*", - }, - { - name: "bucket only", - bucket: "test-bucket", - object: "", - expected: "arn:aws:s3:::test-bucket", - }, - { - name: "bucket and object", - bucket: "test-bucket", - object: "test-object.txt", - expected: "arn:aws:s3:::test-bucket/test-object.txt", - }, - { - name: "bucket and object with leading slash", - bucket: "test-bucket", - object: "/test-object.txt", - expected: "arn:aws:s3:::test-bucket/test-object.txt", - }, - { - name: "bucket and nested object", - bucket: "test-bucket", - object: "folder/subfolder/test-object.txt", - expected: "arn:aws:s3:::test-bucket/folder/subfolder/test-object.txt", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := buildS3ResourceArn(tt.bucket, tt.object) - assert.Equal(t, tt.expected, result) - }) - } -} - -// TestDetermineGranularS3Action tests granular S3 action determination from HTTP requests -func TestDetermineGranularS3Action(t *testing.T) { - tests := []struct { - name string - method string - bucket string - objectKey string - queryParams map[string]string - fallbackAction Action - expected string - description string - }{ - // Object-level operations - { - name: "get_object", - method: "GET", - bucket: "test-bucket", - objectKey: "test-object.txt", - queryParams: map[string]string{}, - fallbackAction: s3_constants.ACTION_READ, - expected: "s3:GetObject", - description: "Basic object retrieval", - }, - { - name: "get_object_acl", - method: "GET", - bucket: "test-bucket", - objectKey: "test-object.txt", - queryParams: map[string]string{"acl": ""}, - fallbackAction: s3_constants.ACTION_READ_ACP, - expected: "s3:GetObjectAcl", - description: "Object ACL retrieval", - }, - { - name: "get_object_tagging", - method: "GET", - bucket: "test-bucket", - objectKey: "test-object.txt", - queryParams: map[string]string{"tagging": ""}, - fallbackAction: s3_constants.ACTION_TAGGING, - expected: "s3:GetObjectTagging", - description: "Object tagging retrieval", - }, - { - name: "put_object", - method: "PUT", - bucket: "test-bucket", - objectKey: "test-object.txt", - queryParams: map[string]string{}, - fallbackAction: s3_constants.ACTION_WRITE, - expected: "s3:PutObject", - description: "Basic object upload", - }, - { - name: "put_object_acl", - method: "PUT", - bucket: "test-bucket", - objectKey: "test-object.txt", - queryParams: map[string]string{"acl": ""}, - fallbackAction: s3_constants.ACTION_WRITE_ACP, - expected: "s3:PutObjectAcl", - description: "Object ACL modification", - }, - { - name: "delete_object", - method: "DELETE", - bucket: "test-bucket", - objectKey: "test-object.txt", - queryParams: map[string]string{}, - fallbackAction: s3_constants.ACTION_WRITE, // DELETE object uses WRITE fallback - expected: "s3:DeleteObject", - description: "Object deletion - correctly mapped to DeleteObject (not PutObject)", - }, - { - name: "delete_object_tagging", - method: "DELETE", - bucket: "test-bucket", - objectKey: "test-object.txt", - queryParams: map[string]string{"tagging": ""}, - fallbackAction: s3_constants.ACTION_TAGGING, - expected: "s3:DeleteObjectTagging", - description: "Object tag deletion", - }, - - // Multipart upload operations - { - name: "create_multipart_upload", - method: "POST", - bucket: "test-bucket", - objectKey: "large-file.txt", - queryParams: map[string]string{"uploads": ""}, - fallbackAction: s3_constants.ACTION_WRITE, - expected: "s3:CreateMultipartUpload", - description: "Multipart upload initiation", - }, - { - name: "upload_part", - method: "PUT", - bucket: "test-bucket", - objectKey: "large-file.txt", - queryParams: map[string]string{"uploadId": "12345", "partNumber": "1"}, - fallbackAction: s3_constants.ACTION_WRITE, - expected: "s3:UploadPart", - description: "Multipart part upload", - }, - { - name: "complete_multipart_upload", - method: "POST", - bucket: "test-bucket", - objectKey: "large-file.txt", - queryParams: map[string]string{"uploadId": "12345"}, - fallbackAction: s3_constants.ACTION_WRITE, - expected: "s3:CompleteMultipartUpload", - description: "Multipart upload completion", - }, - { - name: "abort_multipart_upload", - method: "DELETE", - bucket: "test-bucket", - objectKey: "large-file.txt", - queryParams: map[string]string{"uploadId": "12345"}, - fallbackAction: s3_constants.ACTION_WRITE, - expected: "s3:AbortMultipartUpload", - description: "Multipart upload abort", - }, - - // Bucket-level operations - { - name: "list_bucket", - method: "GET", - bucket: "test-bucket", - objectKey: "", - queryParams: map[string]string{}, - fallbackAction: s3_constants.ACTION_LIST, - expected: "s3:ListBucket", - description: "Bucket listing", - }, - { - name: "get_bucket_acl", - method: "GET", - bucket: "test-bucket", - objectKey: "", - queryParams: map[string]string{"acl": ""}, - fallbackAction: s3_constants.ACTION_READ_ACP, - expected: "s3:GetBucketAcl", - description: "Bucket ACL retrieval", - }, - { - name: "put_bucket_policy", - method: "PUT", - bucket: "test-bucket", - objectKey: "", - queryParams: map[string]string{"policy": ""}, - fallbackAction: s3_constants.ACTION_WRITE, - expected: "s3:PutBucketPolicy", - description: "Bucket policy modification", - }, - { - name: "delete_bucket", - method: "DELETE", - bucket: "test-bucket", - objectKey: "", - queryParams: map[string]string{}, - fallbackAction: s3_constants.ACTION_DELETE_BUCKET, - expected: "s3:DeleteBucket", - description: "Bucket deletion", - }, - { - name: "list_multipart_uploads", - method: "GET", - bucket: "test-bucket", - objectKey: "", - queryParams: map[string]string{"uploads": ""}, - fallbackAction: s3_constants.ACTION_LIST, - expected: "s3:ListBucketMultipartUploads", - description: "List multipart uploads in bucket", - }, - - // Fallback scenarios - { - name: "legacy_read_fallback", - method: "GET", - bucket: "", - objectKey: "", - queryParams: map[string]string{}, - fallbackAction: s3_constants.ACTION_READ, - expected: "s3:GetObject", - description: "Legacy read action fallback", - }, - { - name: "already_granular_action", - method: "GET", - bucket: "", - objectKey: "", - queryParams: map[string]string{}, - fallbackAction: "s3:GetBucketLocation", // Already granular - expected: "s3:GetBucketLocation", - description: "Already granular action passed through", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Create HTTP request with query parameters - req := &http.Request{ - Method: tt.method, - URL: &url.URL{Path: "/" + tt.bucket + "/" + tt.objectKey}, - } - - // Add query parameters - query := req.URL.Query() - for key, value := range tt.queryParams { - query.Set(key, value) - } - req.URL.RawQuery = query.Encode() - - // Test the action determination - result := ResolveS3Action(req, string(tt.fallbackAction), tt.bucket, tt.objectKey) - - assert.Equal(t, tt.expected, result, - "Test %s failed: %s. Expected %s but got %s", - tt.name, tt.description, tt.expected, result) - }) - } -} - -// TestMapLegacyActionToIAM tests the legacy action fallback mapping -func TestMapLegacyActionToIAM(t *testing.T) { - tests := []struct { - name string - legacyAction Action - expected string - }{ - { - name: "read_action_fallback", - legacyAction: s3_constants.ACTION_READ, - expected: "s3:GetObject", - }, - { - name: "write_action_fallback", - legacyAction: s3_constants.ACTION_WRITE, - expected: "s3:PutObject", - }, - { - name: "admin_action_fallback", - legacyAction: s3_constants.ACTION_ADMIN, - expected: "s3:*", - }, - { - name: "granular_multipart_action", - legacyAction: s3_constants.S3_ACTION_CREATE_MULTIPART, - expected: s3_constants.S3_ACTION_CREATE_MULTIPART, - }, - { - name: "unknown_action_with_s3_prefix", - legacyAction: "s3:CustomAction", - expected: "s3:CustomAction", - }, - { - name: "unknown_action_without_prefix", - legacyAction: "CustomAction", - expected: "s3:CustomAction", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := mapLegacyActionToIAM(tt.legacyAction) - assert.Equal(t, tt.expected, result) - }) - } -} - -// TestExtractSourceIP tests source IP extraction from requests -func TestExtractSourceIP(t *testing.T) { - tests := []struct { - name string - setupReq func() *http.Request - expectedIP string - }{ - { - name: "X-Forwarded-For header", - setupReq: func() *http.Request { - req := httptest.NewRequest("GET", "/test", http.NoBody) - req.Header.Set("X-Forwarded-For", "192.168.1.100, 10.0.0.1") - // Set RemoteAddr to private IP to simulate trusted proxy - req.RemoteAddr = "127.0.0.1:12345" - return req - }, - expectedIP: "192.168.1.100", - }, - { - name: "X-Real-IP header", - setupReq: func() *http.Request { - req := httptest.NewRequest("GET", "/test", http.NoBody) - req.Header.Set("X-Real-IP", "192.168.1.200") - // Set RemoteAddr to private IP to simulate trusted proxy - req.RemoteAddr = "127.0.0.1:12345" - return req - }, - expectedIP: "192.168.1.200", - }, - { - name: "RemoteAddr fallback", - setupReq: func() *http.Request { - req := httptest.NewRequest("GET", "/test", http.NoBody) - req.RemoteAddr = "192.168.1.300:12345" - return req - }, - expectedIP: "192.168.1.300", - }, - { - name: "Untrusted proxy - public RemoteAddr ignores X-Forwarded-For", - setupReq: func() *http.Request { - req := httptest.NewRequest("GET", "/test", http.NoBody) - req.Header.Set("X-Forwarded-For", "192.168.1.100") - // Public IP - headers should NOT be trusted - req.RemoteAddr = "8.8.8.8:12345" - return req - }, - expectedIP: "8.8.8.8", // Should use RemoteAddr, not X-Forwarded-For - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req := tt.setupReq() - result := extractSourceIP(req) - assert.Equal(t, tt.expectedIP, result) - }) - } -} - -// TestExtractRoleNameFromPrincipal tests role name extraction -func TestExtractRoleNameFromPrincipal(t *testing.T) { - tests := []struct { - name string - principal string - expected string - }{ - { - name: "valid assumed role ARN", - principal: "arn:aws:sts::assumed-role/S3ReadOnlyRole/session-123", - expected: "S3ReadOnlyRole", - }, - { - name: "invalid format", - principal: "invalid-principal", - expected: "", // Returns empty string to signal invalid format - }, - { - name: "missing session name", - principal: "arn:aws:sts::assumed-role/TestRole", - expected: "TestRole", // Extracts role name even without session name - }, - { - name: "empty principal", - principal: "", - expected: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := utils.ExtractRoleNameFromPrincipal(tt.principal) - assert.Equal(t, tt.expected, result) - }) - } -} - -// TestIAMIdentityIsAdmin tests the IsAdmin method -func TestIAMIdentityIsAdmin(t *testing.T) { - identity := &IAMIdentity{ - Name: "test-identity", - Principal: "arn:aws:sts::assumed-role/TestRole/session", - SessionToken: "test-token", - } - - // In our implementation, IsAdmin always returns false since admin status - // is determined by policies, not identity - result := identity.IsAdmin() - assert.False(t, result) -} diff --git a/weed/s3api/s3_multipart_iam.go b/weed/s3api/s3_multipart_iam.go deleted file mode 100644 index de3bccae9..000000000 --- a/weed/s3api/s3_multipart_iam.go +++ /dev/null @@ -1,420 +0,0 @@ -package s3api - -import ( - "fmt" - "net/http" - "strconv" - "strings" - "time" - - "github.com/seaweedfs/seaweedfs/weed/glog" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" -) - -// S3MultipartIAMManager handles IAM integration for multipart upload operations -type S3MultipartIAMManager struct { - s3iam *S3IAMIntegration -} - -// NewS3MultipartIAMManager creates a new multipart IAM manager -func NewS3MultipartIAMManager(s3iam *S3IAMIntegration) *S3MultipartIAMManager { - return &S3MultipartIAMManager{ - s3iam: s3iam, - } -} - -// MultipartUploadRequest represents a multipart upload request -type MultipartUploadRequest struct { - Bucket string `json:"bucket"` // S3 bucket name - ObjectKey string `json:"object_key"` // S3 object key - UploadID string `json:"upload_id"` // Multipart upload ID - PartNumber int `json:"part_number"` // Part number for upload part - Operation string `json:"operation"` // Multipart operation type - SessionToken string `json:"session_token"` // JWT session token - Headers map[string]string `json:"headers"` // Request headers - ContentSize int64 `json:"content_size"` // Content size for validation -} - -// MultipartUploadPolicy represents security policies for multipart uploads -type MultipartUploadPolicy struct { - MaxPartSize int64 `json:"max_part_size"` // Maximum part size (5GB AWS limit) - MinPartSize int64 `json:"min_part_size"` // Minimum part size (5MB AWS limit, except last part) - MaxParts int `json:"max_parts"` // Maximum number of parts (10,000 AWS limit) - MaxUploadDuration time.Duration `json:"max_upload_duration"` // Maximum time to complete multipart upload - AllowedContentTypes []string `json:"allowed_content_types"` // Allowed content types - RequiredHeaders []string `json:"required_headers"` // Required headers for validation - IPWhitelist []string `json:"ip_whitelist"` // Allowed IP addresses/ranges -} - -// MultipartOperation represents different multipart upload operations -type MultipartOperation string - -const ( - MultipartOpInitiate MultipartOperation = "initiate" - MultipartOpUploadPart MultipartOperation = "upload_part" - MultipartOpComplete MultipartOperation = "complete" - MultipartOpAbort MultipartOperation = "abort" - MultipartOpList MultipartOperation = "list" - MultipartOpListParts MultipartOperation = "list_parts" -) - -// ValidateMultipartOperationWithIAM validates multipart operations using IAM policies -func (iam *IdentityAccessManagement) ValidateMultipartOperationWithIAM(r *http.Request, identity *Identity, operation MultipartOperation) s3err.ErrorCode { - if iam.iamIntegration == nil { - // Fall back to standard validation - return s3err.ErrNone - } - - // Extract bucket and object from request - bucket, object := s3_constants.GetBucketAndObject(r) - - // Determine the S3 action based on multipart operation - action := determineMultipartS3Action(operation) - - // Extract session token from request - sessionToken := extractSessionTokenFromRequest(r) - if sessionToken == "" { - // No session token - use standard auth - return s3err.ErrNone - } - - // Retrieve the actual principal ARN from the request header - // This header is set during initial authentication and contains the correct assumed role ARN - principalArn := r.Header.Get("X-SeaweedFS-Principal") - if principalArn == "" { - glog.V(2).Info("IAM authorization for multipart operation failed: missing principal ARN in request header") - return s3err.ErrAccessDenied - } - - // Create IAM identity for authorization - iamIdentity := &IAMIdentity{ - Name: identity.Name, - Principal: principalArn, - SessionToken: sessionToken, - Account: identity.Account, - } - - // Authorize using IAM - ctx := r.Context() - errCode := iam.iamIntegration.AuthorizeAction(ctx, iamIdentity, action, bucket, object, r) - if errCode != s3err.ErrNone { - glog.V(3).Infof("IAM authorization failed for multipart operation: principal=%s operation=%s action=%s bucket=%s object=%s", - iamIdentity.Principal, operation, action, bucket, object) - return errCode - } - - glog.V(3).Infof("IAM authorization succeeded for multipart operation: principal=%s operation=%s action=%s bucket=%s object=%s", - iamIdentity.Principal, operation, action, bucket, object) - return s3err.ErrNone -} - -// ValidateMultipartRequestWithPolicy validates multipart request against security policy -func (policy *MultipartUploadPolicy) ValidateMultipartRequestWithPolicy(req *MultipartUploadRequest) error { - if req == nil { - return fmt.Errorf("multipart request cannot be nil") - } - - // Validate part size for upload part operations - if req.Operation == string(MultipartOpUploadPart) { - if req.ContentSize > policy.MaxPartSize { - return fmt.Errorf("part size %d exceeds maximum allowed %d", req.ContentSize, policy.MaxPartSize) - } - - // Minimum part size validation (except for last part) - // Note: Last part validation would require knowing if this is the final part - if req.ContentSize < policy.MinPartSize && req.ContentSize > 0 { - glog.V(2).Infof("Part size %d is below minimum %d - assuming last part", req.ContentSize, policy.MinPartSize) - } - - // Validate part number - if req.PartNumber < 1 || req.PartNumber > policy.MaxParts { - return fmt.Errorf("part number %d is invalid (must be 1-%d)", req.PartNumber, policy.MaxParts) - } - } - - // Validate required headers first - if req.Headers != nil { - for _, requiredHeader := range policy.RequiredHeaders { - if _, exists := req.Headers[requiredHeader]; !exists { - // Check lowercase version - if _, exists := req.Headers[strings.ToLower(requiredHeader)]; !exists { - return fmt.Errorf("required header %s is missing", requiredHeader) - } - } - } - } - - // Validate content type if specified - if len(policy.AllowedContentTypes) > 0 && req.Headers != nil { - contentType := req.Headers["Content-Type"] - if contentType == "" { - contentType = req.Headers["content-type"] - } - - allowed := false - for _, allowedType := range policy.AllowedContentTypes { - if contentType == allowedType { - allowed = true - break - } - } - - if !allowed { - return fmt.Errorf("content type %s is not allowed", contentType) - } - } - - return nil -} - -// Enhanced multipart handlers with IAM integration - -// NewMultipartUploadWithIAM handles initiate multipart upload with IAM validation -func (s3a *S3ApiServer) NewMultipartUploadWithIAM(w http.ResponseWriter, r *http.Request) { - // Validate IAM permissions first - if s3a.iam.iamIntegration != nil { - if identity, errCode := s3a.iam.authRequest(r, s3_constants.ACTION_WRITE); errCode != s3err.ErrNone { - s3err.WriteErrorResponse(w, r, errCode) - return - } else { - // Additional multipart-specific IAM validation - if errCode := s3a.iam.ValidateMultipartOperationWithIAM(r, identity, MultipartOpInitiate); errCode != s3err.ErrNone { - s3err.WriteErrorResponse(w, r, errCode) - return - } - } - } - - // Delegate to existing handler - s3a.NewMultipartUploadHandler(w, r) -} - -// CompleteMultipartUploadWithIAM handles complete multipart upload with IAM validation -func (s3a *S3ApiServer) CompleteMultipartUploadWithIAM(w http.ResponseWriter, r *http.Request) { - // Validate IAM permissions first - if s3a.iam.iamIntegration != nil { - if identity, errCode := s3a.iam.authRequest(r, s3_constants.ACTION_WRITE); errCode != s3err.ErrNone { - s3err.WriteErrorResponse(w, r, errCode) - return - } else { - // Additional multipart-specific IAM validation - if errCode := s3a.iam.ValidateMultipartOperationWithIAM(r, identity, MultipartOpComplete); errCode != s3err.ErrNone { - s3err.WriteErrorResponse(w, r, errCode) - return - } - } - } - - // Delegate to existing handler - s3a.CompleteMultipartUploadHandler(w, r) -} - -// AbortMultipartUploadWithIAM handles abort multipart upload with IAM validation -func (s3a *S3ApiServer) AbortMultipartUploadWithIAM(w http.ResponseWriter, r *http.Request) { - // Validate IAM permissions first - if s3a.iam.iamIntegration != nil { - if identity, errCode := s3a.iam.authRequest(r, s3_constants.ACTION_WRITE); errCode != s3err.ErrNone { - s3err.WriteErrorResponse(w, r, errCode) - return - } else { - // Additional multipart-specific IAM validation - if errCode := s3a.iam.ValidateMultipartOperationWithIAM(r, identity, MultipartOpAbort); errCode != s3err.ErrNone { - s3err.WriteErrorResponse(w, r, errCode) - return - } - } - } - - // Delegate to existing handler - s3a.AbortMultipartUploadHandler(w, r) -} - -// ListMultipartUploadsWithIAM handles list multipart uploads with IAM validation -func (s3a *S3ApiServer) ListMultipartUploadsWithIAM(w http.ResponseWriter, r *http.Request) { - // Validate IAM permissions first - if s3a.iam.iamIntegration != nil { - if identity, errCode := s3a.iam.authRequest(r, s3_constants.ACTION_LIST); errCode != s3err.ErrNone { - s3err.WriteErrorResponse(w, r, errCode) - return - } else { - // Additional multipart-specific IAM validation - if errCode := s3a.iam.ValidateMultipartOperationWithIAM(r, identity, MultipartOpList); errCode != s3err.ErrNone { - s3err.WriteErrorResponse(w, r, errCode) - return - } - } - } - - // Delegate to existing handler - s3a.ListMultipartUploadsHandler(w, r) -} - -// UploadPartWithIAM handles upload part with IAM validation -func (s3a *S3ApiServer) UploadPartWithIAM(w http.ResponseWriter, r *http.Request) { - // Validate IAM permissions first - if s3a.iam.iamIntegration != nil { - if identity, errCode := s3a.iam.authRequest(r, s3_constants.ACTION_WRITE); errCode != s3err.ErrNone { - s3err.WriteErrorResponse(w, r, errCode) - return - } else { - // Additional multipart-specific IAM validation - if errCode := s3a.iam.ValidateMultipartOperationWithIAM(r, identity, MultipartOpUploadPart); errCode != s3err.ErrNone { - s3err.WriteErrorResponse(w, r, errCode) - return - } - - // Validate part size and other policies - if err := s3a.validateUploadPartRequest(r); err != nil { - glog.Errorf("Upload part validation failed: %v", err) - s3err.WriteErrorResponse(w, r, s3err.ErrInvalidRequest) - return - } - } - } - - // Delegate to existing object PUT handler (which handles upload part) - s3a.PutObjectHandler(w, r) -} - -// Helper functions - -// determineMultipartS3Action maps multipart operations to granular S3 actions -// This enables fine-grained IAM policies for multipart upload operations -func determineMultipartS3Action(operation MultipartOperation) Action { - switch operation { - case MultipartOpInitiate: - return s3_constants.S3_ACTION_CREATE_MULTIPART - case MultipartOpUploadPart: - return s3_constants.S3_ACTION_UPLOAD_PART - case MultipartOpComplete: - return s3_constants.S3_ACTION_COMPLETE_MULTIPART - case MultipartOpAbort: - return s3_constants.S3_ACTION_ABORT_MULTIPART - case MultipartOpList: - return s3_constants.S3_ACTION_LIST_MULTIPART_UPLOADS - case MultipartOpListParts: - return s3_constants.S3_ACTION_LIST_PARTS - default: - // Fail closed for unmapped operations to prevent unintended access - glog.Errorf("unmapped multipart operation: %s", operation) - return "s3:InternalErrorUnknownMultipartAction" // Non-existent action ensures denial - } -} - -// extractSessionTokenFromRequest extracts session token from various request sources -func extractSessionTokenFromRequest(r *http.Request) string { - // Check Authorization header for Bearer token - if authHeader := r.Header.Get("Authorization"); authHeader != "" { - if strings.HasPrefix(authHeader, "Bearer ") { - return strings.TrimPrefix(authHeader, "Bearer ") - } - } - - // Check X-Amz-Security-Token header - if token := r.Header.Get("X-Amz-Security-Token"); token != "" { - return token - } - - // Check query parameters for presigned URL tokens - if token := r.URL.Query().Get("X-Amz-Security-Token"); token != "" { - return token - } - - return "" -} - -// validateUploadPartRequest validates upload part request against policies -func (s3a *S3ApiServer) validateUploadPartRequest(r *http.Request) error { - // Get default multipart policy - policy := DefaultMultipartUploadPolicy() - - // Extract part number from query - partNumberStr := r.URL.Query().Get("partNumber") - if partNumberStr == "" { - return fmt.Errorf("missing partNumber parameter") - } - - partNumber, err := strconv.Atoi(partNumberStr) - if err != nil { - return fmt.Errorf("invalid partNumber: %v", err) - } - - // Get content length - contentLength := r.ContentLength - if contentLength < 0 { - contentLength = 0 - } - - // Create multipart request for validation - bucket, object := s3_constants.GetBucketAndObject(r) - multipartReq := &MultipartUploadRequest{ - Bucket: bucket, - ObjectKey: object, - PartNumber: partNumber, - Operation: string(MultipartOpUploadPart), - ContentSize: contentLength, - Headers: make(map[string]string), - } - - // Copy relevant headers - for key, values := range r.Header { - if len(values) > 0 { - multipartReq.Headers[key] = values[0] - } - } - - // Validate against policy - return policy.ValidateMultipartRequestWithPolicy(multipartReq) -} - -// DefaultMultipartUploadPolicy returns a default multipart upload security policy -func DefaultMultipartUploadPolicy() *MultipartUploadPolicy { - return &MultipartUploadPolicy{ - MaxPartSize: 5 * 1024 * 1024 * 1024, // 5GB AWS limit - MinPartSize: 5 * 1024 * 1024, // 5MB AWS minimum (except last part) - MaxParts: 10000, // AWS limit - MaxUploadDuration: 7 * 24 * time.Hour, // 7 days to complete upload - AllowedContentTypes: []string{}, // Empty means all types allowed - RequiredHeaders: []string{}, // No required headers by default - IPWhitelist: []string{}, // Empty means no IP restrictions - } -} - -// MultipartUploadSession represents an ongoing multipart upload session -type MultipartUploadSession struct { - UploadID string `json:"upload_id"` - Bucket string `json:"bucket"` - ObjectKey string `json:"object_key"` - Initiator string `json:"initiator"` // User who initiated the upload - Owner string `json:"owner"` // Object owner - CreatedAt time.Time `json:"created_at"` // When upload was initiated - Parts []MultipartUploadPart `json:"parts"` // Uploaded parts - Metadata map[string]string `json:"metadata"` // Object metadata - Policy *MultipartUploadPolicy `json:"policy"` // Applied security policy - SessionToken string `json:"session_token"` // IAM session token -} - -// MultipartUploadPart represents an uploaded part -type MultipartUploadPart struct { - PartNumber int `json:"part_number"` - Size int64 `json:"size"` - ETag string `json:"etag"` - LastModified time.Time `json:"last_modified"` - Checksum string `json:"checksum"` // Optional integrity checksum -} - -// GetMultipartUploadSessions retrieves active multipart upload sessions for a bucket -func (s3a *S3ApiServer) GetMultipartUploadSessions(bucket string) ([]*MultipartUploadSession, error) { - // This would typically query the filer for active multipart uploads - // For now, return empty list as this is a placeholder for the full implementation - return []*MultipartUploadSession{}, nil -} - -// CleanupExpiredMultipartUploads removes expired multipart upload sessions -func (s3a *S3ApiServer) CleanupExpiredMultipartUploads(maxAge time.Duration) error { - // This would typically scan for and remove expired multipart uploads - // Implementation would depend on how multipart sessions are stored in the filer - glog.V(2).Infof("Cleanup expired multipart uploads older than %v", maxAge) - return nil -} diff --git a/weed/s3api/s3_multipart_iam_test.go b/weed/s3api/s3_multipart_iam_test.go deleted file mode 100644 index 12546eb7a..000000000 --- a/weed/s3api/s3_multipart_iam_test.go +++ /dev/null @@ -1,614 +0,0 @@ -package s3api - -import ( - "context" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/golang-jwt/jwt/v5" - "github.com/seaweedfs/seaweedfs/weed/iam/integration" - "github.com/seaweedfs/seaweedfs/weed/iam/ldap" - "github.com/seaweedfs/seaweedfs/weed/iam/oidc" - "github.com/seaweedfs/seaweedfs/weed/iam/policy" - "github.com/seaweedfs/seaweedfs/weed/iam/sts" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// createTestJWTMultipart creates a test JWT token with the specified issuer, subject and signing key -func createTestJWTMultipart(t *testing.T, issuer, subject, signingKey string) string { - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "iss": issuer, - "sub": subject, - "aud": "test-client-id", - "exp": time.Now().Add(time.Hour).Unix(), - "iat": time.Now().Unix(), - // Add claims that trust policy validation expects - "idp": "test-oidc", // Identity provider claim for trust policy matching - }) - - tokenString, err := token.SignedString([]byte(signingKey)) - require.NoError(t, err) - return tokenString -} - -// TestMultipartIAMValidation tests IAM validation for multipart operations -func TestMultipartIAMValidation(t *testing.T) { - // Set up IAM system - iamManager := setupTestIAMManagerForMultipart(t) - s3iam := NewS3IAMIntegration(iamManager, "localhost:8888") - s3iam.enabled = true - - // Create IAM with integration - iam := &IdentityAccessManagement{ - isAuthEnabled: true, - } - iam.SetIAMIntegration(s3iam) - - // Set up roles - ctx := context.Background() - setupTestRolesForMultipart(ctx, iamManager) - - // Create a valid JWT token for testing - validJWTToken := createTestJWTMultipart(t, "https://test-issuer.com", "test-user-123", "test-signing-key") - - // Get session token - response, err := iamManager.AssumeRoleWithWebIdentity(ctx, &sts.AssumeRoleWithWebIdentityRequest{ - RoleArn: "arn:aws:iam::role/S3WriteRole", - WebIdentityToken: validJWTToken, - RoleSessionName: "multipart-test-session", - }) - require.NoError(t, err) - - sessionToken := response.Credentials.SessionToken - - tests := []struct { - name string - operation MultipartOperation - method string - path string - sessionToken string - expectedResult s3err.ErrorCode - }{ - { - name: "Initiate multipart upload", - operation: MultipartOpInitiate, - method: "POST", - path: "/test-bucket/test-file.txt?uploads", - sessionToken: sessionToken, - expectedResult: s3err.ErrNone, - }, - { - name: "Upload part", - operation: MultipartOpUploadPart, - method: "PUT", - path: "/test-bucket/test-file.txt?partNumber=1&uploadId=test-upload-id", - sessionToken: sessionToken, - expectedResult: s3err.ErrNone, - }, - { - name: "Complete multipart upload", - operation: MultipartOpComplete, - method: "POST", - path: "/test-bucket/test-file.txt?uploadId=test-upload-id", - sessionToken: sessionToken, - expectedResult: s3err.ErrNone, - }, - { - name: "Abort multipart upload", - operation: MultipartOpAbort, - method: "DELETE", - path: "/test-bucket/test-file.txt?uploadId=test-upload-id", - sessionToken: sessionToken, - expectedResult: s3err.ErrNone, - }, - { - name: "List multipart uploads", - operation: MultipartOpList, - method: "GET", - path: "/test-bucket?uploads", - sessionToken: sessionToken, - expectedResult: s3err.ErrNone, - }, - { - name: "Upload part without session token", - operation: MultipartOpUploadPart, - method: "PUT", - path: "/test-bucket/test-file.txt?partNumber=1&uploadId=test-upload-id", - sessionToken: "", - expectedResult: s3err.ErrNone, // Falls back to standard auth - }, - { - name: "Upload part with invalid session token", - operation: MultipartOpUploadPart, - method: "PUT", - path: "/test-bucket/test-file.txt?partNumber=1&uploadId=test-upload-id", - sessionToken: "invalid-token", - expectedResult: s3err.ErrAccessDenied, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Create request for multipart operation - req := createMultipartRequest(t, tt.method, tt.path, tt.sessionToken) - - // Create identity for testing - identity := &Identity{ - Name: "test-user", - Account: &AccountAdmin, - } - - // Test validation - result := iam.ValidateMultipartOperationWithIAM(req, identity, tt.operation) - assert.Equal(t, tt.expectedResult, result, "Multipart IAM validation result should match expected") - }) - } -} - -// TestMultipartUploadPolicy tests multipart upload security policies -func TestMultipartUploadPolicy(t *testing.T) { - policy := &MultipartUploadPolicy{ - MaxPartSize: 10 * 1024 * 1024, // 10MB for testing - MinPartSize: 5 * 1024 * 1024, // 5MB minimum - MaxParts: 100, // 100 parts max for testing - AllowedContentTypes: []string{"application/json", "text/plain"}, - RequiredHeaders: []string{"Content-Type"}, - } - - tests := []struct { - name string - request *MultipartUploadRequest - expectedError string - }{ - { - name: "Valid upload part request", - request: &MultipartUploadRequest{ - Bucket: "test-bucket", - ObjectKey: "test-file.txt", - PartNumber: 1, - Operation: string(MultipartOpUploadPart), - ContentSize: 8 * 1024 * 1024, // 8MB - Headers: map[string]string{ - "Content-Type": "application/json", - }, - }, - expectedError: "", - }, - { - name: "Part size too large", - request: &MultipartUploadRequest{ - Bucket: "test-bucket", - ObjectKey: "test-file.txt", - PartNumber: 1, - Operation: string(MultipartOpUploadPart), - ContentSize: 15 * 1024 * 1024, // 15MB exceeds limit - Headers: map[string]string{ - "Content-Type": "application/json", - }, - }, - expectedError: "part size", - }, - { - name: "Invalid part number (too high)", - request: &MultipartUploadRequest{ - Bucket: "test-bucket", - ObjectKey: "test-file.txt", - PartNumber: 150, // Exceeds max parts - Operation: string(MultipartOpUploadPart), - ContentSize: 8 * 1024 * 1024, - Headers: map[string]string{ - "Content-Type": "application/json", - }, - }, - expectedError: "part number", - }, - { - name: "Invalid part number (too low)", - request: &MultipartUploadRequest{ - Bucket: "test-bucket", - ObjectKey: "test-file.txt", - PartNumber: 0, // Must be >= 1 - Operation: string(MultipartOpUploadPart), - ContentSize: 8 * 1024 * 1024, - Headers: map[string]string{ - "Content-Type": "application/json", - }, - }, - expectedError: "part number", - }, - { - name: "Content type not allowed", - request: &MultipartUploadRequest{ - Bucket: "test-bucket", - ObjectKey: "test-file.txt", - PartNumber: 1, - Operation: string(MultipartOpUploadPart), - ContentSize: 8 * 1024 * 1024, - Headers: map[string]string{ - "Content-Type": "video/mp4", // Not in allowed list - }, - }, - expectedError: "content type video/mp4 is not allowed", - }, - { - name: "Missing required header", - request: &MultipartUploadRequest{ - Bucket: "test-bucket", - ObjectKey: "test-file.txt", - PartNumber: 1, - Operation: string(MultipartOpUploadPart), - ContentSize: 8 * 1024 * 1024, - Headers: map[string]string{}, // Missing Content-Type - }, - expectedError: "required header Content-Type is missing", - }, - { - name: "Non-upload operation (should not validate size)", - request: &MultipartUploadRequest{ - Bucket: "test-bucket", - ObjectKey: "test-file.txt", - Operation: string(MultipartOpInitiate), - Headers: map[string]string{ - "Content-Type": "application/json", - }, - }, - expectedError: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := policy.ValidateMultipartRequestWithPolicy(tt.request) - - if tt.expectedError == "" { - assert.NoError(t, err, "Policy validation should succeed") - } else { - assert.Error(t, err, "Policy validation should fail") - assert.Contains(t, err.Error(), tt.expectedError, "Error message should contain expected text") - } - }) - } -} - -// TestMultipartS3ActionMapping tests the mapping of multipart operations to S3 actions -func TestMultipartS3ActionMapping(t *testing.T) { - tests := []struct { - operation MultipartOperation - expectedAction Action - }{ - {MultipartOpInitiate, s3_constants.S3_ACTION_CREATE_MULTIPART}, - {MultipartOpUploadPart, s3_constants.S3_ACTION_UPLOAD_PART}, - {MultipartOpComplete, s3_constants.S3_ACTION_COMPLETE_MULTIPART}, - {MultipartOpAbort, s3_constants.S3_ACTION_ABORT_MULTIPART}, - {MultipartOpList, s3_constants.S3_ACTION_LIST_MULTIPART_UPLOADS}, - {MultipartOpListParts, s3_constants.S3_ACTION_LIST_PARTS}, - {MultipartOperation("unknown"), "s3:InternalErrorUnknownMultipartAction"}, // Fail-closed for security - } - - for _, tt := range tests { - t.Run(string(tt.operation), func(t *testing.T) { - action := determineMultipartS3Action(tt.operation) - assert.Equal(t, tt.expectedAction, action, "S3 action mapping should match expected") - }) - } -} - -// TestSessionTokenExtraction tests session token extraction from various sources -func TestSessionTokenExtraction(t *testing.T) { - tests := []struct { - name string - setupRequest func() *http.Request - expectedToken string - }{ - { - name: "Bearer token in Authorization header", - setupRequest: func() *http.Request { - req := httptest.NewRequest("PUT", "/test-bucket/test-file.txt", nil) - req.Header.Set("Authorization", "Bearer test-session-token-123") - return req - }, - expectedToken: "test-session-token-123", - }, - { - name: "X-Amz-Security-Token header", - setupRequest: func() *http.Request { - req := httptest.NewRequest("PUT", "/test-bucket/test-file.txt", nil) - req.Header.Set("X-Amz-Security-Token", "security-token-456") - return req - }, - expectedToken: "security-token-456", - }, - { - name: "X-Amz-Security-Token query parameter", - setupRequest: func() *http.Request { - req := httptest.NewRequest("PUT", "/test-bucket/test-file.txt?X-Amz-Security-Token=query-token-789", nil) - return req - }, - expectedToken: "query-token-789", - }, - { - name: "No token present", - setupRequest: func() *http.Request { - return httptest.NewRequest("PUT", "/test-bucket/test-file.txt", nil) - }, - expectedToken: "", - }, - { - name: "Authorization header without Bearer", - setupRequest: func() *http.Request { - req := httptest.NewRequest("PUT", "/test-bucket/test-file.txt", nil) - req.Header.Set("Authorization", "AWS access_key:signature") - return req - }, - expectedToken: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req := tt.setupRequest() - token := extractSessionTokenFromRequest(req) - assert.Equal(t, tt.expectedToken, token, "Extracted token should match expected") - }) - } -} - -// TestUploadPartValidation tests upload part request validation -func TestUploadPartValidation(t *testing.T) { - s3Server := &S3ApiServer{} - - tests := []struct { - name string - setupRequest func() *http.Request - expectedError string - }{ - { - name: "Valid upload part request", - setupRequest: func() *http.Request { - req := httptest.NewRequest("PUT", "/test-bucket/test-file.txt?partNumber=1&uploadId=test-123", nil) - req.Header.Set("Content-Type", "application/octet-stream") - req.ContentLength = 6 * 1024 * 1024 // 6MB - return req - }, - expectedError: "", - }, - { - name: "Missing partNumber parameter", - setupRequest: func() *http.Request { - req := httptest.NewRequest("PUT", "/test-bucket/test-file.txt?uploadId=test-123", nil) - req.Header.Set("Content-Type", "application/octet-stream") - req.ContentLength = 6 * 1024 * 1024 - return req - }, - expectedError: "missing partNumber parameter", - }, - { - name: "Invalid partNumber format", - setupRequest: func() *http.Request { - req := httptest.NewRequest("PUT", "/test-bucket/test-file.txt?partNumber=abc&uploadId=test-123", nil) - req.Header.Set("Content-Type", "application/octet-stream") - req.ContentLength = 6 * 1024 * 1024 - return req - }, - expectedError: "invalid partNumber", - }, - { - name: "Part size too large", - setupRequest: func() *http.Request { - req := httptest.NewRequest("PUT", "/test-bucket/test-file.txt?partNumber=1&uploadId=test-123", nil) - req.Header.Set("Content-Type", "application/octet-stream") - req.ContentLength = 6 * 1024 * 1024 * 1024 // 6GB exceeds 5GB limit - return req - }, - expectedError: "part size", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req := tt.setupRequest() - err := s3Server.validateUploadPartRequest(req) - - if tt.expectedError == "" { - assert.NoError(t, err, "Upload part validation should succeed") - } else { - assert.Error(t, err, "Upload part validation should fail") - assert.Contains(t, err.Error(), tt.expectedError, "Error message should contain expected text") - } - }) - } -} - -// TestDefaultMultipartUploadPolicy tests the default policy configuration -func TestDefaultMultipartUploadPolicy(t *testing.T) { - policy := DefaultMultipartUploadPolicy() - - assert.Equal(t, int64(5*1024*1024*1024), policy.MaxPartSize, "Max part size should be 5GB") - assert.Equal(t, int64(5*1024*1024), policy.MinPartSize, "Min part size should be 5MB") - assert.Equal(t, 10000, policy.MaxParts, "Max parts should be 10,000") - assert.Equal(t, 7*24*time.Hour, policy.MaxUploadDuration, "Max upload duration should be 7 days") - assert.Empty(t, policy.AllowedContentTypes, "Should allow all content types by default") - assert.Empty(t, policy.RequiredHeaders, "Should have no required headers by default") - assert.Empty(t, policy.IPWhitelist, "Should have no IP restrictions by default") -} - -// TestMultipartUploadSession tests multipart upload session structure -func TestMultipartUploadSession(t *testing.T) { - session := &MultipartUploadSession{ - UploadID: "test-upload-123", - Bucket: "test-bucket", - ObjectKey: "test-file.txt", - Initiator: "arn:aws:iam::user/testuser", - Owner: "arn:aws:iam::user/testuser", - CreatedAt: time.Now(), - Parts: []MultipartUploadPart{ - { - PartNumber: 1, - Size: 5 * 1024 * 1024, - ETag: "abc123", - LastModified: time.Now(), - Checksum: "sha256:def456", - }, - }, - Metadata: map[string]string{ - "Content-Type": "application/octet-stream", - "x-amz-meta-custom": "value", - }, - Policy: DefaultMultipartUploadPolicy(), - SessionToken: "session-token-789", - } - - assert.NotEmpty(t, session.UploadID, "Upload ID should not be empty") - assert.NotEmpty(t, session.Bucket, "Bucket should not be empty") - assert.NotEmpty(t, session.ObjectKey, "Object key should not be empty") - assert.Len(t, session.Parts, 1, "Should have one part") - assert.Equal(t, 1, session.Parts[0].PartNumber, "Part number should be 1") - assert.NotNil(t, session.Policy, "Policy should not be nil") -} - -// Helper functions for tests - -func setupTestIAMManagerForMultipart(t *testing.T) *integration.IAMManager { - // Create IAM manager - manager := integration.NewIAMManager() - - // Initialize with test configuration - config := &integration.IAMConfig{ - STS: &sts.STSConfig{ - TokenDuration: sts.FlexibleDuration{Duration: time.Hour}, - MaxSessionLength: sts.FlexibleDuration{Duration: time.Hour * 12}, - Issuer: "test-sts", - SigningKey: []byte("test-signing-key-32-characters-long"), - }, - Policy: &policy.PolicyEngineConfig{ - DefaultEffect: "Deny", - StoreType: "memory", - }, - Roles: &integration.RoleStoreConfig{ - StoreType: "memory", - }, - } - - err := manager.Initialize(config, func() string { - return "localhost:8888" // Mock filer address for testing - }) - require.NoError(t, err) - - // Set up test identity providers - setupTestProvidersForMultipart(t, manager) - - return manager -} - -func setupTestProvidersForMultipart(t *testing.T, manager *integration.IAMManager) { - // Set up OIDC provider - oidcProvider := oidc.NewMockOIDCProvider("test-oidc") - oidcConfig := &oidc.OIDCConfig{ - Issuer: "https://test-issuer.com", - ClientID: "test-client-id", - } - err := oidcProvider.Initialize(oidcConfig) - require.NoError(t, err) - oidcProvider.SetupDefaultTestData() - - // Set up LDAP provider - ldapProvider := ldap.NewMockLDAPProvider("test-ldap") - err = ldapProvider.Initialize(nil) // Mock doesn't need real config - require.NoError(t, err) - ldapProvider.SetupDefaultTestData() - - // Register providers - err = manager.RegisterIdentityProvider(oidcProvider) - require.NoError(t, err) - err = manager.RegisterIdentityProvider(ldapProvider) - require.NoError(t, err) -} - -func setupTestRolesForMultipart(ctx context.Context, manager *integration.IAMManager) { - // Create write policy for multipart operations - writePolicy := &policy.PolicyDocument{ - Version: "2012-10-17", - Statement: []policy.Statement{ - { - Sid: "AllowS3MultipartOperations", - Effect: "Allow", - Action: []string{ - "s3:PutObject", - "s3:GetObject", - "s3:ListBucket", - "s3:DeleteObject", - "s3:CreateMultipartUpload", - "s3:UploadPart", - "s3:CompleteMultipartUpload", - "s3:AbortMultipartUpload", - "s3:ListBucketMultipartUploads", - "s3:ListMultipartUploadParts", - }, - Resource: []string{ - "arn:aws:s3:::*", - "arn:aws:s3:::*/*", - }, - }, - }, - } - - manager.CreatePolicy(ctx, "", "S3WritePolicy", writePolicy) - - // Create write role - manager.CreateRole(ctx, "", "S3WriteRole", &integration.RoleDefinition{ - RoleName: "S3WriteRole", - TrustPolicy: &policy.PolicyDocument{ - Version: "2012-10-17", - Statement: []policy.Statement{ - { - Effect: "Allow", - Principal: map[string]interface{}{ - "Federated": "test-oidc", - }, - Action: []string{"sts:AssumeRoleWithWebIdentity"}, - }, - }, - }, - AttachedPolicies: []string{"S3WritePolicy"}, - }) - - // Create a role for multipart users - manager.CreateRole(ctx, "", "MultipartUser", &integration.RoleDefinition{ - RoleName: "MultipartUser", - TrustPolicy: &policy.PolicyDocument{ - Version: "2012-10-17", - Statement: []policy.Statement{ - { - Effect: "Allow", - Principal: map[string]interface{}{ - "Federated": "test-oidc", - }, - Action: []string{"sts:AssumeRoleWithWebIdentity"}, - }, - }, - }, - AttachedPolicies: []string{"S3WritePolicy"}, - }) -} - -func createMultipartRequest(t *testing.T, method, path, sessionToken string) *http.Request { - req := httptest.NewRequest(method, path, nil) - - // Add session token if provided - if sessionToken != "" { - req.Header.Set("Authorization", "Bearer "+sessionToken) - // Set the principal ARN header that matches the assumed role from the test setup - // This corresponds to the role "arn:aws:iam::role/S3WriteRole" with session name "multipart-test-session" - req.Header.Set("X-SeaweedFS-Principal", "arn:aws:sts::assumed-role/S3WriteRole/multipart-test-session") - } - - // Add common headers - req.Header.Set("Content-Type", "application/octet-stream") - - return req -} diff --git a/weed/s3api/s3_policy_templates.go b/weed/s3api/s3_policy_templates.go deleted file mode 100644 index 1506c68ee..000000000 --- a/weed/s3api/s3_policy_templates.go +++ /dev/null @@ -1,618 +0,0 @@ -package s3api - -import ( - "time" - - "github.com/seaweedfs/seaweedfs/weed/iam/policy" -) - -// S3PolicyTemplates provides pre-built IAM policy templates for common S3 use cases -type S3PolicyTemplates struct{} - -// NewS3PolicyTemplates creates a new policy templates provider -func NewS3PolicyTemplates() *S3PolicyTemplates { - return &S3PolicyTemplates{} -} - -// GetS3ReadOnlyPolicy returns a policy that allows read-only access to all S3 resources -func (t *S3PolicyTemplates) GetS3ReadOnlyPolicy() *policy.PolicyDocument { - return &policy.PolicyDocument{ - Version: "2012-10-17", - Statement: []policy.Statement{ - { - Sid: "S3ReadOnlyAccess", - Effect: "Allow", - Action: []string{ - "s3:GetObject", - "s3:GetObjectVersion", - "s3:ListBucket", - "s3:ListBucketVersions", - "s3:GetBucketLocation", - "s3:GetBucketVersioning", - "s3:ListAllMyBuckets", - }, - Resource: []string{ - "arn:aws:s3:::*", - "arn:aws:s3:::*/*", - }, - }, - }, - } -} - -// GetS3WriteOnlyPolicy returns a policy that allows write-only access to all S3 resources -func (t *S3PolicyTemplates) GetS3WriteOnlyPolicy() *policy.PolicyDocument { - return &policy.PolicyDocument{ - Version: "2012-10-17", - Statement: []policy.Statement{ - { - Sid: "S3WriteOnlyAccess", - Effect: "Allow", - Action: []string{ - "s3:PutObject", - "s3:PutObjectAcl", - "s3:CreateMultipartUpload", - "s3:UploadPart", - "s3:CompleteMultipartUpload", - "s3:AbortMultipartUpload", - "s3:ListMultipartUploads", - "s3:ListParts", - }, - Resource: []string{ - "arn:aws:s3:::*", - "arn:aws:s3:::*/*", - }, - }, - }, - } -} - -// GetS3AdminPolicy returns a policy that allows full admin access to all S3 resources -func (t *S3PolicyTemplates) GetS3AdminPolicy() *policy.PolicyDocument { - return &policy.PolicyDocument{ - Version: "2012-10-17", - Statement: []policy.Statement{ - { - Sid: "S3FullAccess", - Effect: "Allow", - Action: []string{ - "s3:*", - }, - Resource: []string{ - "arn:aws:s3:::*", - "arn:aws:s3:::*/*", - }, - }, - }, - } -} - -// GetBucketSpecificReadPolicy returns a policy for read-only access to a specific bucket -func (t *S3PolicyTemplates) GetBucketSpecificReadPolicy(bucketName string) *policy.PolicyDocument { - return &policy.PolicyDocument{ - Version: "2012-10-17", - Statement: []policy.Statement{ - { - Sid: "BucketSpecificReadAccess", - Effect: "Allow", - Action: []string{ - "s3:GetObject", - "s3:GetObjectVersion", - "s3:ListBucket", - "s3:ListBucketVersions", - "s3:GetBucketLocation", - }, - Resource: []string{ - "arn:aws:s3:::" + bucketName, - "arn:aws:s3:::" + bucketName + "/*", - }, - }, - }, - } -} - -// GetBucketSpecificWritePolicy returns a policy for write-only access to a specific bucket -func (t *S3PolicyTemplates) GetBucketSpecificWritePolicy(bucketName string) *policy.PolicyDocument { - return &policy.PolicyDocument{ - Version: "2012-10-17", - Statement: []policy.Statement{ - { - Sid: "BucketSpecificWriteAccess", - Effect: "Allow", - Action: []string{ - "s3:PutObject", - "s3:PutObjectAcl", - "s3:CreateMultipartUpload", - "s3:UploadPart", - "s3:CompleteMultipartUpload", - "s3:AbortMultipartUpload", - "s3:ListMultipartUploads", - "s3:ListParts", - }, - Resource: []string{ - "arn:aws:s3:::" + bucketName, - "arn:aws:s3:::" + bucketName + "/*", - }, - }, - }, - } -} - -// GetPathBasedAccessPolicy returns a policy that restricts access to a specific path within a bucket -func (t *S3PolicyTemplates) GetPathBasedAccessPolicy(bucketName, pathPrefix string) *policy.PolicyDocument { - return &policy.PolicyDocument{ - Version: "2012-10-17", - Statement: []policy.Statement{ - { - Sid: "ListBucketPermission", - Effect: "Allow", - Action: []string{ - "s3:ListBucket", - }, - Resource: []string{ - "arn:aws:s3:::" + bucketName, - }, - Condition: map[string]map[string]interface{}{ - "StringLike": map[string]interface{}{ - "s3:prefix": []string{pathPrefix + "/*"}, - }, - }, - }, - { - Sid: "PathBasedObjectAccess", - Effect: "Allow", - Action: []string{ - "s3:GetObject", - "s3:PutObject", - "s3:DeleteObject", - "s3:CreateMultipartUpload", - "s3:UploadPart", - "s3:CompleteMultipartUpload", - "s3:AbortMultipartUpload", - }, - Resource: []string{ - "arn:aws:s3:::" + bucketName + "/" + pathPrefix + "/*", - }, - }, - }, - } -} - -// GetIPRestrictedPolicy returns a policy that restricts access based on source IP -func (t *S3PolicyTemplates) GetIPRestrictedPolicy(allowedCIDRs []string) *policy.PolicyDocument { - return &policy.PolicyDocument{ - Version: "2012-10-17", - Statement: []policy.Statement{ - { - Sid: "IPRestrictedS3Access", - Effect: "Allow", - Action: []string{ - "s3:*", - }, - Resource: []string{ - "arn:aws:s3:::*", - "arn:aws:s3:::*/*", - }, - Condition: map[string]map[string]interface{}{ - "IpAddress": map[string]interface{}{ - "aws:SourceIp": allowedCIDRs, - }, - }, - }, - }, - } -} - -// GetTimeBasedAccessPolicy returns a policy that allows access only during specific hours -func (t *S3PolicyTemplates) GetTimeBasedAccessPolicy(startHour, endHour int) *policy.PolicyDocument { - return &policy.PolicyDocument{ - Version: "2012-10-17", - Statement: []policy.Statement{ - { - Sid: "TimeBasedS3Access", - Effect: "Allow", - Action: []string{ - "s3:GetObject", - "s3:PutObject", - "s3:ListBucket", - }, - Resource: []string{ - "arn:aws:s3:::*", - "arn:aws:s3:::*/*", - }, - Condition: map[string]map[string]interface{}{ - "DateGreaterThan": map[string]interface{}{ - "aws:CurrentTime": time.Now().Format("2006-01-02") + "T" + - formatHour(startHour) + ":00:00Z", - }, - "DateLessThan": map[string]interface{}{ - "aws:CurrentTime": time.Now().Format("2006-01-02") + "T" + - formatHour(endHour) + ":00:00Z", - }, - }, - }, - }, - } -} - -// GetMultipartUploadPolicy returns a policy specifically for multipart upload operations -func (t *S3PolicyTemplates) GetMultipartUploadPolicy(bucketName string) *policy.PolicyDocument { - return &policy.PolicyDocument{ - Version: "2012-10-17", - Statement: []policy.Statement{ - { - Sid: "MultipartUploadOperations", - Effect: "Allow", - Action: []string{ - "s3:CreateMultipartUpload", - "s3:UploadPart", - "s3:CompleteMultipartUpload", - "s3:AbortMultipartUpload", - "s3:ListMultipartUploads", - "s3:ListParts", - }, - Resource: []string{ - "arn:aws:s3:::" + bucketName + "/*", - }, - }, - { - Sid: "ListBucketForMultipart", - Effect: "Allow", - Action: []string{ - "s3:ListBucket", - }, - Resource: []string{ - "arn:aws:s3:::" + bucketName, - }, - }, - }, - } -} - -// GetPresignedURLPolicy returns a policy for generating and using presigned URLs -func (t *S3PolicyTemplates) GetPresignedURLPolicy(bucketName string) *policy.PolicyDocument { - return &policy.PolicyDocument{ - Version: "2012-10-17", - Statement: []policy.Statement{ - { - Sid: "PresignedURLAccess", - Effect: "Allow", - Action: []string{ - "s3:GetObject", - "s3:PutObject", - }, - Resource: []string{ - "arn:aws:s3:::" + bucketName + "/*", - }, - Condition: map[string]map[string]interface{}{ - "StringEquals": map[string]interface{}{ - "s3:x-amz-signature-version": "AWS4-HMAC-SHA256", - }, - }, - }, - }, - } -} - -// GetTemporaryAccessPolicy returns a policy for temporary access with expiration -func (t *S3PolicyTemplates) GetTemporaryAccessPolicy(bucketName string, expirationHours int) *policy.PolicyDocument { - expirationTime := time.Now().Add(time.Duration(expirationHours) * time.Hour) - - return &policy.PolicyDocument{ - Version: "2012-10-17", - Statement: []policy.Statement{ - { - Sid: "TemporaryS3Access", - Effect: "Allow", - Action: []string{ - "s3:GetObject", - "s3:PutObject", - "s3:ListBucket", - }, - Resource: []string{ - "arn:aws:s3:::" + bucketName, - "arn:aws:s3:::" + bucketName + "/*", - }, - Condition: map[string]map[string]interface{}{ - "DateLessThan": map[string]interface{}{ - "aws:CurrentTime": expirationTime.UTC().Format("2006-01-02T15:04:05Z"), - }, - }, - }, - }, - } -} - -// GetContentTypeRestrictedPolicy returns a policy that restricts uploads to specific content types -func (t *S3PolicyTemplates) GetContentTypeRestrictedPolicy(bucketName string, allowedContentTypes []string) *policy.PolicyDocument { - return &policy.PolicyDocument{ - Version: "2012-10-17", - Statement: []policy.Statement{ - { - Sid: "ContentTypeRestrictedUpload", - Effect: "Allow", - Action: []string{ - "s3:PutObject", - "s3:CreateMultipartUpload", - "s3:UploadPart", - "s3:CompleteMultipartUpload", - }, - Resource: []string{ - "arn:aws:s3:::" + bucketName + "/*", - }, - Condition: map[string]map[string]interface{}{ - "StringEquals": map[string]interface{}{ - "s3:content-type": allowedContentTypes, - }, - }, - }, - { - Sid: "ReadAccess", - Effect: "Allow", - Action: []string{ - "s3:GetObject", - "s3:ListBucket", - }, - Resource: []string{ - "arn:aws:s3:::" + bucketName, - "arn:aws:s3:::" + bucketName + "/*", - }, - }, - }, - } -} - -// GetDenyDeletePolicy returns a policy that allows all operations except delete -func (t *S3PolicyTemplates) GetDenyDeletePolicy() *policy.PolicyDocument { - return &policy.PolicyDocument{ - Version: "2012-10-17", - Statement: []policy.Statement{ - { - Sid: "AllowAllExceptDelete", - Effect: "Allow", - Action: []string{ - "s3:GetObject", - "s3:GetObjectVersion", - "s3:PutObject", - "s3:PutObjectAcl", - "s3:ListBucket", - "s3:ListBucketVersions", - "s3:CreateMultipartUpload", - "s3:UploadPart", - "s3:CompleteMultipartUpload", - "s3:AbortMultipartUpload", - "s3:ListMultipartUploads", - "s3:ListParts", - }, - Resource: []string{ - "arn:aws:s3:::*", - "arn:aws:s3:::*/*", - }, - }, - { - Sid: "DenyDeleteOperations", - Effect: "Deny", - Action: []string{ - "s3:DeleteObject", - "s3:DeleteObjectVersion", - "s3:DeleteBucket", - }, - Resource: []string{ - "arn:aws:s3:::*", - "arn:aws:s3:::*/*", - }, - }, - }, - } -} - -// Helper function to format hour with leading zero -func formatHour(hour int) string { - if hour < 10 { - return "0" + string(rune('0'+hour)) - } - return string(rune('0'+hour/10)) + string(rune('0'+hour%10)) -} - -// PolicyTemplateDefinition represents metadata about a policy template -type PolicyTemplateDefinition struct { - Name string `json:"name"` - Description string `json:"description"` - Category string `json:"category"` - UseCase string `json:"use_case"` - Parameters []PolicyTemplateParam `json:"parameters,omitempty"` - Policy *policy.PolicyDocument `json:"policy"` -} - -// PolicyTemplateParam represents a parameter for customizing policy templates -type PolicyTemplateParam struct { - Name string `json:"name"` - Type string `json:"type"` - Description string `json:"description"` - Required bool `json:"required"` - DefaultValue string `json:"default_value,omitempty"` - Example string `json:"example,omitempty"` -} - -// GetAllPolicyTemplates returns all available policy templates with metadata -func (t *S3PolicyTemplates) GetAllPolicyTemplates() []PolicyTemplateDefinition { - return []PolicyTemplateDefinition{ - { - Name: "S3ReadOnlyAccess", - Description: "Provides read-only access to all S3 buckets and objects", - Category: "Basic Access", - UseCase: "Data consumers, backup services, monitoring applications", - Policy: t.GetS3ReadOnlyPolicy(), - }, - { - Name: "S3WriteOnlyAccess", - Description: "Provides write-only access to all S3 buckets and objects", - Category: "Basic Access", - UseCase: "Data ingestion services, backup applications", - Policy: t.GetS3WriteOnlyPolicy(), - }, - { - Name: "S3AdminAccess", - Description: "Provides full administrative access to all S3 resources", - Category: "Administrative", - UseCase: "S3 administrators, service accounts with full control", - Policy: t.GetS3AdminPolicy(), - }, - { - Name: "BucketSpecificRead", - Description: "Provides read-only access to a specific bucket", - Category: "Bucket-Specific", - UseCase: "Applications that need access to specific data sets", - Parameters: []PolicyTemplateParam{ - { - Name: "bucketName", - Type: "string", - Description: "Name of the S3 bucket to grant access to", - Required: true, - Example: "my-data-bucket", - }, - }, - Policy: t.GetBucketSpecificReadPolicy("${bucketName}"), - }, - { - Name: "BucketSpecificWrite", - Description: "Provides write-only access to a specific bucket", - Category: "Bucket-Specific", - UseCase: "Upload services, data ingestion for specific datasets", - Parameters: []PolicyTemplateParam{ - { - Name: "bucketName", - Type: "string", - Description: "Name of the S3 bucket to grant access to", - Required: true, - Example: "my-upload-bucket", - }, - }, - Policy: t.GetBucketSpecificWritePolicy("${bucketName}"), - }, - { - Name: "PathBasedAccess", - Description: "Restricts access to a specific path/prefix within a bucket", - Category: "Path-Restricted", - UseCase: "Multi-tenant applications, user-specific directories", - Parameters: []PolicyTemplateParam{ - { - Name: "bucketName", - Type: "string", - Description: "Name of the S3 bucket", - Required: true, - Example: "shared-bucket", - }, - { - Name: "pathPrefix", - Type: "string", - Description: "Path prefix to restrict access to", - Required: true, - Example: "user123/documents", - }, - }, - Policy: t.GetPathBasedAccessPolicy("${bucketName}", "${pathPrefix}"), - }, - { - Name: "IPRestrictedAccess", - Description: "Allows access only from specific IP addresses or ranges", - Category: "Security", - UseCase: "Corporate networks, office-based access, VPN restrictions", - Parameters: []PolicyTemplateParam{ - { - Name: "allowedCIDRs", - Type: "array", - Description: "List of allowed IP addresses or CIDR ranges", - Required: true, - Example: "[\"192.168.1.0/24\", \"10.0.0.0/8\"]", - }, - }, - Policy: t.GetIPRestrictedPolicy([]string{"${allowedCIDRs}"}), - }, - { - Name: "MultipartUploadOnly", - Description: "Allows only multipart upload operations on a specific bucket", - Category: "Upload-Specific", - UseCase: "Large file upload services, streaming applications", - Parameters: []PolicyTemplateParam{ - { - Name: "bucketName", - Type: "string", - Description: "Name of the S3 bucket for multipart uploads", - Required: true, - Example: "large-files-bucket", - }, - }, - Policy: t.GetMultipartUploadPolicy("${bucketName}"), - }, - { - Name: "PresignedURLAccess", - Description: "Policy for generating and using presigned URLs", - Category: "Presigned URLs", - UseCase: "Frontend applications, temporary file sharing", - Parameters: []PolicyTemplateParam{ - { - Name: "bucketName", - Type: "string", - Description: "Name of the S3 bucket for presigned URL access", - Required: true, - Example: "shared-files-bucket", - }, - }, - Policy: t.GetPresignedURLPolicy("${bucketName}"), - }, - { - Name: "ContentTypeRestricted", - Description: "Restricts uploads to specific content types", - Category: "Content Control", - UseCase: "Image galleries, document repositories, media libraries", - Parameters: []PolicyTemplateParam{ - { - Name: "bucketName", - Type: "string", - Description: "Name of the S3 bucket", - Required: true, - Example: "media-bucket", - }, - { - Name: "allowedContentTypes", - Type: "array", - Description: "List of allowed MIME content types", - Required: true, - Example: "[\"image/jpeg\", \"image/png\", \"video/mp4\"]", - }, - }, - Policy: t.GetContentTypeRestrictedPolicy("${bucketName}", []string{"${allowedContentTypes}"}), - }, - { - Name: "DenyDeleteAccess", - Description: "Allows all operations except delete (immutable storage)", - Category: "Data Protection", - UseCase: "Compliance storage, audit logs, backup retention", - Policy: t.GetDenyDeletePolicy(), - }, - } -} - -// GetPolicyTemplateByName returns a specific policy template by name -func (t *S3PolicyTemplates) GetPolicyTemplateByName(name string) *PolicyTemplateDefinition { - templates := t.GetAllPolicyTemplates() - for _, template := range templates { - if template.Name == name { - return &template - } - } - return nil -} - -// GetPolicyTemplatesByCategory returns all policy templates in a specific category -func (t *S3PolicyTemplates) GetPolicyTemplatesByCategory(category string) []PolicyTemplateDefinition { - var result []PolicyTemplateDefinition - templates := t.GetAllPolicyTemplates() - for _, template := range templates { - if template.Category == category { - result = append(result, template) - } - } - return result -} diff --git a/weed/s3api/s3_policy_templates_test.go b/weed/s3api/s3_policy_templates_test.go deleted file mode 100644 index 453260c2a..000000000 --- a/weed/s3api/s3_policy_templates_test.go +++ /dev/null @@ -1,504 +0,0 @@ -package s3api - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestS3PolicyTemplates(t *testing.T) { - templates := NewS3PolicyTemplates() - - t.Run("S3ReadOnlyPolicy", func(t *testing.T) { - policy := templates.GetS3ReadOnlyPolicy() - - require.NotNil(t, policy) - assert.Equal(t, "2012-10-17", policy.Version) - assert.Len(t, policy.Statement, 1) - - stmt := policy.Statement[0] - assert.Equal(t, "Allow", stmt.Effect) - assert.Equal(t, "S3ReadOnlyAccess", stmt.Sid) - assert.Contains(t, stmt.Action, "s3:GetObject") - assert.Contains(t, stmt.Action, "s3:ListBucket") - assert.NotContains(t, stmt.Action, "s3:PutObject") - assert.NotContains(t, stmt.Action, "s3:DeleteObject") - - assert.Contains(t, stmt.Resource, "arn:aws:s3:::*") - assert.Contains(t, stmt.Resource, "arn:aws:s3:::*/*") - }) - - t.Run("S3WriteOnlyPolicy", func(t *testing.T) { - policy := templates.GetS3WriteOnlyPolicy() - - require.NotNil(t, policy) - assert.Equal(t, "2012-10-17", policy.Version) - assert.Len(t, policy.Statement, 1) - - stmt := policy.Statement[0] - assert.Equal(t, "Allow", stmt.Effect) - assert.Equal(t, "S3WriteOnlyAccess", stmt.Sid) - assert.Contains(t, stmt.Action, "s3:PutObject") - assert.Contains(t, stmt.Action, "s3:CreateMultipartUpload") - assert.NotContains(t, stmt.Action, "s3:GetObject") - assert.NotContains(t, stmt.Action, "s3:DeleteObject") - - assert.Contains(t, stmt.Resource, "arn:aws:s3:::*") - assert.Contains(t, stmt.Resource, "arn:aws:s3:::*/*") - }) - - t.Run("S3AdminPolicy", func(t *testing.T) { - policy := templates.GetS3AdminPolicy() - - require.NotNil(t, policy) - assert.Equal(t, "2012-10-17", policy.Version) - assert.Len(t, policy.Statement, 1) - - stmt := policy.Statement[0] - assert.Equal(t, "Allow", stmt.Effect) - assert.Equal(t, "S3FullAccess", stmt.Sid) - assert.Contains(t, stmt.Action, "s3:*") - - assert.Contains(t, stmt.Resource, "arn:aws:s3:::*") - assert.Contains(t, stmt.Resource, "arn:aws:s3:::*/*") - }) -} - -func TestBucketSpecificPolicies(t *testing.T) { - templates := NewS3PolicyTemplates() - bucketName := "test-bucket" - - t.Run("BucketSpecificReadPolicy", func(t *testing.T) { - policy := templates.GetBucketSpecificReadPolicy(bucketName) - - require.NotNil(t, policy) - assert.Equal(t, "2012-10-17", policy.Version) - assert.Len(t, policy.Statement, 1) - - stmt := policy.Statement[0] - assert.Equal(t, "Allow", stmt.Effect) - assert.Equal(t, "BucketSpecificReadAccess", stmt.Sid) - assert.Contains(t, stmt.Action, "s3:GetObject") - assert.Contains(t, stmt.Action, "s3:ListBucket") - assert.NotContains(t, stmt.Action, "s3:PutObject") - - expectedBucketArn := "arn:aws:s3:::" + bucketName - expectedObjectArn := "arn:aws:s3:::" + bucketName + "/*" - assert.Contains(t, stmt.Resource, expectedBucketArn) - assert.Contains(t, stmt.Resource, expectedObjectArn) - }) - - t.Run("BucketSpecificWritePolicy", func(t *testing.T) { - policy := templates.GetBucketSpecificWritePolicy(bucketName) - - require.NotNil(t, policy) - assert.Equal(t, "2012-10-17", policy.Version) - assert.Len(t, policy.Statement, 1) - - stmt := policy.Statement[0] - assert.Equal(t, "Allow", stmt.Effect) - assert.Equal(t, "BucketSpecificWriteAccess", stmt.Sid) - assert.Contains(t, stmt.Action, "s3:PutObject") - assert.Contains(t, stmt.Action, "s3:CreateMultipartUpload") - assert.NotContains(t, stmt.Action, "s3:GetObject") - - expectedBucketArn := "arn:aws:s3:::" + bucketName - expectedObjectArn := "arn:aws:s3:::" + bucketName + "/*" - assert.Contains(t, stmt.Resource, expectedBucketArn) - assert.Contains(t, stmt.Resource, expectedObjectArn) - }) -} - -func TestPathBasedAccessPolicy(t *testing.T) { - templates := NewS3PolicyTemplates() - bucketName := "shared-bucket" - pathPrefix := "user123/documents" - - policy := templates.GetPathBasedAccessPolicy(bucketName, pathPrefix) - - require.NotNil(t, policy) - assert.Equal(t, "2012-10-17", policy.Version) - assert.Len(t, policy.Statement, 2) - - // First statement: List bucket with prefix condition - listStmt := policy.Statement[0] - assert.Equal(t, "Allow", listStmt.Effect) - assert.Equal(t, "ListBucketPermission", listStmt.Sid) - assert.Contains(t, listStmt.Action, "s3:ListBucket") - assert.Contains(t, listStmt.Resource, "arn:aws:s3:::"+bucketName) - assert.NotNil(t, listStmt.Condition) - - // Second statement: Object operations on path - objectStmt := policy.Statement[1] - assert.Equal(t, "Allow", objectStmt.Effect) - assert.Equal(t, "PathBasedObjectAccess", objectStmt.Sid) - assert.Contains(t, objectStmt.Action, "s3:GetObject") - assert.Contains(t, objectStmt.Action, "s3:PutObject") - assert.Contains(t, objectStmt.Action, "s3:DeleteObject") - - expectedObjectArn := "arn:aws:s3:::" + bucketName + "/" + pathPrefix + "/*" - assert.Contains(t, objectStmt.Resource, expectedObjectArn) -} - -func TestIPRestrictedPolicy(t *testing.T) { - templates := NewS3PolicyTemplates() - allowedCIDRs := []string{"192.168.1.0/24", "10.0.0.0/8"} - - policy := templates.GetIPRestrictedPolicy(allowedCIDRs) - - require.NotNil(t, policy) - assert.Equal(t, "2012-10-17", policy.Version) - assert.Len(t, policy.Statement, 1) - - stmt := policy.Statement[0] - assert.Equal(t, "Allow", stmt.Effect) - assert.Equal(t, "IPRestrictedS3Access", stmt.Sid) - assert.Contains(t, stmt.Action, "s3:*") - assert.NotNil(t, stmt.Condition) - - // Check IP condition structure - condition := stmt.Condition - ipAddress, exists := condition["IpAddress"] - assert.True(t, exists) - - sourceIp, exists := ipAddress["aws:SourceIp"] - assert.True(t, exists) - assert.Equal(t, allowedCIDRs, sourceIp) -} - -func TestTimeBasedAccessPolicy(t *testing.T) { - templates := NewS3PolicyTemplates() - startHour := 9 // 9 AM - endHour := 17 // 5 PM - - policy := templates.GetTimeBasedAccessPolicy(startHour, endHour) - - require.NotNil(t, policy) - assert.Equal(t, "2012-10-17", policy.Version) - assert.Len(t, policy.Statement, 1) - - stmt := policy.Statement[0] - assert.Equal(t, "Allow", stmt.Effect) - assert.Equal(t, "TimeBasedS3Access", stmt.Sid) - assert.Contains(t, stmt.Action, "s3:GetObject") - assert.Contains(t, stmt.Action, "s3:PutObject") - assert.Contains(t, stmt.Action, "s3:ListBucket") - assert.NotNil(t, stmt.Condition) - - // Check time condition structure - condition := stmt.Condition - _, hasGreater := condition["DateGreaterThan"] - _, hasLess := condition["DateLessThan"] - assert.True(t, hasGreater) - assert.True(t, hasLess) -} - -func TestMultipartUploadPolicyTemplate(t *testing.T) { - templates := NewS3PolicyTemplates() - bucketName := "large-files" - - policy := templates.GetMultipartUploadPolicy(bucketName) - - require.NotNil(t, policy) - assert.Equal(t, "2012-10-17", policy.Version) - assert.Len(t, policy.Statement, 2) - - // First statement: Multipart operations - multipartStmt := policy.Statement[0] - assert.Equal(t, "Allow", multipartStmt.Effect) - assert.Equal(t, "MultipartUploadOperations", multipartStmt.Sid) - assert.Contains(t, multipartStmt.Action, "s3:CreateMultipartUpload") - assert.Contains(t, multipartStmt.Action, "s3:UploadPart") - assert.Contains(t, multipartStmt.Action, "s3:CompleteMultipartUpload") - assert.Contains(t, multipartStmt.Action, "s3:AbortMultipartUpload") - assert.Contains(t, multipartStmt.Action, "s3:ListMultipartUploads") - assert.Contains(t, multipartStmt.Action, "s3:ListParts") - - expectedObjectArn := "arn:aws:s3:::" + bucketName + "/*" - assert.Contains(t, multipartStmt.Resource, expectedObjectArn) - - // Second statement: List bucket - listStmt := policy.Statement[1] - assert.Equal(t, "Allow", listStmt.Effect) - assert.Equal(t, "ListBucketForMultipart", listStmt.Sid) - assert.Contains(t, listStmt.Action, "s3:ListBucket") - - expectedBucketArn := "arn:aws:s3:::" + bucketName - assert.Contains(t, listStmt.Resource, expectedBucketArn) -} - -func TestPresignedURLPolicy(t *testing.T) { - templates := NewS3PolicyTemplates() - bucketName := "shared-files" - - policy := templates.GetPresignedURLPolicy(bucketName) - - require.NotNil(t, policy) - assert.Equal(t, "2012-10-17", policy.Version) - assert.Len(t, policy.Statement, 1) - - stmt := policy.Statement[0] - assert.Equal(t, "Allow", stmt.Effect) - assert.Equal(t, "PresignedURLAccess", stmt.Sid) - assert.Contains(t, stmt.Action, "s3:GetObject") - assert.Contains(t, stmt.Action, "s3:PutObject") - assert.NotNil(t, stmt.Condition) - - expectedObjectArn := "arn:aws:s3:::" + bucketName + "/*" - assert.Contains(t, stmt.Resource, expectedObjectArn) - - // Check signature version condition - condition := stmt.Condition - stringEquals, exists := condition["StringEquals"] - assert.True(t, exists) - - signatureVersion, exists := stringEquals["s3:x-amz-signature-version"] - assert.True(t, exists) - assert.Equal(t, "AWS4-HMAC-SHA256", signatureVersion) -} - -func TestTemporaryAccessPolicy(t *testing.T) { - templates := NewS3PolicyTemplates() - bucketName := "temp-bucket" - expirationHours := 24 - - policy := templates.GetTemporaryAccessPolicy(bucketName, expirationHours) - - require.NotNil(t, policy) - assert.Equal(t, "2012-10-17", policy.Version) - assert.Len(t, policy.Statement, 1) - - stmt := policy.Statement[0] - assert.Equal(t, "Allow", stmt.Effect) - assert.Equal(t, "TemporaryS3Access", stmt.Sid) - assert.Contains(t, stmt.Action, "s3:GetObject") - assert.Contains(t, stmt.Action, "s3:PutObject") - assert.Contains(t, stmt.Action, "s3:ListBucket") - assert.NotNil(t, stmt.Condition) - - // Check expiration condition - condition := stmt.Condition - dateLessThan, exists := condition["DateLessThan"] - assert.True(t, exists) - - currentTime, exists := dateLessThan["aws:CurrentTime"] - assert.True(t, exists) - assert.IsType(t, "", currentTime) // Should be a string timestamp -} - -func TestContentTypeRestrictedPolicy(t *testing.T) { - templates := NewS3PolicyTemplates() - bucketName := "media-bucket" - allowedTypes := []string{"image/jpeg", "image/png", "video/mp4"} - - policy := templates.GetContentTypeRestrictedPolicy(bucketName, allowedTypes) - - require.NotNil(t, policy) - assert.Equal(t, "2012-10-17", policy.Version) - assert.Len(t, policy.Statement, 2) - - // First statement: Upload with content type restriction - uploadStmt := policy.Statement[0] - assert.Equal(t, "Allow", uploadStmt.Effect) - assert.Equal(t, "ContentTypeRestrictedUpload", uploadStmt.Sid) - assert.Contains(t, uploadStmt.Action, "s3:PutObject") - assert.Contains(t, uploadStmt.Action, "s3:CreateMultipartUpload") - assert.NotNil(t, uploadStmt.Condition) - - // Check content type condition - condition := uploadStmt.Condition - stringEquals, exists := condition["StringEquals"] - assert.True(t, exists) - - contentType, exists := stringEquals["s3:content-type"] - assert.True(t, exists) - assert.Equal(t, allowedTypes, contentType) - - // Second statement: Read access without restrictions - readStmt := policy.Statement[1] - assert.Equal(t, "Allow", readStmt.Effect) - assert.Equal(t, "ReadAccess", readStmt.Sid) - assert.Contains(t, readStmt.Action, "s3:GetObject") - assert.Contains(t, readStmt.Action, "s3:ListBucket") - assert.Nil(t, readStmt.Condition) // No conditions for read access -} - -func TestDenyDeletePolicy(t *testing.T) { - templates := NewS3PolicyTemplates() - - policy := templates.GetDenyDeletePolicy() - - require.NotNil(t, policy) - assert.Equal(t, "2012-10-17", policy.Version) - assert.Len(t, policy.Statement, 2) - - // First statement: Allow everything except delete - allowStmt := policy.Statement[0] - assert.Equal(t, "Allow", allowStmt.Effect) - assert.Equal(t, "AllowAllExceptDelete", allowStmt.Sid) - assert.Contains(t, allowStmt.Action, "s3:GetObject") - assert.Contains(t, allowStmt.Action, "s3:PutObject") - assert.Contains(t, allowStmt.Action, "s3:ListBucket") - assert.NotContains(t, allowStmt.Action, "s3:DeleteObject") - assert.NotContains(t, allowStmt.Action, "s3:DeleteBucket") - - // Second statement: Explicitly deny delete operations - denyStmt := policy.Statement[1] - assert.Equal(t, "Deny", denyStmt.Effect) - assert.Equal(t, "DenyDeleteOperations", denyStmt.Sid) - assert.Contains(t, denyStmt.Action, "s3:DeleteObject") - assert.Contains(t, denyStmt.Action, "s3:DeleteObjectVersion") - assert.Contains(t, denyStmt.Action, "s3:DeleteBucket") -} - -func TestPolicyTemplateMetadata(t *testing.T) { - templates := NewS3PolicyTemplates() - - t.Run("GetAllPolicyTemplates", func(t *testing.T) { - allTemplates := templates.GetAllPolicyTemplates() - - assert.Greater(t, len(allTemplates), 10) // Should have many templates - - // Check that each template has required fields - for _, template := range allTemplates { - assert.NotEmpty(t, template.Name) - assert.NotEmpty(t, template.Description) - assert.NotEmpty(t, template.Category) - assert.NotEmpty(t, template.UseCase) - assert.NotNil(t, template.Policy) - assert.Equal(t, "2012-10-17", template.Policy.Version) - } - }) - - t.Run("GetPolicyTemplateByName", func(t *testing.T) { - // Test existing template - template := templates.GetPolicyTemplateByName("S3ReadOnlyAccess") - require.NotNil(t, template) - assert.Equal(t, "S3ReadOnlyAccess", template.Name) - assert.Equal(t, "Basic Access", template.Category) - - // Test non-existing template - nonExistent := templates.GetPolicyTemplateByName("NonExistentTemplate") - assert.Nil(t, nonExistent) - }) - - t.Run("GetPolicyTemplatesByCategory", func(t *testing.T) { - basicAccessTemplates := templates.GetPolicyTemplatesByCategory("Basic Access") - assert.GreaterOrEqual(t, len(basicAccessTemplates), 2) - - for _, template := range basicAccessTemplates { - assert.Equal(t, "Basic Access", template.Category) - } - - // Test non-existing category - emptyCategory := templates.GetPolicyTemplatesByCategory("NonExistentCategory") - assert.Empty(t, emptyCategory) - }) - - t.Run("PolicyTemplateParameters", func(t *testing.T) { - allTemplates := templates.GetAllPolicyTemplates() - - // Find a template with parameters (like BucketSpecificRead) - var templateWithParams *PolicyTemplateDefinition - for _, template := range allTemplates { - if template.Name == "BucketSpecificRead" { - templateWithParams = &template - break - } - } - - require.NotNil(t, templateWithParams) - assert.Greater(t, len(templateWithParams.Parameters), 0) - - param := templateWithParams.Parameters[0] - assert.Equal(t, "bucketName", param.Name) - assert.Equal(t, "string", param.Type) - assert.True(t, param.Required) - assert.NotEmpty(t, param.Description) - assert.NotEmpty(t, param.Example) - }) -} - -func TestFormatHourHelper(t *testing.T) { - tests := []struct { - hour int - expected string - }{ - {0, "00"}, - {5, "05"}, - {9, "09"}, - {10, "10"}, - {15, "15"}, - {23, "23"}, - } - - for _, tt := range tests { - t.Run(fmt.Sprintf("Hour_%d", tt.hour), func(t *testing.T) { - result := formatHour(tt.hour) - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestPolicyTemplateCategories(t *testing.T) { - templates := NewS3PolicyTemplates() - allTemplates := templates.GetAllPolicyTemplates() - - // Extract all categories - categoryMap := make(map[string]int) - for _, template := range allTemplates { - categoryMap[template.Category]++ - } - - // Expected categories - expectedCategories := []string{ - "Basic Access", - "Administrative", - "Bucket-Specific", - "Path-Restricted", - "Security", - "Upload-Specific", - "Presigned URLs", - "Content Control", - "Data Protection", - } - - for _, expectedCategory := range expectedCategories { - count, exists := categoryMap[expectedCategory] - assert.True(t, exists, "Category %s should exist", expectedCategory) - assert.Greater(t, count, 0, "Category %s should have at least one template", expectedCategory) - } -} - -func TestPolicyValidation(t *testing.T) { - templates := NewS3PolicyTemplates() - allTemplates := templates.GetAllPolicyTemplates() - - // Test that all policies have valid structure - for _, template := range allTemplates { - t.Run("Policy_"+template.Name, func(t *testing.T) { - policy := template.Policy - - // Basic validation - assert.Equal(t, "2012-10-17", policy.Version) - assert.Greater(t, len(policy.Statement), 0) - - // Validate each statement - for i, stmt := range policy.Statement { - assert.NotEmpty(t, stmt.Effect, "Statement %d should have effect", i) - assert.Contains(t, []string{"Allow", "Deny"}, stmt.Effect, "Statement %d effect should be Allow or Deny", i) - assert.Greater(t, len(stmt.Action), 0, "Statement %d should have actions", i) - assert.Greater(t, len(stmt.Resource), 0, "Statement %d should have resources", i) - - // Check resource format - for _, resource := range stmt.Resource { - if resource != "*" { - assert.Contains(t, resource, "arn:aws:s3:::", "Resource should be valid AWS S3 ARN: %s", resource) - } - } - } - }) - } -} diff --git a/weed/s3api/s3_presigned_url_iam.go b/weed/s3api/s3_presigned_url_iam.go deleted file mode 100644 index b731b1634..000000000 --- a/weed/s3api/s3_presigned_url_iam.go +++ /dev/null @@ -1,355 +0,0 @@ -package s3api - -import ( - "context" - "crypto/sha256" - "encoding/hex" - "fmt" - "net/http" - "net/url" - "strconv" - "strings" - "time" - - "github.com/seaweedfs/seaweedfs/weed/glog" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" -) - -// S3PresignedURLManager handles IAM integration for presigned URLs -type S3PresignedURLManager struct { - s3iam *S3IAMIntegration -} - -// NewS3PresignedURLManager creates a new presigned URL manager with IAM integration -func NewS3PresignedURLManager(s3iam *S3IAMIntegration) *S3PresignedURLManager { - return &S3PresignedURLManager{ - s3iam: s3iam, - } -} - -// PresignedURLRequest represents a request to generate a presigned URL -type PresignedURLRequest struct { - Method string `json:"method"` // HTTP method (GET, PUT, POST, DELETE) - Bucket string `json:"bucket"` // S3 bucket name - ObjectKey string `json:"object_key"` // S3 object key - Expiration time.Duration `json:"expiration"` // URL expiration duration - SessionToken string `json:"session_token"` // JWT session token for IAM - Headers map[string]string `json:"headers"` // Additional headers to sign - QueryParams map[string]string `json:"query_params"` // Additional query parameters -} - -// PresignedURLResponse represents the generated presigned URL -type PresignedURLResponse struct { - URL string `json:"url"` // The presigned URL - Method string `json:"method"` // HTTP method - Headers map[string]string `json:"headers"` // Required headers - ExpiresAt time.Time `json:"expires_at"` // URL expiration time - SignedHeaders []string `json:"signed_headers"` // List of signed headers - CanonicalQuery string `json:"canonical_query"` // Canonical query string -} - -// ValidatePresignedURLWithIAM validates a presigned URL request using IAM policies -func (iam *IdentityAccessManagement) ValidatePresignedURLWithIAM(r *http.Request, identity *Identity) s3err.ErrorCode { - if iam.iamIntegration == nil { - // Fall back to standard validation - return s3err.ErrNone - } - - // Extract bucket and object from request - bucket, object := s3_constants.GetBucketAndObject(r) - - // Determine the S3 action from HTTP method and path - action := determineS3ActionFromRequest(r, bucket, object) - - // Check if the user has permission for this action - ctx := r.Context() - sessionToken := extractSessionTokenFromPresignedURL(r) - if sessionToken == "" { - // No session token in presigned URL - use standard auth - return s3err.ErrNone - } - - // Create a temporary cloned request with Authorization header to reuse the secure AuthenticateJWT logic - // This ensures we use the same robust validation (STS vs OIDC, signature verification, etc.) - // as standard requests, preventing security regressions. - authReq := r.Clone(ctx) - authReq.Header.Set("Authorization", "Bearer "+sessionToken) - - // Authenticate the token using the centralized IAM integration - iamIdentity, errCode := iam.iamIntegration.AuthenticateJWT(ctx, authReq) - if errCode != s3err.ErrNone { - glog.V(3).Infof("JWT authentication failed for presigned URL: %v", errCode) - return errCode - } - - // Authorize using IAM - errCode = iam.iamIntegration.AuthorizeAction(ctx, iamIdentity, action, bucket, object, r) - if errCode != s3err.ErrNone { - glog.V(3).Infof("IAM authorization failed for presigned URL: principal=%s action=%s bucket=%s object=%s", - iamIdentity.Principal, action, bucket, object) - return errCode - } - - glog.V(3).Infof("IAM authorization succeeded for presigned URL: principal=%s action=%s bucket=%s object=%s", - iamIdentity.Principal, action, bucket, object) - return s3err.ErrNone -} - -// GeneratePresignedURLWithIAM generates a presigned URL with IAM policy validation -func (pm *S3PresignedURLManager) GeneratePresignedURLWithIAM(ctx context.Context, req *PresignedURLRequest, baseURL string) (*PresignedURLResponse, error) { - if pm.s3iam == nil || !pm.s3iam.enabled { - return nil, fmt.Errorf("IAM integration not enabled") - } - if req == nil || strings.TrimSpace(req.SessionToken) == "" { - return nil, fmt.Errorf("IAM authorization failed: session token is required") - } - - authRequest := &http.Request{ - Method: req.Method, - URL: &url.URL{Path: "/" + req.Bucket + "/" + req.ObjectKey}, - Header: make(http.Header), - } - authRequest.Header.Set("Authorization", "Bearer "+req.SessionToken) - authRequest = authRequest.WithContext(ctx) - - iamIdentity, errCode := pm.s3iam.AuthenticateJWT(ctx, authRequest) - if errCode != s3err.ErrNone { - return nil, fmt.Errorf("IAM authorization failed: invalid session token") - } - - // Determine S3 action from method - action := determineS3ActionFromMethodAndPath(req.Method, req.Bucket, req.ObjectKey) - - // Check IAM permissions before generating URL - errCode = pm.s3iam.AuthorizeAction(ctx, iamIdentity, action, req.Bucket, req.ObjectKey, authRequest) - if errCode != s3err.ErrNone { - return nil, fmt.Errorf("IAM authorization failed: user does not have permission for action %s on resource %s/%s", action, req.Bucket, req.ObjectKey) - } - - // Generate presigned URL with validated permissions - return pm.generatePresignedURL(req, baseURL, iamIdentity) -} - -// generatePresignedURL creates the actual presigned URL -func (pm *S3PresignedURLManager) generatePresignedURL(req *PresignedURLRequest, baseURL string, identity *IAMIdentity) (*PresignedURLResponse, error) { - // Calculate expiration time - expiresAt := time.Now().Add(req.Expiration) - - // Build the base URL - urlPath := "/" + req.Bucket - if req.ObjectKey != "" { - urlPath += "/" + req.ObjectKey - } - - // Create query parameters for AWS signature v4 - queryParams := make(map[string]string) - for k, v := range req.QueryParams { - queryParams[k] = v - } - - // Add AWS signature v4 parameters - queryParams["X-Amz-Algorithm"] = "AWS4-HMAC-SHA256" - queryParams["X-Amz-Credential"] = fmt.Sprintf("seaweedfs/%s/us-east-1/s3/aws4_request", expiresAt.Format("20060102")) - queryParams["X-Amz-Date"] = expiresAt.Format("20060102T150405Z") - queryParams["X-Amz-Expires"] = strconv.Itoa(int(req.Expiration.Seconds())) - queryParams["X-Amz-SignedHeaders"] = "host" - - // Add session token if available - if identity.SessionToken != "" { - queryParams["X-Amz-Security-Token"] = identity.SessionToken - } - - // Build canonical query string - canonicalQuery := buildCanonicalQuery(queryParams) - - // For now, we'll create a mock signature - // In production, this would use proper AWS signature v4 signing - mockSignature := generateMockSignature(req.Method, urlPath, canonicalQuery, identity.SessionToken) - queryParams["X-Amz-Signature"] = mockSignature - - // Build final URL - finalQuery := buildCanonicalQuery(queryParams) - fullURL := baseURL + urlPath + "?" + finalQuery - - // Prepare response - headers := make(map[string]string) - for k, v := range req.Headers { - headers[k] = v - } - - return &PresignedURLResponse{ - URL: fullURL, - Method: req.Method, - Headers: headers, - ExpiresAt: expiresAt, - SignedHeaders: []string{"host"}, - CanonicalQuery: canonicalQuery, - }, nil -} - -// Helper functions - -// determineS3ActionFromRequest determines the S3 action based on HTTP request -func determineS3ActionFromRequest(r *http.Request, bucket, object string) Action { - return determineS3ActionFromMethodAndPath(r.Method, bucket, object) -} - -// determineS3ActionFromMethodAndPath determines the S3 action based on method and path -func determineS3ActionFromMethodAndPath(method, bucket, object string) Action { - switch method { - case "GET": - if object == "" { - return s3_constants.ACTION_LIST // ListBucket - } else { - return s3_constants.ACTION_READ // GetObject - } - case "PUT", "POST": - return s3_constants.ACTION_WRITE // PutObject - case "DELETE": - if object == "" { - return s3_constants.ACTION_DELETE_BUCKET // DeleteBucket - } else { - return s3_constants.ACTION_WRITE // DeleteObject (uses WRITE action) - } - case "HEAD": - if object == "" { - return s3_constants.ACTION_LIST // HeadBucket - } else { - return s3_constants.ACTION_READ // HeadObject - } - default: - return s3_constants.ACTION_READ // Default to read - } -} - -// extractSessionTokenFromPresignedURL extracts session token from presigned URL query parameters -func extractSessionTokenFromPresignedURL(r *http.Request) string { - // Check for X-Amz-Security-Token in query parameters - if token := r.URL.Query().Get("X-Amz-Security-Token"); token != "" { - return token - } - - // Check for session token in other possible locations - if token := r.URL.Query().Get("SessionToken"); token != "" { - return token - } - - return "" -} - -// buildCanonicalQuery builds a canonical query string for AWS signature -func buildCanonicalQuery(params map[string]string) string { - var keys []string - for k := range params { - keys = append(keys, k) - } - - // Sort keys for canonical order - for i := 0; i < len(keys); i++ { - for j := i + 1; j < len(keys); j++ { - if keys[i] > keys[j] { - keys[i], keys[j] = keys[j], keys[i] - } - } - } - - var parts []string - for _, k := range keys { - parts = append(parts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(params[k]))) - } - - return strings.Join(parts, "&") -} - -// generateMockSignature generates a mock signature for testing purposes -func generateMockSignature(method, path, query, sessionToken string) string { - // This is a simplified signature for demonstration - // In production, use proper AWS signature v4 calculation - data := fmt.Sprintf("%s\n%s\n%s\n%s", method, path, query, sessionToken) - hash := sha256.Sum256([]byte(data)) - return hex.EncodeToString(hash[:])[:16] // Truncate for readability -} - -// ValidatePresignedURLExpiration validates that a presigned URL hasn't expired -func ValidatePresignedURLExpiration(r *http.Request) error { - query := r.URL.Query() - - // Get X-Amz-Date and X-Amz-Expires - dateStr := query.Get("X-Amz-Date") - expiresStr := query.Get("X-Amz-Expires") - - if dateStr == "" || expiresStr == "" { - return fmt.Errorf("missing required presigned URL parameters") - } - - // Parse date (always in UTC) - signedDate, err := time.Parse("20060102T150405Z", dateStr) - if err != nil { - return fmt.Errorf("invalid X-Amz-Date format: %v", err) - } - - // Parse expires - expires, err := strconv.Atoi(expiresStr) - if err != nil { - return fmt.Errorf("invalid X-Amz-Expires format: %v", err) - } - - // Check expiration - compare in UTC - expirationTime := signedDate.Add(time.Duration(expires) * time.Second) - now := time.Now().UTC() - if now.After(expirationTime) { - return fmt.Errorf("presigned URL has expired") - } - - return nil -} - -// PresignedURLSecurityPolicy represents security constraints for presigned URL generation -type PresignedURLSecurityPolicy struct { - MaxExpirationDuration time.Duration `json:"max_expiration_duration"` // Maximum allowed expiration - AllowedMethods []string `json:"allowed_methods"` // Allowed HTTP methods - RequiredHeaders []string `json:"required_headers"` // Headers that must be present - IPWhitelist []string `json:"ip_whitelist"` // Allowed IP addresses/ranges - MaxFileSize int64 `json:"max_file_size"` // Maximum file size for uploads -} - -// DefaultPresignedURLSecurityPolicy returns a default security policy -func DefaultPresignedURLSecurityPolicy() *PresignedURLSecurityPolicy { - return &PresignedURLSecurityPolicy{ - MaxExpirationDuration: 7 * 24 * time.Hour, // 7 days max - AllowedMethods: []string{"GET", "PUT", "POST", "HEAD"}, - RequiredHeaders: []string{}, - IPWhitelist: []string{}, // Empty means no IP restrictions - MaxFileSize: 5 * 1024 * 1024 * 1024, // 5GB default - } -} - -// ValidatePresignedURLRequest validates a presigned URL request against security policy -func (policy *PresignedURLSecurityPolicy) ValidatePresignedURLRequest(req *PresignedURLRequest) error { - // Check expiration duration - if req.Expiration > policy.MaxExpirationDuration { - return fmt.Errorf("expiration duration %v exceeds maximum allowed %v", req.Expiration, policy.MaxExpirationDuration) - } - - // Check HTTP method - methodAllowed := false - for _, allowedMethod := range policy.AllowedMethods { - if req.Method == allowedMethod { - methodAllowed = true - break - } - } - if !methodAllowed { - return fmt.Errorf("HTTP method %s is not allowed", req.Method) - } - - // Check required headers - for _, requiredHeader := range policy.RequiredHeaders { - if _, exists := req.Headers[requiredHeader]; !exists { - return fmt.Errorf("required header %s is missing", requiredHeader) - } - } - - return nil -} diff --git a/weed/s3api/s3_presigned_url_iam_test.go b/weed/s3api/s3_presigned_url_iam_test.go deleted file mode 100644 index 5d50f06dc..000000000 --- a/weed/s3api/s3_presigned_url_iam_test.go +++ /dev/null @@ -1,631 +0,0 @@ -package s3api - -import ( - "context" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/golang-jwt/jwt/v5" - "github.com/seaweedfs/seaweedfs/weed/iam/integration" - "github.com/seaweedfs/seaweedfs/weed/iam/ldap" - "github.com/seaweedfs/seaweedfs/weed/iam/oidc" - "github.com/seaweedfs/seaweedfs/weed/iam/policy" - "github.com/seaweedfs/seaweedfs/weed/iam/sts" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// createTestJWTPresigned creates a test JWT token with the specified issuer, subject and signing key -func createTestJWTPresigned(t *testing.T, issuer, subject, signingKey string) string { - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "iss": issuer, - "sub": subject, - "aud": "test-client-id", - "exp": time.Now().Add(time.Hour).Unix(), - "iat": time.Now().Unix(), - // Add claims that trust policy validation expects - "idp": "test-oidc", // Identity provider claim for trust policy matching - }) - - tokenString, err := token.SignedString([]byte(signingKey)) - require.NoError(t, err) - return tokenString -} - -// TestPresignedURLIAMValidation tests IAM validation for presigned URLs -func TestPresignedURLIAMValidation(t *testing.T) { - // Set up IAM system - iamManager := setupTestIAMManagerForPresigned(t) - s3iam := NewS3IAMIntegration(iamManager, "localhost:8888") - - // Create IAM with integration - iam := &IdentityAccessManagement{ - isAuthEnabled: true, - } - iam.SetIAMIntegration(s3iam) - - // Set up roles - ctx := context.Background() - setupTestRolesForPresigned(ctx, iamManager) - - // Create a valid JWT token for testing - validJWTToken := createTestJWTPresigned(t, "https://test-issuer.com", "test-user-123", "test-signing-key") - - // Get session token - response, err := iamManager.AssumeRoleWithWebIdentity(ctx, &sts.AssumeRoleWithWebIdentityRequest{ - RoleArn: "arn:aws:iam::role/S3ReadOnlyRole", - WebIdentityToken: validJWTToken, - RoleSessionName: "presigned-test-session", - }) - require.NoError(t, err) - - sessionToken := response.Credentials.SessionToken - - tests := []struct { - name string - method string - path string - sessionToken string - expectedResult s3err.ErrorCode - }{ - { - name: "GET object with read permissions", - method: "GET", - path: "/test-bucket/test-file.txt", - sessionToken: sessionToken, - expectedResult: s3err.ErrNone, - }, - { - name: "PUT object with read-only permissions (should fail)", - method: "PUT", - path: "/test-bucket/new-file.txt", - sessionToken: sessionToken, - expectedResult: s3err.ErrAccessDenied, - }, - { - name: "GET object without session token", - method: "GET", - path: "/test-bucket/test-file.txt", - sessionToken: "", - expectedResult: s3err.ErrNone, // Falls back to standard auth - }, - { - name: "Invalid session token", - method: "GET", - path: "/test-bucket/test-file.txt", - sessionToken: "invalid-token", - expectedResult: s3err.ErrAccessDenied, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Create request with presigned URL parameters - req := createPresignedURLRequest(t, tt.method, tt.path, tt.sessionToken) - - // Create identity for testing - identity := &Identity{ - Name: "test-user", - Account: &AccountAdmin, - } - - // Test validation - result := iam.ValidatePresignedURLWithIAM(req, identity) - assert.Equal(t, tt.expectedResult, result, "IAM validation result should match expected") - }) - } -} - -// TestPresignedURLGeneration tests IAM-aware presigned URL generation -func TestPresignedURLGeneration(t *testing.T) { - // Set up IAM system - iamManager := setupTestIAMManagerForPresigned(t) - s3iam := NewS3IAMIntegration(iamManager, "localhost:8888") - s3iam.enabled = true // Enable IAM integration - presignedManager := NewS3PresignedURLManager(s3iam) - - ctx := context.Background() - setupTestRolesForPresigned(ctx, iamManager) - - // Create a valid JWT token for testing - validJWTToken := createTestJWTPresigned(t, "https://test-issuer.com", "test-user-123", "test-signing-key") - - // Get session token - response, err := iamManager.AssumeRoleWithWebIdentity(ctx, &sts.AssumeRoleWithWebIdentityRequest{ - RoleArn: "arn:aws:iam::role/S3AdminRole", - WebIdentityToken: validJWTToken, - RoleSessionName: "presigned-gen-test-session", - }) - require.NoError(t, err) - - sessionToken := response.Credentials.SessionToken - - tests := []struct { - name string - request *PresignedURLRequest - shouldSucceed bool - expectedError string - }{ - { - name: "Generate valid presigned GET URL", - request: &PresignedURLRequest{ - Method: "GET", - Bucket: "test-bucket", - ObjectKey: "test-file.txt", - Expiration: time.Hour, - SessionToken: sessionToken, - }, - shouldSucceed: true, - }, - { - name: "Generate valid presigned PUT URL", - request: &PresignedURLRequest{ - Method: "PUT", - Bucket: "test-bucket", - ObjectKey: "new-file.txt", - Expiration: time.Hour, - SessionToken: sessionToken, - }, - shouldSucceed: true, - }, - { - name: "Generate URL with invalid session token", - request: &PresignedURLRequest{ - Method: "GET", - Bucket: "test-bucket", - ObjectKey: "test-file.txt", - Expiration: time.Hour, - SessionToken: "invalid-token", - }, - shouldSucceed: false, - expectedError: "IAM authorization failed", - }, - { - name: "Generate URL without session token", - request: &PresignedURLRequest{ - Method: "GET", - Bucket: "test-bucket", - ObjectKey: "test-file.txt", - Expiration: time.Hour, - }, - shouldSucceed: false, - expectedError: "IAM authorization failed", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - response, err := presignedManager.GeneratePresignedURLWithIAM(ctx, tt.request, "http://localhost:8333") - - if tt.shouldSucceed { - assert.NoError(t, err, "Presigned URL generation should succeed") - if response != nil { - assert.NotEmpty(t, response.URL, "URL should not be empty") - assert.Equal(t, tt.request.Method, response.Method, "Method should match") - assert.True(t, response.ExpiresAt.After(time.Now()), "URL should not be expired") - } else { - t.Errorf("Response should not be nil when generation should succeed") - } - } else { - assert.Error(t, err, "Presigned URL generation should fail") - if tt.expectedError != "" { - assert.Contains(t, err.Error(), tt.expectedError, "Error message should contain expected text") - } - } - }) - } -} - -func TestPresignedURLGenerationUsesAuthenticatedPrincipal(t *testing.T) { - iamManager := setupTestIAMManagerForPresigned(t) - s3iam := NewS3IAMIntegration(iamManager, "localhost:8888") - s3iam.enabled = true - presignedManager := NewS3PresignedURLManager(s3iam) - - ctx := context.Background() - setupTestRolesForPresigned(ctx, iamManager) - - validJWTToken := createTestJWTPresigned(t, "https://test-issuer.com", "test-user-123", "test-signing-key") - - response, err := iamManager.AssumeRoleWithWebIdentity(ctx, &sts.AssumeRoleWithWebIdentityRequest{ - RoleArn: "arn:aws:iam::role/S3ReadOnlyRole", - WebIdentityToken: validJWTToken, - RoleSessionName: "presigned-read-only-session", - }) - require.NoError(t, err) - - _, err = presignedManager.GeneratePresignedURLWithIAM(ctx, &PresignedURLRequest{ - Method: "PUT", - Bucket: "test-bucket", - ObjectKey: "new-file.txt", - Expiration: time.Hour, - SessionToken: response.Credentials.SessionToken, - }, "http://localhost:8333") - require.Error(t, err) - assert.Contains(t, err.Error(), "IAM authorization failed") -} - -// TestPresignedURLExpiration tests URL expiration validation -func TestPresignedURLExpiration(t *testing.T) { - tests := []struct { - name string - setupRequest func() *http.Request - expectedError string - }{ - { - name: "Valid non-expired URL", - setupRequest: func() *http.Request { - req := httptest.NewRequest("GET", "/test-bucket/test-file.txt", nil) - q := req.URL.Query() - // Set date to 30 minutes ago with 2 hours expiration for safe margin - q.Set("X-Amz-Date", time.Now().UTC().Add(-30*time.Minute).Format("20060102T150405Z")) - q.Set("X-Amz-Expires", "7200") // 2 hours - req.URL.RawQuery = q.Encode() - return req - }, - expectedError: "", - }, - { - name: "Expired URL", - setupRequest: func() *http.Request { - req := httptest.NewRequest("GET", "/test-bucket/test-file.txt", nil) - q := req.URL.Query() - // Set date to 2 hours ago with 1 hour expiration - q.Set("X-Amz-Date", time.Now().UTC().Add(-2*time.Hour).Format("20060102T150405Z")) - q.Set("X-Amz-Expires", "3600") // 1 hour - req.URL.RawQuery = q.Encode() - return req - }, - expectedError: "presigned URL has expired", - }, - { - name: "Missing date parameter", - setupRequest: func() *http.Request { - req := httptest.NewRequest("GET", "/test-bucket/test-file.txt", nil) - q := req.URL.Query() - q.Set("X-Amz-Expires", "3600") - req.URL.RawQuery = q.Encode() - return req - }, - expectedError: "missing required presigned URL parameters", - }, - { - name: "Invalid date format", - setupRequest: func() *http.Request { - req := httptest.NewRequest("GET", "/test-bucket/test-file.txt", nil) - q := req.URL.Query() - q.Set("X-Amz-Date", "invalid-date") - q.Set("X-Amz-Expires", "3600") - req.URL.RawQuery = q.Encode() - return req - }, - expectedError: "invalid X-Amz-Date format", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req := tt.setupRequest() - err := ValidatePresignedURLExpiration(req) - - if tt.expectedError == "" { - assert.NoError(t, err, "Validation should succeed") - } else { - assert.Error(t, err, "Validation should fail") - assert.Contains(t, err.Error(), tt.expectedError, "Error message should contain expected text") - } - }) - } -} - -// TestPresignedURLSecurityPolicy tests security policy enforcement -func TestPresignedURLSecurityPolicy(t *testing.T) { - policy := &PresignedURLSecurityPolicy{ - MaxExpirationDuration: 24 * time.Hour, - AllowedMethods: []string{"GET", "PUT"}, - RequiredHeaders: []string{"Content-Type"}, - MaxFileSize: 1024 * 1024, // 1MB - } - - tests := []struct { - name string - request *PresignedURLRequest - expectedError string - }{ - { - name: "Valid request", - request: &PresignedURLRequest{ - Method: "GET", - Bucket: "test-bucket", - ObjectKey: "test-file.txt", - Expiration: 12 * time.Hour, - Headers: map[string]string{"Content-Type": "application/json"}, - }, - expectedError: "", - }, - { - name: "Expiration too long", - request: &PresignedURLRequest{ - Method: "GET", - Bucket: "test-bucket", - ObjectKey: "test-file.txt", - Expiration: 48 * time.Hour, // Exceeds 24h limit - Headers: map[string]string{"Content-Type": "application/json"}, - }, - expectedError: "expiration duration", - }, - { - name: "Method not allowed", - request: &PresignedURLRequest{ - Method: "DELETE", // Not in allowed methods - Bucket: "test-bucket", - ObjectKey: "test-file.txt", - Expiration: 12 * time.Hour, - Headers: map[string]string{"Content-Type": "application/json"}, - }, - expectedError: "HTTP method DELETE is not allowed", - }, - { - name: "Missing required header", - request: &PresignedURLRequest{ - Method: "GET", - Bucket: "test-bucket", - ObjectKey: "test-file.txt", - Expiration: 12 * time.Hour, - Headers: map[string]string{}, // Missing Content-Type - }, - expectedError: "required header Content-Type is missing", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := policy.ValidatePresignedURLRequest(tt.request) - - if tt.expectedError == "" { - assert.NoError(t, err, "Policy validation should succeed") - } else { - assert.Error(t, err, "Policy validation should fail") - assert.Contains(t, err.Error(), tt.expectedError, "Error message should contain expected text") - } - }) - } -} - -// TestS3ActionDetermination tests action determination from HTTP methods -func TestS3ActionDetermination(t *testing.T) { - tests := []struct { - name string - method string - bucket string - object string - expectedAction Action - }{ - { - name: "GET object", - method: "GET", - bucket: "test-bucket", - object: "test-file.txt", - expectedAction: s3_constants.ACTION_READ, - }, - { - name: "GET bucket (list)", - method: "GET", - bucket: "test-bucket", - object: "", - expectedAction: s3_constants.ACTION_LIST, - }, - { - name: "PUT object", - method: "PUT", - bucket: "test-bucket", - object: "new-file.txt", - expectedAction: s3_constants.ACTION_WRITE, - }, - { - name: "DELETE object", - method: "DELETE", - bucket: "test-bucket", - object: "old-file.txt", - expectedAction: s3_constants.ACTION_WRITE, - }, - { - name: "DELETE bucket", - method: "DELETE", - bucket: "test-bucket", - object: "", - expectedAction: s3_constants.ACTION_DELETE_BUCKET, - }, - { - name: "HEAD object", - method: "HEAD", - bucket: "test-bucket", - object: "test-file.txt", - expectedAction: s3_constants.ACTION_READ, - }, - { - name: "POST object", - method: "POST", - bucket: "test-bucket", - object: "upload-file.txt", - expectedAction: s3_constants.ACTION_WRITE, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - action := determineS3ActionFromMethodAndPath(tt.method, tt.bucket, tt.object) - assert.Equal(t, tt.expectedAction, action, "S3 action should match expected") - }) - } -} - -// Helper functions for tests - -func setupTestIAMManagerForPresigned(t *testing.T) *integration.IAMManager { - // Create IAM manager - manager := integration.NewIAMManager() - - // Initialize with test configuration - config := &integration.IAMConfig{ - STS: &sts.STSConfig{ - TokenDuration: sts.FlexibleDuration{Duration: time.Hour}, - MaxSessionLength: sts.FlexibleDuration{Duration: time.Hour * 12}, - Issuer: "test-sts", - SigningKey: []byte("test-signing-key-32-characters-long"), - }, - Policy: &policy.PolicyEngineConfig{ - DefaultEffect: "Deny", - StoreType: "memory", - }, - Roles: &integration.RoleStoreConfig{ - StoreType: "memory", - }, - } - - err := manager.Initialize(config, func() string { - return "localhost:8888" // Mock filer address for testing - }) - require.NoError(t, err) - - // Set up test identity providers - setupTestProvidersForPresigned(t, manager) - - return manager -} - -func setupTestProvidersForPresigned(t *testing.T, manager *integration.IAMManager) { - // Set up OIDC provider - oidcProvider := oidc.NewMockOIDCProvider("test-oidc") - oidcConfig := &oidc.OIDCConfig{ - Issuer: "https://test-issuer.com", - ClientID: "test-client-id", - } - err := oidcProvider.Initialize(oidcConfig) - require.NoError(t, err) - oidcProvider.SetupDefaultTestData() - - // Set up LDAP provider - ldapProvider := ldap.NewMockLDAPProvider("test-ldap") - err = ldapProvider.Initialize(nil) // Mock doesn't need real config - require.NoError(t, err) - ldapProvider.SetupDefaultTestData() - - // Register providers - err = manager.RegisterIdentityProvider(oidcProvider) - require.NoError(t, err) - err = manager.RegisterIdentityProvider(ldapProvider) - require.NoError(t, err) -} - -func setupTestRolesForPresigned(ctx context.Context, manager *integration.IAMManager) { - // Create read-only policy - readOnlyPolicy := &policy.PolicyDocument{ - Version: "2012-10-17", - Statement: []policy.Statement{ - { - Sid: "AllowS3ReadOperations", - Effect: "Allow", - Action: []string{"s3:GetObject", "s3:ListBucket", "s3:HeadObject"}, - Resource: []string{ - "arn:aws:s3:::*", - "arn:aws:s3:::*/*", - }, - }, - }, - } - - manager.CreatePolicy(ctx, "", "S3ReadOnlyPolicy", readOnlyPolicy) - - // Create read-only role - manager.CreateRole(ctx, "", "S3ReadOnlyRole", &integration.RoleDefinition{ - RoleName: "S3ReadOnlyRole", - TrustPolicy: &policy.PolicyDocument{ - Version: "2012-10-17", - Statement: []policy.Statement{ - { - Effect: "Allow", - Principal: map[string]interface{}{ - "Federated": "test-oidc", - }, - Action: []string{"sts:AssumeRoleWithWebIdentity"}, - }, - }, - }, - AttachedPolicies: []string{"S3ReadOnlyPolicy"}, - }) - - // Create admin policy - adminPolicy := &policy.PolicyDocument{ - Version: "2012-10-17", - Statement: []policy.Statement{ - { - Sid: "AllowAllS3Operations", - Effect: "Allow", - Action: []string{"s3:*"}, - Resource: []string{ - "arn:aws:s3:::*", - "arn:aws:s3:::*/*", - }, - }, - }, - } - - manager.CreatePolicy(ctx, "", "S3AdminPolicy", adminPolicy) - - // Create admin role - manager.CreateRole(ctx, "", "S3AdminRole", &integration.RoleDefinition{ - RoleName: "S3AdminRole", - TrustPolicy: &policy.PolicyDocument{ - Version: "2012-10-17", - Statement: []policy.Statement{ - { - Effect: "Allow", - Principal: map[string]interface{}{ - "Federated": "test-oidc", - }, - Action: []string{"sts:AssumeRoleWithWebIdentity"}, - }, - }, - }, - AttachedPolicies: []string{"S3AdminPolicy"}, - }) - - // Create a role for presigned URL users with admin permissions for testing - manager.CreateRole(ctx, "", "PresignedUser", &integration.RoleDefinition{ - RoleName: "PresignedUser", - TrustPolicy: &policy.PolicyDocument{ - Version: "2012-10-17", - Statement: []policy.Statement{ - { - Effect: "Allow", - Principal: map[string]interface{}{ - "Federated": "test-oidc", - }, - Action: []string{"sts:AssumeRoleWithWebIdentity"}, - }, - }, - }, - AttachedPolicies: []string{"S3AdminPolicy"}, // Use admin policy for testing - }) -} - -func createPresignedURLRequest(t *testing.T, method, path, sessionToken string) *http.Request { - req := httptest.NewRequest(method, path, nil) - - // Add presigned URL parameters if session token is provided - if sessionToken != "" { - q := req.URL.Query() - q.Set("X-Amz-Algorithm", "AWS4-HMAC-SHA256") - q.Set("X-Amz-Security-Token", sessionToken) - q.Set("X-Amz-Date", time.Now().Format("20060102T150405Z")) - q.Set("X-Amz-Expires", "3600") - req.URL.RawQuery = q.Encode() - } - - return req -} diff --git a/weed/s3api/s3_sse_bucket_test.go b/weed/s3api/s3_sse_bucket_test.go deleted file mode 100644 index 74ad9296b..000000000 --- a/weed/s3api/s3_sse_bucket_test.go +++ /dev/null @@ -1,401 +0,0 @@ -package s3api - -import ( - "fmt" - "strings" - "testing" - - "github.com/seaweedfs/seaweedfs/weed/pb/s3_pb" -) - -// TestBucketDefaultSSEKMSEnforcement tests bucket default encryption enforcement -func TestBucketDefaultSSEKMSEnforcement(t *testing.T) { - kmsKey := SetupTestKMS(t) - defer kmsKey.Cleanup() - - // Create bucket encryption configuration - config := &s3_pb.EncryptionConfiguration{ - SseAlgorithm: "aws:kms", - KmsKeyId: kmsKey.KeyID, - BucketKeyEnabled: false, - } - - t.Run("Bucket with SSE-KMS default encryption", func(t *testing.T) { - // Test that default encryption config is properly stored and retrieved - if config.SseAlgorithm != "aws:kms" { - t.Errorf("Expected SSE algorithm aws:kms, got %s", config.SseAlgorithm) - } - - if config.KmsKeyId != kmsKey.KeyID { - t.Errorf("Expected KMS key ID %s, got %s", kmsKey.KeyID, config.KmsKeyId) - } - }) - - t.Run("Default encryption headers generation", func(t *testing.T) { - // Test generating default encryption headers for objects - headers := GetDefaultEncryptionHeaders(config) - - if headers == nil { - t.Fatal("Expected default headers, got nil") - } - - expectedAlgorithm := headers["X-Amz-Server-Side-Encryption"] - if expectedAlgorithm != "aws:kms" { - t.Errorf("Expected X-Amz-Server-Side-Encryption header aws:kms, got %s", expectedAlgorithm) - } - - expectedKeyID := headers["X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id"] - if expectedKeyID != kmsKey.KeyID { - t.Errorf("Expected X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id header %s, got %s", kmsKey.KeyID, expectedKeyID) - } - }) - - t.Run("Default encryption detection", func(t *testing.T) { - // Test IsDefaultEncryptionEnabled - enabled := IsDefaultEncryptionEnabled(config) - if !enabled { - t.Error("Should detect default encryption as enabled") - } - - // Test with nil config - enabled = IsDefaultEncryptionEnabled(nil) - if enabled { - t.Error("Should detect default encryption as disabled for nil config") - } - - // Test with empty config - emptyConfig := &s3_pb.EncryptionConfiguration{} - enabled = IsDefaultEncryptionEnabled(emptyConfig) - if enabled { - t.Error("Should detect default encryption as disabled for empty config") - } - }) -} - -// TestBucketEncryptionConfigValidation tests XML validation of bucket encryption configurations -func TestBucketEncryptionConfigValidation(t *testing.T) { - testCases := []struct { - name string - xml string - expectError bool - description string - }{ - { - name: "Valid SSE-S3 configuration", - xml: ` - - - AES256 - - - `, - expectError: false, - description: "Basic SSE-S3 configuration should be valid", - }, - { - name: "Valid SSE-KMS configuration", - xml: ` - - - aws:kms - test-key-id - - - `, - expectError: false, - description: "SSE-KMS configuration with key ID should be valid", - }, - { - name: "Valid SSE-KMS without key ID", - xml: ` - - - aws:kms - - - `, - expectError: false, - description: "SSE-KMS without key ID should use default key", - }, - { - name: "Invalid XML structure", - xml: ` - - AES256 - - `, - expectError: true, - description: "Invalid XML structure should be rejected", - }, - { - name: "Empty configuration", - xml: ` - `, - expectError: true, - description: "Empty configuration should be rejected", - }, - { - name: "Invalid algorithm", - xml: ` - - - INVALID - - - `, - expectError: true, - description: "Invalid algorithm should be rejected", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - config, err := encryptionConfigFromXMLBytes([]byte(tc.xml)) - - if tc.expectError && err == nil { - t.Errorf("Expected error for %s, but got none. %s", tc.name, tc.description) - } - - if !tc.expectError && err != nil { - t.Errorf("Expected no error for %s, but got: %v. %s", tc.name, err, tc.description) - } - - if !tc.expectError && config != nil { - // Validate the parsed configuration - t.Logf("Successfully parsed config: Algorithm=%s, KeyID=%s", - config.SseAlgorithm, config.KmsKeyId) - } - }) - } -} - -// TestBucketEncryptionAPIOperations tests the bucket encryption API operations -func TestBucketEncryptionAPIOperations(t *testing.T) { - // Note: These tests would normally require a full S3 API server setup - // For now, we test the individual components - - t.Run("PUT bucket encryption", func(t *testing.T) { - xml := ` - - - aws:kms - test-key-id - - - ` - - // Parse the XML to protobuf - config, err := encryptionConfigFromXMLBytes([]byte(xml)) - if err != nil { - t.Fatalf("Failed to parse encryption config: %v", err) - } - - // Verify the parsed configuration - if config.SseAlgorithm != "aws:kms" { - t.Errorf("Expected algorithm aws:kms, got %s", config.SseAlgorithm) - } - - if config.KmsKeyId != "test-key-id" { - t.Errorf("Expected key ID test-key-id, got %s", config.KmsKeyId) - } - - // Convert back to XML - xmlBytes, err := encryptionConfigToXMLBytes(config) - if err != nil { - t.Fatalf("Failed to convert config to XML: %v", err) - } - - // Verify round-trip - if len(xmlBytes) == 0 { - t.Error("Generated XML should not be empty") - } - - // Parse again to verify - roundTripConfig, err := encryptionConfigFromXMLBytes(xmlBytes) - if err != nil { - t.Fatalf("Failed to parse round-trip XML: %v", err) - } - - if roundTripConfig.SseAlgorithm != config.SseAlgorithm { - t.Error("Round-trip algorithm doesn't match") - } - - if roundTripConfig.KmsKeyId != config.KmsKeyId { - t.Error("Round-trip key ID doesn't match") - } - }) - - t.Run("GET bucket encryption", func(t *testing.T) { - // Test getting encryption configuration - config := &s3_pb.EncryptionConfiguration{ - SseAlgorithm: "AES256", - KmsKeyId: "", - BucketKeyEnabled: false, - } - - // Convert to XML for GET response - xmlBytes, err := encryptionConfigToXMLBytes(config) - if err != nil { - t.Fatalf("Failed to convert config to XML: %v", err) - } - - if len(xmlBytes) == 0 { - t.Error("Generated XML should not be empty") - } - - // Verify XML contains expected elements - xmlStr := string(xmlBytes) - if !strings.Contains(xmlStr, "AES256") { - t.Error("XML should contain AES256 algorithm") - } - }) - - t.Run("DELETE bucket encryption", func(t *testing.T) { - // Test deleting encryption configuration - // This would typically involve removing the configuration from metadata - - // Simulate checking if encryption is enabled after deletion - enabled := IsDefaultEncryptionEnabled(nil) - if enabled { - t.Error("Encryption should be disabled after deletion") - } - }) -} - -// TestBucketEncryptionEdgeCases tests edge cases in bucket encryption -func TestBucketEncryptionEdgeCases(t *testing.T) { - t.Run("Large XML configuration", func(t *testing.T) { - // Test with a large but valid XML - largeXML := ` - - - aws:kms - arn:aws:kms:us-east-1:123456789012:key/12345678-1234-1234-1234-123456789012 - - true - - ` - - config, err := encryptionConfigFromXMLBytes([]byte(largeXML)) - if err != nil { - t.Fatalf("Failed to parse large XML: %v", err) - } - - if config.SseAlgorithm != "aws:kms" { - t.Error("Should parse large XML correctly") - } - }) - - t.Run("XML with namespaces", func(t *testing.T) { - // Test XML with namespaces - namespacedXML := ` - - - AES256 - - - ` - - config, err := encryptionConfigFromXMLBytes([]byte(namespacedXML)) - if err != nil { - t.Fatalf("Failed to parse namespaced XML: %v", err) - } - - if config.SseAlgorithm != "AES256" { - t.Error("Should parse namespaced XML correctly") - } - }) - - t.Run("Malformed XML", func(t *testing.T) { - malformedXMLs := []string{ - `AES256`, // Unclosed tags - ``, // Empty rule - `not-xml-at-all`, // Not XML - `AES256`, // Invalid namespace - } - - for i, malformedXML := range malformedXMLs { - t.Run(fmt.Sprintf("Malformed XML %d", i), func(t *testing.T) { - _, err := encryptionConfigFromXMLBytes([]byte(malformedXML)) - if err == nil { - t.Errorf("Expected error for malformed XML %d, but got none", i) - } - }) - } - }) -} - -// TestGetDefaultEncryptionHeaders tests generation of default encryption headers -func TestGetDefaultEncryptionHeaders(t *testing.T) { - testCases := []struct { - name string - config *s3_pb.EncryptionConfiguration - expectedHeaders map[string]string - }{ - { - name: "Nil configuration", - config: nil, - expectedHeaders: nil, - }, - { - name: "SSE-S3 configuration", - config: &s3_pb.EncryptionConfiguration{ - SseAlgorithm: "AES256", - }, - expectedHeaders: map[string]string{ - "X-Amz-Server-Side-Encryption": "AES256", - }, - }, - { - name: "SSE-KMS configuration with key", - config: &s3_pb.EncryptionConfiguration{ - SseAlgorithm: "aws:kms", - KmsKeyId: "test-key-id", - }, - expectedHeaders: map[string]string{ - "X-Amz-Server-Side-Encryption": "aws:kms", - "X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id": "test-key-id", - }, - }, - { - name: "SSE-KMS configuration without key", - config: &s3_pb.EncryptionConfiguration{ - SseAlgorithm: "aws:kms", - }, - expectedHeaders: map[string]string{ - "X-Amz-Server-Side-Encryption": "aws:kms", - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - headers := GetDefaultEncryptionHeaders(tc.config) - - if tc.expectedHeaders == nil && headers != nil { - t.Error("Expected nil headers but got some") - } - - if tc.expectedHeaders != nil && headers == nil { - t.Error("Expected headers but got nil") - } - - if tc.expectedHeaders != nil && headers != nil { - for key, expectedValue := range tc.expectedHeaders { - if actualValue, exists := headers[key]; !exists { - t.Errorf("Expected header %s not found", key) - } else if actualValue != expectedValue { - t.Errorf("Header %s: expected %s, got %s", key, expectedValue, actualValue) - } - } - - // Check for unexpected headers - for key := range headers { - if _, expected := tc.expectedHeaders[key]; !expected { - t.Errorf("Unexpected header found: %s", key) - } - } - } - }) - } -} diff --git a/weed/s3api/s3_sse_c.go b/weed/s3api/s3_sse_c.go index 79cf96041..97990853f 100644 --- a/weed/s3api/s3_sse_c.go +++ b/weed/s3api/s3_sse_c.go @@ -58,9 +58,9 @@ var ( // SSECustomerKey represents a customer-provided encryption key for SSE-C type SSECustomerKey struct { - Algorithm string - Key []byte - KeyMD5 string + Algorithm string + Key []byte + KeyMD5 string } // IsSSECRequest checks if the request contains SSE-C headers @@ -134,16 +134,6 @@ func validateAndParseSSECHeaders(algorithm, key, keyMD5 string) (*SSECustomerKey }, nil } -// ValidateSSECHeaders validates SSE-C headers in the request -func ValidateSSECHeaders(r *http.Request) error { - algorithm := r.Header.Get(s3_constants.AmzServerSideEncryptionCustomerAlgorithm) - key := r.Header.Get(s3_constants.AmzServerSideEncryptionCustomerKey) - keyMD5 := r.Header.Get(s3_constants.AmzServerSideEncryptionCustomerKeyMD5) - - _, err := validateAndParseSSECHeaders(algorithm, key, keyMD5) - return err -} - // ParseSSECHeaders parses and validates SSE-C headers from the request func ParseSSECHeaders(r *http.Request) (*SSECustomerKey, error) { algorithm := r.Header.Get(s3_constants.AmzServerSideEncryptionCustomerAlgorithm) diff --git a/weed/s3api/s3_sse_c_test.go b/weed/s3api/s3_sse_c_test.go deleted file mode 100644 index 034f07a8e..000000000 --- a/weed/s3api/s3_sse_c_test.go +++ /dev/null @@ -1,407 +0,0 @@ -package s3api - -import ( - "bytes" - "crypto/md5" - "encoding/base64" - "fmt" - "io" - "net/http" - "testing" - - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" -) - -func base64MD5(b []byte) string { - s := md5.Sum(b) - return base64.StdEncoding.EncodeToString(s[:]) -} - -func TestSSECHeaderValidation(t *testing.T) { - // Test valid SSE-C headers - req := &http.Request{Header: make(http.Header)} - - key := make([]byte, 32) // 256-bit key - for i := range key { - key[i] = byte(i) - } - - keyBase64 := base64.StdEncoding.EncodeToString(key) - md5sum := md5.Sum(key) - keyMD5 := base64.StdEncoding.EncodeToString(md5sum[:]) - - req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, "AES256") - req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKey, keyBase64) - req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKeyMD5, keyMD5) - - // Test validation - err := ValidateSSECHeaders(req) - if err != nil { - t.Errorf("Expected valid headers, got error: %v", err) - } - - // Test parsing - customerKey, err := ParseSSECHeaders(req) - if err != nil { - t.Errorf("Expected successful parsing, got error: %v", err) - } - - if customerKey == nil { - t.Error("Expected customer key, got nil") - } - - if customerKey.Algorithm != "AES256" { - t.Errorf("Expected algorithm AES256, got %s", customerKey.Algorithm) - } - - if !bytes.Equal(customerKey.Key, key) { - t.Error("Key doesn't match original") - } - - if customerKey.KeyMD5 != keyMD5 { - t.Errorf("Expected key MD5 %s, got %s", keyMD5, customerKey.KeyMD5) - } -} - -func TestSSECCopySourceHeaders(t *testing.T) { - // Test valid SSE-C copy source headers - req := &http.Request{Header: make(http.Header)} - - key := make([]byte, 32) // 256-bit key - for i := range key { - key[i] = byte(i) + 1 // Different from regular test - } - - keyBase64 := base64.StdEncoding.EncodeToString(key) - md5sum2 := md5.Sum(key) - keyMD5 := base64.StdEncoding.EncodeToString(md5sum2[:]) - - req.Header.Set(s3_constants.AmzCopySourceServerSideEncryptionCustomerAlgorithm, "AES256") - req.Header.Set(s3_constants.AmzCopySourceServerSideEncryptionCustomerKey, keyBase64) - req.Header.Set(s3_constants.AmzCopySourceServerSideEncryptionCustomerKeyMD5, keyMD5) - - // Test parsing copy source headers - customerKey, err := ParseSSECCopySourceHeaders(req) - if err != nil { - t.Errorf("Expected successful copy source parsing, got error: %v", err) - } - - if customerKey == nil { - t.Error("Expected customer key from copy source headers, got nil") - } - - if customerKey.Algorithm != "AES256" { - t.Errorf("Expected algorithm AES256, got %s", customerKey.Algorithm) - } - - if !bytes.Equal(customerKey.Key, key) { - t.Error("Copy source key doesn't match original") - } - - // Test that regular headers don't interfere with copy source headers - regularKey, err := ParseSSECHeaders(req) - if err != nil { - t.Errorf("Regular header parsing should not fail: %v", err) - } - - if regularKey != nil { - t.Error("Expected nil for regular headers when only copy source headers are present") - } -} - -func TestSSECHeaderValidationErrors(t *testing.T) { - tests := []struct { - name string - algorithm string - key string - keyMD5 string - wantErr error - }{ - { - name: "invalid algorithm", - algorithm: "AES128", - key: base64.StdEncoding.EncodeToString(make([]byte, 32)), - keyMD5: base64MD5(make([]byte, 32)), - wantErr: ErrInvalidEncryptionAlgorithm, - }, - { - name: "invalid key length", - algorithm: "AES256", - key: base64.StdEncoding.EncodeToString(make([]byte, 16)), - keyMD5: base64MD5(make([]byte, 16)), - wantErr: ErrInvalidEncryptionKey, - }, - { - name: "mismatched MD5", - algorithm: "AES256", - key: base64.StdEncoding.EncodeToString(make([]byte, 32)), - keyMD5: "wrong==md5", - wantErr: ErrSSECustomerKeyMD5Mismatch, - }, - { - name: "incomplete headers", - algorithm: "AES256", - key: "", - keyMD5: "", - wantErr: ErrInvalidRequest, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req := &http.Request{Header: make(http.Header)} - - if tt.algorithm != "" { - req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, tt.algorithm) - } - if tt.key != "" { - req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKey, tt.key) - } - if tt.keyMD5 != "" { - req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKeyMD5, tt.keyMD5) - } - - err := ValidateSSECHeaders(req) - if err != tt.wantErr { - t.Errorf("Expected error %v, got %v", tt.wantErr, err) - } - }) - } -} - -func TestSSECEncryptionDecryption(t *testing.T) { - // Create customer key - key := make([]byte, 32) - for i := range key { - key[i] = byte(i) - } - - md5sumKey := md5.Sum(key) - customerKey := &SSECustomerKey{ - Algorithm: "AES256", - Key: key, - KeyMD5: base64.StdEncoding.EncodeToString(md5sumKey[:]), - } - - // Test data - testData := []byte("Hello, World! This is a test of SSE-C encryption.") - - // Create encrypted reader - dataReader := bytes.NewReader(testData) - encryptedReader, iv, err := CreateSSECEncryptedReader(dataReader, customerKey) - if err != nil { - t.Fatalf("Failed to create encrypted reader: %v", err) - } - - // Read encrypted data - encryptedData, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read encrypted data: %v", err) - } - - // Verify data is actually encrypted (different from original) - if bytes.Equal(encryptedData[16:], testData) { // Skip IV - t.Error("Data doesn't appear to be encrypted") - } - - // Create decrypted reader - encryptedReader2 := bytes.NewReader(encryptedData) - decryptedReader, err := CreateSSECDecryptedReader(encryptedReader2, customerKey, iv) - if err != nil { - t.Fatalf("Failed to create decrypted reader: %v", err) - } - - // Read decrypted data - decryptedData, err := io.ReadAll(decryptedReader) - if err != nil { - t.Fatalf("Failed to read decrypted data: %v", err) - } - - // Verify decrypted data matches original - if !bytes.Equal(decryptedData, testData) { - t.Errorf("Decrypted data doesn't match original.\nOriginal: %s\nDecrypted: %s", testData, decryptedData) - } -} - -func TestSSECIsSSECRequest(t *testing.T) { - // Test with SSE-C headers - req := &http.Request{Header: make(http.Header)} - req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, "AES256") - - if !IsSSECRequest(req) { - t.Error("Expected IsSSECRequest to return true when SSE-C headers are present") - } - - // Test without SSE-C headers - req2 := &http.Request{Header: make(http.Header)} - if IsSSECRequest(req2) { - t.Error("Expected IsSSECRequest to return false when no SSE-C headers are present") - } -} - -// Test encryption with different data sizes (similar to s3tests) -func TestSSECEncryptionVariousSizes(t *testing.T) { - sizes := []int{1, 13, 1024, 1024 * 1024} // 1B, 13B, 1KB, 1MB - - for _, size := range sizes { - t.Run(fmt.Sprintf("size_%d", size), func(t *testing.T) { - // Create customer key - key := make([]byte, 32) - for i := range key { - key[i] = byte(i + size) // Make key unique per test - } - - md5sumDyn := md5.Sum(key) - customerKey := &SSECustomerKey{ - Algorithm: "AES256", - Key: key, - KeyMD5: base64.StdEncoding.EncodeToString(md5sumDyn[:]), - } - - // Create test data of specified size - testData := make([]byte, size) - for i := range testData { - testData[i] = byte('A' + (i % 26)) // Pattern of A-Z - } - - // Encrypt - dataReader := bytes.NewReader(testData) - encryptedReader, iv, err := CreateSSECEncryptedReader(dataReader, customerKey) - if err != nil { - t.Fatalf("Failed to create encrypted reader: %v", err) - } - - encryptedData, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read encrypted data: %v", err) - } - - // Verify encrypted data has same size as original (IV is stored in metadata, not in stream) - if len(encryptedData) != size { - t.Errorf("Expected encrypted data length %d (same as original), got %d", size, len(encryptedData)) - } - - // Decrypt - encryptedReader2 := bytes.NewReader(encryptedData) - decryptedReader, err := CreateSSECDecryptedReader(encryptedReader2, customerKey, iv) - if err != nil { - t.Fatalf("Failed to create decrypted reader: %v", err) - } - - decryptedData, err := io.ReadAll(decryptedReader) - if err != nil { - t.Fatalf("Failed to read decrypted data: %v", err) - } - - // Verify decrypted data matches original - if !bytes.Equal(decryptedData, testData) { - t.Errorf("Decrypted data doesn't match original for size %d", size) - } - }) - } -} - -func TestSSECEncryptionWithNilKey(t *testing.T) { - testData := []byte("test data") - dataReader := bytes.NewReader(testData) - - // Test encryption with nil key (should pass through) - encryptedReader, iv, err := CreateSSECEncryptedReader(dataReader, nil) - if err != nil { - t.Fatalf("Failed to create encrypted reader with nil key: %v", err) - } - - result, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read from pass-through reader: %v", err) - } - - if !bytes.Equal(result, testData) { - t.Error("Data should pass through unchanged when key is nil") - } - - // Test decryption with nil key (should pass through) - dataReader2 := bytes.NewReader(testData) - decryptedReader, err := CreateSSECDecryptedReader(dataReader2, nil, iv) - if err != nil { - t.Fatalf("Failed to create decrypted reader with nil key: %v", err) - } - - result2, err := io.ReadAll(decryptedReader) - if err != nil { - t.Fatalf("Failed to read from pass-through reader: %v", err) - } - - if !bytes.Equal(result2, testData) { - t.Error("Data should pass through unchanged when key is nil") - } -} - -// TestSSECEncryptionSmallBuffers tests the fix for the critical bug where small buffers -// could corrupt the data stream when reading in chunks smaller than the IV size -func TestSSECEncryptionSmallBuffers(t *testing.T) { - testData := []byte("This is a test message for small buffer reads") - - // Create customer key - key := make([]byte, 32) - for i := range key { - key[i] = byte(i) - } - - md5sumKey3 := md5.Sum(key) - customerKey := &SSECustomerKey{ - Algorithm: "AES256", - Key: key, - KeyMD5: base64.StdEncoding.EncodeToString(md5sumKey3[:]), - } - - // Create encrypted reader - dataReader := bytes.NewReader(testData) - encryptedReader, iv, err := CreateSSECEncryptedReader(dataReader, customerKey) - if err != nil { - t.Fatalf("Failed to create encrypted reader: %v", err) - } - - // Read with very small buffers (smaller than IV size of 16 bytes) - var encryptedData []byte - smallBuffer := make([]byte, 5) // Much smaller than 16-byte IV - - for { - n, err := encryptedReader.Read(smallBuffer) - if n > 0 { - encryptedData = append(encryptedData, smallBuffer[:n]...) - } - if err == io.EOF { - break - } - if err != nil { - t.Fatalf("Error reading encrypted data: %v", err) - } - } - - // Verify we have some encrypted data (IV is in metadata, not in stream) - if len(encryptedData) == 0 && len(testData) > 0 { - t.Fatal("Expected encrypted data but got none") - } - - // Expected size: same as original data (IV is stored in metadata, not in stream) - if len(encryptedData) != len(testData) { - t.Errorf("Expected encrypted data size %d (same as original), got %d", len(testData), len(encryptedData)) - } - - // Decrypt and verify - encryptedReader2 := bytes.NewReader(encryptedData) - decryptedReader, err := CreateSSECDecryptedReader(encryptedReader2, customerKey, iv) - if err != nil { - t.Fatalf("Failed to create decrypted reader: %v", err) - } - - decryptedData, err := io.ReadAll(decryptedReader) - if err != nil { - t.Fatalf("Failed to read decrypted data: %v", err) - } - - if !bytes.Equal(decryptedData, testData) { - t.Errorf("Decrypted data doesn't match original.\nOriginal: %s\nDecrypted: %s", testData, decryptedData) - } -} diff --git a/weed/s3api/s3_sse_copy_test.go b/weed/s3api/s3_sse_copy_test.go deleted file mode 100644 index b377b45a9..000000000 --- a/weed/s3api/s3_sse_copy_test.go +++ /dev/null @@ -1,628 +0,0 @@ -package s3api - -import ( - "bytes" - "io" - "net/http" - "strings" - "testing" - - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" -) - -// TestSSECObjectCopy tests copying SSE-C encrypted objects with different keys -func TestSSECObjectCopy(t *testing.T) { - // Original key for source object - sourceKey := GenerateTestSSECKey(1) - sourceCustomerKey := &SSECustomerKey{ - Algorithm: "AES256", - Key: sourceKey.Key, - KeyMD5: sourceKey.KeyMD5, - } - - // Destination key for target object - destKey := GenerateTestSSECKey(2) - destCustomerKey := &SSECustomerKey{ - Algorithm: "AES256", - Key: destKey.Key, - KeyMD5: destKey.KeyMD5, - } - - testData := "Hello, SSE-C copy world!" - - // Encrypt with source key - encryptedReader, iv, err := CreateSSECEncryptedReader(strings.NewReader(testData), sourceCustomerKey) - if err != nil { - t.Fatalf("Failed to create encrypted reader: %v", err) - } - - encryptedData, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read encrypted data: %v", err) - } - - // Test copy strategy determination - sourceMetadata := make(map[string][]byte) - StoreSSECIVInMetadata(sourceMetadata, iv) - sourceMetadata[s3_constants.AmzServerSideEncryptionCustomerAlgorithm] = []byte("AES256") - sourceMetadata[s3_constants.AmzServerSideEncryptionCustomerKeyMD5] = []byte(sourceKey.KeyMD5) - - t.Run("Same key copy (direct copy)", func(t *testing.T) { - strategy, err := DetermineSSECCopyStrategy(sourceMetadata, sourceCustomerKey, sourceCustomerKey) - if err != nil { - t.Fatalf("Failed to determine copy strategy: %v", err) - } - - if strategy != SSECCopyStrategyDirect { - t.Errorf("Expected direct copy strategy for same key, got %v", strategy) - } - }) - - t.Run("Different key copy (decrypt-encrypt)", func(t *testing.T) { - strategy, err := DetermineSSECCopyStrategy(sourceMetadata, sourceCustomerKey, destCustomerKey) - if err != nil { - t.Fatalf("Failed to determine copy strategy: %v", err) - } - - if strategy != SSECCopyStrategyDecryptEncrypt { - t.Errorf("Expected decrypt-encrypt copy strategy for different keys, got %v", strategy) - } - }) - - t.Run("Can direct copy check", func(t *testing.T) { - // Same key should allow direct copy - canDirect := CanDirectCopySSEC(sourceMetadata, sourceCustomerKey, sourceCustomerKey) - if !canDirect { - t.Error("Should allow direct copy with same key") - } - - // Different key should not allow direct copy - canDirect = CanDirectCopySSEC(sourceMetadata, sourceCustomerKey, destCustomerKey) - if canDirect { - t.Error("Should not allow direct copy with different keys") - } - }) - - // Test actual copy operation (decrypt with source key, encrypt with dest key) - t.Run("Full copy operation", func(t *testing.T) { - // Decrypt with source key - decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedData), sourceCustomerKey, iv) - if err != nil { - t.Fatalf("Failed to create decrypted reader: %v", err) - } - - // Re-encrypt with destination key - reEncryptedReader, destIV, err := CreateSSECEncryptedReader(decryptedReader, destCustomerKey) - if err != nil { - t.Fatalf("Failed to create re-encrypted reader: %v", err) - } - - reEncryptedData, err := io.ReadAll(reEncryptedReader) - if err != nil { - t.Fatalf("Failed to read re-encrypted data: %v", err) - } - - // Verify we can decrypt with destination key - finalDecryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(reEncryptedData), destCustomerKey, destIV) - if err != nil { - t.Fatalf("Failed to create final decrypted reader: %v", err) - } - - finalData, err := io.ReadAll(finalDecryptedReader) - if err != nil { - t.Fatalf("Failed to read final decrypted data: %v", err) - } - - if string(finalData) != testData { - t.Errorf("Expected %s, got %s", testData, string(finalData)) - } - }) -} - -// TestSSEKMSObjectCopy tests copying SSE-KMS encrypted objects -func TestSSEKMSObjectCopy(t *testing.T) { - kmsKey := SetupTestKMS(t) - defer kmsKey.Cleanup() - - testData := "Hello, SSE-KMS copy world!" - encryptionContext := BuildEncryptionContext("test-bucket", "test-object", false) - - // Encrypt with SSE-KMS - encryptedReader, sseKey, err := CreateSSEKMSEncryptedReader(strings.NewReader(testData), kmsKey.KeyID, encryptionContext) - if err != nil { - t.Fatalf("Failed to create encrypted reader: %v", err) - } - - encryptedData, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read encrypted data: %v", err) - } - - t.Run("Same KMS key copy", func(t *testing.T) { - // Decrypt with original key - decryptedReader, err := CreateSSEKMSDecryptedReader(bytes.NewReader(encryptedData), sseKey) - if err != nil { - t.Fatalf("Failed to create decrypted reader: %v", err) - } - - // Re-encrypt with same KMS key - reEncryptedReader, newSseKey, err := CreateSSEKMSEncryptedReader(decryptedReader, kmsKey.KeyID, encryptionContext) - if err != nil { - t.Fatalf("Failed to create re-encrypted reader: %v", err) - } - - reEncryptedData, err := io.ReadAll(reEncryptedReader) - if err != nil { - t.Fatalf("Failed to read re-encrypted data: %v", err) - } - - // Verify we can decrypt with new key - finalDecryptedReader, err := CreateSSEKMSDecryptedReader(bytes.NewReader(reEncryptedData), newSseKey) - if err != nil { - t.Fatalf("Failed to create final decrypted reader: %v", err) - } - - finalData, err := io.ReadAll(finalDecryptedReader) - if err != nil { - t.Fatalf("Failed to read final decrypted data: %v", err) - } - - if string(finalData) != testData { - t.Errorf("Expected %s, got %s", testData, string(finalData)) - } - }) -} - -// TestSSECToSSEKMSCopy tests cross-encryption copy (SSE-C to SSE-KMS) -func TestSSECToSSEKMSCopy(t *testing.T) { - // Setup SSE-C key - ssecKey := GenerateTestSSECKey(1) - ssecCustomerKey := &SSECustomerKey{ - Algorithm: "AES256", - Key: ssecKey.Key, - KeyMD5: ssecKey.KeyMD5, - } - - // Setup SSE-KMS - kmsKey := SetupTestKMS(t) - defer kmsKey.Cleanup() - - testData := "Hello, cross-encryption copy world!" - - // Encrypt with SSE-C - encryptedReader, ssecIV, err := CreateSSECEncryptedReader(strings.NewReader(testData), ssecCustomerKey) - if err != nil { - t.Fatalf("Failed to create SSE-C encrypted reader: %v", err) - } - - encryptedData, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read SSE-C encrypted data: %v", err) - } - - // Decrypt SSE-C data - decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedData), ssecCustomerKey, ssecIV) - if err != nil { - t.Fatalf("Failed to create SSE-C decrypted reader: %v", err) - } - - // Re-encrypt with SSE-KMS - encryptionContext := BuildEncryptionContext("test-bucket", "test-object", false) - reEncryptedReader, sseKmsKey, err := CreateSSEKMSEncryptedReader(decryptedReader, kmsKey.KeyID, encryptionContext) - if err != nil { - t.Fatalf("Failed to create SSE-KMS encrypted reader: %v", err) - } - - reEncryptedData, err := io.ReadAll(reEncryptedReader) - if err != nil { - t.Fatalf("Failed to read SSE-KMS encrypted data: %v", err) - } - - // Decrypt with SSE-KMS - finalDecryptedReader, err := CreateSSEKMSDecryptedReader(bytes.NewReader(reEncryptedData), sseKmsKey) - if err != nil { - t.Fatalf("Failed to create SSE-KMS decrypted reader: %v", err) - } - - finalData, err := io.ReadAll(finalDecryptedReader) - if err != nil { - t.Fatalf("Failed to read final decrypted data: %v", err) - } - - if string(finalData) != testData { - t.Errorf("Expected %s, got %s", testData, string(finalData)) - } -} - -// TestSSEKMSToSSECCopy tests cross-encryption copy (SSE-KMS to SSE-C) -func TestSSEKMSToSSECCopy(t *testing.T) { - // Setup SSE-KMS - kmsKey := SetupTestKMS(t) - defer kmsKey.Cleanup() - - // Setup SSE-C key - ssecKey := GenerateTestSSECKey(1) - ssecCustomerKey := &SSECustomerKey{ - Algorithm: "AES256", - Key: ssecKey.Key, - KeyMD5: ssecKey.KeyMD5, - } - - testData := "Hello, reverse cross-encryption copy world!" - encryptionContext := BuildEncryptionContext("test-bucket", "test-object", false) - - // Encrypt with SSE-KMS - encryptedReader, sseKmsKey, err := CreateSSEKMSEncryptedReader(strings.NewReader(testData), kmsKey.KeyID, encryptionContext) - if err != nil { - t.Fatalf("Failed to create SSE-KMS encrypted reader: %v", err) - } - - encryptedData, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read SSE-KMS encrypted data: %v", err) - } - - // Decrypt SSE-KMS data - decryptedReader, err := CreateSSEKMSDecryptedReader(bytes.NewReader(encryptedData), sseKmsKey) - if err != nil { - t.Fatalf("Failed to create SSE-KMS decrypted reader: %v", err) - } - - // Re-encrypt with SSE-C - reEncryptedReader, reEncryptedIV, err := CreateSSECEncryptedReader(decryptedReader, ssecCustomerKey) - if err != nil { - t.Fatalf("Failed to create SSE-C encrypted reader: %v", err) - } - - reEncryptedData, err := io.ReadAll(reEncryptedReader) - if err != nil { - t.Fatalf("Failed to read SSE-C encrypted data: %v", err) - } - - // Decrypt with SSE-C - finalDecryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(reEncryptedData), ssecCustomerKey, reEncryptedIV) - if err != nil { - t.Fatalf("Failed to create SSE-C decrypted reader: %v", err) - } - - finalData, err := io.ReadAll(finalDecryptedReader) - if err != nil { - t.Fatalf("Failed to read final decrypted data: %v", err) - } - - if string(finalData) != testData { - t.Errorf("Expected %s, got %s", testData, string(finalData)) - } -} - -// TestSSECopyWithCorruptedSource tests copy operations with corrupted source data -func TestSSECopyWithCorruptedSource(t *testing.T) { - ssecKey := GenerateTestSSECKey(1) - ssecCustomerKey := &SSECustomerKey{ - Algorithm: "AES256", - Key: ssecKey.Key, - KeyMD5: ssecKey.KeyMD5, - } - - testData := "Hello, corruption test!" - - // Encrypt data - encryptedReader, iv, err := CreateSSECEncryptedReader(strings.NewReader(testData), ssecCustomerKey) - if err != nil { - t.Fatalf("Failed to create encrypted reader: %v", err) - } - - encryptedData, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read encrypted data: %v", err) - } - - // Corrupt the encrypted data - corruptedData := make([]byte, len(encryptedData)) - copy(corruptedData, encryptedData) - if len(corruptedData) > s3_constants.AESBlockSize { - // Corrupt a byte after the IV - corruptedData[s3_constants.AESBlockSize] ^= 0xFF - } - - // Try to decrypt corrupted data - decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(corruptedData), ssecCustomerKey, iv) - if err != nil { - t.Fatalf("Failed to create decrypted reader for corrupted data: %v", err) - } - - decryptedData, err := io.ReadAll(decryptedReader) - if err != nil { - // This is okay - corrupted data might cause read errors - t.Logf("Read error for corrupted data (expected): %v", err) - return - } - - // If we can read it, the data should be different from original - if string(decryptedData) == testData { - t.Error("Decrypted corrupted data should not match original") - } -} - -// TestSSEKMSCopyStrategy tests SSE-KMS copy strategy determination -func TestSSEKMSCopyStrategy(t *testing.T) { - tests := []struct { - name string - srcMetadata map[string][]byte - destKeyID string - expectedStrategy SSEKMSCopyStrategy - }{ - { - name: "Unencrypted to unencrypted", - srcMetadata: map[string][]byte{}, - destKeyID: "", - expectedStrategy: SSEKMSCopyStrategyDirect, - }, - { - name: "Same KMS key", - srcMetadata: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("aws:kms"), - s3_constants.AmzServerSideEncryptionAwsKmsKeyId: []byte("test-key-123"), - }, - destKeyID: "test-key-123", - expectedStrategy: SSEKMSCopyStrategyDirect, - }, - { - name: "Different KMS keys", - srcMetadata: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("aws:kms"), - s3_constants.AmzServerSideEncryptionAwsKmsKeyId: []byte("test-key-123"), - }, - destKeyID: "test-key-456", - expectedStrategy: SSEKMSCopyStrategyDecryptEncrypt, - }, - { - name: "Encrypted to unencrypted", - srcMetadata: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("aws:kms"), - s3_constants.AmzServerSideEncryptionAwsKmsKeyId: []byte("test-key-123"), - }, - destKeyID: "", - expectedStrategy: SSEKMSCopyStrategyDecryptEncrypt, - }, - { - name: "Unencrypted to encrypted", - srcMetadata: map[string][]byte{}, - destKeyID: "test-key-123", - expectedStrategy: SSEKMSCopyStrategyDecryptEncrypt, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - strategy, err := DetermineSSEKMSCopyStrategy(tt.srcMetadata, tt.destKeyID) - if err != nil { - t.Fatalf("DetermineSSEKMSCopyStrategy failed: %v", err) - } - if strategy != tt.expectedStrategy { - t.Errorf("Expected strategy %v, got %v", tt.expectedStrategy, strategy) - } - }) - } -} - -// TestSSEKMSCopyHeaders tests SSE-KMS copy header parsing -func TestSSEKMSCopyHeaders(t *testing.T) { - tests := []struct { - name string - headers map[string]string - expectedKeyID string - expectedContext map[string]string - expectedBucketKey bool - expectError bool - }{ - { - name: "No SSE-KMS headers", - headers: map[string]string{}, - expectedKeyID: "", - expectedContext: nil, - expectedBucketKey: false, - expectError: false, - }, - { - name: "SSE-KMS with key ID", - headers: map[string]string{ - s3_constants.AmzServerSideEncryption: "aws:kms", - s3_constants.AmzServerSideEncryptionAwsKmsKeyId: "test-key-123", - }, - expectedKeyID: "test-key-123", - expectedContext: nil, - expectedBucketKey: false, - expectError: false, - }, - { - name: "SSE-KMS with all options", - headers: map[string]string{ - s3_constants.AmzServerSideEncryption: "aws:kms", - s3_constants.AmzServerSideEncryptionAwsKmsKeyId: "test-key-123", - s3_constants.AmzServerSideEncryptionContext: "eyJ0ZXN0IjoidmFsdWUifQ==", // base64 of {"test":"value"} - s3_constants.AmzServerSideEncryptionBucketKeyEnabled: "true", - }, - expectedKeyID: "test-key-123", - expectedContext: map[string]string{"test": "value"}, - expectedBucketKey: true, - expectError: false, - }, - { - name: "Invalid key ID", - headers: map[string]string{ - s3_constants.AmzServerSideEncryption: "aws:kms", - s3_constants.AmzServerSideEncryptionAwsKmsKeyId: "invalid key id", - }, - expectError: true, - }, - { - name: "Invalid encryption context", - headers: map[string]string{ - s3_constants.AmzServerSideEncryption: "aws:kms", - s3_constants.AmzServerSideEncryptionContext: "invalid-base64!", - }, - expectError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req, _ := http.NewRequest("PUT", "/test", nil) - for k, v := range tt.headers { - req.Header.Set(k, v) - } - - keyID, context, bucketKey, err := ParseSSEKMSCopyHeaders(req) - - if tt.expectError { - if err == nil { - t.Error("Expected error but got none") - } - return - } - - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - if keyID != tt.expectedKeyID { - t.Errorf("Expected keyID %s, got %s", tt.expectedKeyID, keyID) - } - - if !mapsEqual(context, tt.expectedContext) { - t.Errorf("Expected context %v, got %v", tt.expectedContext, context) - } - - if bucketKey != tt.expectedBucketKey { - t.Errorf("Expected bucketKey %v, got %v", tt.expectedBucketKey, bucketKey) - } - }) - } -} - -// TestSSEKMSDirectCopy tests direct copy scenarios -func TestSSEKMSDirectCopy(t *testing.T) { - tests := []struct { - name string - srcMetadata map[string][]byte - destKeyID string - canDirect bool - }{ - { - name: "Both unencrypted", - srcMetadata: map[string][]byte{}, - destKeyID: "", - canDirect: true, - }, - { - name: "Same key ID", - srcMetadata: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("aws:kms"), - s3_constants.AmzServerSideEncryptionAwsKmsKeyId: []byte("test-key-123"), - }, - destKeyID: "test-key-123", - canDirect: true, - }, - { - name: "Different key IDs", - srcMetadata: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("aws:kms"), - s3_constants.AmzServerSideEncryptionAwsKmsKeyId: []byte("test-key-123"), - }, - destKeyID: "test-key-456", - canDirect: false, - }, - { - name: "Source encrypted, dest unencrypted", - srcMetadata: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("aws:kms"), - s3_constants.AmzServerSideEncryptionAwsKmsKeyId: []byte("test-key-123"), - }, - destKeyID: "", - canDirect: false, - }, - { - name: "Source unencrypted, dest encrypted", - srcMetadata: map[string][]byte{}, - destKeyID: "test-key-123", - canDirect: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - canDirect := CanDirectCopySSEKMS(tt.srcMetadata, tt.destKeyID) - if canDirect != tt.canDirect { - t.Errorf("Expected canDirect %v, got %v", tt.canDirect, canDirect) - } - }) - } -} - -// TestGetSourceSSEKMSInfo tests extraction of SSE-KMS info from metadata -func TestGetSourceSSEKMSInfo(t *testing.T) { - tests := []struct { - name string - metadata map[string][]byte - expectedKeyID string - expectedEncrypted bool - }{ - { - name: "No encryption", - metadata: map[string][]byte{}, - expectedKeyID: "", - expectedEncrypted: false, - }, - { - name: "SSE-KMS with key ID", - metadata: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("aws:kms"), - s3_constants.AmzServerSideEncryptionAwsKmsKeyId: []byte("test-key-123"), - }, - expectedKeyID: "test-key-123", - expectedEncrypted: true, - }, - { - name: "SSE-KMS without key ID (default key)", - metadata: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("aws:kms"), - }, - expectedKeyID: "", - expectedEncrypted: true, - }, - { - name: "Non-KMS encryption", - metadata: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("AES256"), - }, - expectedKeyID: "", - expectedEncrypted: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - keyID, encrypted := GetSourceSSEKMSInfo(tt.metadata) - if keyID != tt.expectedKeyID { - t.Errorf("Expected keyID %s, got %s", tt.expectedKeyID, keyID) - } - if encrypted != tt.expectedEncrypted { - t.Errorf("Expected encrypted %v, got %v", tt.expectedEncrypted, encrypted) - } - }) - } -} - -// Helper function to compare maps -func mapsEqual(a, b map[string]string) bool { - if len(a) != len(b) { - return false - } - for k, v := range a { - if b[k] != v { - return false - } - } - return true -} diff --git a/weed/s3api/s3_sse_error_test.go b/weed/s3api/s3_sse_error_test.go deleted file mode 100644 index a344e2ef7..000000000 --- a/weed/s3api/s3_sse_error_test.go +++ /dev/null @@ -1,400 +0,0 @@ -package s3api - -import ( - "bytes" - "fmt" - "io" - "net/http" - "strings" - "testing" - - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" -) - -// TestSSECWrongKeyDecryption tests decryption with wrong SSE-C key -func TestSSECWrongKeyDecryption(t *testing.T) { - // Setup original key and encrypt data - originalKey := GenerateTestSSECKey(1) - testData := "Hello, SSE-C world!" - - encryptedReader, iv, err := CreateSSECEncryptedReader(strings.NewReader(testData), &SSECustomerKey{ - Algorithm: "AES256", - Key: originalKey.Key, - KeyMD5: originalKey.KeyMD5, - }) - if err != nil { - t.Fatalf("Failed to create encrypted reader: %v", err) - } - - // Read encrypted data - encryptedData, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read encrypted data: %v", err) - } - - // Try to decrypt with wrong key - wrongKey := GenerateTestSSECKey(2) // Different seed = different key - decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedData), &SSECustomerKey{ - Algorithm: "AES256", - Key: wrongKey.Key, - KeyMD5: wrongKey.KeyMD5, - }, iv) - if err != nil { - t.Fatalf("Failed to create decrypted reader: %v", err) - } - - // Read decrypted data - should be garbage/different from original - decryptedData, err := io.ReadAll(decryptedReader) - if err != nil { - t.Fatalf("Failed to read decrypted data: %v", err) - } - - // Verify the decrypted data is NOT the same as original (wrong key used) - if string(decryptedData) == testData { - t.Error("Decryption with wrong key should not produce original data") - } -} - -// TestSSEKMSKeyNotFound tests handling of missing KMS key -func TestSSEKMSKeyNotFound(t *testing.T) { - // Note: The local KMS provider creates keys on-demand by design. - // This test validates that when on-demand creation fails or is disabled, - // appropriate errors are returned. - - // Test with an invalid key ID that would fail even on-demand creation - invalidKeyID := "" // Empty key ID should fail - encryptionContext := BuildEncryptionContext("test-bucket", "test-object", false) - - _, _, err := CreateSSEKMSEncryptedReader(strings.NewReader("test data"), invalidKeyID, encryptionContext) - - // Should get an error for invalid/empty key - if err == nil { - t.Error("Expected error for empty KMS key ID, got none") - } - - // For local KMS with on-demand creation, we test what we can realistically test - if err != nil { - t.Logf("Got expected error for empty key ID: %v", err) - } -} - -// TestSSEHeadersWithoutEncryption tests inconsistent state where headers are present but no encryption -func TestSSEHeadersWithoutEncryption(t *testing.T) { - testCases := []struct { - name string - setupReq func() *http.Request - }{ - { - name: "SSE-C algorithm without key", - setupReq: func() *http.Request { - req := CreateTestHTTPRequest("PUT", "/bucket/object", nil) - req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, "AES256") - // Missing key and MD5 - return req - }, - }, - { - name: "SSE-C key without algorithm", - setupReq: func() *http.Request { - req := CreateTestHTTPRequest("PUT", "/bucket/object", nil) - keyPair := GenerateTestSSECKey(1) - req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKey, keyPair.KeyB64) - // Missing algorithm - return req - }, - }, - { - name: "SSE-KMS key ID without algorithm", - setupReq: func() *http.Request { - req := CreateTestHTTPRequest("PUT", "/bucket/object", nil) - req.Header.Set(s3_constants.AmzServerSideEncryptionAwsKmsKeyId, "test-key-id") - // Missing algorithm - return req - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - req := tc.setupReq() - - // Validate headers - should catch incomplete configurations - if strings.Contains(tc.name, "SSE-C") { - err := ValidateSSECHeaders(req) - if err == nil { - t.Error("Expected validation error for incomplete SSE-C headers") - } - } - }) - } -} - -// TestSSECInvalidKeyFormats tests various invalid SSE-C key formats -func TestSSECInvalidKeyFormats(t *testing.T) { - testCases := []struct { - name string - algorithm string - key string - keyMD5 string - expectErr bool - }{ - { - name: "Invalid algorithm", - algorithm: "AES128", - key: "dGVzdGtleXRlc3RrZXl0ZXN0a2V5dGVzdGtleXRlc3RrZXk=", // 32 bytes base64 - keyMD5: "valid-md5-hash", - expectErr: true, - }, - { - name: "Invalid key length (too short)", - algorithm: "AES256", - key: "c2hvcnRrZXk=", // "shortkey" base64 - too short - keyMD5: "valid-md5-hash", - expectErr: true, - }, - { - name: "Invalid key length (too long)", - algorithm: "AES256", - key: "dGVzdGtleXRlc3RrZXl0ZXN0a2V5dGVzdGtleXRlc3RrZXl0ZXN0a2V5dGVzdGtleQ==", // too long - keyMD5: "valid-md5-hash", - expectErr: true, - }, - { - name: "Invalid base64 key", - algorithm: "AES256", - key: "invalid-base64!", - keyMD5: "valid-md5-hash", - expectErr: true, - }, - { - name: "Invalid base64 MD5", - algorithm: "AES256", - key: "dGVzdGtleXRlc3RrZXl0ZXN0a2V5dGVzdGtleXRlc3RrZXk=", - keyMD5: "invalid-base64!", - expectErr: true, - }, - { - name: "Mismatched MD5", - algorithm: "AES256", - key: "dGVzdGtleXRlc3RrZXl0ZXN0a2V5dGVzdGtleXRlc3RrZXk=", - keyMD5: "d29uZy1tZDUtaGFzaA==", // "wrong-md5-hash" base64 - expectErr: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - req := CreateTestHTTPRequest("PUT", "/bucket/object", nil) - req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, tc.algorithm) - req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKey, tc.key) - req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKeyMD5, tc.keyMD5) - - err := ValidateSSECHeaders(req) - if tc.expectErr && err == nil { - t.Errorf("Expected error for %s, but got none", tc.name) - } - if !tc.expectErr && err != nil { - t.Errorf("Expected no error for %s, but got: %v", tc.name, err) - } - }) - } -} - -// TestSSEKMSInvalidConfigurations tests various invalid SSE-KMS configurations -func TestSSEKMSInvalidConfigurations(t *testing.T) { - testCases := []struct { - name string - setupRequest func() *http.Request - expectError bool - }{ - { - name: "Invalid algorithm", - setupRequest: func() *http.Request { - req := CreateTestHTTPRequest("PUT", "/bucket/object", nil) - req.Header.Set(s3_constants.AmzServerSideEncryption, "invalid-algorithm") - return req - }, - expectError: true, - }, - { - name: "Empty key ID", - setupRequest: func() *http.Request { - req := CreateTestHTTPRequest("PUT", "/bucket/object", nil) - req.Header.Set(s3_constants.AmzServerSideEncryption, "aws:kms") - req.Header.Set(s3_constants.AmzServerSideEncryptionAwsKmsKeyId, "") - return req - }, - expectError: false, // Empty key ID might be valid (use default) - }, - { - name: "Invalid key ID format", - setupRequest: func() *http.Request { - req := CreateTestHTTPRequest("PUT", "/bucket/object", nil) - req.Header.Set(s3_constants.AmzServerSideEncryption, "aws:kms") - req.Header.Set(s3_constants.AmzServerSideEncryptionAwsKmsKeyId, "invalid key id with spaces") - return req - }, - expectError: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - req := tc.setupRequest() - - _, err := ParseSSEKMSHeaders(req) - if tc.expectError && err == nil { - t.Errorf("Expected error for %s, but got none", tc.name) - } - if !tc.expectError && err != nil { - t.Errorf("Expected no error for %s, but got: %v", tc.name, err) - } - }) - } -} - -// TestSSEEmptyDataHandling tests handling of empty data with SSE -func TestSSEEmptyDataHandling(t *testing.T) { - t.Run("SSE-C with empty data", func(t *testing.T) { - keyPair := GenerateTestSSECKey(1) - customerKey := &SSECustomerKey{ - Algorithm: "AES256", - Key: keyPair.Key, - KeyMD5: keyPair.KeyMD5, - } - - // Encrypt empty data - encryptedReader, iv, err := CreateSSECEncryptedReader(strings.NewReader(""), customerKey) - if err != nil { - t.Fatalf("Failed to create encrypted reader for empty data: %v", err) - } - - encryptedData, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read encrypted empty data: %v", err) - } - - // Should have IV for empty data - if len(iv) != s3_constants.AESBlockSize { - t.Error("IV should be present even for empty data") - } - - // Decrypt and verify - decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedData), customerKey, iv) - if err != nil { - t.Fatalf("Failed to create decrypted reader for empty data: %v", err) - } - - decryptedData, err := io.ReadAll(decryptedReader) - if err != nil { - t.Fatalf("Failed to read decrypted empty data: %v", err) - } - - if len(decryptedData) != 0 { - t.Errorf("Expected empty decrypted data, got %d bytes", len(decryptedData)) - } - }) - - t.Run("SSE-KMS with empty data", func(t *testing.T) { - kmsKey := SetupTestKMS(t) - defer kmsKey.Cleanup() - - encryptionContext := BuildEncryptionContext("test-bucket", "test-object", false) - - // Encrypt empty data - encryptedReader, sseKey, err := CreateSSEKMSEncryptedReader(strings.NewReader(""), kmsKey.KeyID, encryptionContext) - if err != nil { - t.Fatalf("Failed to create encrypted reader for empty data: %v", err) - } - - encryptedData, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read encrypted empty data: %v", err) - } - - // Empty data should produce empty encrypted data (IV is stored in metadata) - if len(encryptedData) != 0 { - t.Errorf("Encrypted empty data should be empty, got %d bytes", len(encryptedData)) - } - - // Decrypt and verify - decryptedReader, err := CreateSSEKMSDecryptedReader(bytes.NewReader(encryptedData), sseKey) - if err != nil { - t.Fatalf("Failed to create decrypted reader for empty data: %v", err) - } - - decryptedData, err := io.ReadAll(decryptedReader) - if err != nil { - t.Fatalf("Failed to read decrypted empty data: %v", err) - } - - if len(decryptedData) != 0 { - t.Errorf("Expected empty decrypted data, got %d bytes", len(decryptedData)) - } - }) -} - -// TestSSEConcurrentAccess tests SSE operations under concurrent access -func TestSSEConcurrentAccess(t *testing.T) { - keyPair := GenerateTestSSECKey(1) - customerKey := &SSECustomerKey{ - Algorithm: "AES256", - Key: keyPair.Key, - KeyMD5: keyPair.KeyMD5, - } - - const numGoroutines = 10 - done := make(chan bool, numGoroutines) - errors := make(chan error, numGoroutines) - - // Run multiple encryption/decryption operations concurrently - for i := 0; i < numGoroutines; i++ { - go func(id int) { - defer func() { done <- true }() - - testData := fmt.Sprintf("test data %d", id) - - // Encrypt - encryptedReader, iv, err := CreateSSECEncryptedReader(strings.NewReader(testData), customerKey) - if err != nil { - errors <- fmt.Errorf("goroutine %d encrypt error: %v", id, err) - return - } - - encryptedData, err := io.ReadAll(encryptedReader) - if err != nil { - errors <- fmt.Errorf("goroutine %d read encrypted error: %v", id, err) - return - } - - // Decrypt - decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedData), customerKey, iv) - if err != nil { - errors <- fmt.Errorf("goroutine %d decrypt error: %v", id, err) - return - } - - decryptedData, err := io.ReadAll(decryptedReader) - if err != nil { - errors <- fmt.Errorf("goroutine %d read decrypted error: %v", id, err) - return - } - - if string(decryptedData) != testData { - errors <- fmt.Errorf("goroutine %d data mismatch: expected %s, got %s", id, testData, string(decryptedData)) - return - } - }(i) - } - - // Wait for all goroutines to complete - for i := 0; i < numGoroutines; i++ { - <-done - } - - // Check for errors - close(errors) - for err := range errors { - t.Error(err) - } -} diff --git a/weed/s3api/s3_sse_http_test.go b/weed/s3api/s3_sse_http_test.go deleted file mode 100644 index 95f141ca7..000000000 --- a/weed/s3api/s3_sse_http_test.go +++ /dev/null @@ -1,401 +0,0 @@ -package s3api - -import ( - "bytes" - "net/http" - "net/http/httptest" - "testing" - - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" -) - -// TestPutObjectWithSSEC tests PUT object with SSE-C through HTTP handler -func TestPutObjectWithSSEC(t *testing.T) { - keyPair := GenerateTestSSECKey(1) - testData := "Hello, SSE-C PUT object!" - - // Create HTTP request - req := CreateTestHTTPRequest("PUT", "/test-bucket/test-object", []byte(testData)) - SetupTestSSECHeaders(req, keyPair) - SetupTestMuxVars(req, map[string]string{ - "bucket": "test-bucket", - "object": "test-object", - }) - - // Create response recorder - w := CreateTestHTTPResponse() - - // Test header validation - err := ValidateSSECHeaders(req) - if err != nil { - t.Fatalf("Header validation failed: %v", err) - } - - // Parse SSE-C headers - customerKey, err := ParseSSECHeaders(req) - if err != nil { - t.Fatalf("Failed to parse SSE-C headers: %v", err) - } - - if customerKey == nil { - t.Fatal("Expected customer key, got nil") - } - - // Verify parsed key matches input - if !bytes.Equal(customerKey.Key, keyPair.Key) { - t.Error("Parsed key doesn't match input key") - } - - if customerKey.KeyMD5 != keyPair.KeyMD5 { - t.Errorf("Parsed key MD5 doesn't match: expected %s, got %s", keyPair.KeyMD5, customerKey.KeyMD5) - } - - // Simulate setting response headers - w.Header().Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, "AES256") - w.Header().Set(s3_constants.AmzServerSideEncryptionCustomerKeyMD5, keyPair.KeyMD5) - - // Verify response headers - AssertSSECHeaders(t, w, keyPair) -} - -// TestGetObjectWithSSEC tests GET object with SSE-C through HTTP handler -func TestGetObjectWithSSEC(t *testing.T) { - keyPair := GenerateTestSSECKey(1) - - // Create HTTP request for GET - req := CreateTestHTTPRequest("GET", "/test-bucket/test-object", nil) - SetupTestSSECHeaders(req, keyPair) - SetupTestMuxVars(req, map[string]string{ - "bucket": "test-bucket", - "object": "test-object", - }) - - // Create response recorder - w := CreateTestHTTPResponse() - - // Test that SSE-C is detected for GET requests - if !IsSSECRequest(req) { - t.Error("Should detect SSE-C request for GET with SSE-C headers") - } - - // Validate headers - err := ValidateSSECHeaders(req) - if err != nil { - t.Fatalf("Header validation failed: %v", err) - } - - // Simulate response with SSE-C headers - w.Header().Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, "AES256") - w.Header().Set(s3_constants.AmzServerSideEncryptionCustomerKeyMD5, keyPair.KeyMD5) - w.WriteHeader(http.StatusOK) - - // Verify response - if w.Code != http.StatusOK { - t.Errorf("Expected status 200, got %d", w.Code) - } - - AssertSSECHeaders(t, w, keyPair) -} - -// TestPutObjectWithSSEKMS tests PUT object with SSE-KMS through HTTP handler -func TestPutObjectWithSSEKMS(t *testing.T) { - kmsKey := SetupTestKMS(t) - defer kmsKey.Cleanup() - - testData := "Hello, SSE-KMS PUT object!" - - // Create HTTP request - req := CreateTestHTTPRequest("PUT", "/test-bucket/test-object", []byte(testData)) - SetupTestSSEKMSHeaders(req, kmsKey.KeyID) - SetupTestMuxVars(req, map[string]string{ - "bucket": "test-bucket", - "object": "test-object", - }) - - // Create response recorder - w := CreateTestHTTPResponse() - - // Test that SSE-KMS is detected - if !IsSSEKMSRequest(req) { - t.Error("Should detect SSE-KMS request") - } - - // Parse SSE-KMS headers - sseKmsKey, err := ParseSSEKMSHeaders(req) - if err != nil { - t.Fatalf("Failed to parse SSE-KMS headers: %v", err) - } - - if sseKmsKey == nil { - t.Fatal("Expected SSE-KMS key, got nil") - } - - if sseKmsKey.KeyID != kmsKey.KeyID { - t.Errorf("Parsed key ID doesn't match: expected %s, got %s", kmsKey.KeyID, sseKmsKey.KeyID) - } - - // Simulate setting response headers - w.Header().Set(s3_constants.AmzServerSideEncryption, "aws:kms") - w.Header().Set(s3_constants.AmzServerSideEncryptionAwsKmsKeyId, kmsKey.KeyID) - - // Verify response headers - AssertSSEKMSHeaders(t, w, kmsKey.KeyID) -} - -// TestGetObjectWithSSEKMS tests GET object with SSE-KMS through HTTP handler -func TestGetObjectWithSSEKMS(t *testing.T) { - kmsKey := SetupTestKMS(t) - defer kmsKey.Cleanup() - - // Create HTTP request for GET (no SSE headers needed for GET) - req := CreateTestHTTPRequest("GET", "/test-bucket/test-object", nil) - SetupTestMuxVars(req, map[string]string{ - "bucket": "test-bucket", - "object": "test-object", - }) - - // Create response recorder - w := CreateTestHTTPResponse() - - // Simulate response with SSE-KMS headers (would come from stored metadata) - w.Header().Set(s3_constants.AmzServerSideEncryption, "aws:kms") - w.Header().Set(s3_constants.AmzServerSideEncryptionAwsKmsKeyId, kmsKey.KeyID) - w.WriteHeader(http.StatusOK) - - // Verify response - if w.Code != http.StatusOK { - t.Errorf("Expected status 200, got %d", w.Code) - } - - AssertSSEKMSHeaders(t, w, kmsKey.KeyID) -} - -// TestSSECRangeRequestSupport tests that range requests are now supported for SSE-C -func TestSSECRangeRequestSupport(t *testing.T) { - keyPair := GenerateTestSSECKey(1) - - // Create HTTP request with Range header - req := CreateTestHTTPRequest("GET", "/test-bucket/test-object", nil) - req.Header.Set("Range", "bytes=0-100") - SetupTestSSECHeaders(req, keyPair) - SetupTestMuxVars(req, map[string]string{ - "bucket": "test-bucket", - "object": "test-object", - }) - - // Create a mock proxy response with SSE-C headers - proxyResponse := httptest.NewRecorder() - proxyResponse.Header().Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, "AES256") - proxyResponse.Header().Set(s3_constants.AmzServerSideEncryptionCustomerKeyMD5, keyPair.KeyMD5) - proxyResponse.Header().Set("Content-Length", "1000") - - // Test the detection logic - these should all still work - - // Should detect as SSE-C request - if !IsSSECRequest(req) { - t.Error("Should detect SSE-C request") - } - - // Should detect range request - if req.Header.Get("Range") == "" { - t.Error("Range header should be present") - } - - // The combination should now be allowed and handled by the filer layer - // Range requests with SSE-C are now supported since IV is stored in metadata -} - -// TestSSEHeaderConflicts tests conflicting SSE headers -func TestSSEHeaderConflicts(t *testing.T) { - testCases := []struct { - name string - setupFn func(*http.Request) - valid bool - }{ - { - name: "SSE-C and SSE-KMS conflict", - setupFn: func(req *http.Request) { - keyPair := GenerateTestSSECKey(1) - SetupTestSSECHeaders(req, keyPair) - SetupTestSSEKMSHeaders(req, "test-key-id") - }, - valid: false, - }, - { - name: "Valid SSE-C only", - setupFn: func(req *http.Request) { - keyPair := GenerateTestSSECKey(1) - SetupTestSSECHeaders(req, keyPair) - }, - valid: true, - }, - { - name: "Valid SSE-KMS only", - setupFn: func(req *http.Request) { - SetupTestSSEKMSHeaders(req, "test-key-id") - }, - valid: true, - }, - { - name: "No SSE headers", - setupFn: func(req *http.Request) { - // No SSE headers - }, - valid: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - req := CreateTestHTTPRequest("PUT", "/test-bucket/test-object", []byte("test")) - tc.setupFn(req) - - ssecDetected := IsSSECRequest(req) - sseKmsDetected := IsSSEKMSRequest(req) - - // Both shouldn't be detected simultaneously - if ssecDetected && sseKmsDetected { - t.Error("Both SSE-C and SSE-KMS should not be detected simultaneously") - } - - // Test validation if SSE-C is detected - if ssecDetected { - err := ValidateSSECHeaders(req) - if tc.valid && err != nil { - t.Errorf("Expected valid SSE-C headers, got error: %v", err) - } - if !tc.valid && err == nil && tc.name == "SSE-C and SSE-KMS conflict" { - // This specific test case should probably be handled at a higher level - t.Log("Conflict detection should be handled by higher-level validation") - } - } - }) - } -} - -// TestSSECopySourceHeaders tests copy operations with SSE headers -func TestSSECopySourceHeaders(t *testing.T) { - sourceKey := GenerateTestSSECKey(1) - destKey := GenerateTestSSECKey(2) - - // Create copy request with both source and destination SSE-C headers - req := CreateTestHTTPRequest("PUT", "/dest-bucket/dest-object", nil) - - // Set copy source headers - SetupTestSSECCopyHeaders(req, sourceKey) - - // Set destination headers - SetupTestSSECHeaders(req, destKey) - - // Set copy source - req.Header.Set("X-Amz-Copy-Source", "/source-bucket/source-object") - - SetupTestMuxVars(req, map[string]string{ - "bucket": "dest-bucket", - "object": "dest-object", - }) - - // Parse copy source headers - copySourceKey, err := ParseSSECCopySourceHeaders(req) - if err != nil { - t.Fatalf("Failed to parse copy source headers: %v", err) - } - - if copySourceKey == nil { - t.Fatal("Expected copy source key, got nil") - } - - if !bytes.Equal(copySourceKey.Key, sourceKey.Key) { - t.Error("Copy source key doesn't match") - } - - // Parse destination headers - destCustomerKey, err := ParseSSECHeaders(req) - if err != nil { - t.Fatalf("Failed to parse destination headers: %v", err) - } - - if destCustomerKey == nil { - t.Fatal("Expected destination key, got nil") - } - - if !bytes.Equal(destCustomerKey.Key, destKey.Key) { - t.Error("Destination key doesn't match") - } -} - -// TestSSERequestValidation tests comprehensive request validation -func TestSSERequestValidation(t *testing.T) { - testCases := []struct { - name string - method string - setupFn func(*http.Request) - expectError bool - errorType string - }{ - { - name: "Valid PUT with SSE-C", - method: "PUT", - setupFn: func(req *http.Request) { - keyPair := GenerateTestSSECKey(1) - SetupTestSSECHeaders(req, keyPair) - }, - expectError: false, - }, - { - name: "Valid GET with SSE-C", - method: "GET", - setupFn: func(req *http.Request) { - keyPair := GenerateTestSSECKey(1) - SetupTestSSECHeaders(req, keyPair) - }, - expectError: false, - }, - { - name: "Invalid SSE-C key format", - method: "PUT", - setupFn: func(req *http.Request) { - req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, "AES256") - req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKey, "invalid-key") - req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKeyMD5, "invalid-md5") - }, - expectError: true, - errorType: "InvalidRequest", - }, - { - name: "Missing SSE-C key MD5", - method: "PUT", - setupFn: func(req *http.Request) { - keyPair := GenerateTestSSECKey(1) - req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, "AES256") - req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKey, keyPair.KeyB64) - // Missing MD5 - }, - expectError: true, - errorType: "InvalidRequest", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - req := CreateTestHTTPRequest(tc.method, "/test-bucket/test-object", []byte("test data")) - tc.setupFn(req) - - SetupTestMuxVars(req, map[string]string{ - "bucket": "test-bucket", - "object": "test-object", - }) - - // Test header validation - if IsSSECRequest(req) { - err := ValidateSSECHeaders(req) - if tc.expectError && err == nil { - t.Errorf("Expected error for %s, but got none", tc.name) - } - if !tc.expectError && err != nil { - t.Errorf("Expected no error for %s, but got: %v", tc.name, err) - } - } - }) - } -} diff --git a/weed/s3api/s3_sse_kms.go b/weed/s3api/s3_sse_kms.go index fa9451a8f..b87e0bf1a 100644 --- a/weed/s3api/s3_sse_kms.go +++ b/weed/s3api/s3_sse_kms.go @@ -59,11 +59,6 @@ const ( // Bucket key cache TTL (moved to be used with per-bucket cache) const BucketKeyCacheTTL = time.Hour -// CreateSSEKMSEncryptedReader creates an encrypted reader using KMS envelope encryption -func CreateSSEKMSEncryptedReader(r io.Reader, keyID string, encryptionContext map[string]string) (io.Reader, *SSEKMSKey, error) { - return CreateSSEKMSEncryptedReaderWithBucketKey(r, keyID, encryptionContext, false) -} - // CreateSSEKMSEncryptedReaderWithBucketKey creates an encrypted reader with optional S3 Bucket Keys optimization func CreateSSEKMSEncryptedReaderWithBucketKey(r io.Reader, keyID string, encryptionContext map[string]string, bucketKeyEnabled bool) (io.Reader, *SSEKMSKey, error) { if bucketKeyEnabled { @@ -111,42 +106,6 @@ func CreateSSEKMSEncryptedReaderWithBucketKey(r io.Reader, keyID string, encrypt return encryptedReader, sseKey, nil } -// CreateSSEKMSEncryptedReaderWithBaseIV creates an SSE-KMS encrypted reader using a provided base IV -// This is used for multipart uploads where all chunks need to use the same base IV -func CreateSSEKMSEncryptedReaderWithBaseIV(r io.Reader, keyID string, encryptionContext map[string]string, bucketKeyEnabled bool, baseIV []byte) (io.Reader, *SSEKMSKey, error) { - if err := ValidateIV(baseIV, "base IV"); err != nil { - return nil, nil, err - } - - // Generate data key using common utility - dataKeyResult, err := generateKMSDataKey(keyID, encryptionContext) - if err != nil { - return nil, nil, err - } - - // Ensure we clear the plaintext data key from memory when done - defer clearKMSDataKey(dataKeyResult) - - // Use the provided base IV instead of generating a new one - iv := make([]byte, s3_constants.AESBlockSize) - copy(iv, baseIV) - - // Create CTR mode cipher stream - stream := cipher.NewCTR(dataKeyResult.Block, iv) - - // Create the SSE-KMS metadata using utility function - sseKey := createSSEKMSKey(dataKeyResult, encryptionContext, bucketKeyEnabled, iv, 0) - - // The IV is stored in SSE key metadata, so the encrypted stream does not need to prepend the IV - // This ensures correct Content-Length for clients - encryptedReader := &cipher.StreamReader{S: stream, R: r} - - // Store the base IV in the SSE key for metadata storage - sseKey.IV = iv - - return encryptedReader, sseKey, nil -} - // CreateSSEKMSEncryptedReaderWithBaseIVAndOffset creates an SSE-KMS encrypted reader using a provided base IV and offset // This is used for multipart uploads where all chunks need unique IVs to prevent IV reuse vulnerabilities func CreateSSEKMSEncryptedReaderWithBaseIVAndOffset(r io.Reader, keyID string, encryptionContext map[string]string, bucketKeyEnabled bool, baseIV []byte, offset int64) (io.Reader, *SSEKMSKey, error) { @@ -453,67 +412,6 @@ func CreateSSEKMSDecryptedReader(r io.Reader, sseKey *SSEKMSKey) (io.Reader, err return decryptReader, nil } -// ParseSSEKMSHeaders parses SSE-KMS headers from an HTTP request -func ParseSSEKMSHeaders(r *http.Request) (*SSEKMSKey, error) { - sseAlgorithm := r.Header.Get(s3_constants.AmzServerSideEncryption) - - // Check if SSE-KMS is requested - if sseAlgorithm == "" { - return nil, nil // No SSE headers present - } - if sseAlgorithm != s3_constants.SSEAlgorithmKMS { - return nil, fmt.Errorf("invalid SSE algorithm: %s", sseAlgorithm) - } - - keyID := r.Header.Get(s3_constants.AmzServerSideEncryptionAwsKmsKeyId) - encryptionContextHeader := r.Header.Get(s3_constants.AmzServerSideEncryptionContext) - bucketKeyEnabledHeader := r.Header.Get(s3_constants.AmzServerSideEncryptionBucketKeyEnabled) - - // Parse encryption context if provided - var encryptionContext map[string]string - if encryptionContextHeader != "" { - // Decode base64-encoded JSON encryption context - contextBytes, err := base64.StdEncoding.DecodeString(encryptionContextHeader) - if err != nil { - return nil, fmt.Errorf("invalid encryption context format: %v", err) - } - - if err := json.Unmarshal(contextBytes, &encryptionContext); err != nil { - return nil, fmt.Errorf("invalid encryption context JSON: %v", err) - } - } - - // Parse bucket key enabled flag - bucketKeyEnabled := strings.ToLower(bucketKeyEnabledHeader) == "true" - - sseKey := &SSEKMSKey{ - KeyID: keyID, - EncryptionContext: encryptionContext, - BucketKeyEnabled: bucketKeyEnabled, - } - - // Validate the parsed key including key ID format - if err := ValidateSSEKMSKeyInternal(sseKey); err != nil { - return nil, err - } - - return sseKey, nil -} - -// ValidateSSEKMSKey validates an SSE-KMS key configuration -func ValidateSSEKMSKeyInternal(sseKey *SSEKMSKey) error { - if err := ValidateSSEKMSKey(sseKey); err != nil { - return err - } - - // An empty key ID is valid and means the default KMS key should be used. - if sseKey.KeyID != "" && !isValidKMSKeyID(sseKey.KeyID) { - return fmt.Errorf("invalid KMS key ID format: %s", sseKey.KeyID) - } - - return nil -} - // BuildEncryptionContext creates the encryption context for S3 objects func BuildEncryptionContext(bucketName, objectKey string, useBucketKey bool) map[string]string { return kms.BuildS3EncryptionContext(bucketName, objectKey, useBucketKey) @@ -732,28 +630,6 @@ func IsSSEKMSEncrypted(metadata map[string][]byte) bool { return false } -// IsAnySSEEncrypted checks if metadata indicates any type of SSE encryption -func IsAnySSEEncrypted(metadata map[string][]byte) bool { - if metadata == nil { - return false - } - - // Check for any SSE type - if IsSSECEncrypted(metadata) { - return true - } - if IsSSEKMSEncrypted(metadata) { - return true - } - - // Check for SSE-S3 - if sseAlgorithm, exists := metadata[s3_constants.AmzServerSideEncryption]; exists { - return string(sseAlgorithm) == s3_constants.SSEAlgorithmAES256 - } - - return false -} - // MapKMSErrorToS3Error maps KMS errors to appropriate S3 error codes func MapKMSErrorToS3Error(err error) s3err.ErrorCode { if err == nil { @@ -990,21 +866,6 @@ func DetermineUnifiedCopyStrategy(state *EncryptionState, srcMetadata map[string return CopyStrategyDirect, nil } -// DetectEncryptionState analyzes the source metadata and request headers to determine encryption state -func DetectEncryptionState(srcMetadata map[string][]byte, r *http.Request, srcPath, dstPath string) *EncryptionState { - state := &EncryptionState{ - SrcSSEC: IsSSECEncrypted(srcMetadata), - SrcSSEKMS: IsSSEKMSEncrypted(srcMetadata), - SrcSSES3: IsSSES3EncryptedInternal(srcMetadata), - DstSSEC: IsSSECRequest(r), - DstSSEKMS: IsSSEKMSRequest(r), - DstSSES3: IsSSES3RequestInternal(r), - SameObject: srcPath == dstPath, - } - - return state -} - // DetectEncryptionStateWithEntry analyzes the source entry and request headers to determine encryption state // This version can detect multipart encrypted objects by examining chunks func DetectEncryptionStateWithEntry(entry *filer_pb.Entry, r *http.Request, srcPath, dstPath string) *EncryptionState { diff --git a/weed/s3api/s3_sse_kms_test.go b/weed/s3api/s3_sse_kms_test.go deleted file mode 100644 index 487a239a5..000000000 --- a/weed/s3api/s3_sse_kms_test.go +++ /dev/null @@ -1,399 +0,0 @@ -package s3api - -import ( - "bytes" - "encoding/json" - "io" - "strings" - "testing" - - "github.com/seaweedfs/seaweedfs/weed/kms" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" -) - -func TestSSEKMSEncryptionDecryption(t *testing.T) { - kmsKey := SetupTestKMS(t) - defer kmsKey.Cleanup() - - // Test data - testData := "Hello, SSE-KMS world! This is a test of envelope encryption." - testReader := strings.NewReader(testData) - - // Create encryption context - encryptionContext := BuildEncryptionContext("test-bucket", "test-object", false) - - // Encrypt the data - encryptedReader, sseKey, err := CreateSSEKMSEncryptedReader(testReader, kmsKey.KeyID, encryptionContext) - if err != nil { - t.Fatalf("Failed to create encrypted reader: %v", err) - } - - // Verify SSE key metadata - if sseKey.KeyID != kmsKey.KeyID { - t.Errorf("Expected key ID %s, got %s", kmsKey.KeyID, sseKey.KeyID) - } - - if len(sseKey.EncryptedDataKey) == 0 { - t.Error("Encrypted data key should not be empty") - } - - if sseKey.EncryptionContext == nil { - t.Error("Encryption context should not be nil") - } - - // Read the encrypted data - encryptedData, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read encrypted data: %v", err) - } - - // Verify the encrypted data is different from original - if string(encryptedData) == testData { - t.Error("Encrypted data should be different from original data") - } - - // The encrypted data should be same size as original (IV is stored in metadata, not in stream) - if len(encryptedData) != len(testData) { - t.Errorf("Encrypted data should be same size as original: expected %d, got %d", len(testData), len(encryptedData)) - } - - // Decrypt the data - decryptedReader, err := CreateSSEKMSDecryptedReader(bytes.NewReader(encryptedData), sseKey) - if err != nil { - t.Fatalf("Failed to create decrypted reader: %v", err) - } - - // Read the decrypted data - decryptedData, err := io.ReadAll(decryptedReader) - if err != nil { - t.Fatalf("Failed to read decrypted data: %v", err) - } - - // Verify the decrypted data matches the original - if string(decryptedData) != testData { - t.Errorf("Decrypted data does not match original.\nExpected: %s\nGot: %s", testData, string(decryptedData)) - } -} - -func TestSSEKMSKeyValidation(t *testing.T) { - tests := []struct { - name string - keyID string - wantValid bool - }{ - { - name: "Valid UUID key ID", - keyID: "12345678-1234-1234-1234-123456789012", - wantValid: true, - }, - { - name: "Valid alias", - keyID: "alias/my-test-key", - wantValid: true, - }, - { - name: "Valid ARN", - keyID: "arn:aws:kms:us-east-1:123456789012:key/12345678-1234-1234-1234-123456789012", - wantValid: true, - }, - { - name: "Valid alias ARN", - keyID: "arn:aws:kms:us-east-1:123456789012:alias/my-test-key", - wantValid: true, - }, - - { - name: "Valid test key format", - keyID: "invalid-key-format", - wantValid: true, // Now valid - following Minio's permissive approach - }, - { - name: "Valid short key", - keyID: "12345678-1234", - wantValid: true, // Now valid - following Minio's permissive approach - }, - { - name: "Invalid - leading space", - keyID: " leading-space", - wantValid: false, - }, - { - name: "Invalid - trailing space", - keyID: "trailing-space ", - wantValid: false, - }, - { - name: "Invalid - empty", - keyID: "", - wantValid: false, - }, - { - name: "Invalid - internal spaces", - keyID: "invalid key id", - wantValid: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - valid := isValidKMSKeyID(tt.keyID) - if valid != tt.wantValid { - t.Errorf("isValidKMSKeyID(%s) = %v, want %v", tt.keyID, valid, tt.wantValid) - } - }) - } -} - -func TestSSEKMSMetadataSerialization(t *testing.T) { - // Create test SSE key - sseKey := &SSEKMSKey{ - KeyID: "test-key-id", - EncryptedDataKey: []byte("encrypted-data-key"), - EncryptionContext: map[string]string{ - "aws:s3:arn": "arn:aws:s3:::test-bucket/test-object", - }, - BucketKeyEnabled: true, - } - - // Serialize metadata - serialized, err := SerializeSSEKMSMetadata(sseKey) - if err != nil { - t.Fatalf("Failed to serialize SSE-KMS metadata: %v", err) - } - - // Verify it's valid JSON - var jsonData map[string]interface{} - if err := json.Unmarshal(serialized, &jsonData); err != nil { - t.Fatalf("Serialized data is not valid JSON: %v", err) - } - - // Deserialize metadata - deserializedKey, err := DeserializeSSEKMSMetadata(serialized) - if err != nil { - t.Fatalf("Failed to deserialize SSE-KMS metadata: %v", err) - } - - // Verify the deserialized data matches original - if deserializedKey.KeyID != sseKey.KeyID { - t.Errorf("KeyID mismatch: expected %s, got %s", sseKey.KeyID, deserializedKey.KeyID) - } - - if !bytes.Equal(deserializedKey.EncryptedDataKey, sseKey.EncryptedDataKey) { - t.Error("EncryptedDataKey mismatch") - } - - if len(deserializedKey.EncryptionContext) != len(sseKey.EncryptionContext) { - t.Error("EncryptionContext length mismatch") - } - - for k, v := range sseKey.EncryptionContext { - if deserializedKey.EncryptionContext[k] != v { - t.Errorf("EncryptionContext mismatch for key %s: expected %s, got %s", k, v, deserializedKey.EncryptionContext[k]) - } - } - - if deserializedKey.BucketKeyEnabled != sseKey.BucketKeyEnabled { - t.Errorf("BucketKeyEnabled mismatch: expected %v, got %v", sseKey.BucketKeyEnabled, deserializedKey.BucketKeyEnabled) - } -} - -func TestBuildEncryptionContext(t *testing.T) { - tests := []struct { - name string - bucket string - object string - useBucketKey bool - expectedARN string - }{ - { - name: "Object-level encryption", - bucket: "test-bucket", - object: "test-object", - useBucketKey: false, - expectedARN: "arn:aws:s3:::test-bucket/test-object", - }, - { - name: "Bucket-level encryption", - bucket: "test-bucket", - object: "test-object", - useBucketKey: true, - expectedARN: "arn:aws:s3:::test-bucket", - }, - { - name: "Nested object path", - bucket: "my-bucket", - object: "folder/subfolder/file.txt", - useBucketKey: false, - expectedARN: "arn:aws:s3:::my-bucket/folder/subfolder/file.txt", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - context := BuildEncryptionContext(tt.bucket, tt.object, tt.useBucketKey) - - if context == nil { - t.Fatal("Encryption context should not be nil") - } - - arn, exists := context[kms.EncryptionContextS3ARN] - if !exists { - t.Error("Encryption context should contain S3 ARN") - } - - if arn != tt.expectedARN { - t.Errorf("Expected ARN %s, got %s", tt.expectedARN, arn) - } - }) - } -} - -func TestKMSErrorMapping(t *testing.T) { - tests := []struct { - name string - kmsError *kms.KMSError - expectedErr string - }{ - { - name: "Key not found", - kmsError: &kms.KMSError{ - Code: kms.ErrCodeNotFoundException, - Message: "Key not found", - }, - expectedErr: "KMSKeyNotFoundException", - }, - { - name: "Access denied", - kmsError: &kms.KMSError{ - Code: kms.ErrCodeAccessDenied, - Message: "Access denied", - }, - expectedErr: "KMSAccessDeniedException", - }, - { - name: "Key unavailable", - kmsError: &kms.KMSError{ - Code: kms.ErrCodeKeyUnavailable, - Message: "Key is disabled", - }, - expectedErr: "KMSKeyDisabledException", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - errorCode := MapKMSErrorToS3Error(tt.kmsError) - - // Get the actual error description - apiError := s3err.GetAPIError(errorCode) - if apiError.Code != tt.expectedErr { - t.Errorf("Expected error code %s, got %s", tt.expectedErr, apiError.Code) - } - }) - } -} - -// TestLargeDataEncryption tests encryption/decryption of larger data streams -func TestSSEKMSLargeDataEncryption(t *testing.T) { - kmsKey := SetupTestKMS(t) - defer kmsKey.Cleanup() - - // Create a larger test dataset (1MB) - testData := strings.Repeat("This is a test of SSE-KMS with larger data streams. ", 20000) - testReader := strings.NewReader(testData) - - // Create encryption context - encryptionContext := BuildEncryptionContext("large-bucket", "large-object", false) - - // Encrypt the data - encryptedReader, sseKey, err := CreateSSEKMSEncryptedReader(testReader, kmsKey.KeyID, encryptionContext) - if err != nil { - t.Fatalf("Failed to create encrypted reader: %v", err) - } - - // Read the encrypted data - encryptedData, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read encrypted data: %v", err) - } - - // Decrypt the data - decryptedReader, err := CreateSSEKMSDecryptedReader(bytes.NewReader(encryptedData), sseKey) - if err != nil { - t.Fatalf("Failed to create decrypted reader: %v", err) - } - - // Read the decrypted data - decryptedData, err := io.ReadAll(decryptedReader) - if err != nil { - t.Fatalf("Failed to read decrypted data: %v", err) - } - - // Verify the decrypted data matches the original - if string(decryptedData) != testData { - t.Errorf("Decrypted data length: %d, original data length: %d", len(decryptedData), len(testData)) - t.Error("Decrypted large data does not match original") - } - - t.Logf("Successfully encrypted/decrypted %d bytes of data", len(testData)) -} - -// TestValidateSSEKMSKey tests the ValidateSSEKMSKey function, which correctly handles empty key IDs -func TestValidateSSEKMSKey(t *testing.T) { - tests := []struct { - name string - sseKey *SSEKMSKey - wantErr bool - }{ - { - name: "nil SSE-KMS key", - sseKey: nil, - wantErr: true, - }, - { - name: "empty key ID (valid - represents default KMS key)", - sseKey: &SSEKMSKey{ - KeyID: "", - EncryptionContext: map[string]string{"test": "value"}, - BucketKeyEnabled: false, - }, - wantErr: false, - }, - { - name: "valid UUID key ID", - sseKey: &SSEKMSKey{ - KeyID: "12345678-1234-1234-1234-123456789012", - EncryptionContext: map[string]string{"test": "value"}, - BucketKeyEnabled: true, - }, - wantErr: false, - }, - { - name: "valid alias", - sseKey: &SSEKMSKey{ - KeyID: "alias/my-test-key", - EncryptionContext: map[string]string{}, - BucketKeyEnabled: false, - }, - wantErr: false, - }, - { - name: "valid flexible key ID format", - sseKey: &SSEKMSKey{ - KeyID: "invalid-format", - EncryptionContext: map[string]string{}, - BucketKeyEnabled: false, - }, - wantErr: false, // Now valid - following Minio's permissive approach - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := ValidateSSEKMSKey(tt.sseKey) - if (err != nil) != tt.wantErr { - t.Errorf("ValidateSSEKMSKey() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} diff --git a/weed/s3api/s3_sse_metadata_test.go b/weed/s3api/s3_sse_metadata_test.go deleted file mode 100644 index c0c1360af..000000000 --- a/weed/s3api/s3_sse_metadata_test.go +++ /dev/null @@ -1,328 +0,0 @@ -package s3api - -import ( - "testing" - - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" -) - -// TestSSECIsEncrypted tests detection of SSE-C encryption from metadata -func TestSSECIsEncrypted(t *testing.T) { - testCases := []struct { - name string - metadata map[string][]byte - expected bool - }{ - { - name: "Empty metadata", - metadata: CreateTestMetadata(), - expected: false, - }, - { - name: "Valid SSE-C metadata", - metadata: CreateTestMetadataWithSSEC(GenerateTestSSECKey(1)), - expected: true, - }, - { - name: "SSE-C algorithm only", - metadata: map[string][]byte{ - s3_constants.AmzServerSideEncryptionCustomerAlgorithm: []byte("AES256"), - }, - expected: true, - }, - { - name: "SSE-C key MD5 only", - metadata: map[string][]byte{ - s3_constants.AmzServerSideEncryptionCustomerKeyMD5: []byte("somemd5"), - }, - expected: true, - }, - { - name: "Other encryption type (SSE-KMS)", - metadata: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("aws:kms"), - }, - expected: false, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - result := IsSSECEncrypted(tc.metadata) - if result != tc.expected { - t.Errorf("Expected %v, got %v", tc.expected, result) - } - }) - } -} - -// TestSSEKMSIsEncrypted tests detection of SSE-KMS encryption from metadata -func TestSSEKMSIsEncrypted(t *testing.T) { - testCases := []struct { - name string - metadata map[string][]byte - expected bool - }{ - { - name: "Empty metadata", - metadata: CreateTestMetadata(), - expected: false, - }, - { - name: "Valid SSE-KMS metadata", - metadata: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("aws:kms"), - s3_constants.AmzEncryptedDataKey: []byte("encrypted-key"), - }, - expected: true, - }, - { - name: "SSE-KMS algorithm only", - metadata: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("aws:kms"), - }, - expected: true, - }, - { - name: "SSE-KMS encrypted data key only", - metadata: map[string][]byte{ - s3_constants.AmzEncryptedDataKey: []byte("encrypted-key"), - }, - expected: false, // Only encrypted data key without algorithm header should not be considered SSE-KMS - }, - { - name: "Other encryption type (SSE-C)", - metadata: map[string][]byte{ - s3_constants.AmzServerSideEncryptionCustomerAlgorithm: []byte("AES256"), - }, - expected: false, - }, - { - name: "SSE-S3 (AES256)", - metadata: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("AES256"), - }, - expected: false, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - result := IsSSEKMSEncrypted(tc.metadata) - if result != tc.expected { - t.Errorf("Expected %v, got %v", tc.expected, result) - } - }) - } -} - -// TestSSETypeDiscrimination tests that SSE types don't interfere with each other -func TestSSETypeDiscrimination(t *testing.T) { - // Test SSE-C headers don't trigger SSE-KMS detection - t.Run("SSE-C headers don't trigger SSE-KMS", func(t *testing.T) { - req := CreateTestHTTPRequest("PUT", "/bucket/object", nil) - keyPair := GenerateTestSSECKey(1) - SetupTestSSECHeaders(req, keyPair) - - // Should detect SSE-C, not SSE-KMS - if !IsSSECRequest(req) { - t.Error("Should detect SSE-C request") - } - if IsSSEKMSRequest(req) { - t.Error("Should not detect SSE-KMS request for SSE-C headers") - } - }) - - // Test SSE-KMS headers don't trigger SSE-C detection - t.Run("SSE-KMS headers don't trigger SSE-C", func(t *testing.T) { - req := CreateTestHTTPRequest("PUT", "/bucket/object", nil) - SetupTestSSEKMSHeaders(req, "test-key-id") - - // Should detect SSE-KMS, not SSE-C - if IsSSECRequest(req) { - t.Error("Should not detect SSE-C request for SSE-KMS headers") - } - if !IsSSEKMSRequest(req) { - t.Error("Should detect SSE-KMS request") - } - }) - - // Test metadata discrimination - t.Run("Metadata type discrimination", func(t *testing.T) { - ssecMetadata := CreateTestMetadataWithSSEC(GenerateTestSSECKey(1)) - - // Should detect as SSE-C, not SSE-KMS - if !IsSSECEncrypted(ssecMetadata) { - t.Error("Should detect SSE-C encrypted metadata") - } - if IsSSEKMSEncrypted(ssecMetadata) { - t.Error("Should not detect SSE-KMS for SSE-C metadata") - } - }) -} - -// TestSSECParseCorruptedMetadata tests handling of corrupted SSE-C metadata -func TestSSECParseCorruptedMetadata(t *testing.T) { - testCases := []struct { - name string - metadata map[string][]byte - expectError bool - errorMessage string - }{ - { - name: "Missing algorithm", - metadata: map[string][]byte{ - s3_constants.AmzServerSideEncryptionCustomerKeyMD5: []byte("valid-md5"), - }, - expectError: false, // Detection should still work with partial metadata - }, - { - name: "Invalid key MD5 format", - metadata: map[string][]byte{ - s3_constants.AmzServerSideEncryptionCustomerAlgorithm: []byte("AES256"), - s3_constants.AmzServerSideEncryptionCustomerKeyMD5: []byte("invalid-base64!"), - }, - expectError: false, // Detection should work, validation happens later - }, - { - name: "Empty values", - metadata: map[string][]byte{ - s3_constants.AmzServerSideEncryptionCustomerAlgorithm: []byte(""), - s3_constants.AmzServerSideEncryptionCustomerKeyMD5: []byte(""), - }, - expectError: false, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Test that detection doesn't panic on corrupted metadata - result := IsSSECEncrypted(tc.metadata) - // The detection should be robust and not crash - t.Logf("Detection result for %s: %v", tc.name, result) - }) - } -} - -// TestSSEKMSParseCorruptedMetadata tests handling of corrupted SSE-KMS metadata -func TestSSEKMSParseCorruptedMetadata(t *testing.T) { - testCases := []struct { - name string - metadata map[string][]byte - }{ - { - name: "Invalid encrypted data key", - metadata: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("aws:kms"), - s3_constants.AmzEncryptedDataKey: []byte("invalid-base64!"), - }, - }, - { - name: "Invalid encryption context", - metadata: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("aws:kms"), - s3_constants.AmzEncryptionContextMeta: []byte("invalid-json"), - }, - }, - { - name: "Empty values", - metadata: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte(""), - s3_constants.AmzEncryptedDataKey: []byte(""), - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Test that detection doesn't panic on corrupted metadata - result := IsSSEKMSEncrypted(tc.metadata) - t.Logf("Detection result for %s: %v", tc.name, result) - }) - } -} - -// TestSSEMetadataDeserialization tests SSE-KMS metadata deserialization with various inputs -func TestSSEMetadataDeserialization(t *testing.T) { - testCases := []struct { - name string - data []byte - expectError bool - }{ - { - name: "Empty data", - data: []byte{}, - expectError: true, - }, - { - name: "Invalid JSON", - data: []byte("invalid-json"), - expectError: true, - }, - { - name: "Valid JSON but wrong structure", - data: []byte(`{"wrong": "structure"}`), - expectError: false, // Our deserialization might be lenient - }, - { - name: "Null data", - data: nil, - expectError: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - _, err := DeserializeSSEKMSMetadata(tc.data) - if tc.expectError && err == nil { - t.Error("Expected error but got none") - } - if !tc.expectError && err != nil { - t.Errorf("Expected no error but got: %v", err) - } - }) - } -} - -// TestGeneralSSEDetection tests the general SSE detection that works across types -func TestGeneralSSEDetection(t *testing.T) { - testCases := []struct { - name string - metadata map[string][]byte - expected bool - }{ - { - name: "No encryption", - metadata: CreateTestMetadata(), - expected: false, - }, - { - name: "SSE-C encrypted", - metadata: CreateTestMetadataWithSSEC(GenerateTestSSECKey(1)), - expected: true, - }, - { - name: "SSE-KMS encrypted", - metadata: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("aws:kms"), - }, - expected: true, - }, - { - name: "SSE-S3 encrypted", - metadata: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("AES256"), - }, - expected: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - result := IsAnySSEEncrypted(tc.metadata) - if result != tc.expected { - t.Errorf("Expected %v, got %v", tc.expected, result) - } - }) - } -} diff --git a/weed/s3api/s3_sse_multipart_test.go b/weed/s3api/s3_sse_multipart_test.go deleted file mode 100644 index c4dc9a45a..000000000 --- a/weed/s3api/s3_sse_multipart_test.go +++ /dev/null @@ -1,569 +0,0 @@ -package s3api - -import ( - "bytes" - "fmt" - "io" - "strings" - "testing" - - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" -) - -// TestSSECMultipartUpload tests SSE-C with multipart uploads -func TestSSECMultipartUpload(t *testing.T) { - keyPair := GenerateTestSSECKey(1) - customerKey := &SSECustomerKey{ - Algorithm: "AES256", - Key: keyPair.Key, - KeyMD5: keyPair.KeyMD5, - } - - // Test data larger than typical part size - testData := strings.Repeat("Hello, SSE-C multipart world! ", 1000) // ~30KB - - t.Run("Single part encryption/decryption", func(t *testing.T) { - // Encrypt the data - encryptedReader, iv, err := CreateSSECEncryptedReader(strings.NewReader(testData), customerKey) - if err != nil { - t.Fatalf("Failed to create encrypted reader: %v", err) - } - - encryptedData, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read encrypted data: %v", err) - } - - // Decrypt the data - decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedData), customerKey, iv) - if err != nil { - t.Fatalf("Failed to create decrypted reader: %v", err) - } - - decryptedData, err := io.ReadAll(decryptedReader) - if err != nil { - t.Fatalf("Failed to read decrypted data: %v", err) - } - - if string(decryptedData) != testData { - t.Error("Decrypted data doesn't match original") - } - }) - - t.Run("Simulated multipart upload parts", func(t *testing.T) { - // Simulate multiple parts (each part gets encrypted separately) - partSize := 5 * 1024 // 5KB parts - var encryptedParts [][]byte - var partIVs [][]byte - - for i := 0; i < len(testData); i += partSize { - end := i + partSize - if end > len(testData) { - end = len(testData) - } - - partData := testData[i:end] - - // Each part is encrypted separately in multipart uploads - encryptedReader, iv, err := CreateSSECEncryptedReader(strings.NewReader(partData), customerKey) - if err != nil { - t.Fatalf("Failed to create encrypted reader for part %d: %v", i/partSize, err) - } - - encryptedPart, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read encrypted part %d: %v", i/partSize, err) - } - - encryptedParts = append(encryptedParts, encryptedPart) - partIVs = append(partIVs, iv) - } - - // Simulate reading back the multipart object - var reconstructedData strings.Builder - - for i, encryptedPart := range encryptedParts { - decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedPart), customerKey, partIVs[i]) - if err != nil { - t.Fatalf("Failed to create decrypted reader for part %d: %v", i, err) - } - - decryptedPart, err := io.ReadAll(decryptedReader) - if err != nil { - t.Fatalf("Failed to read decrypted part %d: %v", i, err) - } - - reconstructedData.Write(decryptedPart) - } - - if reconstructedData.String() != testData { - t.Error("Reconstructed multipart data doesn't match original") - } - }) - - t.Run("Multipart with different part sizes", func(t *testing.T) { - partSizes := []int{1024, 2048, 4096, 8192} // Various part sizes - - for _, partSize := range partSizes { - t.Run(fmt.Sprintf("PartSize_%d", partSize), func(t *testing.T) { - var encryptedParts [][]byte - var partIVs [][]byte - - for i := 0; i < len(testData); i += partSize { - end := i + partSize - if end > len(testData) { - end = len(testData) - } - - partData := testData[i:end] - - encryptedReader, iv, err := CreateSSECEncryptedReader(strings.NewReader(partData), customerKey) - if err != nil { - t.Fatalf("Failed to create encrypted reader: %v", err) - } - - encryptedPart, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read encrypted part: %v", err) - } - - encryptedParts = append(encryptedParts, encryptedPart) - partIVs = append(partIVs, iv) - } - - // Verify reconstruction - var reconstructedData strings.Builder - - for j, encryptedPart := range encryptedParts { - decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedPart), customerKey, partIVs[j]) - if err != nil { - t.Fatalf("Failed to create decrypted reader: %v", err) - } - - decryptedPart, err := io.ReadAll(decryptedReader) - if err != nil { - t.Fatalf("Failed to read decrypted part: %v", err) - } - - reconstructedData.Write(decryptedPart) - } - - if reconstructedData.String() != testData { - t.Errorf("Reconstructed data doesn't match original for part size %d", partSize) - } - }) - } - }) -} - -// TestSSEKMSMultipartUpload tests SSE-KMS with multipart uploads -func TestSSEKMSMultipartUpload(t *testing.T) { - kmsKey := SetupTestKMS(t) - defer kmsKey.Cleanup() - - // Test data larger than typical part size - testData := strings.Repeat("Hello, SSE-KMS multipart world! ", 1000) // ~30KB - encryptionContext := BuildEncryptionContext("test-bucket", "test-object", false) - - t.Run("Single part encryption/decryption", func(t *testing.T) { - // Encrypt the data - encryptedReader, sseKey, err := CreateSSEKMSEncryptedReader(strings.NewReader(testData), kmsKey.KeyID, encryptionContext) - if err != nil { - t.Fatalf("Failed to create encrypted reader: %v", err) - } - - encryptedData, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read encrypted data: %v", err) - } - - // Decrypt the data - decryptedReader, err := CreateSSEKMSDecryptedReader(bytes.NewReader(encryptedData), sseKey) - if err != nil { - t.Fatalf("Failed to create decrypted reader: %v", err) - } - - decryptedData, err := io.ReadAll(decryptedReader) - if err != nil { - t.Fatalf("Failed to read decrypted data: %v", err) - } - - if string(decryptedData) != testData { - t.Error("Decrypted data doesn't match original") - } - }) - - t.Run("Simulated multipart upload parts", func(t *testing.T) { - // Simulate multiple parts (each part might use the same or different KMS operations) - partSize := 5 * 1024 // 5KB parts - var encryptedParts [][]byte - var sseKeys []*SSEKMSKey - - for i := 0; i < len(testData); i += partSize { - end := i + partSize - if end > len(testData) { - end = len(testData) - } - - partData := testData[i:end] - - // Each part might get its own data key in KMS multipart uploads - encryptedReader, sseKey, err := CreateSSEKMSEncryptedReader(strings.NewReader(partData), kmsKey.KeyID, encryptionContext) - if err != nil { - t.Fatalf("Failed to create encrypted reader for part %d: %v", i/partSize, err) - } - - encryptedPart, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read encrypted part %d: %v", i/partSize, err) - } - - encryptedParts = append(encryptedParts, encryptedPart) - sseKeys = append(sseKeys, sseKey) - } - - // Simulate reading back the multipart object - var reconstructedData strings.Builder - - for i, encryptedPart := range encryptedParts { - decryptedReader, err := CreateSSEKMSDecryptedReader(bytes.NewReader(encryptedPart), sseKeys[i]) - if err != nil { - t.Fatalf("Failed to create decrypted reader for part %d: %v", i, err) - } - - decryptedPart, err := io.ReadAll(decryptedReader) - if err != nil { - t.Fatalf("Failed to read decrypted part %d: %v", i, err) - } - - reconstructedData.Write(decryptedPart) - } - - if reconstructedData.String() != testData { - t.Error("Reconstructed multipart data doesn't match original") - } - }) - - t.Run("Multipart consistency checks", func(t *testing.T) { - // Test that all parts use the same KMS key ID but different data keys - partSize := 5 * 1024 - var sseKeys []*SSEKMSKey - - for i := 0; i < len(testData); i += partSize { - end := i + partSize - if end > len(testData) { - end = len(testData) - } - - partData := testData[i:end] - - _, sseKey, err := CreateSSEKMSEncryptedReader(strings.NewReader(partData), kmsKey.KeyID, encryptionContext) - if err != nil { - t.Fatalf("Failed to create encrypted reader: %v", err) - } - - sseKeys = append(sseKeys, sseKey) - } - - // Verify all parts use the same KMS key ID - for i, sseKey := range sseKeys { - if sseKey.KeyID != kmsKey.KeyID { - t.Errorf("Part %d has wrong KMS key ID: expected %s, got %s", i, kmsKey.KeyID, sseKey.KeyID) - } - } - - // Verify each part has different encrypted data keys (they should be unique) - for i := 0; i < len(sseKeys); i++ { - for j := i + 1; j < len(sseKeys); j++ { - if bytes.Equal(sseKeys[i].EncryptedDataKey, sseKeys[j].EncryptedDataKey) { - t.Errorf("Parts %d and %d have identical encrypted data keys (should be unique)", i, j) - } - } - } - }) -} - -// TestMultipartSSEMixedScenarios tests edge cases with multipart and SSE -func TestMultipartSSEMixedScenarios(t *testing.T) { - t.Run("Empty parts handling", func(t *testing.T) { - keyPair := GenerateTestSSECKey(1) - customerKey := &SSECustomerKey{ - Algorithm: "AES256", - Key: keyPair.Key, - KeyMD5: keyPair.KeyMD5, - } - - // Test empty part - encryptedReader, iv, err := CreateSSECEncryptedReader(strings.NewReader(""), customerKey) - if err != nil { - t.Fatalf("Failed to create encrypted reader for empty data: %v", err) - } - - encryptedData, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read encrypted empty data: %v", err) - } - - // Empty part should produce empty encrypted data, but still have a valid IV - if len(encryptedData) != 0 { - t.Errorf("Expected empty encrypted data for empty part, got %d bytes", len(encryptedData)) - } - if len(iv) != s3_constants.AESBlockSize { - t.Errorf("Expected IV of size %d, got %d", s3_constants.AESBlockSize, len(iv)) - } - - // Decrypt and verify - decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedData), customerKey, iv) - if err != nil { - t.Fatalf("Failed to create decrypted reader for empty data: %v", err) - } - - decryptedData, err := io.ReadAll(decryptedReader) - if err != nil { - t.Fatalf("Failed to read decrypted empty data: %v", err) - } - - if len(decryptedData) != 0 { - t.Errorf("Expected empty decrypted data, got %d bytes", len(decryptedData)) - } - }) - - t.Run("Single byte parts", func(t *testing.T) { - keyPair := GenerateTestSSECKey(1) - customerKey := &SSECustomerKey{ - Algorithm: "AES256", - Key: keyPair.Key, - KeyMD5: keyPair.KeyMD5, - } - - testData := "ABCDEFGHIJ" - var encryptedParts [][]byte - var partIVs [][]byte - - // Encrypt each byte as a separate part - for i, b := range []byte(testData) { - partData := string(b) - - encryptedReader, iv, err := CreateSSECEncryptedReader(strings.NewReader(partData), customerKey) - if err != nil { - t.Fatalf("Failed to create encrypted reader for byte %d: %v", i, err) - } - - encryptedPart, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read encrypted byte %d: %v", i, err) - } - - encryptedParts = append(encryptedParts, encryptedPart) - partIVs = append(partIVs, iv) - } - - // Reconstruct - var reconstructedData strings.Builder - - for i, encryptedPart := range encryptedParts { - decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedPart), customerKey, partIVs[i]) - if err != nil { - t.Fatalf("Failed to create decrypted reader for byte %d: %v", i, err) - } - - decryptedPart, err := io.ReadAll(decryptedReader) - if err != nil { - t.Fatalf("Failed to read decrypted byte %d: %v", i, err) - } - - reconstructedData.Write(decryptedPart) - } - - if reconstructedData.String() != testData { - t.Errorf("Expected %s, got %s", testData, reconstructedData.String()) - } - }) - - t.Run("Very large parts", func(t *testing.T) { - keyPair := GenerateTestSSECKey(1) - customerKey := &SSECustomerKey{ - Algorithm: "AES256", - Key: keyPair.Key, - KeyMD5: keyPair.KeyMD5, - } - - // Create a large part (1MB) - largeData := make([]byte, 1024*1024) - for i := range largeData { - largeData[i] = byte(i % 256) - } - - // Encrypt - encryptedReader, iv, err := CreateSSECEncryptedReader(bytes.NewReader(largeData), customerKey) - if err != nil { - t.Fatalf("Failed to create encrypted reader for large data: %v", err) - } - - encryptedData, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read encrypted large data: %v", err) - } - - // Decrypt - decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedData), customerKey, iv) - if err != nil { - t.Fatalf("Failed to create decrypted reader for large data: %v", err) - } - - decryptedData, err := io.ReadAll(decryptedReader) - if err != nil { - t.Fatalf("Failed to read decrypted large data: %v", err) - } - - if !bytes.Equal(decryptedData, largeData) { - t.Error("Large data doesn't match after encryption/decryption") - } - }) -} - -func TestSSECLargeObjectChunkReassembly(t *testing.T) { - keyPair := GenerateTestSSECKey(1) - customerKey := &SSECustomerKey{ - Algorithm: "AES256", - Key: keyPair.Key, - KeyMD5: keyPair.KeyMD5, - } - - const chunkSize = 8 * 1024 * 1024 // matches putToFiler chunk size - totalSize := chunkSize*2 + 3*1024*1024 - plaintext := make([]byte, totalSize) - for i := range plaintext { - plaintext[i] = byte(i % 251) - } - - encryptedReader, iv, err := CreateSSECEncryptedReader(bytes.NewReader(plaintext), customerKey) - if err != nil { - t.Fatalf("Failed to create encrypted reader: %v", err) - } - encryptedData, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read encrypted data: %v", err) - } - - var reconstructed bytes.Buffer - offset := int64(0) - for offset < int64(len(encryptedData)) { - end := offset + chunkSize - if end > int64(len(encryptedData)) { - end = int64(len(encryptedData)) - } - - chunkIV := make([]byte, len(iv)) - copy(chunkIV, iv) - chunkReader := bytes.NewReader(encryptedData[offset:end]) - decryptedReader, decErr := CreateSSECDecryptedReaderWithOffset(chunkReader, customerKey, chunkIV, uint64(offset)) - if decErr != nil { - t.Fatalf("Failed to create decrypted reader for offset %d: %v", offset, decErr) - } - decryptedChunk, decErr := io.ReadAll(decryptedReader) - if decErr != nil { - t.Fatalf("Failed to read decrypted chunk at offset %d: %v", offset, decErr) - } - reconstructed.Write(decryptedChunk) - offset = end - } - - if !bytes.Equal(reconstructed.Bytes(), plaintext) { - t.Fatalf("Reconstructed data mismatch: expected %d bytes, got %d", len(plaintext), reconstructed.Len()) - } -} - -// TestMultipartSSEPerformance tests performance characteristics of SSE with multipart -func TestMultipartSSEPerformance(t *testing.T) { - if testing.Short() { - t.Skip("Skipping performance test in short mode") - } - - t.Run("SSE-C performance with multiple parts", func(t *testing.T) { - keyPair := GenerateTestSSECKey(1) - customerKey := &SSECustomerKey{ - Algorithm: "AES256", - Key: keyPair.Key, - KeyMD5: keyPair.KeyMD5, - } - - partSize := 64 * 1024 // 64KB parts - numParts := 10 - - for partNum := 0; partNum < numParts; partNum++ { - partData := make([]byte, partSize) - for i := range partData { - partData[i] = byte((partNum + i) % 256) - } - - // Encrypt - encryptedReader, iv, err := CreateSSECEncryptedReader(bytes.NewReader(partData), customerKey) - if err != nil { - t.Fatalf("Failed to create encrypted reader for part %d: %v", partNum, err) - } - - encryptedData, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read encrypted data for part %d: %v", partNum, err) - } - - // Decrypt - decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedData), customerKey, iv) - if err != nil { - t.Fatalf("Failed to create decrypted reader for part %d: %v", partNum, err) - } - - decryptedData, err := io.ReadAll(decryptedReader) - if err != nil { - t.Fatalf("Failed to read decrypted data for part %d: %v", partNum, err) - } - - if !bytes.Equal(decryptedData, partData) { - t.Errorf("Data mismatch for part %d", partNum) - } - } - }) - - t.Run("SSE-KMS performance with multiple parts", func(t *testing.T) { - kmsKey := SetupTestKMS(t) - defer kmsKey.Cleanup() - - partSize := 64 * 1024 // 64KB parts - numParts := 5 // Fewer parts for KMS due to overhead - encryptionContext := BuildEncryptionContext("test-bucket", "test-object", false) - - for partNum := 0; partNum < numParts; partNum++ { - partData := make([]byte, partSize) - for i := range partData { - partData[i] = byte((partNum + i) % 256) - } - - // Encrypt - encryptedReader, sseKey, err := CreateSSEKMSEncryptedReader(bytes.NewReader(partData), kmsKey.KeyID, encryptionContext) - if err != nil { - t.Fatalf("Failed to create encrypted reader for part %d: %v", partNum, err) - } - - encryptedData, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read encrypted data for part %d: %v", partNum, err) - } - - // Decrypt - decryptedReader, err := CreateSSEKMSDecryptedReader(bytes.NewReader(encryptedData), sseKey) - if err != nil { - t.Fatalf("Failed to create decrypted reader for part %d: %v", partNum, err) - } - - decryptedData, err := io.ReadAll(decryptedReader) - if err != nil { - t.Fatalf("Failed to read decrypted data for part %d: %v", partNum, err) - } - - if !bytes.Equal(decryptedData, partData) { - t.Errorf("Data mismatch for part %d", partNum) - } - } - }) -} diff --git a/weed/s3api/s3_sse_s3.go b/weed/s3api/s3_sse_s3.go index d9ea5a919..801221ed3 100644 --- a/weed/s3api/s3_sse_s3.go +++ b/weed/s3api/s3_sse_s3.go @@ -137,13 +137,6 @@ func CreateSSES3DecryptedReader(reader io.Reader, key *SSES3Key, iv []byte) (io. return decryptReader, nil } -// GetSSES3Headers returns the headers for SSE-S3 encrypted objects -func GetSSES3Headers() map[string]string { - return map[string]string{ - s3_constants.AmzServerSideEncryption: SSES3Algorithm, - } -} - // SerializeSSES3Metadata serializes SSE-S3 metadata for storage using envelope encryption func SerializeSSES3Metadata(key *SSES3Key) ([]byte, error) { if err := ValidateSSES3Key(key); err != nil { @@ -339,7 +332,7 @@ func (km *SSES3KeyManager) InitializeWithFiler(filerClient filer_pb.FilerClient) v := util.GetViper() cfgKEK := v.GetString(sseS3KEKConfigKey) // hex-encoded, drop-in for filer file - cfgKey := v.GetString(sseS3KeyConfigKey) // any string, HKDF-derived + cfgKey := v.GetString(sseS3KeyConfigKey) // any string, HKDF-derived if cfgKEK != "" && cfgKey != "" { return fmt.Errorf("only one of %s and %s may be set, not both", sseS3KEKConfigKey, sseS3KeyConfigKey) @@ -454,7 +447,6 @@ func (km *SSES3KeyManager) loadSuperKeyFromFiler() error { return nil } - // GetOrCreateKey gets an existing key or creates a new one // With envelope encryption, we always generate a new DEK since we don't store them func (km *SSES3KeyManager) GetOrCreateKey(keyID string) (*SSES3Key, error) { @@ -532,14 +524,6 @@ func (km *SSES3KeyManager) StoreKey(key *SSES3Key) { // The DEK is encrypted with the super key and stored in object metadata } -// GetKey is now a no-op since we don't cache keys -// Keys are retrieved by decrypting the encrypted DEK from object metadata -func (km *SSES3KeyManager) GetKey(keyID string) (*SSES3Key, bool) { - // No-op: With envelope encryption, keys are not cached - // Each object's metadata contains the encrypted DEK - return nil, false -} - // GetMasterKey returns a derived key from the master KEK for STS signing // This uses HKDF to isolate the STS security domain from the SSE-S3 domain func (km *SSES3KeyManager) GetMasterKey() []byte { @@ -596,47 +580,6 @@ func InitializeGlobalSSES3KeyManager(filerClient *wdclient.FilerClient, grpcDial return globalSSES3KeyManager.InitializeWithFiler(wrapper) } -// ProcessSSES3Request processes an SSE-S3 request and returns encryption metadata -func ProcessSSES3Request(r *http.Request) (map[string][]byte, error) { - if !IsSSES3RequestInternal(r) { - return nil, nil - } - - // Generate or retrieve encryption key - keyManager := GetSSES3KeyManager() - key, err := keyManager.GetOrCreateKey("") - if err != nil { - return nil, fmt.Errorf("get SSE-S3 key: %w", err) - } - - // Serialize key metadata - keyData, err := SerializeSSES3Metadata(key) - if err != nil { - return nil, fmt.Errorf("serialize SSE-S3 metadata: %w", err) - } - - // Store key in manager - keyManager.StoreKey(key) - - // Return metadata - metadata := map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte(SSES3Algorithm), - s3_constants.SeaweedFSSSES3Key: keyData, - } - - return metadata, nil -} - -// GetSSES3KeyFromMetadata extracts SSE-S3 key from object metadata -func GetSSES3KeyFromMetadata(metadata map[string][]byte, keyManager *SSES3KeyManager) (*SSES3Key, error) { - keyData, exists := metadata[s3_constants.SeaweedFSSSES3Key] - if !exists { - return nil, fmt.Errorf("SSE-S3 key not found in metadata") - } - - return DeserializeSSES3Metadata(keyData, keyManager) -} - // GetSSES3IV extracts the IV for single-part SSE-S3 objects // Priority: 1) object-level metadata (for inline/small files), 2) first chunk metadata func GetSSES3IV(entry *filer_pb.Entry, sseS3Key *SSES3Key, keyManager *SSES3KeyManager) ([]byte, error) { diff --git a/weed/s3api/s3_sse_s3_test.go b/weed/s3api/s3_sse_s3_test.go deleted file mode 100644 index af64850d9..000000000 --- a/weed/s3api/s3_sse_s3_test.go +++ /dev/null @@ -1,1079 +0,0 @@ -package s3api - -import ( - "bytes" - "encoding/hex" - "fmt" - "io" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" - "github.com/seaweedfs/seaweedfs/weed/util" -) - -// TestSSES3EncryptionDecryption tests basic SSE-S3 encryption and decryption -func TestSSES3EncryptionDecryption(t *testing.T) { - // Generate SSE-S3 key - sseS3Key, err := GenerateSSES3Key() - if err != nil { - t.Fatalf("Failed to generate SSE-S3 key: %v", err) - } - - // Test data - testData := []byte("Hello, World! This is a test of SSE-S3 encryption.") - - // Create encrypted reader - dataReader := bytes.NewReader(testData) - encryptedReader, iv, err := CreateSSES3EncryptedReader(dataReader, sseS3Key) - if err != nil { - t.Fatalf("Failed to create encrypted reader: %v", err) - } - - // Read encrypted data - encryptedData, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read encrypted data: %v", err) - } - - // Verify data is actually encrypted (different from original) - if bytes.Equal(encryptedData, testData) { - t.Error("Data doesn't appear to be encrypted") - } - - // Create decrypted reader - encryptedReader2 := bytes.NewReader(encryptedData) - decryptedReader, err := CreateSSES3DecryptedReader(encryptedReader2, sseS3Key, iv) - if err != nil { - t.Fatalf("Failed to create decrypted reader: %v", err) - } - - // Read decrypted data - decryptedData, err := io.ReadAll(decryptedReader) - if err != nil { - t.Fatalf("Failed to read decrypted data: %v", err) - } - - // Verify decrypted data matches original - if !bytes.Equal(decryptedData, testData) { - t.Errorf("Decrypted data doesn't match original.\nOriginal: %s\nDecrypted: %s", testData, decryptedData) - } -} - -// TestSSES3IsRequestInternal tests detection of SSE-S3 requests -func TestSSES3IsRequestInternal(t *testing.T) { - testCases := []struct { - name string - headers map[string]string - expected bool - }{ - { - name: "Valid SSE-S3 request", - headers: map[string]string{ - s3_constants.AmzServerSideEncryption: "AES256", - }, - expected: true, - }, - { - name: "No SSE headers", - headers: map[string]string{}, - expected: false, - }, - { - name: "SSE-KMS request", - headers: map[string]string{ - s3_constants.AmzServerSideEncryption: "aws:kms", - }, - expected: false, - }, - { - name: "SSE-C request", - headers: map[string]string{ - s3_constants.AmzServerSideEncryptionCustomerAlgorithm: "AES256", - }, - expected: false, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - req := &http.Request{Header: make(http.Header)} - for k, v := range tc.headers { - req.Header.Set(k, v) - } - - result := IsSSES3RequestInternal(req) - if result != tc.expected { - t.Errorf("Expected %v, got %v", tc.expected, result) - } - }) - } -} - -// TestSSES3MetadataSerialization tests SSE-S3 metadata serialization and deserialization -func TestSSES3MetadataSerialization(t *testing.T) { - // Initialize global key manager - globalSSES3KeyManager = NewSSES3KeyManager() - defer func() { - globalSSES3KeyManager = NewSSES3KeyManager() - }() - - // Set up the key manager with a super key for testing - keyManager := GetSSES3KeyManager() - keyManager.superKey = make([]byte, 32) - for i := range keyManager.superKey { - keyManager.superKey[i] = byte(i) - } - - // Generate SSE-S3 key - sseS3Key, err := GenerateSSES3Key() - if err != nil { - t.Fatalf("Failed to generate SSE-S3 key: %v", err) - } - - // Add IV to the key - sseS3Key.IV = make([]byte, 16) - for i := range sseS3Key.IV { - sseS3Key.IV[i] = byte(i * 2) - } - - // Serialize metadata - serialized, err := SerializeSSES3Metadata(sseS3Key) - if err != nil { - t.Fatalf("Failed to serialize SSE-S3 metadata: %v", err) - } - - if len(serialized) == 0 { - t.Error("Serialized metadata is empty") - } - - // Deserialize metadata - deserializedKey, err := DeserializeSSES3Metadata(serialized, keyManager) - if err != nil { - t.Fatalf("Failed to deserialize SSE-S3 metadata: %v", err) - } - - // Verify key matches - if !bytes.Equal(deserializedKey.Key, sseS3Key.Key) { - t.Error("Deserialized key doesn't match original key") - } - - // Verify IV matches - if !bytes.Equal(deserializedKey.IV, sseS3Key.IV) { - t.Error("Deserialized IV doesn't match original IV") - } - - // Verify algorithm matches - if deserializedKey.Algorithm != sseS3Key.Algorithm { - t.Errorf("Algorithm mismatch: expected %s, got %s", sseS3Key.Algorithm, deserializedKey.Algorithm) - } - - // Verify key ID matches - if deserializedKey.KeyID != sseS3Key.KeyID { - t.Errorf("Key ID mismatch: expected %s, got %s", sseS3Key.KeyID, deserializedKey.KeyID) - } -} - -// TestDetectPrimarySSETypeS3 tests detection of SSE-S3 as primary encryption type -func TestDetectPrimarySSETypeS3(t *testing.T) { - s3a := &S3ApiServer{} - - testCases := []struct { - name string - entry *filer_pb.Entry - expected string - }{ - { - name: "Single SSE-S3 chunk", - entry: &filer_pb.Entry{ - Extended: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("AES256"), - }, - Attributes: &filer_pb.FuseAttributes{}, - Chunks: []*filer_pb.FileChunk{ - { - FileId: "1,123", - Offset: 0, - Size: 1024, - SseType: filer_pb.SSEType_SSE_S3, - SseMetadata: []byte("metadata"), - }, - }, - }, - expected: s3_constants.SSETypeS3, - }, - { - name: "Multiple SSE-S3 chunks", - entry: &filer_pb.Entry{ - Extended: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("AES256"), - }, - Attributes: &filer_pb.FuseAttributes{}, - Chunks: []*filer_pb.FileChunk{ - { - FileId: "1,123", - Offset: 0, - Size: 1024, - SseType: filer_pb.SSEType_SSE_S3, - SseMetadata: []byte("metadata1"), - }, - { - FileId: "2,456", - Offset: 1024, - Size: 1024, - SseType: filer_pb.SSEType_SSE_S3, - SseMetadata: []byte("metadata2"), - }, - }, - }, - expected: s3_constants.SSETypeS3, - }, - { - name: "Mixed SSE-S3 and SSE-KMS chunks (SSE-S3 majority)", - entry: &filer_pb.Entry{ - Extended: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("AES256"), - }, - Attributes: &filer_pb.FuseAttributes{}, - Chunks: []*filer_pb.FileChunk{ - { - FileId: "1,123", - Offset: 0, - Size: 1024, - SseType: filer_pb.SSEType_SSE_S3, - SseMetadata: []byte("metadata1"), - }, - { - FileId: "2,456", - Offset: 1024, - Size: 1024, - SseType: filer_pb.SSEType_SSE_S3, - SseMetadata: []byte("metadata2"), - }, - { - FileId: "3,789", - Offset: 2048, - Size: 1024, - SseType: filer_pb.SSEType_SSE_KMS, - SseMetadata: []byte("metadata3"), - }, - }, - }, - expected: s3_constants.SSETypeS3, - }, - { - name: "No chunks, SSE-S3 metadata without KMS key ID", - entry: &filer_pb.Entry{ - Extended: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("AES256"), - }, - Attributes: &filer_pb.FuseAttributes{}, - Chunks: []*filer_pb.FileChunk{}, - }, - expected: s3_constants.SSETypeS3, - }, - { - name: "No chunks, SSE-KMS metadata with KMS key ID", - entry: &filer_pb.Entry{ - Extended: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("AES256"), - s3_constants.AmzServerSideEncryptionAwsKmsKeyId: []byte("test-key-id"), - }, - Attributes: &filer_pb.FuseAttributes{}, - Chunks: []*filer_pb.FileChunk{}, - }, - expected: s3_constants.SSETypeKMS, - }, - { - name: "SSE-C chunks", - entry: &filer_pb.Entry{ - Extended: map[string][]byte{ - s3_constants.AmzServerSideEncryptionCustomerAlgorithm: []byte("AES256"), - }, - Attributes: &filer_pb.FuseAttributes{}, - Chunks: []*filer_pb.FileChunk{ - { - FileId: "1,123", - Offset: 0, - Size: 1024, - SseType: filer_pb.SSEType_SSE_C, - SseMetadata: []byte("metadata"), - }, - }, - }, - expected: s3_constants.SSETypeC, - }, - { - name: "Unencrypted", - entry: &filer_pb.Entry{ - Extended: map[string][]byte{}, - Attributes: &filer_pb.FuseAttributes{}, - Chunks: []*filer_pb.FileChunk{ - { - FileId: "1,123", - Offset: 0, - Size: 1024, - }, - }, - }, - expected: "None", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - result := s3a.detectPrimarySSEType(tc.entry) - if result != tc.expected { - t.Errorf("Expected %s, got %s", tc.expected, result) - } - }) - } -} - -// TestSSES3EncryptionWithBaseIV tests multipart encryption with base IV -func TestSSES3EncryptionWithBaseIV(t *testing.T) { - // Generate SSE-S3 key - sseS3Key, err := GenerateSSES3Key() - if err != nil { - t.Fatalf("Failed to generate SSE-S3 key: %v", err) - } - - // Generate base IV - baseIV := make([]byte, 16) - for i := range baseIV { - baseIV[i] = byte(i) - } - - // Test data for two parts - testData1 := []byte("Part 1 of multipart upload test.") - testData2 := []byte("Part 2 of multipart upload test.") - - // Encrypt part 1 at offset 0 - dataReader1 := bytes.NewReader(testData1) - encryptedReader1, iv1, err := CreateSSES3EncryptedReaderWithBaseIV(dataReader1, sseS3Key, baseIV, 0) - if err != nil { - t.Fatalf("Failed to create encrypted reader for part 1: %v", err) - } - - encryptedData1, err := io.ReadAll(encryptedReader1) - if err != nil { - t.Fatalf("Failed to read encrypted data for part 1: %v", err) - } - - // Encrypt part 2 at offset (simulating second part) - dataReader2 := bytes.NewReader(testData2) - offset2 := int64(len(testData1)) - encryptedReader2, iv2, err := CreateSSES3EncryptedReaderWithBaseIV(dataReader2, sseS3Key, baseIV, offset2) - if err != nil { - t.Fatalf("Failed to create encrypted reader for part 2: %v", err) - } - - encryptedData2, err := io.ReadAll(encryptedReader2) - if err != nil { - t.Fatalf("Failed to read encrypted data for part 2: %v", err) - } - - // IVs should be different (offset-based) - if bytes.Equal(iv1, iv2) { - t.Error("IVs should be different for different offsets") - } - - // Decrypt part 1 - decryptedReader1, err := CreateSSES3DecryptedReader(bytes.NewReader(encryptedData1), sseS3Key, iv1) - if err != nil { - t.Fatalf("Failed to create decrypted reader for part 1: %v", err) - } - - decryptedData1, err := io.ReadAll(decryptedReader1) - if err != nil { - t.Fatalf("Failed to read decrypted data for part 1: %v", err) - } - - // Decrypt part 2 - decryptedReader2, err := CreateSSES3DecryptedReader(bytes.NewReader(encryptedData2), sseS3Key, iv2) - if err != nil { - t.Fatalf("Failed to create decrypted reader for part 2: %v", err) - } - - decryptedData2, err := io.ReadAll(decryptedReader2) - if err != nil { - t.Fatalf("Failed to read decrypted data for part 2: %v", err) - } - - // Verify decrypted data matches original - if !bytes.Equal(decryptedData1, testData1) { - t.Errorf("Decrypted part 1 doesn't match original.\nOriginal: %s\nDecrypted: %s", testData1, decryptedData1) - } - - if !bytes.Equal(decryptedData2, testData2) { - t.Errorf("Decrypted part 2 doesn't match original.\nOriginal: %s\nDecrypted: %s", testData2, decryptedData2) - } -} - -// TestSSES3WrongKeyDecryption tests that wrong key fails decryption -func TestSSES3WrongKeyDecryption(t *testing.T) { - // Generate two different keys - sseS3Key1, err := GenerateSSES3Key() - if err != nil { - t.Fatalf("Failed to generate SSE-S3 key 1: %v", err) - } - - sseS3Key2, err := GenerateSSES3Key() - if err != nil { - t.Fatalf("Failed to generate SSE-S3 key 2: %v", err) - } - - // Test data - testData := []byte("Secret data encrypted with key 1") - - // Encrypt with key 1 - dataReader := bytes.NewReader(testData) - encryptedReader, iv, err := CreateSSES3EncryptedReader(dataReader, sseS3Key1) - if err != nil { - t.Fatalf("Failed to create encrypted reader: %v", err) - } - - encryptedData, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read encrypted data: %v", err) - } - - // Try to decrypt with key 2 (wrong key) - decryptedReader, err := CreateSSES3DecryptedReader(bytes.NewReader(encryptedData), sseS3Key2, iv) - if err != nil { - t.Fatalf("Failed to create decrypted reader: %v", err) - } - - decryptedData, err := io.ReadAll(decryptedReader) - if err != nil { - t.Fatalf("Failed to read decrypted data: %v", err) - } - - // Decrypted data should NOT match original (wrong key produces garbage) - if bytes.Equal(decryptedData, testData) { - t.Error("Decryption with wrong key should not produce correct plaintext") - } -} - -// TestSSES3KeyGeneration tests SSE-S3 key generation -func TestSSES3KeyGeneration(t *testing.T) { - // Generate multiple keys - keys := make([]*SSES3Key, 10) - for i := range keys { - key, err := GenerateSSES3Key() - if err != nil { - t.Fatalf("Failed to generate SSE-S3 key %d: %v", i, err) - } - keys[i] = key - - // Verify key properties - if len(key.Key) != SSES3KeySize { - t.Errorf("Key %d has wrong size: expected %d, got %d", i, SSES3KeySize, len(key.Key)) - } - - if key.Algorithm != SSES3Algorithm { - t.Errorf("Key %d has wrong algorithm: expected %s, got %s", i, SSES3Algorithm, key.Algorithm) - } - - if key.KeyID == "" { - t.Errorf("Key %d has empty key ID", i) - } - } - - // Verify keys are unique - for i := 0; i < len(keys); i++ { - for j := i + 1; j < len(keys); j++ { - if bytes.Equal(keys[i].Key, keys[j].Key) { - t.Errorf("Keys %d and %d are identical (should be unique)", i, j) - } - if keys[i].KeyID == keys[j].KeyID { - t.Errorf("Key IDs %d and %d are identical (should be unique)", i, j) - } - } - } -} - -// TestSSES3VariousSizes tests SSE-S3 encryption/decryption with various data sizes -func TestSSES3VariousSizes(t *testing.T) { - sizes := []int{1, 15, 16, 17, 100, 1024, 4096, 1048576} - - for _, size := range sizes { - t.Run(fmt.Sprintf("size_%d", size), func(t *testing.T) { - // Generate test data - testData := make([]byte, size) - for i := range testData { - testData[i] = byte(i % 256) - } - - // Generate key - sseS3Key, err := GenerateSSES3Key() - if err != nil { - t.Fatalf("Failed to generate SSE-S3 key: %v", err) - } - - // Encrypt - dataReader := bytes.NewReader(testData) - encryptedReader, iv, err := CreateSSES3EncryptedReader(dataReader, sseS3Key) - if err != nil { - t.Fatalf("Failed to create encrypted reader: %v", err) - } - - encryptedData, err := io.ReadAll(encryptedReader) - if err != nil { - t.Fatalf("Failed to read encrypted data: %v", err) - } - - // Verify encrypted size matches original - if len(encryptedData) != size { - t.Errorf("Encrypted size mismatch: expected %d, got %d", size, len(encryptedData)) - } - - // Decrypt - decryptedReader, err := CreateSSES3DecryptedReader(bytes.NewReader(encryptedData), sseS3Key, iv) - if err != nil { - t.Fatalf("Failed to create decrypted reader: %v", err) - } - - decryptedData, err := io.ReadAll(decryptedReader) - if err != nil { - t.Fatalf("Failed to read decrypted data: %v", err) - } - - // Verify - if !bytes.Equal(decryptedData, testData) { - t.Errorf("Decrypted data doesn't match original for size %d", size) - } - }) - } -} - -// TestSSES3ResponseHeaders tests that SSE-S3 response headers are set correctly -func TestSSES3ResponseHeaders(t *testing.T) { - w := httptest.NewRecorder() - - // Simulate setting SSE-S3 response headers - w.Header().Set(s3_constants.AmzServerSideEncryption, SSES3Algorithm) - - // Verify headers - algorithm := w.Header().Get(s3_constants.AmzServerSideEncryption) - if algorithm != "AES256" { - t.Errorf("Expected algorithm AES256, got %s", algorithm) - } - - // Should NOT have customer key headers - if w.Header().Get(s3_constants.AmzServerSideEncryptionCustomerAlgorithm) != "" { - t.Error("Should not have SSE-C customer algorithm header") - } - - if w.Header().Get(s3_constants.AmzServerSideEncryptionCustomerKeyMD5) != "" { - t.Error("Should not have SSE-C customer key MD5 header") - } - - // Should NOT have KMS key ID - if w.Header().Get(s3_constants.AmzServerSideEncryptionAwsKmsKeyId) != "" { - t.Error("Should not have SSE-KMS key ID header") - } -} - -// TestSSES3IsEncryptedInternal tests detection of SSE-S3 encryption from metadata -func TestSSES3IsEncryptedInternal(t *testing.T) { - testCases := []struct { - name string - metadata map[string][]byte - expected bool - }{ - { - name: "Empty metadata", - metadata: map[string][]byte{}, - expected: false, - }, - { - name: "Valid SSE-S3 metadata with key", - metadata: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("AES256"), - s3_constants.SeaweedFSSSES3Key: []byte("test-key-data"), - }, - expected: true, - }, - { - name: "SSE-S3 header without key (orphaned header - GitHub #7562)", - metadata: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("AES256"), - }, - expected: false, // Should not be considered encrypted without the key - }, - { - name: "SSE-KMS metadata", - metadata: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("aws:kms"), - }, - expected: false, - }, - { - name: "SSE-C metadata", - metadata: map[string][]byte{ - s3_constants.AmzServerSideEncryptionCustomerAlgorithm: []byte("AES256"), - }, - expected: false, - }, - { - name: "Key without header", - metadata: map[string][]byte{ - s3_constants.SeaweedFSSSES3Key: []byte("test-key-data"), - }, - expected: false, // Need both header and key - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - result := IsSSES3EncryptedInternal(tc.metadata) - if result != tc.expected { - t.Errorf("Expected %v, got %v", tc.expected, result) - } - }) - } -} - -// TestSSES3InvalidMetadataDeserialization tests error handling for invalid metadata -func TestSSES3InvalidMetadataDeserialization(t *testing.T) { - keyManager := NewSSES3KeyManager() - keyManager.superKey = make([]byte, 32) - - testCases := []struct { - name string - metadata []byte - shouldError bool - }{ - { - name: "Empty metadata", - metadata: []byte{}, - shouldError: true, - }, - { - name: "Invalid JSON", - metadata: []byte("not valid json"), - shouldError: true, - }, - { - name: "Missing keyId", - metadata: []byte(`{"algorithm":"AES256"}`), - shouldError: true, - }, - { - name: "Invalid base64 encrypted DEK", - metadata: []byte(`{"keyId":"test","algorithm":"AES256","encryptedDEK":"not-valid-base64!","nonce":"dGVzdA=="}`), - shouldError: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - _, err := DeserializeSSES3Metadata(tc.metadata, keyManager) - if tc.shouldError && err == nil { - t.Error("Expected error but got none") - } - if !tc.shouldError && err != nil { - t.Errorf("Unexpected error: %v", err) - } - }) - } -} - -// setViperKey is a test helper that sets a config key via its WEED_ env var. -func setViperKey(t *testing.T, key, value string) { - t.Helper() - util.GetViper().SetDefault(key, "") - t.Setenv("WEED_"+strings.ReplaceAll(strings.ToUpper(key), ".", "_"), value) -} - -// TestSSES3KEKConfig tests that sse_s3.kek (hex format) is used as KEK -func TestSSES3KEKConfig(t *testing.T) { - testKey := make([]byte, 32) - for i := range testKey { - testKey[i] = byte(i + 50) - } - setViperKey(t, sseS3KEKConfigKey, hex.EncodeToString(testKey)) - - km := NewSSES3KeyManager() - err := km.InitializeWithFiler(nil) - if err != nil { - t.Fatalf("InitializeWithFiler failed: %v", err) - } - - if !bytes.Equal(km.superKey, testKey) { - t.Errorf("superKey mismatch: expected %x, got %x", testKey, km.superKey) - } - - // Round-trip DEK encryption - dek := make([]byte, 32) - for i := range dek { - dek[i] = byte(i) - } - encrypted, nonce, err := km.encryptKeyWithSuperKey(dek) - if err != nil { - t.Fatalf("encryptKeyWithSuperKey failed: %v", err) - } - decrypted, err := km.decryptKeyWithSuperKey(encrypted, nonce) - if err != nil { - t.Fatalf("decryptKeyWithSuperKey failed: %v", err) - } - if !bytes.Equal(decrypted, dek) { - t.Error("round-trip DEK mismatch") - } -} - -// TestSSES3KEKConfigInvalidHex tests rejection of bad hex -func TestSSES3KEKConfigInvalidHex(t *testing.T) { - setViperKey(t, sseS3KEKConfigKey, "not-valid-hex") - - km := NewSSES3KeyManager() - err := km.InitializeWithFiler(nil) - if err == nil { - t.Fatal("expected error for invalid hex, got nil") - } - if !strings.Contains(err.Error(), "hex-encoded") { - t.Errorf("expected hex error, got: %v", err) - } -} - -// TestSSES3KEKConfigWrongSize tests rejection of wrong-size hex key -func TestSSES3KEKConfigWrongSize(t *testing.T) { - setViperKey(t, sseS3KEKConfigKey, hex.EncodeToString(make([]byte, 16))) - - km := NewSSES3KeyManager() - err := km.InitializeWithFiler(nil) - if err == nil { - t.Fatal("expected error for wrong key size, got nil") - } - if !strings.Contains(err.Error(), "32 bytes") { - t.Errorf("expected size error, got: %v", err) - } -} - -// TestSSES3KeyConfig tests that sse_s3.key (any string, HKDF) works -func TestSSES3KeyConfig(t *testing.T) { - setViperKey(t, sseS3KeyConfigKey, "my-secret-passphrase") - - km := NewSSES3KeyManager() - err := km.InitializeWithFiler(nil) - if err != nil { - t.Fatalf("InitializeWithFiler failed: %v", err) - } - - if len(km.superKey) != SSES3KeySize { - t.Fatalf("expected %d-byte superKey, got %d", SSES3KeySize, len(km.superKey)) - } - - // Deterministic: same input → same output - expected, err := deriveKeyFromSecret("my-secret-passphrase") - if err != nil { - t.Fatalf("deriveKeyFromSecret failed: %v", err) - } - if !bytes.Equal(km.superKey, expected) { - t.Errorf("superKey mismatch: expected %x, got %x", expected, km.superKey) - } -} - -// TestSSES3KeyConfigDifferentSecrets tests different strings produce different keys -func TestSSES3KeyConfigDifferentSecrets(t *testing.T) { - k1, _ := deriveKeyFromSecret("secret-one") - k2, _ := deriveKeyFromSecret("secret-two") - if bytes.Equal(k1, k2) { - t.Error("different secrets should produce different keys") - } -} - -// TestSSES3BothConfigsReject tests that setting both config keys is rejected -func TestSSES3BothConfigsReject(t *testing.T) { - setViperKey(t, sseS3KEKConfigKey, hex.EncodeToString(make([]byte, 32))) - setViperKey(t, sseS3KeyConfigKey, "some-passphrase") - - km := NewSSES3KeyManager() - err := km.InitializeWithFiler(nil) - if err == nil { - t.Fatal("expected error when both configs set, got nil") - } - if !strings.Contains(err.Error(), "only one") { - t.Errorf("expected 'only one' error, got: %v", err) - } -} - -// TestGetSSES3Headers tests SSE-S3 header generation -func TestGetSSES3Headers(t *testing.T) { - headers := GetSSES3Headers() - - if len(headers) == 0 { - t.Error("Expected headers to be non-empty") - } - - algorithm, exists := headers[s3_constants.AmzServerSideEncryption] - if !exists { - t.Error("Expected AmzServerSideEncryption header to exist") - } - - if algorithm != "AES256" { - t.Errorf("Expected algorithm AES256, got %s", algorithm) - } -} - -// TestProcessSSES3Request tests processing of SSE-S3 requests -func TestProcessSSES3Request(t *testing.T) { - // Initialize global key manager - globalSSES3KeyManager = NewSSES3KeyManager() - defer func() { - globalSSES3KeyManager = NewSSES3KeyManager() - }() - - // Set up the key manager with a super key for testing - keyManager := GetSSES3KeyManager() - keyManager.superKey = make([]byte, 32) - for i := range keyManager.superKey { - keyManager.superKey[i] = byte(i) - } - - // Create SSE-S3 request - req := httptest.NewRequest("PUT", "/bucket/object", nil) - req.Header.Set(s3_constants.AmzServerSideEncryption, "AES256") - - // Process request - metadata, err := ProcessSSES3Request(req) - if err != nil { - t.Fatalf("Failed to process SSE-S3 request: %v", err) - } - - if metadata == nil { - t.Fatal("Expected metadata to be non-nil") - } - - // Verify metadata contains SSE algorithm - if sseAlgo, exists := metadata[s3_constants.AmzServerSideEncryption]; !exists { - t.Error("Expected SSE algorithm in metadata") - } else if string(sseAlgo) != "AES256" { - t.Errorf("Expected AES256, got %s", string(sseAlgo)) - } - - // Verify metadata contains key data - if _, exists := metadata[s3_constants.SeaweedFSSSES3Key]; !exists { - t.Error("Expected SSE-S3 key data in metadata") - } -} - -// TestGetSSES3KeyFromMetadata tests extraction of SSE-S3 key from metadata -func TestGetSSES3KeyFromMetadata(t *testing.T) { - // Initialize global key manager - globalSSES3KeyManager = NewSSES3KeyManager() - defer func() { - globalSSES3KeyManager = NewSSES3KeyManager() - }() - - // Set up the key manager with a super key for testing - keyManager := GetSSES3KeyManager() - keyManager.superKey = make([]byte, 32) - for i := range keyManager.superKey { - keyManager.superKey[i] = byte(i) - } - - // Generate and serialize key - sseS3Key, err := GenerateSSES3Key() - if err != nil { - t.Fatalf("Failed to generate SSE-S3 key: %v", err) - } - - sseS3Key.IV = make([]byte, 16) - for i := range sseS3Key.IV { - sseS3Key.IV[i] = byte(i) - } - - serialized, err := SerializeSSES3Metadata(sseS3Key) - if err != nil { - t.Fatalf("Failed to serialize SSE-S3 metadata: %v", err) - } - - metadata := map[string][]byte{ - s3_constants.SeaweedFSSSES3Key: serialized, - } - - // Extract key - extractedKey, err := GetSSES3KeyFromMetadata(metadata, keyManager) - if err != nil { - t.Fatalf("Failed to get SSE-S3 key from metadata: %v", err) - } - - // Verify key matches - if !bytes.Equal(extractedKey.Key, sseS3Key.Key) { - t.Error("Extracted key doesn't match original key") - } - - if !bytes.Equal(extractedKey.IV, sseS3Key.IV) { - t.Error("Extracted IV doesn't match original IV") - } -} - -// TestSSES3EnvelopeEncryption tests that envelope encryption works correctly -func TestSSES3EnvelopeEncryption(t *testing.T) { - // Initialize key manager with a super key - keyManager := NewSSES3KeyManager() - keyManager.superKey = make([]byte, 32) - for i := range keyManager.superKey { - keyManager.superKey[i] = byte(i + 100) - } - - // Generate a DEK - dek := make([]byte, 32) - for i := range dek { - dek[i] = byte(i) - } - - // Encrypt DEK with super key - encryptedDEK, nonce, err := keyManager.encryptKeyWithSuperKey(dek) - if err != nil { - t.Fatalf("Failed to encrypt DEK: %v", err) - } - - if len(encryptedDEK) == 0 { - t.Error("Encrypted DEK is empty") - } - - if len(nonce) == 0 { - t.Error("Nonce is empty") - } - - // Decrypt DEK with super key - decryptedDEK, err := keyManager.decryptKeyWithSuperKey(encryptedDEK, nonce) - if err != nil { - t.Fatalf("Failed to decrypt DEK: %v", err) - } - - // Verify DEK matches - if !bytes.Equal(decryptedDEK, dek) { - t.Error("Decrypted DEK doesn't match original DEK") - } -} - -// TestValidateSSES3Key tests SSE-S3 key validation -func TestValidateSSES3Key(t *testing.T) { - testCases := []struct { - name string - key *SSES3Key - shouldError bool - errorMsg string - }{ - { - name: "Nil key", - key: nil, - shouldError: true, - errorMsg: "SSE-S3 key cannot be nil", - }, - { - name: "Valid key", - key: &SSES3Key{ - Key: make([]byte, 32), - KeyID: "test-key", - Algorithm: "AES256", - }, - shouldError: false, - }, - { - name: "Valid key with IV", - key: &SSES3Key{ - Key: make([]byte, 32), - KeyID: "test-key", - Algorithm: "AES256", - IV: make([]byte, 16), - }, - shouldError: false, - }, - { - name: "Invalid key size (too small)", - key: &SSES3Key{ - Key: make([]byte, 16), - KeyID: "test-key", - Algorithm: "AES256", - }, - shouldError: true, - errorMsg: "invalid SSE-S3 key size", - }, - { - name: "Invalid key size (too large)", - key: &SSES3Key{ - Key: make([]byte, 64), - KeyID: "test-key", - Algorithm: "AES256", - }, - shouldError: true, - errorMsg: "invalid SSE-S3 key size", - }, - { - name: "Nil key bytes", - key: &SSES3Key{ - Key: nil, - KeyID: "test-key", - Algorithm: "AES256", - }, - shouldError: true, - errorMsg: "SSE-S3 key bytes cannot be nil", - }, - { - name: "Empty key ID", - key: &SSES3Key{ - Key: make([]byte, 32), - KeyID: "", - Algorithm: "AES256", - }, - shouldError: true, - errorMsg: "SSE-S3 key ID cannot be empty", - }, - { - name: "Invalid algorithm", - key: &SSES3Key{ - Key: make([]byte, 32), - KeyID: "test-key", - Algorithm: "INVALID", - }, - shouldError: true, - errorMsg: "invalid SSE-S3 algorithm", - }, - { - name: "Invalid IV length", - key: &SSES3Key{ - Key: make([]byte, 32), - KeyID: "test-key", - Algorithm: "AES256", - IV: make([]byte, 8), // Wrong size - }, - shouldError: true, - errorMsg: "invalid SSE-S3 IV length", - }, - { - name: "Empty IV is allowed (set during encryption)", - key: &SSES3Key{ - Key: make([]byte, 32), - KeyID: "test-key", - Algorithm: "AES256", - IV: []byte{}, // Empty is OK - }, - shouldError: false, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - err := ValidateSSES3Key(tc.key) - if tc.shouldError { - if err == nil { - t.Error("Expected error but got none") - } else if tc.errorMsg != "" && !strings.Contains(err.Error(), tc.errorMsg) { - t.Errorf("Expected error containing %q, got: %v", tc.errorMsg, err) - } - } else { - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - } - }) - } -} diff --git a/weed/s3api/s3_validation_utils.go b/weed/s3api/s3_validation_utils.go index f69fc9c26..16e63595c 100644 --- a/weed/s3api/s3_validation_utils.go +++ b/weed/s3api/s3_validation_utils.go @@ -58,14 +58,6 @@ func ValidateSSEKMSKey(sseKey *SSEKMSKey) error { return nil } -// ValidateSSECKey validates that an SSE-C key is not nil -func ValidateSSECKey(customerKey *SSECustomerKey) error { - if customerKey == nil { - return fmt.Errorf("SSE-C customer key cannot be nil") - } - return nil -} - // ValidateSSES3Key validates that an SSE-S3 key has valid structure and contents func ValidateSSES3Key(sseKey *SSES3Key) error { if sseKey == nil { diff --git a/weed/s3api/s3api_acl_helper.go b/weed/s3api/s3api_acl_helper.go index 6cfa17f34..5c4804536 100644 --- a/weed/s3api/s3api_acl_helper.go +++ b/weed/s3api/s3api_acl_helper.go @@ -20,16 +20,6 @@ type AccountManager interface { GetAccountIdByEmail(email string) string } -// GetAccountId get AccountId from request headers, AccountAnonymousId will be return if not presen -func GetAccountId(r *http.Request) string { - id := r.Header.Get(s3_constants.AmzAccountId) - if len(id) == 0 { - return s3_constants.AccountAnonymousId - } else { - return id - } -} - // ExtractAcl extracts the acl from the request body, or from the header if request body is empty func ExtractAcl(r *http.Request, accountManager AccountManager, ownership, bucketOwnerId, ownerId, accountId string) (grants []*s3.Grant, errCode s3err.ErrorCode) { if r.Body != nil && r.Body != http.NoBody { @@ -318,83 +308,6 @@ func ValidateAndTransferGrants(accountManager AccountManager, grants []*s3.Grant return result, s3err.ErrNone } -// DetermineReqGrants generates the grant set (Grants) according to accountId and reqPermission. -func DetermineReqGrants(accountId, aclAction string) (grants []*s3.Grant) { - // group grantee (AllUsers) - grants = append(grants, &s3.Grant{ - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - URI: &s3_constants.GranteeGroupAllUsers, - }, - Permission: &aclAction, - }) - grants = append(grants, &s3.Grant{ - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - URI: &s3_constants.GranteeGroupAllUsers, - }, - Permission: &s3_constants.PermissionFullControl, - }) - - // canonical grantee (accountId) - grants = append(grants, &s3.Grant{ - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeCanonicalUser, - ID: &accountId, - }, - Permission: &aclAction, - }) - grants = append(grants, &s3.Grant{ - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeCanonicalUser, - ID: &accountId, - }, - Permission: &s3_constants.PermissionFullControl, - }) - - // group grantee (AuthenticateUsers) - if accountId != s3_constants.AccountAnonymousId { - grants = append(grants, &s3.Grant{ - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - URI: &s3_constants.GranteeGroupAuthenticatedUsers, - }, - Permission: &aclAction, - }) - grants = append(grants, &s3.Grant{ - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - URI: &s3_constants.GranteeGroupAuthenticatedUsers, - }, - Permission: &s3_constants.PermissionFullControl, - }) - } - return -} - -func SetAcpOwnerHeader(r *http.Request, acpOwnerId string) { - r.Header.Set(s3_constants.ExtAmzOwnerKey, acpOwnerId) -} - -func GetAcpOwner(entryExtended map[string][]byte, defaultOwner string) string { - ownerIdBytes, ok := entryExtended[s3_constants.ExtAmzOwnerKey] - if ok && len(ownerIdBytes) > 0 { - return string(ownerIdBytes) - } - return defaultOwner -} - -func SetAcpGrantsHeader(r *http.Request, acpGrants []*s3.Grant) { - if len(acpGrants) > 0 { - a, err := json.Marshal(acpGrants) - if err == nil { - r.Header.Set(s3_constants.ExtAmzAclKey, string(a)) - } else { - glog.Warning("Marshal acp grants err", err) - } - } -} - // GetAcpGrants return grants parsed from entry func GetAcpGrants(entryExtended map[string][]byte) []*s3.Grant { acpBytes, ok := entryExtended[s3_constants.ExtAmzAclKey] @@ -433,82 +346,3 @@ func AssembleEntryWithAcp(objectEntry *filer_pb.Entry, objectOwner string, grant return s3err.ErrNone } - -// GrantEquals Compare whether two Grants are equal in meaning, not completely -// equal (compare Grantee.Type and the corresponding Value for equality, other -// fields of Grantee are ignored) -func GrantEquals(a, b *s3.Grant) bool { - // grant - if a == b { - return true - } - - if a == nil || b == nil { - return false - } - - // grant.Permission - if a.Permission != b.Permission { - if a.Permission == nil || b.Permission == nil { - return false - } - - if *a.Permission != *b.Permission { - return false - } - } - - // grant.Grantee - ag := a.Grantee - bg := b.Grantee - if ag != bg { - if ag == nil || bg == nil { - return false - } - // grantee.Type - if ag.Type != bg.Type { - if ag.Type == nil || bg.Type == nil { - return false - } - if *ag.Type != *bg.Type { - return false - } - } - // value corresponding to granteeType - if ag.Type != nil { - switch *ag.Type { - case s3_constants.GrantTypeGroup: - if ag.URI != bg.URI { - if ag.URI == nil || bg.URI == nil { - return false - } - - if *ag.URI != *bg.URI { - return false - } - } - case s3_constants.GrantTypeCanonicalUser: - if ag.ID != bg.ID { - if ag.ID == nil || bg.ID == nil { - return false - } - - if *ag.ID != *bg.ID { - return false - } - } - case s3_constants.GrantTypeAmazonCustomerByEmail: - if ag.EmailAddress != bg.EmailAddress { - if ag.EmailAddress == nil || bg.EmailAddress == nil { - return false - } - - if *ag.EmailAddress != *bg.EmailAddress { - return false - } - } - } - } - } - return true -} diff --git a/weed/s3api/s3api_acl_helper_test.go b/weed/s3api/s3api_acl_helper_test.go deleted file mode 100644 index d3a625ce2..000000000 --- a/weed/s3api/s3api_acl_helper_test.go +++ /dev/null @@ -1,710 +0,0 @@ -package s3api - -import ( - "bytes" - "encoding/json" - "io" - "net/http" - "testing" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/s3" - "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" - "github.com/seaweedfs/seaweedfs/weed/pb/iam_pb" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" -) - -var accountManager *IdentityAccessManagement - -func init() { - accountManager = &IdentityAccessManagement{} - _ = accountManager.loadS3ApiConfiguration(&iam_pb.S3ApiConfiguration{ - Accounts: []*iam_pb.Account{ - { - Id: "accountA", - DisplayName: "accountAName", - EmailAddress: "accountA@example.com", - }, - { - Id: "accountB", - DisplayName: "accountBName", - EmailAddress: "accountB@example.com", - }, - }, - }) -} - -func TestGetAccountId(t *testing.T) { - req := &http.Request{ - Header: make(map[string][]string), - } - //case1 - //accountId: "admin" - req.Header.Set(s3_constants.AmzAccountId, s3_constants.AccountAdminId) - if GetAccountId(req) != s3_constants.AccountAdminId { - t.Fatal("expect accountId: admin") - } - - //case2 - //accountId: "anoymous" - req.Header.Set(s3_constants.AmzAccountId, s3_constants.AccountAnonymousId) - if GetAccountId(req) != s3_constants.AccountAnonymousId { - t.Fatal("expect accountId: anonymous") - } - - //case3 - //accountId is nil => "anonymous" - req.Header.Del(s3_constants.AmzAccountId) - if GetAccountId(req) != s3_constants.AccountAnonymousId { - t.Fatal("expect accountId: anonymous") - } -} - -func TestExtractAcl(t *testing.T) { - type Case struct { - id int - resultErrCode, expectErrCode s3err.ErrorCode - resultGrants, expectGrants []*s3.Grant - } - testCases := make([]*Case, 0) - accountAdminId := "admin" - { - //case1 (good case) - //parse acp from request body - req := &http.Request{ - Header: make(map[string][]string), - } - req.Body = io.NopCloser(bytes.NewReader([]byte(` - - - admin - admin - - - - - admin - - FULL_CONTROL - - - - http://acs.amazonaws.com/groups/global/AllUsers - - FULL_CONTROL - - - - `))) - objectWriter := "accountA" - grants, errCode := ExtractAcl(req, accountManager, s3_constants.OwnershipObjectWriter, accountAdminId, accountAdminId, objectWriter) - testCases = append(testCases, &Case{ - 1, - errCode, s3err.ErrNone, - grants, []*s3.Grant{ - { - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeCanonicalUser, - ID: &accountAdminId, - }, - Permission: &s3_constants.PermissionFullControl, - }, - { - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - URI: &s3_constants.GranteeGroupAllUsers, - }, - Permission: &s3_constants.PermissionFullControl, - }, - }, - }) - } - - { - //case2 (good case) - //parse acp from header (cannedAcl) - req := &http.Request{ - Header: make(map[string][]string), - } - req.Body = nil - req.Header.Set(s3_constants.AmzCannedAcl, s3_constants.CannedAclPrivate) - objectWriter := "accountA" - grants, errCode := ExtractAcl(req, accountManager, s3_constants.OwnershipObjectWriter, accountAdminId, accountAdminId, objectWriter) - testCases = append(testCases, &Case{ - 2, - errCode, s3err.ErrNone, - grants, []*s3.Grant{ - { - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeCanonicalUser, - ID: &objectWriter, - }, - Permission: &s3_constants.PermissionFullControl, - }, - }, - }) - } - - { - //case3 (bad case) - //parse acp from request body (content is invalid) - req := &http.Request{ - Header: make(map[string][]string), - } - req.Body = io.NopCloser(bytes.NewReader([]byte("zdfsaf"))) - req.Header.Set(s3_constants.AmzCannedAcl, s3_constants.CannedAclPrivate) - objectWriter := "accountA" - _, errCode := ExtractAcl(req, accountManager, s3_constants.OwnershipObjectWriter, accountAdminId, accountAdminId, objectWriter) - testCases = append(testCases, &Case{ - id: 3, - resultErrCode: errCode, expectErrCode: s3err.ErrInvalidRequest, - }) - } - - //case4 (bad case) - //parse acp from header (cannedAcl is invalid) - req := &http.Request{ - Header: make(map[string][]string), - } - req.Body = nil - req.Header.Set(s3_constants.AmzCannedAcl, "dfaksjfk") - objectWriter := "accountA" - _, errCode := ExtractAcl(req, accountManager, s3_constants.OwnershipObjectWriter, accountAdminId, "", objectWriter) - testCases = append(testCases, &Case{ - id: 4, - resultErrCode: errCode, expectErrCode: s3err.ErrInvalidRequest, - }) - - { - //case5 (bad case) - //parse acp from request body: owner is inconsistent - req.Body = io.NopCloser(bytes.NewReader([]byte(` - - - admin - admin - - - - - admin - - FULL_CONTROL - - - - http://acs.amazonaws.com/groups/global/AllUsers - - FULL_CONTROL - - - - `))) - objectWriter = "accountA" - _, errCode := ExtractAcl(req, accountManager, s3_constants.OwnershipObjectWriter, accountAdminId, objectWriter, objectWriter) - testCases = append(testCases, &Case{ - id: 5, - resultErrCode: errCode, expectErrCode: s3err.ErrAccessDenied, - }) - } - - for _, tc := range testCases { - if tc.resultErrCode != tc.expectErrCode { - t.Fatalf("case[%d]: errorCode not expect", tc.id) - } - if !grantsEquals(tc.resultGrants, tc.expectGrants) { - t.Fatalf("case[%d]: grants not expect", tc.id) - } - } -} - -func TestParseAndValidateAclHeaders(t *testing.T) { - type Case struct { - id int - resultOwner, expectOwner string - resultErrCode, expectErrCode s3err.ErrorCode - resultGrants, expectGrants []*s3.Grant - } - testCases := make([]*Case, 0) - bucketOwner := "admin" - - { - //case1 (good case) - //parse custom acl - req := &http.Request{ - Header: make(map[string][]string), - } - objectWriter := "accountA" - req.Header.Set(s3_constants.AmzAclFullControl, `uri="http://acs.amazonaws.com/groups/global/AllUsers", id="anonymous", emailAddress="admin@example.com"`) - ownerId, grants, errCode := ParseAndValidateAclHeaders(req, accountManager, s3_constants.OwnershipObjectWriter, bucketOwner, objectWriter, false) - testCases = append(testCases, &Case{ - 1, - ownerId, objectWriter, - errCode, s3err.ErrNone, - grants, []*s3.Grant{ - { - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - URI: &s3_constants.GranteeGroupAllUsers, - }, - Permission: &s3_constants.PermissionFullControl, - }, - { - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeCanonicalUser, - ID: aws.String(s3_constants.AccountAnonymousId), - }, - Permission: &s3_constants.PermissionFullControl, - }, - { - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeCanonicalUser, - ID: aws.String(s3_constants.AccountAdminId), - }, - Permission: &s3_constants.PermissionFullControl, - }, - }, - }) - } - { - //case2 (good case) - //parse canned acl (ownership=ObjectWriter) - req := &http.Request{ - Header: make(map[string][]string), - } - objectWriter := "accountA" - req.Header.Set(s3_constants.AmzCannedAcl, s3_constants.CannedAclBucketOwnerFullControl) - ownerId, grants, errCode := ParseAndValidateAclHeaders(req, accountManager, s3_constants.OwnershipObjectWriter, bucketOwner, objectWriter, false) - testCases = append(testCases, &Case{ - 2, - ownerId, objectWriter, - errCode, s3err.ErrNone, - grants, []*s3.Grant{ - { - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeCanonicalUser, - ID: &objectWriter, - }, - Permission: &s3_constants.PermissionFullControl, - }, - { - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeCanonicalUser, - ID: &bucketOwner, - }, - Permission: &s3_constants.PermissionFullControl, - }, - }, - }) - } - { - //case3 (good case) - //parse canned acl (ownership=OwnershipBucketOwnerPreferred) - req := &http.Request{ - Header: make(map[string][]string), - } - objectWriter := "accountA" - req.Header.Set(s3_constants.AmzCannedAcl, s3_constants.CannedAclBucketOwnerFullControl) - ownerId, grants, errCode := ParseAndValidateAclHeaders(req, accountManager, s3_constants.OwnershipBucketOwnerPreferred, bucketOwner, objectWriter, false) - testCases = append(testCases, &Case{ - 3, - ownerId, bucketOwner, - errCode, s3err.ErrNone, - grants, []*s3.Grant{ - { - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeCanonicalUser, - ID: &bucketOwner, - }, - Permission: &s3_constants.PermissionFullControl, - }, - }, - }) - } - { - //case4 (bad case) - //parse custom acl (grantee id not exists) - req := &http.Request{ - Header: make(map[string][]string), - } - objectWriter := "accountA" - req.Header.Set(s3_constants.AmzAclFullControl, `uri="http://acs.amazonaws.com/groups/global/AllUsers", id="notExistsAccount", emailAddress="admin@example.com"`) - _, _, errCode := ParseAndValidateAclHeaders(req, accountManager, s3_constants.OwnershipObjectWriter, bucketOwner, objectWriter, false) - testCases = append(testCases, &Case{ - id: 4, - resultErrCode: errCode, expectErrCode: s3err.ErrInvalidRequest, - }) - } - - { - //case5 (bad case) - //parse custom acl (invalid format) - req := &http.Request{ - Header: make(map[string][]string), - } - objectWriter := "accountA" - req.Header.Set(s3_constants.AmzAclFullControl, `uri="http:sfasf"`) - _, _, errCode := ParseAndValidateAclHeaders(req, accountManager, s3_constants.OwnershipObjectWriter, bucketOwner, objectWriter, false) - testCases = append(testCases, &Case{ - id: 5, - resultErrCode: errCode, expectErrCode: s3err.ErrInvalidRequest, - }) - } - - { - //case6 (bad case) - //parse canned acl (invalid value) - req := &http.Request{ - Header: make(map[string][]string), - } - objectWriter := "accountA" - req.Header.Set(s3_constants.AmzCannedAcl, `uri="http:sfasf"`) - _, _, errCode := ParseAndValidateAclHeaders(req, accountManager, s3_constants.OwnershipObjectWriter, bucketOwner, objectWriter, false) - testCases = append(testCases, &Case{ - id: 5, - resultErrCode: errCode, expectErrCode: s3err.ErrInvalidRequest, - }) - } - - for _, tc := range testCases { - if tc.expectErrCode != tc.resultErrCode { - t.Errorf("case[%d]: errCode unexpect", tc.id) - } - if tc.resultOwner != tc.expectOwner { - t.Errorf("case[%d]: ownerId unexpect", tc.id) - } - if !grantsEquals(tc.resultGrants, tc.expectGrants) { - t.Fatalf("case[%d]: grants not expect", tc.id) - } - } -} - -func grantsEquals(a, b []*s3.Grant) bool { - if len(a) != len(b) { - return false - } - for i, grant := range a { - if !GrantEquals(grant, b[i]) { - return false - } - } - return true -} - -func TestDetermineReqGrants(t *testing.T) { - { - //case1: request account is anonymous - accountId := s3_constants.AccountAnonymousId - reqPermission := s3_constants.PermissionRead - - resultGrants := DetermineReqGrants(accountId, reqPermission) - expectGrants := []*s3.Grant{ - { - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - URI: &s3_constants.GranteeGroupAllUsers, - }, - Permission: &reqPermission, - }, - { - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - URI: &s3_constants.GranteeGroupAllUsers, - }, - Permission: &s3_constants.PermissionFullControl, - }, - { - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeCanonicalUser, - ID: &accountId, - }, - Permission: &reqPermission, - }, - { - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeCanonicalUser, - ID: &accountId, - }, - Permission: &s3_constants.PermissionFullControl, - }, - } - if !grantsEquals(resultGrants, expectGrants) { - t.Fatalf("grants not expect") - } - } - { - //case2: request account is not anonymous (Iam authed) - accountId := "accountX" - reqPermission := s3_constants.PermissionRead - - resultGrants := DetermineReqGrants(accountId, reqPermission) - expectGrants := []*s3.Grant{ - { - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - URI: &s3_constants.GranteeGroupAllUsers, - }, - Permission: &reqPermission, - }, - { - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - URI: &s3_constants.GranteeGroupAllUsers, - }, - Permission: &s3_constants.PermissionFullControl, - }, - { - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeCanonicalUser, - ID: &accountId, - }, - Permission: &reqPermission, - }, - { - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeCanonicalUser, - ID: &accountId, - }, - Permission: &s3_constants.PermissionFullControl, - }, - { - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - URI: &s3_constants.GranteeGroupAuthenticatedUsers, - }, - Permission: &reqPermission, - }, - { - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - URI: &s3_constants.GranteeGroupAuthenticatedUsers, - }, - Permission: &s3_constants.PermissionFullControl, - }, - } - if !grantsEquals(resultGrants, expectGrants) { - t.Fatalf("grants not expect") - } - } -} - -func TestAssembleEntryWithAcp(t *testing.T) { - defaultOwner := "admin" - - //case1 - //assemble with non-empty grants - expectOwner := "accountS" - expectGrants := []*s3.Grant{ - { - Permission: &s3_constants.PermissionRead, - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - ID: aws.String(s3_constants.AccountAdminId), - URI: &s3_constants.GranteeGroupAllUsers, - }, - }, - } - entry := &filer_pb.Entry{} - AssembleEntryWithAcp(entry, expectOwner, expectGrants) - - resultOwner := GetAcpOwner(entry.Extended, defaultOwner) - if resultOwner != expectOwner { - t.Fatalf("owner not expect") - } - - resultGrants := GetAcpGrants(entry.Extended) - if !grantsEquals(resultGrants, expectGrants) { - t.Fatal("grants not expect") - } - - //case2 - //assemble with empty grants (override) - AssembleEntryWithAcp(entry, "", nil) - resultOwner = GetAcpOwner(entry.Extended, defaultOwner) - if resultOwner != defaultOwner { - t.Fatalf("owner not expect") - } - - resultGrants = GetAcpGrants(entry.Extended) - if len(resultGrants) != 0 { - t.Fatal("grants not expect") - } - -} - -func TestGrantEquals(t *testing.T) { - testCases := map[bool]bool{ - GrantEquals(nil, nil): true, - - GrantEquals(&s3.Grant{}, nil): false, - - GrantEquals(&s3.Grant{}, &s3.Grant{}): true, - - GrantEquals(&s3.Grant{ - Permission: &s3_constants.PermissionRead, - }, &s3.Grant{}): false, - - GrantEquals(&s3.Grant{ - Permission: &s3_constants.PermissionRead, - }, &s3.Grant{ - Permission: &s3_constants.PermissionRead, - }): true, - - GrantEquals(&s3.Grant{ - Permission: &s3_constants.PermissionRead, - Grantee: &s3.Grantee{}, - }, &s3.Grant{ - Permission: &s3_constants.PermissionRead, - Grantee: &s3.Grantee{}, - }): true, - - GrantEquals(&s3.Grant{ - Permission: &s3_constants.PermissionRead, - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - }, - }, &s3.Grant{ - Permission: &s3_constants.PermissionRead, - Grantee: &s3.Grantee{}, - }): false, - - //type not present, compare other fields of grant is meaningless - GrantEquals(&s3.Grant{ - Permission: &s3_constants.PermissionRead, - Grantee: &s3.Grantee{ - ID: aws.String(s3_constants.AccountAdminId), - //EmailAddress: &s3account.AccountAdmin.EmailAddress, - }, - }, &s3.Grant{ - Permission: &s3_constants.PermissionRead, - Grantee: &s3.Grantee{ - ID: aws.String(s3_constants.AccountAdminId), - }, - }): true, - - GrantEquals(&s3.Grant{ - Permission: &s3_constants.PermissionRead, - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - }, - }, &s3.Grant{ - Permission: &s3_constants.PermissionRead, - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - }, - }): true, - - GrantEquals(&s3.Grant{ - Permission: &s3_constants.PermissionRead, - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - URI: &s3_constants.GranteeGroupAllUsers, - }, - }, &s3.Grant{ - Permission: &s3_constants.PermissionRead, - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - URI: &s3_constants.GranteeGroupAllUsers, - }, - }): true, - - GrantEquals(&s3.Grant{ - Permission: &s3_constants.PermissionWrite, - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - URI: &s3_constants.GranteeGroupAllUsers, - }, - }, &s3.Grant{ - Permission: &s3_constants.PermissionRead, - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - URI: &s3_constants.GranteeGroupAllUsers, - }, - }): false, - - GrantEquals(&s3.Grant{ - Permission: &s3_constants.PermissionRead, - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - ID: aws.String(s3_constants.AccountAdminId), - }, - }, &s3.Grant{ - Permission: &s3_constants.PermissionRead, - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - ID: aws.String(s3_constants.AccountAdminId), - }, - }): true, - - GrantEquals(&s3.Grant{ - Permission: &s3_constants.PermissionRead, - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - ID: aws.String(s3_constants.AccountAdminId), - URI: &s3_constants.GranteeGroupAllUsers, - }, - }, &s3.Grant{ - Permission: &s3_constants.PermissionRead, - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - ID: aws.String(s3_constants.AccountAdminId), - }, - }): false, - - GrantEquals(&s3.Grant{ - Permission: &s3_constants.PermissionRead, - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - ID: aws.String(s3_constants.AccountAdminId), - URI: &s3_constants.GranteeGroupAllUsers, - }, - }, &s3.Grant{ - Permission: &s3_constants.PermissionRead, - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - URI: &s3_constants.GranteeGroupAllUsers, - }, - }): true, - } - - for tc, expect := range testCases { - if tc != expect { - t.Fatal("TestGrantEquals not expect!") - } - } -} - -func TestSetAcpOwnerHeader(t *testing.T) { - ownerId := "accountZ" - req := &http.Request{ - Header: make(map[string][]string), - } - SetAcpOwnerHeader(req, ownerId) - - if req.Header.Get(s3_constants.ExtAmzOwnerKey) != ownerId { - t.Fatalf("owner unexpect") - } -} - -func TestSetAcpGrantsHeader(t *testing.T) { - req := &http.Request{ - Header: make(map[string][]string), - } - grants := []*s3.Grant{ - { - Permission: &s3_constants.PermissionRead, - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - ID: aws.String(s3_constants.AccountAdminId), - URI: &s3_constants.GranteeGroupAllUsers, - }, - }, - } - SetAcpGrantsHeader(req, grants) - - grantsJson, _ := json.Marshal(grants) - if req.Header.Get(s3_constants.ExtAmzAclKey) != string(grantsJson) { - t.Fatalf("owner unexpect") - } -} diff --git a/weed/s3api/s3api_bucket_handlers.go b/weed/s3api/s3api_bucket_handlers.go index 5abbd5d22..eec6a11ed 100644 --- a/weed/s3api/s3api_bucket_handlers.go +++ b/weed/s3api/s3api_bucket_handlers.go @@ -150,14 +150,6 @@ func isBucketOwnedByIdentity(entry *filer_pb.Entry, identity *Identity) bool { return true } -// isBucketVisibleToIdentity is kept for backward compatibility with tests. -// It checks if a bucket should be visible based on ownership only. -// Deprecated: Use isBucketOwnedByIdentity instead. The ListBucketsHandler -// now uses OR logic: a bucket is visible if user owns it OR has List permission. -func isBucketVisibleToIdentity(entry *filer_pb.Entry, identity *Identity) bool { - return isBucketOwnedByIdentity(entry, identity) -} - func (s3a *S3ApiServer) PutBucketHandler(w http.ResponseWriter, r *http.Request) { // collect parameters diff --git a/weed/s3api/s3api_bucket_handlers_test.go b/weed/s3api/s3api_bucket_handlers_test.go deleted file mode 100644 index ee79381b3..000000000 --- a/weed/s3api/s3api_bucket_handlers_test.go +++ /dev/null @@ -1,1085 +0,0 @@ -package s3api - -import ( - "encoding/json" - "encoding/xml" - "fmt" - "net/http/httptest" - "strings" - "testing" - "time" - - "github.com/aws/aws-sdk-go/service/s3" - "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestPutBucketAclCannedAclSupport(t *testing.T) { - // Test that the ExtractAcl function can handle various canned ACLs - // This tests the core functionality without requiring a fully initialized S3ApiServer - - testCases := []struct { - name string - cannedAcl string - shouldWork bool - description string - }{ - { - name: "private", - cannedAcl: s3_constants.CannedAclPrivate, - shouldWork: true, - description: "private ACL should be accepted", - }, - { - name: "public-read", - cannedAcl: s3_constants.CannedAclPublicRead, - shouldWork: true, - description: "public-read ACL should be accepted", - }, - { - name: "public-read-write", - cannedAcl: s3_constants.CannedAclPublicReadWrite, - shouldWork: true, - description: "public-read-write ACL should be accepted", - }, - { - name: "authenticated-read", - cannedAcl: s3_constants.CannedAclAuthenticatedRead, - shouldWork: true, - description: "authenticated-read ACL should be accepted", - }, - { - name: "bucket-owner-read", - cannedAcl: s3_constants.CannedAclBucketOwnerRead, - shouldWork: true, - description: "bucket-owner-read ACL should be accepted", - }, - { - name: "bucket-owner-full-control", - cannedAcl: s3_constants.CannedAclBucketOwnerFullControl, - shouldWork: true, - description: "bucket-owner-full-control ACL should be accepted", - }, - { - name: "invalid-acl", - cannedAcl: "invalid-acl-value", - shouldWork: false, - description: "invalid ACL should be rejected", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Create a request with the specified canned ACL - req := httptest.NewRequest("PUT", "/bucket?acl", nil) - req.Header.Set(s3_constants.AmzCannedAcl, tc.cannedAcl) - req.Header.Set(s3_constants.AmzAccountId, "test-account-123") - - // Create a mock IAM for testing - mockIam := &mockIamInterface{} - - // Test the ACL extraction directly - grants, errCode := ExtractAcl(req, mockIam, "", "test-account-123", "test-account-123", "test-account-123") - - if tc.shouldWork { - assert.Equal(t, s3err.ErrNone, errCode, "Expected ACL parsing to succeed for %s", tc.cannedAcl) - assert.NotEmpty(t, grants, "Expected grants to be generated for valid ACL %s", tc.cannedAcl) - t.Logf("✓ PASS: %s - %s", tc.name, tc.description) - } else { - assert.NotEqual(t, s3err.ErrNone, errCode, "Expected ACL parsing to fail for invalid ACL %s", tc.cannedAcl) - t.Logf("✓ PASS: %s - %s", tc.name, tc.description) - } - }) - } -} - -// TestBucketWithoutACLIsNotPublicRead tests that buckets without ACLs are not public-read -func TestBucketWithoutACLIsNotPublicRead(t *testing.T) { - // Create a bucket config without ACL (like a freshly created bucket) - config := &BucketConfig{ - Name: "test-bucket", - IsPublicRead: false, // Should be explicitly false - } - - // Verify that buckets without ACL are not public-read - assert.False(t, config.IsPublicRead, "Bucket without ACL should not be public-read") -} - -func TestBucketConfigInitialization(t *testing.T) { - // Test that BucketConfig properly initializes IsPublicRead field - config := &BucketConfig{ - Name: "test-bucket", - IsPublicRead: false, // Explicitly set to false for private buckets - } - - // Verify proper initialization - assert.False(t, config.IsPublicRead, "Newly created bucket should not be public-read by default") -} - -// TestUpdateBucketConfigCacheConsistency tests that updateBucketConfigCacheFromEntry -// properly handles the IsPublicRead flag consistently with getBucketConfig -func TestUpdateBucketConfigCacheConsistency(t *testing.T) { - t.Run("bucket without ACL should have IsPublicRead=false", func(t *testing.T) { - // Simulate an entry without ACL (like a freshly created bucket) - entry := &filer_pb.Entry{ - Name: "test-bucket", - Attributes: &filer_pb.FuseAttributes{ - FileMode: 0755, - }, - // Extended is nil or doesn't contain ACL - } - - // Test what updateBucketConfigCacheFromEntry would create - config := &BucketConfig{ - Name: entry.Name, - Entry: entry, - IsPublicRead: false, // Should be explicitly false - } - - // When Extended is nil, IsPublicRead should be false - assert.False(t, config.IsPublicRead, "Bucket without Extended metadata should not be public-read") - - // When Extended exists but has no ACL key, IsPublicRead should also be false - entry.Extended = make(map[string][]byte) - entry.Extended["some-other-key"] = []byte("some-value") - - config = &BucketConfig{ - Name: entry.Name, - Entry: entry, - IsPublicRead: false, // Should be explicitly false - } - - // Simulate the else branch: no ACL means private bucket - if _, exists := entry.Extended[s3_constants.ExtAmzAclKey]; !exists { - config.IsPublicRead = false - } - - assert.False(t, config.IsPublicRead, "Bucket with Extended but no ACL should not be public-read") - }) - - t.Run("bucket with public-read ACL should have IsPublicRead=true", func(t *testing.T) { - // Create a mock public-read ACL using AWS S3 SDK types - publicReadGrants := []*s3.Grant{ - { - Grantee: &s3.Grantee{ - Type: &s3_constants.GrantTypeGroup, - URI: &s3_constants.GranteeGroupAllUsers, - }, - Permission: &s3_constants.PermissionRead, - }, - } - - aclBytes, err := json.Marshal(publicReadGrants) - require.NoError(t, err) - - entry := &filer_pb.Entry{ - Name: "public-bucket", - Extended: map[string][]byte{ - s3_constants.ExtAmzAclKey: aclBytes, - }, - } - - config := &BucketConfig{ - Name: entry.Name, - Entry: entry, - IsPublicRead: false, // Start with false - } - - // Simulate what updateBucketConfigCacheFromEntry would do - if acl, exists := entry.Extended[s3_constants.ExtAmzAclKey]; exists { - config.ACL = acl - config.IsPublicRead = parseAndCachePublicReadStatus(acl) - } - - assert.True(t, config.IsPublicRead, "Bucket with public-read ACL should be public-read") - }) -} - -// mockIamInterface is a simple mock for testing -type mockIamInterface struct{} - -func (m *mockIamInterface) GetAccountNameById(canonicalId string) string { - return "test-user-" + canonicalId -} - -func (m *mockIamInterface) GetAccountIdByEmail(email string) string { - return "account-for-" + email -} - -// TestListAllMyBucketsResultNamespace verifies that the ListAllMyBucketsResult -// XML response includes the proper S3 namespace URI -func TestListAllMyBucketsResultNamespace(t *testing.T) { - // Create a sample ListAllMyBucketsResult response - response := ListAllMyBucketsResult{ - Owner: CanonicalUser{ - ID: "test-owner-id", - DisplayName: "test-owner", - }, - Buckets: ListAllMyBucketsList{ - Bucket: []ListAllMyBucketsEntry{ - { - Name: "test-bucket", - CreationDate: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), - }, - }, - }, - } - - // Marshal the response to XML - xmlData, err := xml.Marshal(response) - require.NoError(t, err, "Failed to marshal XML response") - - xmlString := string(xmlData) - - // Verify that the XML contains the proper namespace - assert.Contains(t, xmlString, `xmlns="http://s3.amazonaws.com/doc/2006-03-01/"`, - "XML response should contain the S3 namespace URI") - - // Verify the root element has the correct name - assert.Contains(t, xmlString, "", "XML should contain Owner element") - assert.Contains(t, xmlString, "", "XML should contain Buckets element") - assert.Contains(t, xmlString, "", "XML should contain Bucket element") - assert.Contains(t, xmlString, "test-bucket", "XML should contain bucket name") - - t.Logf("Generated XML:\n%s", xmlString) -} - -// TestListBucketsOwnershipFiltering tests that ListBucketsHandler properly filters -// buckets based on ownership, allowing only bucket owners (or admins) to see their buckets -func TestListBucketsOwnershipFiltering(t *testing.T) { - testCases := []struct { - name string - buckets []testBucket - requestIdentityId string - requestIsAdmin bool - expectedBucketNames []string - description string - }{ - { - name: "non-admin sees only owned buckets", - buckets: []testBucket{ - {name: "user1-bucket", ownerId: "user1"}, - {name: "user2-bucket", ownerId: "user2"}, - {name: "user1-bucket2", ownerId: "user1"}, - }, - requestIdentityId: "user1", - requestIsAdmin: false, - expectedBucketNames: []string{"user1-bucket", "user1-bucket2"}, - description: "Non-admin user should only see buckets they own", - }, - { - name: "admin sees all buckets", - buckets: []testBucket{ - {name: "user1-bucket", ownerId: "user1"}, - {name: "user2-bucket", ownerId: "user2"}, - {name: "user3-bucket", ownerId: "user3"}, - }, - requestIdentityId: "admin", - requestIsAdmin: true, - expectedBucketNames: []string{"user1-bucket", "user2-bucket", "user3-bucket"}, - description: "Admin should see all buckets regardless of owner", - }, - { - name: "buckets without owner are hidden from non-admins", - buckets: []testBucket{ - {name: "owned-bucket", ownerId: "user1"}, - {name: "unowned-bucket", ownerId: ""}, // No owner set - }, - requestIdentityId: "user2", - requestIsAdmin: false, - expectedBucketNames: []string{}, - description: "Buckets without owner should be hidden from non-admin users", - }, - { - name: "unauthenticated user sees no buckets", - buckets: []testBucket{ - {name: "owned-bucket", ownerId: "user1"}, - {name: "unowned-bucket", ownerId: ""}, - }, - requestIdentityId: "", - requestIsAdmin: false, - expectedBucketNames: []string{}, - description: "Unauthenticated requests should not see any buckets", - }, - { - name: "admin sees buckets regardless of ownership", - buckets: []testBucket{ - {name: "user1-bucket", ownerId: "user1"}, - {name: "user2-bucket", ownerId: "user2"}, - {name: "unowned-bucket", ownerId: ""}, - }, - requestIdentityId: "admin", - requestIsAdmin: true, - expectedBucketNames: []string{"user1-bucket", "user2-bucket", "unowned-bucket"}, - description: "Admin should see all buckets regardless of ownership", - }, - { - name: "buckets with nil Extended metadata hidden from non-admins", - buckets: []testBucket{ - {name: "bucket-no-extended", ownerId: "", nilExtended: true}, - {name: "bucket-with-owner", ownerId: "user1"}, - }, - requestIdentityId: "user1", - requestIsAdmin: false, - expectedBucketNames: []string{"bucket-with-owner"}, - description: "Buckets with nil Extended (no owner) should be hidden from non-admins", - }, - { - name: "user sees only their bucket among many", - buckets: []testBucket{ - {name: "alice-bucket", ownerId: "alice"}, - {name: "bob-bucket", ownerId: "bob"}, - {name: "charlie-bucket", ownerId: "charlie"}, - {name: "alice-bucket2", ownerId: "alice"}, - }, - requestIdentityId: "bob", - requestIsAdmin: false, - expectedBucketNames: []string{"bob-bucket"}, - description: "User should see only their single bucket among many", - }, - { - name: "admin sees buckets without owners", - buckets: []testBucket{ - {name: "owned-bucket", ownerId: "user1"}, - {name: "unowned-bucket", ownerId: ""}, - {name: "no-metadata-bucket", ownerId: "", nilExtended: true}, - }, - requestIdentityId: "admin", - requestIsAdmin: true, - expectedBucketNames: []string{"owned-bucket", "unowned-bucket", "no-metadata-bucket"}, - description: "Admin should see all buckets including those without owners", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Create mock entries - entries := make([]*filer_pb.Entry, 0, len(tc.buckets)) - for _, bucket := range tc.buckets { - entry := &filer_pb.Entry{ - Name: bucket.name, - IsDirectory: true, - Attributes: &filer_pb.FuseAttributes{ - Crtime: time.Now().Unix(), - }, - } - - if !bucket.nilExtended { - entry.Extended = make(map[string][]byte) - if bucket.ownerId != "" { - entry.Extended[s3_constants.AmzIdentityId] = []byte(bucket.ownerId) - } - } - - entries = append(entries, entry) - } - - // Filter entries using the actual production code - var filteredBuckets []string - for _, entry := range entries { - var identity *Identity - if tc.requestIdentityId != "" { - identity = mockIdentity(tc.requestIdentityId, tc.requestIsAdmin) - } - if isBucketVisibleToIdentity(entry, identity) { - filteredBuckets = append(filteredBuckets, entry.Name) - } - } - - // Assert expected buckets match filtered buckets - assert.ElementsMatch(t, tc.expectedBucketNames, filteredBuckets, - "%s - Expected buckets: %v, Got: %v", tc.description, tc.expectedBucketNames, filteredBuckets) - }) - } -} - -// testBucket represents a bucket for testing with ownership metadata -type testBucket struct { - name string - ownerId string - nilExtended bool -} - -// mockIdentity creates a mock Identity for testing bucket visibility -func mockIdentity(name string, isAdmin bool) *Identity { - identity := &Identity{ - Name: name, - } - if isAdmin { - identity.Credentials = []*Credential{ - { - AccessKey: "admin-key", - SecretKey: "admin-secret", - }, - } - identity.Actions = []Action{Action(s3_constants.ACTION_ADMIN)} - } - return identity -} - -// TestListBucketsOwnershipEdgeCases tests edge cases in ownership filtering -func TestListBucketsOwnershipEdgeCases(t *testing.T) { - t.Run("malformed owner id with special characters", func(t *testing.T) { - entry := &filer_pb.Entry{ - Name: "test-bucket", - IsDirectory: true, - Extended: map[string][]byte{ - s3_constants.AmzIdentityId: []byte("user@domain.com"), - }, - Attributes: &filer_pb.FuseAttributes{ - Crtime: time.Now().Unix(), - }, - } - - identity := mockIdentity("user@domain.com", false) - - // Should match exactly even with special characters - isVisible := isBucketVisibleToIdentity(entry, identity) - - assert.True(t, isVisible, "Should match owner ID with special characters exactly") - }) - - t.Run("owner id with unicode characters", func(t *testing.T) { - unicodeOwnerId := "用户123" - entry := &filer_pb.Entry{ - Name: "test-bucket", - IsDirectory: true, - Extended: map[string][]byte{ - s3_constants.AmzIdentityId: []byte(unicodeOwnerId), - }, - Attributes: &filer_pb.FuseAttributes{ - Crtime: time.Now().Unix(), - }, - } - - identity := mockIdentity(unicodeOwnerId, false) - - isVisible := isBucketVisibleToIdentity(entry, identity) - - assert.True(t, isVisible, "Should handle unicode owner IDs correctly") - }) - - t.Run("owner id with binary data", func(t *testing.T) { - entry := &filer_pb.Entry{ - Name: "test-bucket", - IsDirectory: true, - Extended: map[string][]byte{ - s3_constants.AmzIdentityId: []byte{0x00, 0x01, 0x02, 0xFF}, - }, - Attributes: &filer_pb.FuseAttributes{ - Crtime: time.Now().Unix(), - }, - } - - identity := mockIdentity("normaluser", false) - - // Should not panic when converting binary data to string - assert.NotPanics(t, func() { - isVisible := isBucketVisibleToIdentity(entry, identity) - assert.False(t, isVisible, "Binary owner ID should not match normal user") - }) - }) - - t.Run("empty owner id in Extended", func(t *testing.T) { - entry := &filer_pb.Entry{ - Name: "test-bucket", - IsDirectory: true, - Extended: map[string][]byte{ - s3_constants.AmzIdentityId: []byte(""), - }, - Attributes: &filer_pb.FuseAttributes{ - Crtime: time.Now().Unix(), - }, - } - - identity := mockIdentity("user1", false) - - isVisible := isBucketVisibleToIdentity(entry, identity) - - assert.False(t, isVisible, "Empty owner ID should be treated as unowned (hidden from non-admins)") - }) - - t.Run("nil Extended map safe access", func(t *testing.T) { - entry := &filer_pb.Entry{ - Name: "test-bucket", - IsDirectory: true, - Extended: nil, // Explicitly nil - Attributes: &filer_pb.FuseAttributes{ - Crtime: time.Now().Unix(), - }, - } - - identity := mockIdentity("user1", false) - - // Should not panic with nil Extended map - assert.NotPanics(t, func() { - isVisible := isBucketVisibleToIdentity(entry, identity) - assert.False(t, isVisible, "Nil Extended (no owner) should be hidden from non-admins") - }) - }) - - t.Run("very long owner id", func(t *testing.T) { - longOwnerId := strings.Repeat("a", 10000) - entry := &filer_pb.Entry{ - Name: "test-bucket", - IsDirectory: true, - Extended: map[string][]byte{ - s3_constants.AmzIdentityId: []byte(longOwnerId), - }, - Attributes: &filer_pb.FuseAttributes{ - Crtime: time.Now().Unix(), - }, - } - - identity := mockIdentity(longOwnerId, false) - - // Should handle very long owner IDs without panic - assert.NotPanics(t, func() { - isVisible := isBucketVisibleToIdentity(entry, identity) - assert.True(t, isVisible, "Long owner ID should match correctly") - }) - }) -} - -// TestListBucketsOwnershipWithPermissions tests that ownership filtering -// works in conjunction with permission checks -func TestListBucketsOwnershipWithPermissions(t *testing.T) { - t.Run("ownership check before permission check", func(t *testing.T) { - // Simulate scenario where ownership check filters first, - // then permission check applies to remaining buckets - entries := []*filer_pb.Entry{ - { - Name: "owned-bucket", - IsDirectory: true, - Extended: map[string][]byte{ - s3_constants.AmzIdentityId: []byte("user1"), - }, - Attributes: &filer_pb.FuseAttributes{Crtime: time.Now().Unix()}, - }, - { - Name: "other-bucket", - IsDirectory: true, - Extended: map[string][]byte{ - s3_constants.AmzIdentityId: []byte("user2"), - }, - Attributes: &filer_pb.FuseAttributes{Crtime: time.Now().Unix()}, - }, - } - - identity := mockIdentity("user1", false) - - // First pass: ownership filtering - var afterOwnershipFilter []*filer_pb.Entry - for _, entry := range entries { - if isBucketVisibleToIdentity(entry, identity) { - afterOwnershipFilter = append(afterOwnershipFilter, entry) - } - } - - // Only owned-bucket should remain after ownership filter - assert.Len(t, afterOwnershipFilter, 1, "Only owned bucket should pass ownership filter") - assert.Equal(t, "owned-bucket", afterOwnershipFilter[0].Name) - - // Permission checks would apply to afterOwnershipFilter entries - // (not tested here as it depends on IAM system) - }) - - t.Run("admin bypasses ownership but not permissions", func(t *testing.T) { - entries := []*filer_pb.Entry{ - { - Name: "user1-bucket", - IsDirectory: true, - Extended: map[string][]byte{ - s3_constants.AmzIdentityId: []byte("user1"), - }, - Attributes: &filer_pb.FuseAttributes{Crtime: time.Now().Unix()}, - }, - { - Name: "user2-bucket", - IsDirectory: true, - Extended: map[string][]byte{ - s3_constants.AmzIdentityId: []byte("user2"), - }, - Attributes: &filer_pb.FuseAttributes{Crtime: time.Now().Unix()}, - }, - } - - identity := mockIdentity("admin-user", true) - - // Admin bypasses ownership check - var afterOwnershipFilter []*filer_pb.Entry - for _, entry := range entries { - if isBucketVisibleToIdentity(entry, identity) { - afterOwnershipFilter = append(afterOwnershipFilter, entry) - } - } - - // Admin should see all buckets after ownership filter - assert.Len(t, afterOwnershipFilter, 2, "Admin should see all buckets after ownership filter") - // Note: Permission checks still apply to admins in actual implementation - }) -} - -// TestListBucketsOwnershipCaseSensitivity tests case sensitivity in owner matching -func TestListBucketsOwnershipCaseSensitivity(t *testing.T) { - entry := &filer_pb.Entry{ - Name: "test-bucket", - IsDirectory: true, - Extended: map[string][]byte{ - s3_constants.AmzIdentityId: []byte("User1"), - }, - Attributes: &filer_pb.FuseAttributes{ - Crtime: time.Now().Unix(), - }, - } - - testCases := []struct { - requestIdentityId string - shouldMatch bool - }{ - {"User1", true}, - {"user1", false}, // Case sensitive - {"USER1", false}, // Case sensitive - {"User2", false}, - } - - for _, tc := range testCases { - t.Run(fmt.Sprintf("identity_%s", tc.requestIdentityId), func(t *testing.T) { - identity := mockIdentity(tc.requestIdentityId, false) - isVisible := isBucketVisibleToIdentity(entry, identity) - - if tc.shouldMatch { - assert.True(t, isVisible, "Identity %s should match (case sensitive)", tc.requestIdentityId) - } else { - assert.False(t, isVisible, "Identity %s should not match (case sensitive)", tc.requestIdentityId) - } - }) - } -} - -// TestListBucketsIssue7647 reproduces and verifies the fix for issue #7647 -// where an admin user with proper permissions could create buckets but couldn't list them -func TestListBucketsIssue7647(t *testing.T) { - t.Run("admin user can see their created buckets", func(t *testing.T) { - // Simulate the exact scenario from issue #7647: - // User "root" with ["Admin", "Read", "Write", "Tagging", "List"] permissions - - // Create identity for root user with Admin action - rootIdentity := &Identity{ - Name: "root", - Credentials: []*Credential{ - { - AccessKey: "ROOTID", - SecretKey: "ROOTSECRET", - }, - }, - Actions: []Action{ - s3_constants.ACTION_ADMIN, - s3_constants.ACTION_READ, - s3_constants.ACTION_WRITE, - s3_constants.ACTION_TAGGING, - s3_constants.ACTION_LIST, - }, - } - - // Create a bucket entry as if it was created by the root user - bucketEntry := &filer_pb.Entry{ - Name: "test", - IsDirectory: true, - Extended: map[string][]byte{ - s3_constants.AmzIdentityId: []byte("root"), - }, - Attributes: &filer_pb.FuseAttributes{ - Crtime: time.Now().Unix(), - Mtime: time.Now().Unix(), - }, - } - - // Test bucket visibility - should be visible to root (owner) - isVisible := isBucketVisibleToIdentity(bucketEntry, rootIdentity) - assert.True(t, isVisible, "Root user should see their own bucket") - - // Test that admin can also see buckets they don't own - otherUserBucket := &filer_pb.Entry{ - Name: "other-bucket", - IsDirectory: true, - Extended: map[string][]byte{ - s3_constants.AmzIdentityId: []byte("otheruser"), - }, - Attributes: &filer_pb.FuseAttributes{ - Crtime: time.Now().Unix(), - Mtime: time.Now().Unix(), - }, - } - - isVisible = isBucketVisibleToIdentity(otherUserBucket, rootIdentity) - assert.True(t, isVisible, "Admin user should see all buckets, even ones they don't own") - - // Test permission check for List action - canList := rootIdentity.CanDo(s3_constants.ACTION_LIST, "test", "") - assert.True(t, canList, "Root user with List action should be able to list buckets") - }) - - t.Run("admin user sees buckets without owner metadata", func(t *testing.T) { - // Admin users should see buckets even if they don't have owner metadata - // (this can happen with legacy buckets or manual creation) - - rootIdentity := &Identity{ - Name: "root", - Actions: []Action{ - s3_constants.ACTION_ADMIN, - s3_constants.ACTION_LIST, - }, - } - - bucketWithoutOwner := &filer_pb.Entry{ - Name: "legacy-bucket", - IsDirectory: true, - Extended: map[string][]byte{}, // No owner metadata - Attributes: &filer_pb.FuseAttributes{ - Crtime: time.Now().Unix(), - }, - } - - isVisible := isBucketVisibleToIdentity(bucketWithoutOwner, rootIdentity) - assert.True(t, isVisible, "Admin should see buckets without owner metadata") - }) - - t.Run("non-admin user cannot see buckets without owner", func(t *testing.T) { - // Non-admin users should not see buckets without owner metadata - - regularUser := &Identity{ - Name: "user1", - Actions: []Action{ - s3_constants.ACTION_READ, - s3_constants.ACTION_LIST, - }, - } - - bucketWithoutOwner := &filer_pb.Entry{ - Name: "legacy-bucket", - IsDirectory: true, - Extended: map[string][]byte{}, // No owner metadata - Attributes: &filer_pb.FuseAttributes{ - Crtime: time.Now().Unix(), - }, - } - - isVisible := isBucketVisibleToIdentity(bucketWithoutOwner, regularUser) - assert.False(t, isVisible, "Non-admin should not see buckets without owner metadata") - }) -} - -// TestListBucketsIssue7796 reproduces and verifies the fix for issue #7796 -// where a user with bucket-specific List permission (e.g., "List:geoserver") -// couldn't see buckets they have access to but don't own -func TestListBucketsIssue7796(t *testing.T) { - t.Run("user with bucket-specific List permission can see bucket they don't own", func(t *testing.T) { - // Simulate the exact scenario from issue #7796: - // User "geoserver" with ["List:geoserver", "Read:geoserver", "Write:geoserver", ...] permissions - // But the bucket "geoserver" was created by a different user (e.g., admin) - - geoserverIdentity := &Identity{ - Name: "geoserver", - Credentials: []*Credential{ - { - AccessKey: "geoserver", - SecretKey: "secret", - }, - }, - Actions: []Action{ - Action("List:geoserver"), - Action("Read:geoserver"), - Action("Write:geoserver"), - Action("Admin:geoserver"), - Action("List:geoserver-ttl"), - Action("Read:geoserver-ttl"), - Action("Write:geoserver-ttl"), - }, - } - - // Bucket was created by admin, not by geoserver user - geoserverBucket := &filer_pb.Entry{ - Name: "geoserver", - IsDirectory: true, - Extended: map[string][]byte{ - s3_constants.AmzIdentityId: []byte("admin"), // Different owner - }, - Attributes: &filer_pb.FuseAttributes{ - Crtime: time.Now().Unix(), - Mtime: time.Now().Unix(), - }, - } - - // Test ownership check - should return false (not owned by geoserver) - isOwner := isBucketOwnedByIdentity(geoserverBucket, geoserverIdentity) - assert.False(t, isOwner, "geoserver user should not be owner of bucket created by admin") - - // Test permission check - should return true (has List:geoserver permission) - canList := geoserverIdentity.CanDo(s3_constants.ACTION_LIST, "geoserver", "") - assert.True(t, canList, "geoserver user with List:geoserver should be able to list geoserver bucket") - - // Verify the combined visibility logic: ownership OR permission - isVisible := isOwner || canList - assert.True(t, isVisible, "Bucket should be visible due to permission (even though not owner)") - }) - - t.Run("user with bucket-specific permission sees bucket without owner metadata", func(t *testing.T) { - // Bucket exists but has no owner metadata (legacy bucket or created before ownership tracking) - - geoserverIdentity := &Identity{ - Name: "geoserver", - Actions: []Action{ - Action("List:geoserver"), - Action("Read:geoserver"), - }, - } - - bucketWithoutOwner := &filer_pb.Entry{ - Name: "geoserver", - IsDirectory: true, - Extended: map[string][]byte{}, // No owner metadata - Attributes: &filer_pb.FuseAttributes{ - Crtime: time.Now().Unix(), - }, - } - - // Not owner (no owner metadata) - isOwner := isBucketOwnedByIdentity(bucketWithoutOwner, geoserverIdentity) - assert.False(t, isOwner, "No owner metadata means not owned") - - // But has permission - canList := geoserverIdentity.CanDo(s3_constants.ACTION_LIST, "geoserver", "") - assert.True(t, canList, "Has explicit List:geoserver permission") - - // Verify the combined visibility logic: ownership OR permission - isVisible := isOwner || canList - assert.True(t, isVisible, "Bucket should be visible due to permission (even without owner metadata)") - }) - - t.Run("user cannot see bucket they neither own nor have permission for", func(t *testing.T) { - // User has no ownership and no permission for the bucket - - geoserverIdentity := &Identity{ - Name: "geoserver", - Actions: []Action{ - Action("List:geoserver"), - Action("Read:geoserver"), - }, - } - - otherBucket := &filer_pb.Entry{ - Name: "otherbucket", - IsDirectory: true, - Extended: map[string][]byte{ - s3_constants.AmzIdentityId: []byte("admin"), - }, - Attributes: &filer_pb.FuseAttributes{ - Crtime: time.Now().Unix(), - }, - } - - // Not owner - isOwner := isBucketOwnedByIdentity(otherBucket, geoserverIdentity) - assert.False(t, isOwner, "geoserver doesn't own otherbucket") - - // No permission for this bucket - canList := geoserverIdentity.CanDo(s3_constants.ACTION_LIST, "otherbucket", "") - assert.False(t, canList, "geoserver has no List permission for otherbucket") - - // Verify the combined visibility logic: ownership OR permission - isVisible := isOwner || canList - assert.False(t, isVisible, "Bucket should NOT be visible (neither owner nor has permission)") - }) - - t.Run("user with wildcard permission sees matching buckets", func(t *testing.T) { - // User has "List:geo*" permission - should see any bucket starting with "geo" - - geoIdentity := &Identity{ - Name: "geouser", - Actions: []Action{ - Action("List:geo*"), - Action("Read:geo*"), - }, - } - - geoBucket := &filer_pb.Entry{ - Name: "geoserver", - IsDirectory: true, - Extended: map[string][]byte{ - s3_constants.AmzIdentityId: []byte("admin"), - }, - Attributes: &filer_pb.FuseAttributes{ - Crtime: time.Now().Unix(), - }, - } - - geoTTLBucket := &filer_pb.Entry{ - Name: "geoserver-ttl", - IsDirectory: true, - Extended: map[string][]byte{ - s3_constants.AmzIdentityId: []byte("admin"), - }, - Attributes: &filer_pb.FuseAttributes{ - Crtime: time.Now().Unix(), - }, - } - - // Not owner of either bucket - isOwnerGeo := isBucketOwnedByIdentity(geoBucket, geoIdentity) - isOwnerGeoTTL := isBucketOwnedByIdentity(geoTTLBucket, geoIdentity) - assert.False(t, isOwnerGeo) - assert.False(t, isOwnerGeoTTL) - - // But has permission via wildcard - canListGeo := geoIdentity.CanDo(s3_constants.ACTION_LIST, "geoserver", "") - canListGeoTTL := geoIdentity.CanDo(s3_constants.ACTION_LIST, "geoserver-ttl", "") - assert.True(t, canListGeo) - assert.True(t, canListGeoTTL) - - // Verify the combined visibility logic for matching buckets - assert.True(t, isOwnerGeo || canListGeo, "geoserver bucket should be visible via wildcard permission") - assert.True(t, isOwnerGeoTTL || canListGeoTTL, "geoserver-ttl bucket should be visible via wildcard permission") - - // Should NOT have permission for unrelated buckets - canListOther := geoIdentity.CanDo(s3_constants.ACTION_LIST, "otherbucket", "") - assert.False(t, canListOther, "No permission for otherbucket") - assert.False(t, false || canListOther, "otherbucket should NOT be visible (no ownership, no permission)") - }) - - t.Run("integration test: complete handler filtering logic", func(t *testing.T) { - // This test simulates the complete filtering logic as used in ListBucketsHandler - // to verify that the combination of ownership OR permission check works correctly - - // User "geoserver" with bucket-specific permissions (same as issue #7796) - geoserverIdentity := &Identity{ - Name: "geoserver", - Credentials: []*Credential{ - {AccessKey: "geoserver", SecretKey: "secret"}, - }, - Actions: []Action{ - Action("List:geoserver"), - Action("Read:geoserver"), - Action("Write:geoserver"), - Action("Admin:geoserver"), - Action("List:geoserver-ttl"), - Action("Read:geoserver-ttl"), - Action("Write:geoserver-ttl"), - }, - } - - // Create test buckets with various ownership scenarios - buckets := []*filer_pb.Entry{ - { - // Bucket owned by admin but geoserver has permission - Name: "geoserver", - IsDirectory: true, - Extended: map[string][]byte{s3_constants.AmzIdentityId: []byte("admin")}, - Attributes: &filer_pb.FuseAttributes{Crtime: time.Now().Unix()}, - }, - { - // Bucket with no owner but geoserver has permission - Name: "geoserver-ttl", - IsDirectory: true, - Extended: map[string][]byte{}, - Attributes: &filer_pb.FuseAttributes{Crtime: time.Now().Unix()}, - }, - { - // Bucket owned by geoserver (should be visible via ownership) - Name: "geoserver-owned", - IsDirectory: true, - Extended: map[string][]byte{s3_constants.AmzIdentityId: []byte("geoserver")}, - Attributes: &filer_pb.FuseAttributes{Crtime: time.Now().Unix()}, - }, - { - // Bucket owned by someone else, no permission for geoserver - Name: "otherbucket", - IsDirectory: true, - Extended: map[string][]byte{s3_constants.AmzIdentityId: []byte("otheruser")}, - Attributes: &filer_pb.FuseAttributes{Crtime: time.Now().Unix()}, - }, - } - - // Simulate the exact filtering logic from ListBucketsHandler - var visibleBuckets []string - for _, entry := range buckets { - if !entry.IsDirectory { - continue - } - - // Check ownership - isOwner := isBucketOwnedByIdentity(entry, geoserverIdentity) - - // Skip permission check if user is already the owner (optimization) - if !isOwner { - // Check permission - hasPermission := geoserverIdentity.CanDo(s3_constants.ACTION_LIST, entry.Name, "") - if !hasPermission { - continue - } - } - - visibleBuckets = append(visibleBuckets, entry.Name) - } - - // Expected: geoserver should see: - // - "geoserver" (has List:geoserver permission, even though owned by admin) - // - "geoserver-ttl" (has List:geoserver-ttl permission, even though no owner) - // - "geoserver-owned" (owns this bucket) - // NOT "otherbucket" (neither owns nor has permission) - expectedBuckets := []string{"geoserver", "geoserver-ttl", "geoserver-owned"} - assert.ElementsMatch(t, expectedBuckets, visibleBuckets, - "geoserver should see buckets they own OR have permission for") - - // Verify "otherbucket" is NOT in the list - assert.NotContains(t, visibleBuckets, "otherbucket", - "geoserver should NOT see buckets they neither own nor have permission for") - }) -} - -func TestListBucketsIssue8516PolicyBasedVisibility(t *testing.T) { - iam := &IdentityAccessManagement{} - require.NoError(t, iam.PutPolicy("listOnly", `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:ListBucket","Resource":"arn:aws:s3:::policy-bucket"}]}`)) - - identity := &Identity{ - Name: "policy-user", - Account: &AccountAdmin, - PolicyNames: []string{"listOnly"}, - } - - req := httptest.NewRequest("GET", "http://s3.amazonaws.com/", nil) - buckets := []*filer_pb.Entry{ - { - Name: "policy-bucket", - IsDirectory: true, - Extended: map[string][]byte{s3_constants.AmzIdentityId: []byte("admin")}, - Attributes: &filer_pb.FuseAttributes{Crtime: time.Now().Unix()}, - }, - { - Name: "other-bucket", - IsDirectory: true, - Extended: map[string][]byte{s3_constants.AmzIdentityId: []byte("admin")}, - Attributes: &filer_pb.FuseAttributes{Crtime: time.Now().Unix()}, - }, - } - - var visibleBuckets []string - for _, entry := range buckets { - isOwner := isBucketOwnedByIdentity(entry, identity) - if !isOwner { - if errCode := iam.VerifyActionPermission(req, identity, s3_constants.ACTION_LIST, entry.Name, ""); errCode != s3err.ErrNone { - continue - } - } - visibleBuckets = append(visibleBuckets, entry.Name) - } - - assert.Equal(t, []string{"policy-bucket"}, visibleBuckets) -} diff --git a/weed/s3api/s3api_conditional_headers_test.go b/weed/s3api/s3api_conditional_headers_test.go deleted file mode 100644 index 9cd220603..000000000 --- a/weed/s3api/s3api_conditional_headers_test.go +++ /dev/null @@ -1,984 +0,0 @@ -package s3api - -import ( - "bytes" - "encoding/hex" - "fmt" - "net/http" - "net/url" - "testing" - "time" - - "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" -) - -// TestConditionalHeadersWithExistingObjects tests conditional headers against existing objects -// This addresses the PR feedback about missing test coverage for object existence scenarios -func TestConditionalHeadersWithExistingObjects(t *testing.T) { - bucket := "test-bucket" - object := "/test-object" - - // Mock object with known ETag and modification time - testObject := &filer_pb.Entry{ - Name: "test-object", - Extended: map[string][]byte{ - s3_constants.ExtETagKey: []byte("\"abc123\""), - }, - Attributes: &filer_pb.FuseAttributes{ - Mtime: time.Date(2024, 6, 15, 12, 0, 0, 0, time.UTC).Unix(), // June 15, 2024 - FileSize: 1024, // Add file size - }, - Chunks: []*filer_pb.FileChunk{ - // Add a mock chunk to make calculateETagFromChunks work - { - FileId: "test-file-id", - Offset: 0, - Size: 1024, - }, - }, - } - - // Test If-None-Match with existing object - t.Run("IfNoneMatch_ObjectExists", func(t *testing.T) { - // Test case 1: If-None-Match=* when object exists (should fail) - t.Run("Asterisk_ShouldFail", func(t *testing.T) { - getter := createMockEntryGetter(testObject) - req := createTestPutRequest(bucket, object, "test content") - req.Header.Set(s3_constants.IfNoneMatch, "*") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrPreconditionFailed { - t.Errorf("Expected ErrPreconditionFailed when object exists with If-None-Match=*, got %v", errCode) - } - }) - - // Test case 2: If-None-Match with matching ETag (should fail) - t.Run("MatchingETag_ShouldFail", func(t *testing.T) { - getter := createMockEntryGetter(testObject) - req := createTestPutRequest(bucket, object, "test content") - req.Header.Set(s3_constants.IfNoneMatch, "\"abc123\"") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrPreconditionFailed { - t.Errorf("Expected ErrPreconditionFailed when ETag matches, got %v", errCode) - } - }) - - // Test case 3: If-None-Match with non-matching ETag (should succeed) - t.Run("NonMatchingETag_ShouldSucceed", func(t *testing.T) { - getter := createMockEntryGetter(testObject) - req := createTestPutRequest(bucket, object, "test content") - req.Header.Set(s3_constants.IfNoneMatch, "\"xyz789\"") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrNone { - t.Errorf("Expected ErrNone when ETag doesn't match, got %v", errCode) - } - }) - - // Test case 4: If-None-Match with multiple ETags, one matching (should fail) - t.Run("MultipleETags_OneMatches_ShouldFail", func(t *testing.T) { - getter := createMockEntryGetter(testObject) - req := createTestPutRequest(bucket, object, "test content") - req.Header.Set(s3_constants.IfNoneMatch, "\"xyz789\", \"abc123\", \"def456\"") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrPreconditionFailed { - t.Errorf("Expected ErrPreconditionFailed when one ETag matches, got %v", errCode) - } - }) - - // Test case 5: If-None-Match with multiple ETags, none matching (should succeed) - t.Run("MultipleETags_NoneMatch_ShouldSucceed", func(t *testing.T) { - getter := createMockEntryGetter(testObject) - req := createTestPutRequest(bucket, object, "test content") - req.Header.Set(s3_constants.IfNoneMatch, "\"xyz789\", \"def456\", \"ghi123\"") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrNone { - t.Errorf("Expected ErrNone when no ETags match, got %v", errCode) - } - }) - }) - - // Test If-Match with existing object - t.Run("IfMatch_ObjectExists", func(t *testing.T) { - // Test case 1: If-Match with matching ETag (should succeed) - t.Run("MatchingETag_ShouldSucceed", func(t *testing.T) { - getter := createMockEntryGetter(testObject) - req := createTestPutRequest(bucket, object, "test content") - req.Header.Set(s3_constants.IfMatch, "\"abc123\"") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrNone { - t.Errorf("Expected ErrNone when ETag matches, got %v", errCode) - } - }) - - // Test case 2: If-Match with non-matching ETag (should fail) - t.Run("NonMatchingETag_ShouldFail", func(t *testing.T) { - getter := createMockEntryGetter(testObject) - req := createTestPutRequest(bucket, object, "test content") - req.Header.Set(s3_constants.IfMatch, "\"xyz789\"") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrPreconditionFailed { - t.Errorf("Expected ErrPreconditionFailed when ETag doesn't match, got %v", errCode) - } - }) - - // Test case 3: If-Match with multiple ETags, one matching (should succeed) - t.Run("MultipleETags_OneMatches_ShouldSucceed", func(t *testing.T) { - getter := createMockEntryGetter(testObject) - req := createTestPutRequest(bucket, object, "test content") - req.Header.Set(s3_constants.IfMatch, "\"xyz789\", \"abc123\"") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrNone { - t.Errorf("Expected ErrNone when one ETag matches, got %v", errCode) - } - }) - - // Test case 4: If-Match with wildcard * (should succeed if object exists) - t.Run("Wildcard_ShouldSucceed", func(t *testing.T) { - getter := createMockEntryGetter(testObject) - req := createTestPutRequest(bucket, object, "test content") - req.Header.Set(s3_constants.IfMatch, "*") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrNone { - t.Errorf("Expected ErrNone when If-Match=* and object exists, got %v", errCode) - } - }) - }) - - // Test If-Modified-Since with existing object - t.Run("IfModifiedSince_ObjectExists", func(t *testing.T) { - // Test case 1: If-Modified-Since with date before object modification (should succeed) - t.Run("DateBefore_ShouldSucceed", func(t *testing.T) { - getter := createMockEntryGetter(testObject) - req := createTestPutRequest(bucket, object, "test content") - dateBeforeModification := time.Date(2024, 6, 14, 12, 0, 0, 0, time.UTC) - req.Header.Set(s3_constants.IfModifiedSince, dateBeforeModification.Format(time.RFC1123)) - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrNone { - t.Errorf("Expected ErrNone when object was modified after date, got %v", errCode) - } - }) - - // Test case 2: If-Modified-Since with date after object modification (should fail) - t.Run("DateAfter_ShouldFail", func(t *testing.T) { - getter := createMockEntryGetter(testObject) - req := createTestPutRequest(bucket, object, "test content") - dateAfterModification := time.Date(2024, 6, 16, 12, 0, 0, 0, time.UTC) - req.Header.Set(s3_constants.IfModifiedSince, dateAfterModification.Format(time.RFC1123)) - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrPreconditionFailed { - t.Errorf("Expected ErrPreconditionFailed when object wasn't modified since date, got %v", errCode) - } - }) - - // Test case 3: If-Modified-Since with exact modification date (should fail - not after) - t.Run("ExactDate_ShouldFail", func(t *testing.T) { - getter := createMockEntryGetter(testObject) - req := createTestPutRequest(bucket, object, "test content") - exactDate := time.Date(2024, 6, 15, 12, 0, 0, 0, time.UTC) - req.Header.Set(s3_constants.IfModifiedSince, exactDate.Format(time.RFC1123)) - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrPreconditionFailed { - t.Errorf("Expected ErrPreconditionFailed when object modification time equals header date, got %v", errCode) - } - }) - }) - - // Test If-Unmodified-Since with existing object - t.Run("IfUnmodifiedSince_ObjectExists", func(t *testing.T) { - // Test case 1: If-Unmodified-Since with date after object modification (should succeed) - t.Run("DateAfter_ShouldSucceed", func(t *testing.T) { - getter := createMockEntryGetter(testObject) - req := createTestPutRequest(bucket, object, "test content") - dateAfterModification := time.Date(2024, 6, 16, 12, 0, 0, 0, time.UTC) - req.Header.Set(s3_constants.IfUnmodifiedSince, dateAfterModification.Format(time.RFC1123)) - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrNone { - t.Errorf("Expected ErrNone when object wasn't modified after date, got %v", errCode) - } - }) - - // Test case 2: If-Unmodified-Since with date before object modification (should fail) - t.Run("DateBefore_ShouldFail", func(t *testing.T) { - getter := createMockEntryGetter(testObject) - req := createTestPutRequest(bucket, object, "test content") - dateBeforeModification := time.Date(2024, 6, 14, 12, 0, 0, 0, time.UTC) - req.Header.Set(s3_constants.IfUnmodifiedSince, dateBeforeModification.Format(time.RFC1123)) - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrPreconditionFailed { - t.Errorf("Expected ErrPreconditionFailed when object was modified after date, got %v", errCode) - } - }) - }) -} - -// TestConditionalHeadersForReads tests conditional headers for read operations (GET, HEAD) -// This implements AWS S3 conditional reads behavior where different conditions return different status codes -// See: https://docs.aws.amazon.com/AmazonS3/latest/userguide/conditional-reads.html -func TestConditionalHeadersForReads(t *testing.T) { - bucket := "test-bucket" - object := "/test-read-object" - - // Mock existing object to test conditional headers against - existingObject := &filer_pb.Entry{ - Name: "test-read-object", - Extended: map[string][]byte{ - s3_constants.ExtETagKey: []byte("\"read123\""), - }, - Attributes: &filer_pb.FuseAttributes{ - Mtime: time.Date(2024, 6, 15, 12, 0, 0, 0, time.UTC).Unix(), - FileSize: 1024, - }, - Chunks: []*filer_pb.FileChunk{ - { - FileId: "read-file-id", - Offset: 0, - Size: 1024, - }, - }, - } - - // Test conditional reads with existing object - t.Run("ConditionalReads_ObjectExists", func(t *testing.T) { - // Test If-None-Match with existing object (should return 304 Not Modified) - t.Run("IfNoneMatch_ObjectExists_ShouldReturn304", func(t *testing.T) { - getter := createMockEntryGetter(existingObject) - - req := createTestGetRequest(bucket, object) - req.Header.Set(s3_constants.IfNoneMatch, "\"read123\"") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object) - if errCode.ErrorCode != s3err.ErrNotModified { - t.Errorf("Expected ErrNotModified when If-None-Match matches, got %v", errCode) - } - }) - - // Test If-None-Match=* with existing object (should return 304 Not Modified) - t.Run("IfNoneMatchAsterisk_ObjectExists_ShouldReturn304", func(t *testing.T) { - getter := createMockEntryGetter(existingObject) - - req := createTestGetRequest(bucket, object) - req.Header.Set(s3_constants.IfNoneMatch, "*") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object) - if errCode.ErrorCode != s3err.ErrNotModified { - t.Errorf("Expected ErrNotModified when If-None-Match=* with existing object, got %v", errCode) - } - }) - - // Test If-None-Match with non-matching ETag (should succeed) - t.Run("IfNoneMatch_NonMatchingETag_ShouldSucceed", func(t *testing.T) { - getter := createMockEntryGetter(existingObject) - - req := createTestGetRequest(bucket, object) - req.Header.Set(s3_constants.IfNoneMatch, "\"different-etag\"") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object) - if errCode.ErrorCode != s3err.ErrNone { - t.Errorf("Expected ErrNone when If-None-Match doesn't match, got %v", errCode) - } - }) - - // Test If-Match with matching ETag (should succeed) - t.Run("IfMatch_MatchingETag_ShouldSucceed", func(t *testing.T) { - getter := createMockEntryGetter(existingObject) - - req := createTestGetRequest(bucket, object) - req.Header.Set(s3_constants.IfMatch, "\"read123\"") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object) - if errCode.ErrorCode != s3err.ErrNone { - t.Errorf("Expected ErrNone when If-Match matches, got %v", errCode) - } - }) - - // Test If-Match with non-matching ETag (should return 412 Precondition Failed) - t.Run("IfMatch_NonMatchingETag_ShouldReturn412", func(t *testing.T) { - getter := createMockEntryGetter(existingObject) - - req := createTestGetRequest(bucket, object) - req.Header.Set(s3_constants.IfMatch, "\"different-etag\"") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object) - if errCode.ErrorCode != s3err.ErrPreconditionFailed { - t.Errorf("Expected ErrPreconditionFailed when If-Match doesn't match, got %v", errCode) - } - }) - - // Test If-Match=* with existing object (should succeed) - t.Run("IfMatchAsterisk_ObjectExists_ShouldSucceed", func(t *testing.T) { - getter := createMockEntryGetter(existingObject) - - req := createTestGetRequest(bucket, object) - req.Header.Set(s3_constants.IfMatch, "*") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object) - if errCode.ErrorCode != s3err.ErrNone { - t.Errorf("Expected ErrNone when If-Match=* with existing object, got %v", errCode) - } - }) - - // Test If-Modified-Since (object modified after date - should succeed) - t.Run("IfModifiedSince_ObjectModifiedAfter_ShouldSucceed", func(t *testing.T) { - getter := createMockEntryGetter(existingObject) - - req := createTestGetRequest(bucket, object) - req.Header.Set(s3_constants.IfModifiedSince, "Sat, 14 Jun 2024 12:00:00 GMT") // Before object mtime - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object) - if errCode.ErrorCode != s3err.ErrNone { - t.Errorf("Expected ErrNone when object modified after If-Modified-Since date, got %v", errCode) - } - }) - - // Test If-Modified-Since (object not modified since date - should return 304) - t.Run("IfModifiedSince_ObjectNotModified_ShouldReturn304", func(t *testing.T) { - getter := createMockEntryGetter(existingObject) - - req := createTestGetRequest(bucket, object) - req.Header.Set(s3_constants.IfModifiedSince, "Sun, 16 Jun 2024 12:00:00 GMT") // After object mtime - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object) - if errCode.ErrorCode != s3err.ErrNotModified { - t.Errorf("Expected ErrNotModified when object not modified since If-Modified-Since date, got %v", errCode) - } - }) - - // Test If-Unmodified-Since (object not modified since date - should succeed) - t.Run("IfUnmodifiedSince_ObjectNotModified_ShouldSucceed", func(t *testing.T) { - getter := createMockEntryGetter(existingObject) - - req := createTestGetRequest(bucket, object) - req.Header.Set(s3_constants.IfUnmodifiedSince, "Sun, 16 Jun 2024 12:00:00 GMT") // After object mtime - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object) - if errCode.ErrorCode != s3err.ErrNone { - t.Errorf("Expected ErrNone when object not modified since If-Unmodified-Since date, got %v", errCode) - } - }) - - // Test If-Unmodified-Since (object modified since date - should return 412) - t.Run("IfUnmodifiedSince_ObjectModified_ShouldReturn412", func(t *testing.T) { - getter := createMockEntryGetter(existingObject) - - req := createTestGetRequest(bucket, object) - req.Header.Set(s3_constants.IfUnmodifiedSince, "Fri, 14 Jun 2024 12:00:00 GMT") // Before object mtime - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object) - if errCode.ErrorCode != s3err.ErrPreconditionFailed { - t.Errorf("Expected ErrPreconditionFailed when object modified since If-Unmodified-Since date, got %v", errCode) - } - }) - }) - - // Test conditional reads with non-existent object - t.Run("ConditionalReads_ObjectNotExists", func(t *testing.T) { - // Test If-None-Match with non-existent object (should succeed) - t.Run("IfNoneMatch_ObjectNotExists_ShouldSucceed", func(t *testing.T) { - getter := createMockEntryGetter(nil) // No object - - req := createTestGetRequest(bucket, object) - req.Header.Set(s3_constants.IfNoneMatch, "\"any-etag\"") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object) - if errCode.ErrorCode != s3err.ErrNone { - t.Errorf("Expected ErrNone when object doesn't exist with If-None-Match, got %v", errCode) - } - }) - - // Test If-Match with non-existent object (should return 412) - t.Run("IfMatch_ObjectNotExists_ShouldReturn412", func(t *testing.T) { - getter := createMockEntryGetter(nil) // No object - - req := createTestGetRequest(bucket, object) - req.Header.Set(s3_constants.IfMatch, "\"any-etag\"") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object) - if errCode.ErrorCode != s3err.ErrPreconditionFailed { - t.Errorf("Expected ErrPreconditionFailed when object doesn't exist with If-Match, got %v", errCode) - } - }) - - // Test If-Modified-Since with non-existent object (should succeed) - t.Run("IfModifiedSince_ObjectNotExists_ShouldSucceed", func(t *testing.T) { - getter := createMockEntryGetter(nil) // No object - - req := createTestGetRequest(bucket, object) - req.Header.Set(s3_constants.IfModifiedSince, "Sat, 15 Jun 2024 12:00:00 GMT") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object) - if errCode.ErrorCode != s3err.ErrNone { - t.Errorf("Expected ErrNone when object doesn't exist with If-Modified-Since, got %v", errCode) - } - }) - - // Test If-Unmodified-Since with non-existent object (should return 412) - t.Run("IfUnmodifiedSince_ObjectNotExists_ShouldReturn412", func(t *testing.T) { - getter := createMockEntryGetter(nil) // No object - - req := createTestGetRequest(bucket, object) - req.Header.Set(s3_constants.IfUnmodifiedSince, "Sat, 15 Jun 2024 12:00:00 GMT") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object) - if errCode.ErrorCode != s3err.ErrPreconditionFailed { - t.Errorf("Expected ErrPreconditionFailed when object doesn't exist with If-Unmodified-Since, got %v", errCode) - } - }) - }) -} - -// Helper function to create a GET request for testing -func createTestGetRequest(bucket, object string) *http.Request { - return &http.Request{ - Method: "GET", - Header: make(http.Header), - URL: &url.URL{ - Path: fmt.Sprintf("/%s/%s", bucket, object), - }, - } -} - -// TestConditionalHeadersWithNonExistentObjects tests the original scenarios (object doesn't exist) -func TestConditionalHeadersWithNonExistentObjects(t *testing.T) { - s3a := NewS3ApiServerForTest() - if s3a == nil { - t.Skip("S3ApiServer not available for testing") - } - - bucket := "test-bucket" - object := "/test-object" - - // Test If-None-Match header when object doesn't exist - t.Run("IfNoneMatch_ObjectDoesNotExist", func(t *testing.T) { - // Test case 1: If-None-Match=* when object doesn't exist (should return ErrNone) - t.Run("Asterisk_ShouldSucceed", func(t *testing.T) { - getter := createMockEntryGetter(nil) // No object exists - req := createTestPutRequest(bucket, object, "test content") - req.Header.Set(s3_constants.IfNoneMatch, "*") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrNone { - t.Errorf("Expected ErrNone when object doesn't exist, got %v", errCode) - } - }) - - // Test case 2: If-None-Match with specific ETag when object doesn't exist - t.Run("SpecificETag_ShouldSucceed", func(t *testing.T) { - getter := createMockEntryGetter(nil) // No object exists - req := createTestPutRequest(bucket, object, "test content") - req.Header.Set(s3_constants.IfNoneMatch, "\"some-etag\"") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrNone { - t.Errorf("Expected ErrNone when object doesn't exist, got %v", errCode) - } - }) - }) - - // Test If-Match header when object doesn't exist - t.Run("IfMatch_ObjectDoesNotExist", func(t *testing.T) { - // Test case 1: If-Match with specific ETag when object doesn't exist (should fail - critical bug fix) - t.Run("SpecificETag_ShouldFail", func(t *testing.T) { - getter := createMockEntryGetter(nil) // No object exists - req := createTestPutRequest(bucket, object, "test content") - req.Header.Set(s3_constants.IfMatch, "\"some-etag\"") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrPreconditionFailed { - t.Errorf("Expected ErrPreconditionFailed when object doesn't exist with If-Match header, got %v", errCode) - } - }) - - // Test case 2: If-Match with wildcard * when object doesn't exist (should fail) - t.Run("Wildcard_ShouldFail", func(t *testing.T) { - getter := createMockEntryGetter(nil) // No object exists - req := createTestPutRequest(bucket, object, "test content") - req.Header.Set(s3_constants.IfMatch, "*") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrPreconditionFailed { - t.Errorf("Expected ErrPreconditionFailed when object doesn't exist with If-Match=*, got %v", errCode) - } - }) - }) - - // Test date format validation (works regardless of object existence) - t.Run("DateFormatValidation", func(t *testing.T) { - // Test case 1: Valid If-Modified-Since date format - t.Run("IfModifiedSince_ValidFormat", func(t *testing.T) { - getter := createMockEntryGetter(nil) // No object exists - req := createTestPutRequest(bucket, object, "test content") - req.Header.Set(s3_constants.IfModifiedSince, time.Now().Format(time.RFC1123)) - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrNone { - t.Errorf("Expected ErrNone with valid date format, got %v", errCode) - } - }) - - // Test case 2: Invalid If-Modified-Since date format - t.Run("IfModifiedSince_InvalidFormat", func(t *testing.T) { - getter := createMockEntryGetter(nil) // No object exists - req := createTestPutRequest(bucket, object, "test content") - req.Header.Set(s3_constants.IfModifiedSince, "invalid-date") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrInvalidRequest { - t.Errorf("Expected ErrInvalidRequest for invalid date format, got %v", errCode) - } - }) - - // Test case 3: Invalid If-Unmodified-Since date format - t.Run("IfUnmodifiedSince_InvalidFormat", func(t *testing.T) { - getter := createMockEntryGetter(nil) // No object exists - req := createTestPutRequest(bucket, object, "test content") - req.Header.Set(s3_constants.IfUnmodifiedSince, "invalid-date") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrInvalidRequest { - t.Errorf("Expected ErrInvalidRequest for invalid date format, got %v", errCode) - } - }) - }) - - // Test no conditional headers - t.Run("NoConditionalHeaders", func(t *testing.T) { - getter := createMockEntryGetter(nil) // No object exists - req := createTestPutRequest(bucket, object, "test content") - // Don't set any conditional headers - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrNone { - t.Errorf("Expected ErrNone when no conditional headers, got %v", errCode) - } - }) -} - -// TestETagMatching tests the etagMatches helper function -func TestETagMatching(t *testing.T) { - s3a := NewS3ApiServerForTest() - if s3a == nil { - t.Skip("S3ApiServer not available for testing") - } - - testCases := []struct { - name string - headerValue string - objectETag string - expected bool - }{ - { - name: "ExactMatch", - headerValue: "\"abc123\"", - objectETag: "abc123", - expected: true, - }, - { - name: "ExactMatchWithQuotes", - headerValue: "\"abc123\"", - objectETag: "\"abc123\"", - expected: true, - }, - { - name: "NoMatch", - headerValue: "\"abc123\"", - objectETag: "def456", - expected: false, - }, - { - name: "MultipleETags_FirstMatch", - headerValue: "\"abc123\", \"def456\"", - objectETag: "abc123", - expected: true, - }, - { - name: "MultipleETags_SecondMatch", - headerValue: "\"abc123\", \"def456\"", - objectETag: "def456", - expected: true, - }, - { - name: "MultipleETags_NoMatch", - headerValue: "\"abc123\", \"def456\"", - objectETag: "ghi789", - expected: false, - }, - { - name: "WithSpaces", - headerValue: " \"abc123\" , \"def456\" ", - objectETag: "def456", - expected: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - result := s3a.etagMatches(tc.headerValue, tc.objectETag) - if result != tc.expected { - t.Errorf("Expected %v, got %v for headerValue='%s', objectETag='%s'", - tc.expected, result, tc.headerValue, tc.objectETag) - } - }) - } -} - -// TestGetObjectETagWithMd5AndChunks tests the fix for issue #7274 -// When an object has both Attributes.Md5 and multiple chunks, getObjectETag should -// prefer Attributes.Md5 to match the behavior of HeadObject and filer.ETag -func TestGetObjectETagWithMd5AndChunks(t *testing.T) { - s3a := NewS3ApiServerForTest() - if s3a == nil { - t.Skip("S3ApiServer not available for testing") - } - - // Create an object with both Md5 and multiple chunks (like in issue #7274) - // Md5: ZjcmMwrCVGNVgb4HoqHe9g== (base64) = 663726330ac254635581be07a2a1def6 (hex) - md5HexString := "663726330ac254635581be07a2a1def6" - md5Bytes, err := hex.DecodeString(md5HexString) - if err != nil { - t.Fatalf("failed to decode md5 hex string: %v", err) - } - - entry := &filer_pb.Entry{ - Name: "test-multipart-object", - Attributes: &filer_pb.FuseAttributes{ - Mtime: time.Now().Unix(), - FileSize: 5597744, - Md5: md5Bytes, - }, - // Two chunks - if we only used ETagChunks, it would return format "hash-2" - Chunks: []*filer_pb.FileChunk{ - { - FileId: "chunk1", - Offset: 0, - Size: 4194304, - ETag: "9+yCD2DGwMG5uKwAd+y04Q==", - }, - { - FileId: "chunk2", - Offset: 4194304, - Size: 1403440, - ETag: "cs6SVSTgZ8W3IbIrAKmklg==", - }, - }, - } - - // getObjectETag should return the Md5 in hex with quotes - expectedETag := "\"" + md5HexString + "\"" - actualETag := s3a.getObjectETag(entry) - - if actualETag != expectedETag { - t.Errorf("Expected ETag %s, got %s", expectedETag, actualETag) - } - - // Now test that conditional headers work with this ETag - bucket := "test-bucket" - object := "/test-object" - - // Test If-Match with the Md5-based ETag (should succeed) - t.Run("IfMatch_WithMd5BasedETag_ShouldSucceed", func(t *testing.T) { - getter := createMockEntryGetter(entry) - req := createTestGetRequest(bucket, object) - // Client sends the ETag from HeadObject (without quotes) - req.Header.Set(s3_constants.IfMatch, md5HexString) - - result := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object) - if result.ErrorCode != s3err.ErrNone { - t.Errorf("Expected ErrNone when If-Match uses Md5-based ETag, got %v (ETag was %s)", result.ErrorCode, actualETag) - } - }) - - // Test If-Match with chunk-based ETag format (should fail - this was the old incorrect behavior) - t.Run("IfMatch_WithChunkBasedETag_ShouldFail", func(t *testing.T) { - getter := createMockEntryGetter(entry) - req := createTestGetRequest(bucket, object) - // If we incorrectly calculated ETag from chunks, it would be in format "hash-2" - req.Header.Set(s3_constants.IfMatch, "123294de680f28bde364b81477549f7d-2") - - result := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object) - if result.ErrorCode != s3err.ErrPreconditionFailed { - t.Errorf("Expected ErrPreconditionFailed when If-Match uses chunk-based ETag format, got %v", result.ErrorCode) - } - }) -} - -// TestConditionalHeadersIntegration tests conditional headers with full integration -func TestConditionalHeadersIntegration(t *testing.T) { - // This would be a full integration test that requires a running SeaweedFS instance - t.Skip("Integration test - requires running SeaweedFS instance") -} - -// createTestPutRequest creates a test HTTP PUT request -func createTestPutRequest(bucket, object, content string) *http.Request { - req, _ := http.NewRequest("PUT", "/"+bucket+object, bytes.NewReader([]byte(content))) - req.Header.Set("Content-Type", "application/octet-stream") - - // Set up mux vars to simulate the bucket and object extraction - // In real tests, this would be handled by the gorilla mux router - return req -} - -// NewS3ApiServerForTest creates a minimal S3ApiServer for testing -// Note: This is a simplified version for unit testing conditional logic -func NewS3ApiServerForTest() *S3ApiServer { - // In a real test environment, this would set up a proper S3ApiServer - // with filer connection, etc. For unit testing conditional header logic, - // we create a minimal instance - return &S3ApiServer{ - option: &S3ApiServerOption{ - BucketsPath: "/buckets", - }, - bucketConfigCache: NewBucketConfigCache(60 * time.Minute), - } -} - -// MockEntryGetter implements the simplified EntryGetter interface for testing -// Only mocks the data access dependency - tests use production getObjectETag and etagMatches -type MockEntryGetter struct { - mockEntry *filer_pb.Entry -} - -// Implement only the simplified EntryGetter interface -func (m *MockEntryGetter) getEntry(parentDirectoryPath, entryName string) (*filer_pb.Entry, error) { - if m.mockEntry != nil { - return m.mockEntry, nil - } - return nil, filer_pb.ErrNotFound -} - -// createMockEntryGetter creates a mock EntryGetter for testing -func createMockEntryGetter(mockEntry *filer_pb.Entry) *MockEntryGetter { - return &MockEntryGetter{ - mockEntry: mockEntry, - } -} - -// TestConditionalHeadersMultipartUpload tests conditional headers with multipart uploads -// This verifies AWS S3 compatibility where conditional headers only apply to CompleteMultipartUpload -func TestConditionalHeadersMultipartUpload(t *testing.T) { - bucket := "test-bucket" - object := "/test-multipart-object" - - // Mock existing object to test conditional headers against - existingObject := &filer_pb.Entry{ - Name: "test-multipart-object", - Extended: map[string][]byte{ - s3_constants.ExtETagKey: []byte("\"existing123\""), - }, - Attributes: &filer_pb.FuseAttributes{ - Mtime: time.Date(2024, 6, 15, 12, 0, 0, 0, time.UTC).Unix(), - FileSize: 2048, - }, - Chunks: []*filer_pb.FileChunk{ - { - FileId: "existing-file-id", - Offset: 0, - Size: 2048, - }, - }, - } - - // Test CompleteMultipartUpload with If-None-Match: * (should fail when object exists) - t.Run("CompleteMultipartUpload_IfNoneMatchAsterisk_ObjectExists_ShouldFail", func(t *testing.T) { - getter := createMockEntryGetter(existingObject) - - // Create a mock CompleteMultipartUpload request with If-None-Match: * - req := &http.Request{ - Method: "POST", - Header: make(http.Header), - URL: &url.URL{ - RawQuery: "uploadId=test-upload-id", - }, - } - req.Header.Set(s3_constants.IfNoneMatch, "*") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrPreconditionFailed { - t.Errorf("Expected ErrPreconditionFailed when object exists with If-None-Match=*, got %v", errCode) - } - }) - - // Test CompleteMultipartUpload with If-None-Match: * (should succeed when object doesn't exist) - t.Run("CompleteMultipartUpload_IfNoneMatchAsterisk_ObjectNotExists_ShouldSucceed", func(t *testing.T) { - getter := createMockEntryGetter(nil) // No existing object - - req := &http.Request{ - Method: "POST", - Header: make(http.Header), - URL: &url.URL{ - RawQuery: "uploadId=test-upload-id", - }, - } - req.Header.Set(s3_constants.IfNoneMatch, "*") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrNone { - t.Errorf("Expected ErrNone when object doesn't exist with If-None-Match=*, got %v", errCode) - } - }) - - // Test CompleteMultipartUpload with If-Match (should succeed when ETag matches) - t.Run("CompleteMultipartUpload_IfMatch_ETagMatches_ShouldSucceed", func(t *testing.T) { - getter := createMockEntryGetter(existingObject) - - req := &http.Request{ - Method: "POST", - Header: make(http.Header), - URL: &url.URL{ - RawQuery: "uploadId=test-upload-id", - }, - } - req.Header.Set(s3_constants.IfMatch, "\"existing123\"") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrNone { - t.Errorf("Expected ErrNone when ETag matches, got %v", errCode) - } - }) - - // Test CompleteMultipartUpload with If-Match (should fail when object doesn't exist) - t.Run("CompleteMultipartUpload_IfMatch_ObjectNotExists_ShouldFail", func(t *testing.T) { - getter := createMockEntryGetter(nil) // No existing object - - req := &http.Request{ - Method: "POST", - Header: make(http.Header), - URL: &url.URL{ - RawQuery: "uploadId=test-upload-id", - }, - } - req.Header.Set(s3_constants.IfMatch, "\"any-etag\"") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrPreconditionFailed { - t.Errorf("Expected ErrPreconditionFailed when object doesn't exist with If-Match, got %v", errCode) - } - }) - - // Test CompleteMultipartUpload with If-Match wildcard (should succeed when object exists) - t.Run("CompleteMultipartUpload_IfMatchWildcard_ObjectExists_ShouldSucceed", func(t *testing.T) { - getter := createMockEntryGetter(existingObject) - - req := &http.Request{ - Method: "POST", - Header: make(http.Header), - URL: &url.URL{ - RawQuery: "uploadId=test-upload-id", - }, - } - req.Header.Set(s3_constants.IfMatch, "*") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrNone { - t.Errorf("Expected ErrNone when object exists with If-Match=*, got %v", errCode) - } - }) -} - -func TestConditionalHeadersTreatDeleteMarkerAsMissing(t *testing.T) { - bucket := "test-bucket" - object := "/deleted-object" - deleteMarkerEntry := &filer_pb.Entry{ - Name: "deleted-object", - Extended: map[string][]byte{ - s3_constants.ExtDeleteMarkerKey: []byte("true"), - }, - Attributes: &filer_pb.FuseAttributes{ - Mtime: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC).Unix(), - }, - } - - t.Run("WriteIfNoneMatchAsteriskSucceeds", func(t *testing.T) { - getter := createMockEntryGetter(deleteMarkerEntry) - req := createTestPutRequest(bucket, object, "new content") - req.Header.Set(s3_constants.IfNoneMatch, "*") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrNone { - t.Fatalf("expected ErrNone for delete marker with If-None-Match=*, got %v", errCode) - } - }) - - t.Run("WriteIfMatchAsteriskFails", func(t *testing.T) { - getter := createMockEntryGetter(deleteMarkerEntry) - req := createTestPutRequest(bucket, object, "new content") - req.Header.Set(s3_constants.IfMatch, "*") - - s3a := NewS3ApiServerForTest() - errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object) - if errCode != s3err.ErrPreconditionFailed { - t.Fatalf("expected ErrPreconditionFailed for delete marker with If-Match=*, got %v", errCode) - } - }) - - t.Run("ReadIfMatchAsteriskFails", func(t *testing.T) { - getter := createMockEntryGetter(deleteMarkerEntry) - req := &http.Request{Method: http.MethodGet, Header: make(http.Header)} - req.Header.Set(s3_constants.IfMatch, "*") - - s3a := NewS3ApiServerForTest() - result := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object) - if result.ErrorCode != s3err.ErrPreconditionFailed { - t.Fatalf("expected ErrPreconditionFailed for read against delete marker with If-Match=*, got %v", result.ErrorCode) - } - if result.Entry != nil { - t.Fatalf("expected no entry to be returned for delete marker, got %#v", result.Entry) - } - }) -} diff --git a/weed/s3api/s3api_copy_size_calculation.go b/weed/s3api/s3api_copy_size_calculation.go index a11c46cdf..eb8bbf0d8 100644 --- a/weed/s3api/s3api_copy_size_calculation.go +++ b/weed/s3api/s3api_copy_size_calculation.go @@ -4,7 +4,6 @@ import ( "net/http" "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" ) // CopySizeCalculator handles size calculations for different copy scenarios @@ -78,12 +77,6 @@ func (calc *CopySizeCalculator) CalculateActualSize() int64 { return calc.srcSize } -// CalculateEncryptedSize calculates the encrypted size for the given encryption type -func (calc *CopySizeCalculator) CalculateEncryptedSize(encType EncryptionType) int64 { - // With IV in metadata, encrypted size equals actual size - return calc.CalculateActualSize() -} - // getSourceEncryptionType determines the encryption type of the source object func getSourceEncryptionType(metadata map[string][]byte) (EncryptionType, bool) { if IsSSECEncrypted(metadata) { @@ -169,22 +162,6 @@ func (calc *CopySizeCalculator) GetSizeTransitionInfo() *SizeTransitionInfo { return info } -// String returns a string representation of the encryption type -func (e EncryptionType) String() string { - switch e { - case EncryptionTypeNone: - return "None" - case EncryptionTypeSSEC: - return s3_constants.SSETypeC - case EncryptionTypeSSEKMS: - return s3_constants.SSETypeKMS - case EncryptionTypeSSES3: - return s3_constants.SSETypeS3 - default: - return "Unknown" - } -} - // OptimizedSizeCalculation provides size calculations optimized for different scenarios type OptimizedSizeCalculation struct { Strategy UnifiedCopyStrategy diff --git a/weed/s3api/s3api_etag_quoting_test.go b/weed/s3api/s3api_etag_quoting_test.go deleted file mode 100644 index 89223c9b3..000000000 --- a/weed/s3api/s3api_etag_quoting_test.go +++ /dev/null @@ -1,167 +0,0 @@ -package s3api - -import ( - "fmt" - "testing" - "time" - - "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" -) - -// TestReproIfMatchMismatch tests specifically for the scenario where internal ETag -// is unquoted (common in SeaweedFS) but client sends quoted ETag in If-Match. -func TestReproIfMatchMismatch(t *testing.T) { - bucket := "test-bucket" - object := "/test-key" - etagValue := "37b51d194a7513e45b56f6524f2d51f2" - - // Scenario 1: Internal ETag is UNQUOTED (stored in Extended), Client sends QUOTED If-Match - // This mirrors the behavior we enforced in filer_multipart.go - t.Run("UnquotedInternal_QuotedHeader", func(t *testing.T) { - entry := &filer_pb.Entry{ - Name: "test-key", - Extended: map[string][]byte{ - s3_constants.ExtETagKey: []byte(etagValue), // Unquoted - }, - Attributes: &filer_pb.FuseAttributes{ - Mtime: time.Now().Unix(), - FileSize: 1024, - }, - } - - getter := &MockEntryGetter{mockEntry: entry} - req := createTestGetRequest(bucket, object) - // Client sends quoted ETag - req.Header.Set(s3_constants.IfMatch, "\""+etagValue+"\"") - - s3a := NewS3ApiServerForTest() - result := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object) - - if result.ErrorCode != s3err.ErrNone { - t.Errorf("Expected success (ErrNone) for unquoted internal ETag and quoted header, got %v. Internal ETag: %s", result.ErrorCode, string(entry.Extended[s3_constants.ExtETagKey])) - } - }) - - // Scenario 2: Internal ETag is QUOTED (stored in Extended), Client sends QUOTED If-Match - // This handles legacy or mixed content - t.Run("QuotedInternal_QuotedHeader", func(t *testing.T) { - entry := &filer_pb.Entry{ - Name: "test-key", - Extended: map[string][]byte{ - s3_constants.ExtETagKey: []byte("\"" + etagValue + "\""), // Quoted - }, - Attributes: &filer_pb.FuseAttributes{ - Mtime: time.Now().Unix(), - FileSize: 1024, - }, - } - - getter := &MockEntryGetter{mockEntry: entry} - req := createTestGetRequest(bucket, object) - req.Header.Set(s3_constants.IfMatch, "\""+etagValue+"\"") - - s3a := NewS3ApiServerForTest() - result := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object) - - if result.ErrorCode != s3err.ErrNone { - t.Errorf("Expected success (ErrNone) for quoted internal ETag and quoted header, got %v", result.ErrorCode) - } - }) - - // Scenario 3: Internal ETag is from Md5 (QUOTED by getObjectETag), Client sends QUOTED If-Match - t.Run("Md5Internal_QuotedHeader", func(t *testing.T) { - // Mock Md5 attribute (16 bytes) - md5Bytes := make([]byte, 16) - copy(md5Bytes, []byte("1234567890123456")) // This doesn't match the hex string below, but getObjectETag formats it as hex - - // Expected ETag from Md5 is hex string of bytes - expectedHex := fmt.Sprintf("%x", md5Bytes) - - entry := &filer_pb.Entry{ - Name: "test-key", - Attributes: &filer_pb.FuseAttributes{ - Mtime: time.Now().Unix(), - FileSize: 1024, - Md5: md5Bytes, - }, - } - - getter := &MockEntryGetter{mockEntry: entry} - req := createTestGetRequest(bucket, object) - req.Header.Set(s3_constants.IfMatch, "\""+expectedHex+"\"") - - s3a := NewS3ApiServerForTest() - result := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object) - - if result.ErrorCode != s3err.ErrNone { - t.Errorf("Expected success (ErrNone) for Md5 internal ETag and quoted header, got %v", result.ErrorCode) - } - }) - - // Test getObjectETag specifically ensuring it returns quoted strings - t.Run("getObjectETag_ShouldReturnQuoted", func(t *testing.T) { - entry := &filer_pb.Entry{ - Name: "test-key", - Extended: map[string][]byte{ - s3_constants.ExtETagKey: []byte("unquoted-etag"), - }, - } - - s3a := NewS3ApiServerForTest() - etag := s3a.getObjectETag(entry) - - expected := "\"unquoted-etag\"" - if etag != expected { - t.Errorf("Expected quoted ETag %s, got %s", expected, etag) - } - }) - - // Test getObjectETag fallback when Extended ETag is present but empty - t.Run("getObjectETag_EmptyExtended_ShouldFallback", func(t *testing.T) { - md5Bytes := []byte("1234567890123456") - expectedHex := fmt.Sprintf("\"%x\"", md5Bytes) - - entry := &filer_pb.Entry{ - Name: "test-key-fallback", - Extended: map[string][]byte{ - s3_constants.ExtETagKey: []byte(""), // Present but empty - }, - Attributes: &filer_pb.FuseAttributes{ - Mtime: time.Now().Unix(), - FileSize: 1024, - Md5: md5Bytes, - }, - } - - s3a := NewS3ApiServerForTest() - etag := s3a.getObjectETag(entry) - - if etag != expectedHex { - t.Errorf("Expected fallback ETag %s, got %s", expectedHex, etag) - } - }) - - // Test newListEntry ETag behavior - t.Run("newListEntry_ShouldReturnQuoted", func(t *testing.T) { - entry := &filer_pb.Entry{ - Name: "test-key", - Extended: map[string][]byte{ - s3_constants.ExtETagKey: []byte("unquoted-etag"), - }, - Attributes: &filer_pb.FuseAttributes{ - Mtime: time.Now().Unix(), - FileSize: 1024, - }, - } - - s3a := NewS3ApiServerForTest() - listEntry := newListEntry(s3a, entry, "", "bucket/dir", "test-key", "bucket/", false, false, false) - - expected := "\"unquoted-etag\"" - if listEntry.ETag != expected { - t.Errorf("Expected quoted ETag %s, got %s", expected, listEntry.ETag) - } - }) -} diff --git a/weed/s3api/s3api_key_rotation.go b/weed/s3api/s3api_key_rotation.go deleted file mode 100644 index c99c13415..000000000 --- a/weed/s3api/s3api_key_rotation.go +++ /dev/null @@ -1,30 +0,0 @@ -package s3api - -import ( - "net/http" - - "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" -) - -// IsSameObjectCopy determines if this is a same-object copy operation -func IsSameObjectCopy(r *http.Request, srcBucket, srcObject, dstBucket, dstObject string) bool { - return srcBucket == dstBucket && srcObject == dstObject -} - -// NeedsKeyRotation determines if the copy operation requires key rotation -func NeedsKeyRotation(entry *filer_pb.Entry, r *http.Request) bool { - // Check for SSE-C key rotation - if IsSSECEncrypted(entry.Extended) && IsSSECRequest(r) { - return true // Assume different keys for safety - } - - // Check for SSE-KMS key rotation - if IsSSEKMSEncrypted(entry.Extended) && IsSSEKMSRequest(r) { - srcKeyID, _ := GetSourceSSEKMSInfo(entry.Extended) - dstKeyID := r.Header.Get(s3_constants.AmzServerSideEncryptionAwsKmsKeyId) - return srcKeyID != dstKeyID - } - - return false -} diff --git a/weed/s3api/s3api_object_handlers.go b/weed/s3api/s3api_object_handlers.go index 7a7538214..71d6bc26d 100644 --- a/weed/s3api/s3api_object_handlers.go +++ b/weed/s3api/s3api_object_handlers.go @@ -268,15 +268,6 @@ func mimeDetect(r *http.Request, dataReader io.Reader) io.ReadCloser { return io.NopCloser(dataReader) } -func urlEscapeObject(object string) string { - normalized := s3_constants.NormalizeObjectKey(object) - // Ensure leading slash for filer paths - if normalized != "" && !strings.HasPrefix(normalized, "/") { - normalized = "/" + normalized - } - return urlPathEscape(normalized) -} - func entryUrlEncode(dir string, entry string, encodingTypeUrl bool) (dirName string, entryName string, prefix string) { if !encodingTypeUrl { return dir, entry, entry @@ -2895,59 +2886,6 @@ func (m *MultipartSSEReader) Close() error { return lastErr } -// Read implements the io.Reader interface for SSERangeReader -func (r *SSERangeReader) Read(p []byte) (n int, err error) { - // Skip bytes iteratively (no recursion) until we reach the offset - for r.skipped < r.offset { - skipNeeded := r.offset - r.skipped - - // Lazily allocate skip buffer on first use, reuse thereafter - if r.skipBuf == nil { - // Use a fixed 32KB buffer for skipping (avoids per-call allocation) - r.skipBuf = make([]byte, 32*1024) - } - - // Determine how much to skip in this iteration - bufSize := int64(len(r.skipBuf)) - if skipNeeded < bufSize { - bufSize = skipNeeded - } - - skipRead, skipErr := r.reader.Read(r.skipBuf[:bufSize]) - r.skipped += int64(skipRead) - - if skipErr != nil { - return 0, skipErr - } - - // Guard against infinite loop: io.Reader may return (0, nil) - // which is permitted by the interface contract for non-empty buffers. - // If we get zero bytes without an error, treat it as an unexpected EOF. - if skipRead == 0 { - return 0, io.ErrUnexpectedEOF - } - } - - // If we have a remaining limit and it's reached - if r.remaining == 0 { - return 0, io.EOF - } - - // Calculate how much to read - readSize := len(p) - if r.remaining > 0 && int64(readSize) > r.remaining { - readSize = int(r.remaining) - } - - // Read the data - n, err = r.reader.Read(p[:readSize]) - if r.remaining > 0 { - r.remaining -= int64(n) - } - - return n, err -} - // PartBoundaryInfo holds information about a part's chunk boundaries type PartBoundaryInfo struct { PartNumber int `json:"part"` diff --git a/weed/s3api/s3api_object_handlers_copy.go b/weed/s3api/s3api_object_handlers_copy.go index 58f18a038..cac24c946 100644 --- a/weed/s3api/s3api_object_handlers_copy.go +++ b/weed/s3api/s3api_object_handlers_copy.go @@ -14,8 +14,6 @@ import ( "strings" "time" - "modernc.org/strutil" - "github.com/seaweedfs/seaweedfs/weed/filer" "github.com/seaweedfs/seaweedfs/weed/glog" "github.com/seaweedfs/seaweedfs/weed/operation" @@ -797,58 +795,6 @@ func replaceDirective(reqHeader http.Header) (replaceMeta, replaceTagging bool) return reqHeader.Get(s3_constants.AmzUserMetaDirective) == DirectiveReplace, reqHeader.Get(s3_constants.AmzObjectTaggingDirective) == DirectiveReplace } -func processMetadata(reqHeader, existing http.Header, replaceMeta, replaceTagging bool, getTags func(parentDirectoryPath string, entryName string) (tags map[string]string, err error), dir, name string) (err error) { - if sc := reqHeader.Get(s3_constants.AmzStorageClass); len(sc) == 0 { - if sc := existing.Get(s3_constants.AmzStorageClass); len(sc) > 0 { - reqHeader.Set(s3_constants.AmzStorageClass, sc) - } - } - - if !replaceMeta { - for header := range reqHeader { - if strings.HasPrefix(header, s3_constants.AmzUserMetaPrefix) { - delete(reqHeader, header) - } - } - for k, v := range existing { - if strings.HasPrefix(k, s3_constants.AmzUserMetaPrefix) { - reqHeader[k] = v - } - } - } - - if !replaceTagging { - for header, _ := range reqHeader { - if strings.HasPrefix(header, s3_constants.AmzObjectTagging) { - delete(reqHeader, header) - } - } - - found := false - for k, _ := range existing { - if strings.HasPrefix(k, s3_constants.AmzObjectTaggingPrefix) { - found = true - break - } - } - - if found { - tags, err := getTags(dir, name) - if err != nil { - return err - } - - var tagArr []string - for k, v := range tags { - tagArr = append(tagArr, fmt.Sprintf("%s=%s", k, v)) - } - tagStr := strutil.JoinFields(tagArr, "&") - reqHeader.Set(s3_constants.AmzObjectTagging, tagStr) - } - } - return -} - func processMetadataBytes(reqHeader http.Header, existing map[string][]byte, replaceMeta, replaceTagging bool) (metadata map[string][]byte, err error) { metadata = make(map[string][]byte) @@ -2632,13 +2578,6 @@ func cleanupVersioningMetadata(metadata map[string][]byte) { delete(metadata, s3_constants.ExtETagKey) } -// shouldCreateVersionForCopy determines whether a version should be created during a copy operation -// based on the destination bucket's versioning state. -// Returns true only if versioning is explicitly "Enabled", not "Suspended" or unconfigured. -func shouldCreateVersionForCopy(versioningState string) bool { - return versioningState == s3_constants.VersioningEnabled -} - // isOrphanedSSES3Header checks if a header is an orphaned SSE-S3 encryption header. // An orphaned header is one where the encryption indicator exists but the actual key is missing. // This can happen when an object was previously encrypted but then copied without encryption, diff --git a/weed/s3api/s3api_object_handlers_copy_test.go b/weed/s3api/s3api_object_handlers_copy_test.go deleted file mode 100644 index 93d1475cd..000000000 --- a/weed/s3api/s3api_object_handlers_copy_test.go +++ /dev/null @@ -1,760 +0,0 @@ -package s3api - -import ( - "fmt" - "net/http" - "reflect" - "sort" - "strings" - "testing" - - "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" - "github.com/seaweedfs/seaweedfs/weed/util" -) - -type H map[string]string - -func (h H) String() string { - pairs := make([]string, 0, len(h)) - for k, v := range h { - pairs = append(pairs, fmt.Sprintf("%s : %s", k, v)) - } - sort.Strings(pairs) - join := strings.Join(pairs, "\n") - return "\n" + join + "\n" -} - -var processMetadataTestCases = []struct { - caseId int - request H - existing H - getTags H - want H -}{ - { - 201, - H{ - "User-Agent": "firefox", - "X-Amz-Meta-My-Meta": "request", - "X-Amz-Tagging": "A=B&a=b&type=request", - }, - H{ - "X-Amz-Meta-My-Meta": "existing", - "X-Amz-Tagging-A": "B", - "X-Amz-Tagging-Type": "existing", - }, - H{ - "A": "B", - "a": "b", - "type": "existing", - }, - H{ - "User-Agent": "firefox", - "X-Amz-Meta-My-Meta": "existing", - "X-Amz-Tagging": "A=B&a=b&type=existing", - }, - }, - { - 202, - H{ - "User-Agent": "firefox", - "X-Amz-Meta-My-Meta": "request", - "X-Amz-Tagging": "A=B&a=b&type=request", - s3_constants.AmzUserMetaDirective: DirectiveReplace, - }, - H{ - "X-Amz-Meta-My-Meta": "existing", - "X-Amz-Tagging-A": "B", - "X-Amz-Tagging-Type": "existing", - }, - H{ - "A": "B", - "a": "b", - "type": "existing", - }, - H{ - "User-Agent": "firefox", - "X-Amz-Meta-My-Meta": "request", - "X-Amz-Tagging": "A=B&a=b&type=existing", - s3_constants.AmzUserMetaDirective: DirectiveReplace, - }, - }, - - { - 203, - H{ - "User-Agent": "firefox", - "X-Amz-Meta-My-Meta": "request", - "X-Amz-Tagging": "A=B&a=b&type=request", - s3_constants.AmzObjectTaggingDirective: DirectiveReplace, - }, - H{ - "X-Amz-Meta-My-Meta": "existing", - "X-Amz-Tagging-A": "B", - "X-Amz-Tagging-Type": "existing", - }, - H{ - "A": "B", - "a": "b", - "type": "existing", - }, - H{ - "User-Agent": "firefox", - "X-Amz-Meta-My-Meta": "existing", - "X-Amz-Tagging": "A=B&a=b&type=request", - s3_constants.AmzObjectTaggingDirective: DirectiveReplace, - }, - }, - - { - 204, - H{ - "User-Agent": "firefox", - "X-Amz-Meta-My-Meta": "request", - "X-Amz-Tagging": "A=B&a=b&type=request", - s3_constants.AmzUserMetaDirective: DirectiveReplace, - s3_constants.AmzObjectTaggingDirective: DirectiveReplace, - }, - H{ - "X-Amz-Meta-My-Meta": "existing", - "X-Amz-Tagging-A": "B", - "X-Amz-Tagging-a": "b", - "X-Amz-Tagging-Type": "existing", - }, - H{ - "A": "B", - "a": "b", - "type": "existing", - }, - H{ - "User-Agent": "firefox", - "X-Amz-Meta-My-Meta": "request", - "X-Amz-Tagging": "A=B&a=b&type=request", - s3_constants.AmzUserMetaDirective: DirectiveReplace, - s3_constants.AmzObjectTaggingDirective: DirectiveReplace, - }, - }, - - { - 205, - H{ - "User-Agent": "firefox", - "X-Amz-Meta-My-Meta": "request", - "X-Amz-Tagging": "A=B&a=b&type=request", - s3_constants.AmzUserMetaDirective: DirectiveReplace, - s3_constants.AmzObjectTaggingDirective: DirectiveReplace, - }, - H{}, - H{}, - H{ - "User-Agent": "firefox", - "X-Amz-Meta-My-Meta": "request", - "X-Amz-Tagging": "A=B&a=b&type=request", - s3_constants.AmzUserMetaDirective: DirectiveReplace, - s3_constants.AmzObjectTaggingDirective: DirectiveReplace, - }, - }, - - { - 206, - H{ - "User-Agent": "firefox", - s3_constants.AmzUserMetaDirective: DirectiveReplace, - s3_constants.AmzObjectTaggingDirective: DirectiveReplace, - }, - H{ - "X-Amz-Meta-My-Meta": "existing", - "X-Amz-Tagging-A": "B", - "X-Amz-Tagging-a": "b", - "X-Amz-Tagging-Type": "existing", - }, - H{ - "A": "B", - "a": "b", - "type": "existing", - }, - H{ - "User-Agent": "firefox", - s3_constants.AmzUserMetaDirective: DirectiveReplace, - s3_constants.AmzObjectTaggingDirective: DirectiveReplace, - }, - }, - - { - 207, - H{ - "User-Agent": "firefox", - "X-Amz-Meta-My-Meta": "request", - s3_constants.AmzUserMetaDirective: DirectiveReplace, - s3_constants.AmzObjectTaggingDirective: DirectiveReplace, - }, - H{ - "X-Amz-Meta-My-Meta": "existing", - "X-Amz-Tagging-A": "B", - "X-Amz-Tagging-a": "b", - "X-Amz-Tagging-Type": "existing", - }, - H{ - "A": "B", - "a": "b", - "type": "existing", - }, - H{ - "User-Agent": "firefox", - "X-Amz-Meta-My-Meta": "request", - s3_constants.AmzUserMetaDirective: DirectiveReplace, - s3_constants.AmzObjectTaggingDirective: DirectiveReplace, - }, - }, -} -var processMetadataBytesTestCases = []struct { - caseId int - request H - existing H - want H -}{ - { - 101, - H{ - "User-Agent": "firefox", - "X-Amz-Meta-My-Meta": "request", - "X-Amz-Tagging": "A=B&a=b&type=request", - }, - H{ - "X-Amz-Meta-My-Meta": "existing", - "X-Amz-Tagging-A": "B", - "X-Amz-Tagging-a": "b", - "X-Amz-Tagging-type": "existing", - }, - H{ - "X-Amz-Meta-My-Meta": "existing", - "X-Amz-Tagging-A": "B", - "X-Amz-Tagging-a": "b", - "X-Amz-Tagging-type": "existing", - }, - }, - - { - 102, - H{ - "User-Agent": "firefox", - "X-Amz-Meta-My-Meta": "request", - "X-Amz-Tagging": "A=B&a=b&type=request", - s3_constants.AmzUserMetaDirective: DirectiveReplace, - }, - H{ - "X-Amz-Meta-My-Meta": "existing", - "X-Amz-Tagging-A": "B", - "X-Amz-Tagging-a": "b", - "X-Amz-Tagging-type": "existing", - }, - H{ - "X-Amz-Meta-My-Meta": "request", - "X-Amz-Tagging-A": "B", - "X-Amz-Tagging-a": "b", - "X-Amz-Tagging-type": "existing", - }, - }, - - { - 103, - H{ - "User-Agent": "firefox", - "X-Amz-Meta-My-Meta": "request", - "X-Amz-Tagging": "A=B&a=b&type=request", - s3_constants.AmzObjectTaggingDirective: DirectiveReplace, - }, - H{ - "X-Amz-Meta-My-Meta": "existing", - "X-Amz-Tagging-A": "B", - "X-Amz-Tagging-a": "b", - "X-Amz-Tagging-type": "existing", - }, - H{ - "X-Amz-Meta-My-Meta": "existing", - "X-Amz-Tagging-A": "B", - "X-Amz-Tagging-a": "b", - "X-Amz-Tagging-type": "request", - }, - }, - - { - 104, - H{ - "User-Agent": "firefox", - "X-Amz-Meta-My-Meta": "request", - "X-Amz-Tagging": "A=B&a=b&type=request", - s3_constants.AmzUserMetaDirective: DirectiveReplace, - s3_constants.AmzObjectTaggingDirective: DirectiveReplace, - }, - H{ - "X-Amz-Meta-My-Meta": "existing", - "X-Amz-Tagging-A": "B", - "X-Amz-Tagging-a": "b", - "X-Amz-Tagging-type": "existing", - }, - H{ - "X-Amz-Meta-My-Meta": "request", - "X-Amz-Tagging-A": "B", - "X-Amz-Tagging-a": "b", - "X-Amz-Tagging-type": "request", - }, - }, - - { - 105, - H{ - "User-Agent": "firefox", - s3_constants.AmzUserMetaDirective: DirectiveReplace, - s3_constants.AmzObjectTaggingDirective: DirectiveReplace, - }, - H{ - "X-Amz-Meta-My-Meta": "existing", - "X-Amz-Tagging-A": "B", - "X-Amz-Tagging-a": "b", - "X-Amz-Tagging-type": "existing", - }, - H{}, - }, - - { - 107, - H{ - "User-Agent": "firefox", - "X-Amz-Meta-My-Meta": "request", - "X-Amz-Tagging": "A=B&a=b&type=request", - s3_constants.AmzUserMetaDirective: DirectiveReplace, - s3_constants.AmzObjectTaggingDirective: DirectiveReplace, - }, - H{}, - H{ - "X-Amz-Meta-My-Meta": "request", - "X-Amz-Tagging-A": "B", - "X-Amz-Tagging-a": "b", - "X-Amz-Tagging-type": "request", - }, - }, - - { - 108, - H{ - "User-Agent": "firefox", - "X-Amz-Meta-My-Meta": "request", - "X-Amz-Tagging": "A=B&a=b&type=request*", - s3_constants.AmzUserMetaDirective: DirectiveReplace, - s3_constants.AmzObjectTaggingDirective: DirectiveReplace, - }, - H{}, - H{}, - }, -} - -func TestProcessMetadata(t *testing.T) { - for _, tc := range processMetadataTestCases { - reqHeader := transferHToHeader(tc.request) - existing := transferHToHeader(tc.existing) - replaceMeta, replaceTagging := replaceDirective(reqHeader) - err := processMetadata(reqHeader, existing, replaceMeta, replaceTagging, func(_ string, _ string) (tags map[string]string, err error) { - return tc.getTags, nil - }, "", "") - if err != nil { - t.Error(err) - } - - result := transferHeaderToH(reqHeader) - fmtTagging(result, tc.want) - - if !reflect.DeepEqual(result, tc.want) { - t.Error(fmt.Errorf("\n### CaseID: %d ###"+ - "\nRequest:%v"+ - "\nExisting:%v"+ - "\nGetTags:%v"+ - "\nWant:%v"+ - "\nActual:%v", - tc.caseId, tc.request, tc.existing, tc.getTags, tc.want, result)) - } - } -} - -func TestProcessMetadataBytes(t *testing.T) { - for _, tc := range processMetadataBytesTestCases { - reqHeader := transferHToHeader(tc.request) - existing := transferHToBytesArr(tc.existing) - replaceMeta, replaceTagging := replaceDirective(reqHeader) - extends, _ := processMetadataBytes(reqHeader, existing, replaceMeta, replaceTagging) - - result := transferBytesArrToH(extends) - fmtTagging(result, tc.want) - - if !reflect.DeepEqual(result, tc.want) { - t.Error(fmt.Errorf("\n### CaseID: %d ###"+ - "\nRequest:%v"+ - "\nExisting:%v"+ - "\nWant:%v"+ - "\nActual:%v", - tc.caseId, tc.request, tc.existing, tc.want, result)) - } - } -} - -func TestMergeCopyMetadataPreservesInternalFields(t *testing.T) { - existing := map[string][]byte{ - s3_constants.SeaweedFSSSEKMSKey: []byte("kms-secret"), - s3_constants.SeaweedFSSSEIV: []byte("iv"), - "X-Amz-Meta-Old": []byte("old"), - "X-Amz-Tagging-Old": []byte("old-tag"), - s3_constants.AmzStorageClass: []byte("STANDARD"), - } - updated := map[string][]byte{ - "X-Amz-Meta-New": []byte("new"), - "X-Amz-Tagging-New": []byte("new-tag"), - s3_constants.AmzStorageClass: []byte("GLACIER"), - } - - merged := mergeCopyMetadata(existing, updated) - - if got := string(merged[s3_constants.SeaweedFSSSEKMSKey]); got != "kms-secret" { - t.Fatalf("expected internal KMS key to be preserved, got %q", got) - } - if got := string(merged[s3_constants.SeaweedFSSSEIV]); got != "iv" { - t.Fatalf("expected internal IV to be preserved, got %q", got) - } - if _, ok := merged["X-Amz-Meta-Old"]; ok { - t.Fatalf("expected stale user metadata to be removed, got %#v", merged) - } - if _, ok := merged["X-Amz-Tagging-Old"]; ok { - t.Fatalf("expected stale tagging metadata to be removed, got %#v", merged) - } - if got := string(merged["X-Amz-Meta-New"]); got != "new" { - t.Fatalf("expected replacement user metadata to be applied, got %q", got) - } - if got := string(merged["X-Amz-Tagging-New"]); got != "new-tag" { - t.Fatalf("expected replacement tagging metadata to be applied, got %q", got) - } - if got := string(merged[s3_constants.AmzStorageClass]); got != "GLACIER" { - t.Fatalf("expected storage class to be updated, got %q", got) - } -} - -func TestCopyEntryETagPrefersStoredETag(t *testing.T) { - entry := &filer_pb.Entry{ - Extended: map[string][]byte{ - s3_constants.ExtETagKey: []byte("\"stored-etag\""), - }, - Attributes: &filer_pb.FuseAttributes{}, - } - - if got := copyEntryETag(util.FullPath("/buckets/test-bucket/object.txt"), entry); got != "\"stored-etag\"" { - t.Fatalf("copyEntryETag() = %q, want %q", got, "\"stored-etag\"") - } -} - -func fmtTagging(maps ...map[string]string) { - for _, m := range maps { - if tagging := m[s3_constants.AmzObjectTagging]; len(tagging) > 0 { - split := strings.Split(tagging, "&") - sort.Strings(split) - m[s3_constants.AmzObjectTagging] = strings.Join(split, "&") - } - } -} - -func transferHToHeader(data map[string]string) http.Header { - header := http.Header{} - for k, v := range data { - header.Add(k, v) - } - return header -} - -func transferHToBytesArr(data map[string]string) map[string][]byte { - m := make(map[string][]byte, len(data)) - for k, v := range data { - m[k] = []byte(v) - } - return m -} - -func transferBytesArrToH(data map[string][]byte) H { - m := make(map[string]string, len(data)) - for k, v := range data { - m[k] = string(v) - } - return m -} - -func transferHeaderToH(data map[string][]string) H { - m := make(map[string]string, len(data)) - for k, v := range data { - m[k] = v[len(v)-1] - } - return m -} - -// TestShouldCreateVersionForCopy tests the production function that determines -// whether a version should be created during a copy operation. -// This addresses issue #7505 where copies were incorrectly creating versions for non-versioned buckets. -func TestShouldCreateVersionForCopy(t *testing.T) { - testCases := []struct { - name string - versioningState string - expectedResult bool - description string - }{ - { - name: "VersioningEnabled", - versioningState: s3_constants.VersioningEnabled, - expectedResult: true, - description: "Should create versions in .versions/ directory when versioning is Enabled", - }, - { - name: "VersioningSuspended", - versioningState: s3_constants.VersioningSuspended, - expectedResult: false, - description: "Should NOT create versions when versioning is Suspended", - }, - { - name: "VersioningNotConfigured", - versioningState: "", - expectedResult: false, - description: "Should NOT create versions when versioning is not configured", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Call the actual production function - result := shouldCreateVersionForCopy(tc.versioningState) - - if result != tc.expectedResult { - t.Errorf("Test case %s failed: %s\nExpected shouldCreateVersionForCopy(%q)=%v, got %v", - tc.name, tc.description, tc.versioningState, tc.expectedResult, result) - } - }) - } -} - -// TestCleanupVersioningMetadata tests the production function that removes versioning metadata. -// This ensures objects copied to non-versioned buckets don't carry invalid versioning metadata -// or stale ETag values from the source. -func TestCleanupVersioningMetadata(t *testing.T) { - testCases := []struct { - name string - sourceMetadata map[string][]byte - expectedKeys []string // Keys that should be present after cleanup - removedKeys []string // Keys that should be removed - }{ - { - name: "RemovesAllVersioningMetadata", - sourceMetadata: map[string][]byte{ - s3_constants.ExtVersionIdKey: []byte("version-123"), - s3_constants.ExtDeleteMarkerKey: []byte("false"), - s3_constants.ExtIsLatestKey: []byte("true"), - s3_constants.ExtETagKey: []byte("\"abc123\""), - "X-Amz-Meta-Custom": []byte("value"), - }, - expectedKeys: []string{"X-Amz-Meta-Custom"}, - removedKeys: []string{s3_constants.ExtVersionIdKey, s3_constants.ExtDeleteMarkerKey, s3_constants.ExtIsLatestKey, s3_constants.ExtETagKey}, - }, - { - name: "HandlesEmptyMetadata", - sourceMetadata: map[string][]byte{}, - expectedKeys: []string{}, - removedKeys: []string{s3_constants.ExtVersionIdKey, s3_constants.ExtDeleteMarkerKey, s3_constants.ExtIsLatestKey, s3_constants.ExtETagKey}, - }, - { - name: "PreservesNonVersioningMetadata", - sourceMetadata: map[string][]byte{ - s3_constants.ExtVersionIdKey: []byte("version-456"), - s3_constants.ExtETagKey: []byte("\"def456\""), - "X-Amz-Meta-Custom": []byte("value1"), - "X-Amz-Meta-Another": []byte("value2"), - s3_constants.ExtIsLatestKey: []byte("true"), - }, - expectedKeys: []string{"X-Amz-Meta-Custom", "X-Amz-Meta-Another"}, - removedKeys: []string{s3_constants.ExtVersionIdKey, s3_constants.ExtETagKey, s3_constants.ExtIsLatestKey}, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Create a copy of the source metadata - dstMetadata := make(map[string][]byte) - for k, v := range tc.sourceMetadata { - dstMetadata[k] = v - } - - // Call the actual production function - cleanupVersioningMetadata(dstMetadata) - - // Verify expected keys are present - for _, key := range tc.expectedKeys { - if _, exists := dstMetadata[key]; !exists { - t.Errorf("Expected key %s to be present in destination metadata", key) - } - } - - // Verify removed keys are absent - for _, key := range tc.removedKeys { - if _, exists := dstMetadata[key]; exists { - t.Errorf("Expected key %s to be removed from destination metadata, but it's still present", key) - } - } - - // Verify the count matches to ensure no extra keys are present - if len(dstMetadata) != len(tc.expectedKeys) { - t.Errorf("Expected %d metadata keys, but got %d. Extra keys might be present.", len(tc.expectedKeys), len(dstMetadata)) - } - }) - } -} - -// TestCopyVersioningIntegration validates the metadata shaping that happens -// before copy finalization for each destination versioning mode. -func TestCopyVersioningIntegration(t *testing.T) { - testCases := []struct { - name string - versioningState string - sourceMetadata map[string][]byte - expectVersionPath bool - expectMetadataKeys []string - }{ - { - name: "EnabledPreservesMetadata", - versioningState: s3_constants.VersioningEnabled, - sourceMetadata: map[string][]byte{ - s3_constants.ExtVersionIdKey: []byte("v123"), - "X-Amz-Meta-Custom": []byte("value"), - }, - expectVersionPath: true, - expectMetadataKeys: []string{ - s3_constants.ExtVersionIdKey, - "X-Amz-Meta-Custom", - }, - }, - { - name: "SuspendedCleansVersionMetadataBeforeFinalize", - versioningState: s3_constants.VersioningSuspended, - sourceMetadata: map[string][]byte{ - s3_constants.ExtVersionIdKey: []byte("v123"), - "X-Amz-Meta-Custom": []byte("value"), - }, - expectVersionPath: false, - expectMetadataKeys: []string{ - "X-Amz-Meta-Custom", - }, - }, - { - name: "NotConfiguredCleansMetadata", - versioningState: "", - sourceMetadata: map[string][]byte{ - s3_constants.ExtVersionIdKey: []byte("v123"), - s3_constants.ExtDeleteMarkerKey: []byte("false"), - "X-Amz-Meta-Custom": []byte("value"), - }, - expectVersionPath: false, - expectMetadataKeys: []string{ - "X-Amz-Meta-Custom", - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Test version creation decision using production function - shouldCreateVersion := shouldCreateVersionForCopy(tc.versioningState) - if shouldCreateVersion != tc.expectVersionPath { - t.Errorf("shouldCreateVersionForCopy(%q) = %v, expected %v", - tc.versioningState, shouldCreateVersion, tc.expectVersionPath) - } - - // Test metadata cleanup using production function - metadata := make(map[string][]byte) - for k, v := range tc.sourceMetadata { - metadata[k] = v - } - - if !shouldCreateVersion { - cleanupVersioningMetadata(metadata) - } - - // Verify only expected keys remain - for _, expectedKey := range tc.expectMetadataKeys { - if _, exists := metadata[expectedKey]; !exists { - t.Errorf("Expected key %q to be present in metadata", expectedKey) - } - } - - // Verify the count matches (no extra keys) - if len(metadata) != len(tc.expectMetadataKeys) { - t.Errorf("Expected %d metadata keys, got %d", len(tc.expectMetadataKeys), len(metadata)) - } - }) - } -} - -// TestIsOrphanedSSES3Header tests detection of orphaned SSE-S3 headers. -// This is a regression test for GitHub issue #7562 where copying from an -// encrypted bucket to an unencrypted bucket left behind the encryption header -// without the actual key, causing subsequent copy operations to fail. -func TestIsOrphanedSSES3Header(t *testing.T) { - testCases := []struct { - name string - headerKey string - metadata map[string][]byte - expected bool - }{ - { - name: "Not an encryption header", - headerKey: "X-Amz-Meta-Custom", - metadata: map[string][]byte{ - "X-Amz-Meta-Custom": []byte("value"), - }, - expected: false, - }, - { - name: "SSE-S3 header with key present (valid)", - headerKey: s3_constants.AmzServerSideEncryption, - metadata: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("AES256"), - s3_constants.SeaweedFSSSES3Key: []byte("key-data"), - }, - expected: false, - }, - { - name: "SSE-S3 header without key (orphaned - GitHub #7562)", - headerKey: s3_constants.AmzServerSideEncryption, - metadata: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("AES256"), - }, - expected: true, - }, - { - name: "SSE-KMS header (not SSE-S3)", - headerKey: s3_constants.AmzServerSideEncryption, - metadata: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("aws:kms"), - }, - expected: false, - }, - { - name: "Different header key entirely", - headerKey: s3_constants.SeaweedFSSSES3Key, - metadata: map[string][]byte{ - s3_constants.AmzServerSideEncryption: []byte("AES256"), - }, - expected: false, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - result := isOrphanedSSES3Header(tc.headerKey, tc.metadata) - if result != tc.expected { - t.Errorf("isOrphanedSSES3Header(%q, metadata) = %v, expected %v", - tc.headerKey, result, tc.expected) - } - }) - } -} diff --git a/weed/s3api/s3api_object_handlers_delete_test.go b/weed/s3api/s3api_object_handlers_delete_test.go deleted file mode 100644 index 5596d6130..000000000 --- a/weed/s3api/s3api_object_handlers_delete_test.go +++ /dev/null @@ -1,119 +0,0 @@ -package s3api - -import ( - "encoding/xml" - "testing" - - "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" -) - -func TestValidateDeleteIfMatch(t *testing.T) { - s3a := NewS3ApiServerForTest() - existingEntry := &filer_pb.Entry{ - Extended: map[string][]byte{ - s3_constants.ExtETagKey: []byte("\"abc123\""), - }, - } - deleteMarkerEntry := &filer_pb.Entry{ - Extended: map[string][]byte{ - s3_constants.ExtDeleteMarkerKey: []byte("true"), - }, - } - - testCases := []struct { - name string - entry *filer_pb.Entry - ifMatch string - missingCode s3err.ErrorCode - expected s3err.ErrorCode - }{ - { - name: "matching etag succeeds", - entry: existingEntry, - ifMatch: "\"abc123\"", - missingCode: s3err.ErrPreconditionFailed, - expected: s3err.ErrNone, - }, - { - name: "wildcard succeeds for existing entry", - entry: existingEntry, - ifMatch: "*", - missingCode: s3err.ErrPreconditionFailed, - expected: s3err.ErrNone, - }, - { - name: "mismatched etag fails", - entry: existingEntry, - ifMatch: "\"other\"", - missingCode: s3err.ErrPreconditionFailed, - expected: s3err.ErrPreconditionFailed, - }, - { - name: "missing current object fails single delete", - entry: nil, - ifMatch: "*", - missingCode: s3err.ErrPreconditionFailed, - expected: s3err.ErrPreconditionFailed, - }, - { - name: "missing current object returns no such key for batch delete", - entry: nil, - ifMatch: "*", - missingCode: s3err.ErrNoSuchKey, - expected: s3err.ErrNoSuchKey, - }, - { - name: "current delete marker behaves like missing object", - entry: normalizeConditionalTargetEntry(deleteMarkerEntry), - ifMatch: "*", - missingCode: s3err.ErrPreconditionFailed, - expected: s3err.ErrPreconditionFailed, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - if errCode := s3a.validateDeleteIfMatch(tc.entry, tc.ifMatch, tc.missingCode); errCode != tc.expected { - t.Fatalf("validateDeleteIfMatch() = %v, want %v", errCode, tc.expected) - } - }) - } -} - -func TestDeleteObjectsRequestUnmarshalConditionalETags(t *testing.T) { - var req DeleteObjectsRequest - body := []byte(` - - true - - first.txt - * - - - second.txt - 3HL4kqCxf3vjVBH40Nrjfkd - "abc123" - -`) - - if err := xml.Unmarshal(body, &req); err != nil { - t.Fatalf("xml.Unmarshal() error = %v", err) - } - if !req.Quiet { - t.Fatalf("expected Quiet=true") - } - if len(req.Objects) != 2 { - t.Fatalf("expected 2 objects, got %d", len(req.Objects)) - } - if req.Objects[0].ETag != "*" { - t.Fatalf("expected first object ETag to be '*', got %q", req.Objects[0].ETag) - } - if req.Objects[1].ETag != "\"abc123\"" { - t.Fatalf("expected second object ETag to preserve quotes, got %q", req.Objects[1].ETag) - } - if req.Objects[1].VersionId != "3HL4kqCxf3vjVBH40Nrjfkd" { - t.Fatalf("expected second object VersionId to unmarshal, got %q", req.Objects[1].VersionId) - } -} diff --git a/weed/s3api/s3api_object_handlers_put.go b/weed/s3api/s3api_object_handlers_put.go index 805d63133..adda8b1c7 100644 --- a/weed/s3api/s3api_object_handlers_put.go +++ b/weed/s3api/s3api_object_handlers_put.go @@ -1859,28 +1859,6 @@ func (s3a *S3ApiServer) validateConditionalHeaders(r *http.Request, headers cond return s3err.ErrNone } -// checkConditionalHeadersWithGetter is a testable method that accepts a simple EntryGetter -// Uses the production getObjectETag and etagMatches methods to ensure testing of real logic -func (s3a *S3ApiServer) checkConditionalHeadersWithGetter(getter EntryGetter, r *http.Request, bucket, object string) s3err.ErrorCode { - headers, errCode := parseConditionalHeaders(r) - if errCode != s3err.ErrNone { - return errCode - } - // Get object entry for conditional checks. - bucketDir := "/buckets/" + bucket - entry, entryErr := getter.getEntry(bucketDir, object) - if entryErr != nil { - if errors.Is(entryErr, filer_pb.ErrNotFound) { - entry = nil - } else { - glog.Errorf("checkConditionalHeadersWithGetter: failed to get entry for %s/%s: %v", bucket, object, entryErr) - return s3err.ErrInternalError - } - } - - return s3a.validateConditionalHeaders(r, headers, entry, bucket, object) -} - // checkConditionalHeaders is the production method that uses the S3ApiServer as EntryGetter func (s3a *S3ApiServer) checkConditionalHeaders(r *http.Request, bucket, object string) s3err.ErrorCode { // Fast path: if no conditional headers are present, skip object resolution entirely. @@ -2002,28 +1980,6 @@ func (s3a *S3ApiServer) validateConditionalHeadersForReads(r *http.Request, head return ConditionalHeaderResult{ErrorCode: s3err.ErrNone, Entry: entry} } -// checkConditionalHeadersForReadsWithGetter is a testable method for read operations -// Uses the production getObjectETag and etagMatches methods to ensure testing of real logic -func (s3a *S3ApiServer) checkConditionalHeadersForReadsWithGetter(getter EntryGetter, r *http.Request, bucket, object string) ConditionalHeaderResult { - headers, errCode := parseConditionalHeaders(r) - if errCode != s3err.ErrNone { - return ConditionalHeaderResult{ErrorCode: errCode} - } - // Get object entry for conditional checks. - bucketDir := "/buckets/" + bucket - entry, entryErr := getter.getEntry(bucketDir, object) - if entryErr != nil { - if errors.Is(entryErr, filer_pb.ErrNotFound) { - entry = nil - } else { - glog.Errorf("checkConditionalHeadersForReadsWithGetter: failed to get entry for %s/%s: %v", bucket, object, entryErr) - return ConditionalHeaderResult{ErrorCode: s3err.ErrInternalError} - } - } - - return s3a.validateConditionalHeadersForReads(r, headers, entry, bucket, object) -} - // checkConditionalHeadersForReads is the production method that uses the S3ApiServer as EntryGetter func (s3a *S3ApiServer) checkConditionalHeadersForReads(r *http.Request, bucket, object string) ConditionalHeaderResult { // Fast path: if no conditional headers are present, skip object resolution entirely. diff --git a/weed/s3api/s3api_object_handlers_put_test.go b/weed/s3api/s3api_object_handlers_put_test.go deleted file mode 100644 index a5646bff7..000000000 --- a/weed/s3api/s3api_object_handlers_put_test.go +++ /dev/null @@ -1,341 +0,0 @@ -package s3api - -import ( - "encoding/xml" - "errors" - "fmt" - "net/http" - "net/http/httptest" - "strings" - "sync" - "testing" - - "github.com/gorilla/mux" - "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3err" - weed_server "github.com/seaweedfs/seaweedfs/weed/server" - "github.com/seaweedfs/seaweedfs/weed/util/constants" -) - -func TestFilerErrorToS3Error(t *testing.T) { - tests := []struct { - name string - err error - expectedErr s3err.ErrorCode - }{ - { - name: "nil error", - err: nil, - expectedErr: s3err.ErrNone, - }, - { - name: "MD5 mismatch error", - err: errors.New(constants.ErrMsgBadDigest), - expectedErr: s3err.ErrBadDigest, - }, - { - name: "Read only error (direct)", - err: weed_server.ErrReadOnly, - expectedErr: s3err.ErrAccessDenied, - }, - { - name: "Read only error (wrapped)", - err: fmt.Errorf("create file /buckets/test/file.txt: %w", weed_server.ErrReadOnly), - expectedErr: s3err.ErrAccessDenied, - }, - { - name: "Context canceled error", - err: errors.New("rpc error: code = Canceled desc = context canceled"), - expectedErr: s3err.ErrInvalidRequest, - }, - { - name: "Context canceled error (simple)", - err: errors.New("context canceled"), - expectedErr: s3err.ErrInvalidRequest, - }, - { - name: "Directory exists error (sentinel)", - err: fmt.Errorf("CreateEntry /path: %w", filer_pb.ErrExistingIsDirectory), - expectedErr: s3err.ErrExistingObjectIsDirectory, - }, - { - name: "Parent is file error (sentinel)", - err: fmt.Errorf("CreateEntry /path: %w", filer_pb.ErrParentIsFile), - expectedErr: s3err.ErrExistingObjectIsFile, - }, - { - name: "Existing is file error (sentinel)", - err: fmt.Errorf("CreateEntry /path: %w", filer_pb.ErrExistingIsFile), - expectedErr: s3err.ErrExistingObjectIsFile, - }, - { - name: "Entry name too long (sentinel)", - err: fmt.Errorf("CreateEntry: %w", filer_pb.ErrEntryNameTooLong), - expectedErr: s3err.ErrKeyTooLongError, - }, - { - name: "Entry name too long (bare sentinel)", - err: filer_pb.ErrEntryNameTooLong, - expectedErr: s3err.ErrKeyTooLongError, - }, - { - name: "Unknown error", - err: errors.New("some random error"), - expectedErr: s3err.ErrInternalError, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := filerErrorToS3Error(tt.err) - if result != tt.expectedErr { - t.Errorf("filerErrorToS3Error(%v) = %v, want %v", tt.err, result, tt.expectedErr) - } - }) - } -} - -// setupKeyLengthTestRouter creates a minimal router that maps requests directly -// to the given handler with {bucket} and {object} mux vars, bypassing auth. -func setupKeyLengthTestRouter(handler http.HandlerFunc) *mux.Router { - router := mux.NewRouter() - bucket := router.PathPrefix("/{bucket}").Subrouter() - bucket.Path("/{object:.+}").HandlerFunc(handler) - return router -} - -func TestPutObjectHandler_KeyTooLong(t *testing.T) { - s3a := &S3ApiServer{} - router := setupKeyLengthTestRouter(s3a.PutObjectHandler) - - longKey := strings.Repeat("a", s3_constants.MaxS3ObjectKeyLength+1) - req := httptest.NewRequest(http.MethodPut, "/bucket/"+longKey, nil) - rr := httptest.NewRecorder() - router.ServeHTTP(rr, req) - - if rr.Code != http.StatusBadRequest { - t.Errorf("expected status %d, got %d", http.StatusBadRequest, rr.Code) - } - var errResp s3err.RESTErrorResponse - if err := xml.Unmarshal(rr.Body.Bytes(), &errResp); err != nil { - t.Fatalf("failed to parse error XML: %v", err) - } - if errResp.Code != "KeyTooLongError" { - t.Errorf("expected error code KeyTooLongError, got %s", errResp.Code) - } -} - -func TestPutObjectHandler_KeyAtLimit(t *testing.T) { - s3a := &S3ApiServer{} - - // Wrap handler to convert panics from uninitialized server state into 500 - // responses. The key length check runs early and writes 400 KeyTooLongError - // before reaching any code that needs a fully initialized server. A panic - // means the handler accepted the key and continued past the check. - panicSafe := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - defer func() { - if p := recover(); p != nil { - w.WriteHeader(http.StatusInternalServerError) - } - }() - s3a.PutObjectHandler(w, r) - }) - router := setupKeyLengthTestRouter(panicSafe) - - atLimitKey := strings.Repeat("a", s3_constants.MaxS3ObjectKeyLength) - req := httptest.NewRequest(http.MethodPut, "/bucket/"+atLimitKey, nil) - rr := httptest.NewRecorder() - router.ServeHTTP(rr, req) - - // Must NOT be KeyTooLongError — any other response (including 500 from - // the minimal server hitting uninitialized state) proves the key passed. - var errResp s3err.RESTErrorResponse - if rr.Code == http.StatusBadRequest { - if err := xml.Unmarshal(rr.Body.Bytes(), &errResp); err == nil && errResp.Code == "KeyTooLongError" { - t.Errorf("key at exactly %d bytes should not be rejected as too long", s3_constants.MaxS3ObjectKeyLength) - } - } -} - -func TestCopyObjectHandler_KeyTooLong(t *testing.T) { - s3a := &S3ApiServer{} - router := setupKeyLengthTestRouter(s3a.CopyObjectHandler) - - longKey := strings.Repeat("a", s3_constants.MaxS3ObjectKeyLength+1) - req := httptest.NewRequest(http.MethodPut, "/bucket/"+longKey, nil) - req.Header.Set("X-Amz-Copy-Source", "/src-bucket/src-object") - rr := httptest.NewRecorder() - router.ServeHTTP(rr, req) - - if rr.Code != http.StatusBadRequest { - t.Errorf("expected status %d, got %d", http.StatusBadRequest, rr.Code) - } - var errResp s3err.RESTErrorResponse - if err := xml.Unmarshal(rr.Body.Bytes(), &errResp); err != nil { - t.Fatalf("failed to parse error XML: %v", err) - } - if errResp.Code != "KeyTooLongError" { - t.Errorf("expected error code KeyTooLongError, got %s", errResp.Code) - } -} - -func TestNewMultipartUploadHandler_KeyTooLong(t *testing.T) { - s3a := &S3ApiServer{} - router := setupKeyLengthTestRouter(s3a.NewMultipartUploadHandler) - - longKey := strings.Repeat("a", s3_constants.MaxS3ObjectKeyLength+1) - req := httptest.NewRequest(http.MethodPost, "/bucket/"+longKey+"?uploads", nil) - rr := httptest.NewRecorder() - router.ServeHTTP(rr, req) - - if rr.Code != http.StatusBadRequest { - t.Errorf("expected status %d, got %d", http.StatusBadRequest, rr.Code) - } - var errResp s3err.RESTErrorResponse - if err := xml.Unmarshal(rr.Body.Bytes(), &errResp); err != nil { - t.Fatalf("failed to parse error XML: %v", err) - } - if errResp.Code != "KeyTooLongError" { - t.Errorf("expected error code KeyTooLongError, got %s", errResp.Code) - } -} - -type testObjectWriteLockFactory struct { - mu sync.Mutex - locks map[string]*sync.Mutex -} - -func (f *testObjectWriteLockFactory) newLock(bucket, object string) objectWriteLock { - key := bucket + "|" + object - - f.mu.Lock() - lock, ok := f.locks[key] - if !ok { - lock = &sync.Mutex{} - f.locks[key] = lock - } - f.mu.Unlock() - - lock.Lock() - return &testObjectWriteLock{unlock: lock.Unlock} -} - -type testObjectWriteLock struct { - once sync.Once - unlock func() -} - -func (l *testObjectWriteLock) StopShortLivedLock() error { - l.once.Do(l.unlock) - return nil -} - -func TestWithObjectWriteLockSerializesConcurrentPreconditions(t *testing.T) { - s3a := NewS3ApiServerForTest() - lockFactory := &testObjectWriteLockFactory{ - locks: make(map[string]*sync.Mutex), - } - s3a.newObjectWriteLock = lockFactory.newLock - - const workers = 3 - const bucket = "test-bucket" - const object = "/file.txt" - - start := make(chan struct{}) - results := make(chan s3err.ErrorCode, workers) - var wg sync.WaitGroup - - var stateMu sync.Mutex - objectExists := false - - for i := 0; i < workers; i++ { - wg.Add(1) - go func() { - defer wg.Done() - <-start - - errCode := s3a.withObjectWriteLock(bucket, object, - func() s3err.ErrorCode { - stateMu.Lock() - defer stateMu.Unlock() - if objectExists { - return s3err.ErrPreconditionFailed - } - return s3err.ErrNone - }, - func() s3err.ErrorCode { - stateMu.Lock() - defer stateMu.Unlock() - objectExists = true - return s3err.ErrNone - }, - ) - - results <- errCode - }() - } - - close(start) - wg.Wait() - close(results) - - var successCount int - var preconditionFailedCount int - - for errCode := range results { - switch errCode { - case s3err.ErrNone: - successCount++ - case s3err.ErrPreconditionFailed: - preconditionFailedCount++ - default: - t.Fatalf("unexpected error code: %v", errCode) - } - } - - if successCount != 1 { - t.Fatalf("expected exactly one successful writer, got %d", successCount) - } - if preconditionFailedCount != workers-1 { - t.Fatalf("expected %d precondition failures, got %d", workers-1, preconditionFailedCount) - } -} - -func TestResolveFileMode(t *testing.T) { - tests := []struct { - name string - acl string - defaultFileMode uint32 - expected uint32 - }{ - {"no acl, no default", "", 0, 0660}, - {"no acl, with default", "", 0644, 0644}, - {"private", s3_constants.CannedAclPrivate, 0, 0660}, - {"private overrides default", s3_constants.CannedAclPrivate, 0644, 0660}, - {"public-read", s3_constants.CannedAclPublicRead, 0, 0644}, - {"public-read overrides default", s3_constants.CannedAclPublicRead, 0666, 0644}, - {"public-read-write", s3_constants.CannedAclPublicReadWrite, 0, 0666}, - {"authenticated-read", s3_constants.CannedAclAuthenticatedRead, 0, 0644}, - {"bucket-owner-read", s3_constants.CannedAclBucketOwnerRead, 0, 0644}, - {"bucket-owner-full-control", s3_constants.CannedAclBucketOwnerFullControl, 0, 0660}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - s3a := &S3ApiServer{ - option: &S3ApiServerOption{ - DefaultFileMode: tt.defaultFileMode, - }, - } - req := httptest.NewRequest(http.MethodPut, "/bucket/object", nil) - if tt.acl != "" { - req.Header.Set(s3_constants.AmzCannedAcl, tt.acl) - } - got := s3a.resolveFileMode(req) - if got != tt.expected { - t.Errorf("resolveFileMode() = %04o, want %04o", got, tt.expected) - } - }) - } -} diff --git a/weed/s3api/s3api_object_handlers_test.go b/weed/s3api/s3api_object_handlers_test.go deleted file mode 100644 index 5ca04c3ce..000000000 --- a/weed/s3api/s3api_object_handlers_test.go +++ /dev/null @@ -1,244 +0,0 @@ -package s3api - -import ( - "testing" - "time" - - "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" - "github.com/stretchr/testify/assert" -) - -func TestNewListEntryOwnerDisplayName(t *testing.T) { - // Create S3ApiServer with a properly initialized IAM - s3a := &S3ApiServer{ - iam: &IdentityAccessManagement{ - accounts: map[string]*Account{ - "testid": {Id: "testid", DisplayName: "M. Tester"}, - "userid123": {Id: "userid123", DisplayName: "John Doe"}, - }, - }, - } - - // Create test entry with owner metadata - entry := &filer_pb.Entry{ - Name: "test-object", - Attributes: &filer_pb.FuseAttributes{ - Mtime: time.Now().Unix(), - FileSize: 1024, - }, - Extended: map[string][]byte{ - s3_constants.ExtAmzOwnerKey: []byte("testid"), - }, - } - - // Test that display name is correctly looked up from IAM - listEntry := newListEntry(s3a, entry, "", "dir", "test-object", "/buckets/test/", true, false, false) - - assert.NotNil(t, listEntry.Owner, "Owner should be set when fetchOwner is true") - assert.Equal(t, "testid", listEntry.Owner.ID, "Owner ID should match stored owner") - assert.Equal(t, "M. Tester", listEntry.Owner.DisplayName, "Display name should be looked up from IAM") - - // Test with owner that doesn't exist in IAM (should fallback to ID) - entry.Extended[s3_constants.ExtAmzOwnerKey] = []byte("unknown-user") - listEntry = newListEntry(s3a, entry, "", "dir", "test-object", "/buckets/test/", true, false, false) - - assert.Equal(t, "unknown-user", listEntry.Owner.ID, "Owner ID should match stored owner") - assert.Equal(t, "unknown-user", listEntry.Owner.DisplayName, "Display name should fallback to ID when not found in IAM") - - // Test with no owner metadata (should use anonymous) - entry.Extended = make(map[string][]byte) - listEntry = newListEntry(s3a, entry, "", "dir", "test-object", "/buckets/test/", true, false, false) - - assert.Equal(t, s3_constants.AccountAnonymousId, listEntry.Owner.ID, "Should use anonymous ID when no owner metadata") - assert.Equal(t, "anonymous", listEntry.Owner.DisplayName, "Should use anonymous display name when no owner metadata") - - // Test with fetchOwner false (should not set owner) - listEntry = newListEntry(s3a, entry, "", "dir", "test-object", "/buckets/test/", false, false, false) - - assert.Nil(t, listEntry.Owner, "Owner should not be set when fetchOwner is false") -} - -func TestRemoveDuplicateSlashes(t *testing.T) { - tests := []struct { - name string - path string - expectedResult string - }{ - { - name: "empty", - path: "", - expectedResult: "", - }, - { - name: "slash", - path: "/", - expectedResult: "/", - }, - { - name: "object", - path: "object", - expectedResult: "object", - }, - { - name: "correct path", - path: "/path/to/object", - expectedResult: "/path/to/object", - }, - { - name: "path with duplicates", - path: "///path//to/object//", - expectedResult: "/path/to/object/", - }, - } - - for _, tst := range tests { - t.Run(tst.name, func(t *testing.T) { - obj := removeDuplicateSlashes(tst.path) - assert.Equal(t, tst.expectedResult, obj) - }) - } -} - -func TestS3ApiServer_toFilerPath(t *testing.T) { - tests := []struct { - name string - args string - want string - }{ - { - "simple", - "/uploads/eaf10b3b-3b3a-4dcd-92a7-edf2a512276e/67b8b9bf-7cca-4cb6-9b34-22fcb4d6e27d/Bildschirmfoto 2022-09-19 um 21.38.37.png", - "/uploads/eaf10b3b-3b3a-4dcd-92a7-edf2a512276e/67b8b9bf-7cca-4cb6-9b34-22fcb4d6e27d/Bildschirmfoto%202022-09-19%20um%2021.38.37.png", - }, - { - "double prefix", - "//uploads/t.png", - "/uploads/t.png", - }, - { - "triple prefix", - "///uploads/t.png", - "/uploads/t.png", - }, - { - "empty prefix", - "uploads/t.png", - "/uploads/t.png", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert.Equalf(t, tt.want, urlEscapeObject(tt.args), "clean %v", tt.args) - }) - } -} - -func TestPartNumberWithRangeHeader(t *testing.T) { - tests := []struct { - name string - partStartOffset int64 // Part's start offset in the object - partEndOffset int64 // Part's end offset in the object - clientRangeHeader string - expectedStart int64 // Expected absolute start offset - expectedEnd int64 // Expected absolute end offset - expectError bool - }{ - { - name: "No client range - full part", - partStartOffset: 1000, - partEndOffset: 1999, - clientRangeHeader: "", - expectedStart: 1000, - expectedEnd: 1999, - expectError: false, - }, - { - name: "Range within part - start and end", - partStartOffset: 1000, - partEndOffset: 1999, // Part size: 1000 bytes - clientRangeHeader: "bytes=0-99", - expectedStart: 1000, // 1000 + 0 - expectedEnd: 1099, // 1000 + 99 - expectError: false, - }, - { - name: "Range within part - start to end", - partStartOffset: 1000, - partEndOffset: 1999, - clientRangeHeader: "bytes=100-", - expectedStart: 1100, // 1000 + 100 - expectedEnd: 1999, // 1000 + 999 (end of part) - expectError: false, - }, - { - name: "Range suffix - last 100 bytes", - partStartOffset: 1000, - partEndOffset: 1999, // Part size: 1000 bytes - clientRangeHeader: "bytes=-100", - expectedStart: 1900, // 1000 + (1000 - 100) - expectedEnd: 1999, // 1000 + 999 - expectError: false, - }, - { - name: "Range suffix larger than part", - partStartOffset: 1000, - partEndOffset: 1999, // Part size: 1000 bytes - clientRangeHeader: "bytes=-2000", - expectedStart: 1000, // Start of part (clamped) - expectedEnd: 1999, // End of part - expectError: false, - }, - { - name: "Range start beyond part size", - partStartOffset: 1000, - partEndOffset: 1999, - clientRangeHeader: "bytes=1000-1100", - expectedStart: 0, - expectedEnd: 0, - expectError: true, - }, - { - name: "Range end clamped to part size", - partStartOffset: 1000, - partEndOffset: 1999, - clientRangeHeader: "bytes=0-2000", - expectedStart: 1000, // 1000 + 0 - expectedEnd: 1999, // Clamped to end of part - expectError: false, - }, - { - name: "Single byte range at start", - partStartOffset: 5000, - partEndOffset: 9999, // Part size: 5000 bytes - clientRangeHeader: "bytes=0-0", - expectedStart: 5000, - expectedEnd: 5000, - expectError: false, - }, - { - name: "Single byte range in middle", - partStartOffset: 5000, - partEndOffset: 9999, - clientRangeHeader: "bytes=100-100", - expectedStart: 5100, - expectedEnd: 5100, - expectError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Test the actual range adjustment logic from GetObjectHandler - startOffset, endOffset, err := adjustRangeForPart(tt.partStartOffset, tt.partEndOffset, tt.clientRangeHeader) - - if tt.expectError { - assert.Error(t, err, "Expected error for range %s", tt.clientRangeHeader) - } else { - assert.NoError(t, err, "Unexpected error for range %s: %v", tt.clientRangeHeader, err) - assert.Equal(t, tt.expectedStart, startOffset, "Start offset mismatch") - assert.Equal(t, tt.expectedEnd, endOffset, "End offset mismatch") - } - }) - } -} diff --git a/weed/s3api/s3api_sosapi.go b/weed/s3api/s3api_sosapi.go index 53d7acdb4..673b60993 100644 --- a/weed/s3api/s3api_sosapi.go +++ b/weed/s3api/s3api_sosapi.go @@ -14,7 +14,6 @@ import ( "fmt" "net/http" "strconv" - "strings" "time" "github.com/seaweedfs/seaweedfs/weed/glog" @@ -97,13 +96,6 @@ func isSOSAPIObject(object string) bool { } } -// isSOSAPIClient checks if the request comes from a SOSAPI-compatible client -// by examining the User-Agent header. -func isSOSAPIClient(r *http.Request) bool { - userAgent := r.Header.Get("User-Agent") - return strings.Contains(userAgent, sosAPIClientUserAgent) -} - // generateSystemXML creates the system.xml response containing storage system // capabilities and recommendations. func generateSystemXML() ([]byte, error) { diff --git a/weed/s3api/s3api_sosapi_test.go b/weed/s3api/s3api_sosapi_test.go deleted file mode 100644 index c14bd16f6..000000000 --- a/weed/s3api/s3api_sosapi_test.go +++ /dev/null @@ -1,248 +0,0 @@ -package s3api - -import ( - "encoding/xml" - "net/http/httptest" - "strings" - "testing" - - "github.com/seaweedfs/seaweedfs/weed/pb/master_pb" -) - -func TestIsSOSAPIObject(t *testing.T) { - tests := []struct { - name string - object string - expected bool - }{ - { - name: "system.xml should be detected", - object: ".system-d26a9498-cb7c-4a87-a44a-8ae204f5ba6c/system.xml", - expected: true, - }, - { - name: "capacity.xml should be detected", - object: ".system-d26a9498-cb7c-4a87-a44a-8ae204f5ba6c/capacity.xml", - expected: true, - }, - { - name: "regular object should not be detected", - object: "myfile.txt", - expected: false, - }, - { - name: "similar but different path should not be detected", - object: ".system-other-uuid/system.xml", - expected: false, - }, - { - name: "nested path should not be detected", - object: "prefix/.system-d26a9498-cb7c-4a87-a44a-8ae204f5ba6c/system.xml", - expected: false, - }, - { - name: "empty string should not be detected", - object: "", - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := isSOSAPIObject(tt.object) - if result != tt.expected { - t.Errorf("isSOSAPIObject(%q) = %v, want %v", tt.object, result, tt.expected) - } - }) - } -} - -func TestIsSOSAPIClient(t *testing.T) { - tests := []struct { - name string - userAgent string - expected bool - }{ - { - name: "Veeam backup client should be detected", - userAgent: "APN/1.0 Veeam/1.0 Backup/10.0", - expected: true, - }, - { - name: "exact match should be detected", - userAgent: "APN/1.0 Veeam/1.0", - expected: true, - }, - { - name: "AWS CLI should not be detected", - userAgent: "aws-cli/2.0.0 Python/3.8", - expected: false, - }, - { - name: "empty user agent should not be detected", - userAgent: "", - expected: false, - }, - { - name: "partial match should not be detected", - userAgent: "Veeam/1.0", - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest("GET", "/bucket/object", nil) - req.Header.Set("User-Agent", tt.userAgent) - result := isSOSAPIClient(req) - if result != tt.expected { - t.Errorf("isSOSAPIClient() with User-Agent %q = %v, want %v", tt.userAgent, result, tt.expected) - } - }) - } -} - -func TestGenerateSystemXML(t *testing.T) { - xmlData, err := generateSystemXML() - if err != nil { - t.Fatalf("generateSystemXML() failed: %v", err) - } - - // Verify it's valid XML - var si SystemInfo - if err := xml.Unmarshal(xmlData, &si); err != nil { - t.Fatalf("generated XML is invalid: %v", err) - } - - // Verify required fields - if si.ProtocolVersion != sosAPIProtocolVersion { - t.Errorf("ProtocolVersion = %q, want %q", si.ProtocolVersion, sosAPIProtocolVersion) - } - - if !strings.Contains(si.ModelName, "SeaweedFS") { - t.Errorf("ModelName = %q, should contain 'SeaweedFS'", si.ModelName) - } - - if !si.ProtocolCapabilities.CapacityInfo { - t.Error("ProtocolCapabilities.CapacityInfo should be true") - } - - if si.SystemRecommendations == nil { - t.Fatal("SystemRecommendations should not be nil") - } - - if si.SystemRecommendations.KBBlockSize != sosAPIDefaultBlockSizeKB { - t.Errorf("KBBlockSize = %d, want %d", si.SystemRecommendations.KBBlockSize, sosAPIDefaultBlockSizeKB) - } -} - -func TestSOSAPIObjectDetectionEdgeCases(t *testing.T) { - edgeCases := []struct { - object string - expected bool - }{ - // With leading slash - {"/.system-d26a9498-cb7c-4a87-a44a-8ae204f5ba6c/system.xml", false}, - // URL encoded - {".system-d26a9498-cb7c-4a87-a44a-8ae204f5ba6c%2Fsystem.xml", false}, - // Mixed case - {".System-d26a9498-cb7c-4a87-a44a-8ae204f5ba6c/system.xml", false}, - // Extra slashes - {".system-d26a9498-cb7c-4a87-a44a-8ae204f5ba6c//system.xml", false}, - // Correct paths - {".system-d26a9498-cb7c-4a87-a44a-8ae204f5ba6c/system.xml", true}, - {".system-d26a9498-cb7c-4a87-a44a-8ae204f5ba6c/capacity.xml", true}, - } - - for _, tc := range edgeCases { - result := isSOSAPIObject(tc.object) - if result != tc.expected { - t.Errorf("isSOSAPIObject(%q) = %v, want %v", tc.object, result, tc.expected) - } - } -} - -func TestCollectBucketUsageFromTopology(t *testing.T) { - topo := &master_pb.TopologyInfo{ - DataCenterInfos: []*master_pb.DataCenterInfo{ - { - RackInfos: []*master_pb.RackInfo{ - { - DataNodeInfos: []*master_pb.DataNodeInfo{ - { - DiskInfos: map[string]*master_pb.DiskInfo{ - "hdd": { - VolumeInfos: []*master_pb.VolumeInformationMessage{ - {Id: 1, Size: 100, Collection: "bucket1"}, - {Id: 2, Size: 200, Collection: "bucket2"}, - {Id: 3, Size: 300, Collection: "bucket1"}, - {Id: 1, Size: 100, Collection: "bucket1"}, // Duplicate (replica), should be ignored - }, - }, - }, - }, - }, - }, - }, - }, - }, - } - - usage := collectBucketUsageFromTopology(topo, "bucket1") - expected := int64(400) // 100 + 300 - if usage != expected { - t.Errorf("collectBucketUsageFromTopology = %d, want %d", usage, expected) - } - - usage2 := collectBucketUsageFromTopology(topo, "bucket2") - expected2 := int64(200) - if usage2 != expected2 { - t.Errorf("collectBucketUsageFromTopology = %d, want %d", usage2, expected2) - } -} - -func TestCalculateClusterCapacity(t *testing.T) { - topo := &master_pb.TopologyInfo{ - DataCenterInfos: []*master_pb.DataCenterInfo{ - { - RackInfos: []*master_pb.RackInfo{ - { - DataNodeInfos: []*master_pb.DataNodeInfo{ - { - DiskInfos: map[string]*master_pb.DiskInfo{ - "hdd": { - MaxVolumeCount: 100, - FreeVolumeCount: 40, - }, - }, - }, - { - DiskInfos: map[string]*master_pb.DiskInfo{ - "hdd": { - MaxVolumeCount: 200, - FreeVolumeCount: 160, - }, - }, - }, - }, - }, - }, - }, - }, - } - - volumeSizeLimitMb := uint64(1000) // 1GB - volumeSizeBytes := int64(1000) * 1024 * 1024 - - total, available := calculateClusterCapacity(topo, volumeSizeLimitMb) - - expectedTotal := int64(300) * volumeSizeBytes - expectedAvailable := int64(200) * volumeSizeBytes - - if total != expectedTotal { - t.Errorf("calculateClusterCapacity total = %d, want %d", total, expectedTotal) - } - if available != expectedAvailable { - t.Errorf("calculateClusterCapacity available = %d, want %d", available, expectedAvailable) - } -} diff --git a/weed/s3api/s3api_sse_chunk_metadata_test.go b/weed/s3api/s3api_sse_chunk_metadata_test.go deleted file mode 100644 index ca38f44f4..000000000 --- a/weed/s3api/s3api_sse_chunk_metadata_test.go +++ /dev/null @@ -1,361 +0,0 @@ -package s3api - -import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "encoding/json" - "io" - "testing" - - "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" -) - -// TestSSEKMSChunkMetadataAssignment tests that SSE-KMS creates per-chunk metadata -// with correct ChunkOffset values for each chunk (matching the fix in putToFiler) -func TestSSEKMSChunkMetadataAssignment(t *testing.T) { - kmsKey := SetupTestKMS(t) - defer kmsKey.Cleanup() - - // Generate SSE-KMS key by encrypting test data (this gives us a real SSEKMSKey) - encryptionContext := BuildEncryptionContext("test-bucket", "test-object", false) - testData := "Test data for SSE-KMS chunk metadata validation" - encryptedReader, sseKMSKey, err := CreateSSEKMSEncryptedReader(bytes.NewReader([]byte(testData)), kmsKey.KeyID, encryptionContext) - if err != nil { - t.Fatalf("Failed to create encrypted reader: %v", err) - } - // Read to complete encryption setup - io.ReadAll(encryptedReader) - - // Serialize the base metadata (what putToFiler receives before chunking) - baseMetadata, err := SerializeSSEKMSMetadata(sseKMSKey) - if err != nil { - t.Fatalf("Failed to serialize base SSE-KMS metadata: %v", err) - } - - // Simulate multi-chunk upload scenario (what putToFiler does after UploadReaderInChunks) - simulatedChunks := []*filer_pb.FileChunk{ - {FileId: "chunk1", Offset: 0, Size: 8 * 1024 * 1024}, // 8MB chunk at offset 0 - {FileId: "chunk2", Offset: 8 * 1024 * 1024, Size: 8 * 1024 * 1024}, // 8MB chunk at offset 8MB - {FileId: "chunk3", Offset: 16 * 1024 * 1024, Size: 4 * 1024 * 1024}, // 4MB chunk at offset 16MB - } - - // THIS IS THE CRITICAL FIX: Create per-chunk metadata (lines 421-443 in putToFiler) - for _, chunk := range simulatedChunks { - chunk.SseType = filer_pb.SSEType_SSE_KMS - - // Create a copy of the SSE-KMS key with chunk-specific offset - chunkSSEKey := &SSEKMSKey{ - KeyID: sseKMSKey.KeyID, - EncryptedDataKey: sseKMSKey.EncryptedDataKey, - EncryptionContext: sseKMSKey.EncryptionContext, - BucketKeyEnabled: sseKMSKey.BucketKeyEnabled, - IV: sseKMSKey.IV, - ChunkOffset: chunk.Offset, // Set chunk-specific offset - } - - // Serialize per-chunk metadata - chunkMetadata, serErr := SerializeSSEKMSMetadata(chunkSSEKey) - if serErr != nil { - t.Fatalf("Failed to serialize SSE-KMS metadata for chunk at offset %d: %v", chunk.Offset, serErr) - } - chunk.SseMetadata = chunkMetadata - } - - // VERIFICATION 1: Each chunk should have different metadata (due to different ChunkOffset) - metadataSet := make(map[string]bool) - for i, chunk := range simulatedChunks { - metadataStr := string(chunk.SseMetadata) - if metadataSet[metadataStr] { - t.Errorf("Chunk %d has duplicate metadata (should be unique per chunk)", i) - } - metadataSet[metadataStr] = true - - // Deserialize and verify ChunkOffset - var metadata SSEKMSMetadata - if err := json.Unmarshal(chunk.SseMetadata, &metadata); err != nil { - t.Fatalf("Failed to deserialize chunk %d metadata: %v", i, err) - } - - expectedOffset := chunk.Offset - if metadata.PartOffset != expectedOffset { - t.Errorf("Chunk %d: expected PartOffset=%d, got %d", i, expectedOffset, metadata.PartOffset) - } - - t.Logf("✓ Chunk %d: PartOffset=%d (correct)", i, metadata.PartOffset) - } - - // VERIFICATION 2: Verify metadata can be deserialized and has correct ChunkOffset - for i, chunk := range simulatedChunks { - // Deserialize chunk metadata - deserializedKey, err := DeserializeSSEKMSMetadata(chunk.SseMetadata) - if err != nil { - t.Fatalf("Failed to deserialize chunk %d metadata: %v", i, err) - } - - // Verify the deserialized key has correct ChunkOffset - if deserializedKey.ChunkOffset != chunk.Offset { - t.Errorf("Chunk %d: deserialized ChunkOffset=%d, expected %d", - i, deserializedKey.ChunkOffset, chunk.Offset) - } - - // Verify IV is set (should be inherited from base) - if len(deserializedKey.IV) != aes.BlockSize { - t.Errorf("Chunk %d: invalid IV length: %d", i, len(deserializedKey.IV)) - } - - // Verify KeyID matches - if deserializedKey.KeyID != sseKMSKey.KeyID { - t.Errorf("Chunk %d: KeyID mismatch", i) - } - - t.Logf("✓ Chunk %d: metadata deserialized successfully (ChunkOffset=%d, KeyID=%s)", - i, deserializedKey.ChunkOffset, deserializedKey.KeyID) - } - - // VERIFICATION 3: Ensure base metadata is NOT reused (the bug we're preventing) - var baseMetadataStruct SSEKMSMetadata - if err := json.Unmarshal(baseMetadata, &baseMetadataStruct); err != nil { - t.Fatalf("Failed to deserialize base metadata: %v", err) - } - - // Base metadata should have ChunkOffset=0 - if baseMetadataStruct.PartOffset != 0 { - t.Errorf("Base metadata should have PartOffset=0, got %d", baseMetadataStruct.PartOffset) - } - - // Chunks 2 and 3 should NOT have the same metadata as base (proving we're not reusing) - for i := 1; i < len(simulatedChunks); i++ { - if bytes.Equal(simulatedChunks[i].SseMetadata, baseMetadata) { - t.Errorf("CRITICAL BUG: Chunk %d reuses base metadata (should have per-chunk metadata)", i) - } - } - - t.Log("✓ All chunks have unique per-chunk metadata (bug prevented)") -} - -// TestSSES3ChunkMetadataAssignment tests that SSE-S3 creates per-chunk metadata -// with offset-adjusted IVs for each chunk (matching the fix in putToFiler) -func TestSSES3ChunkMetadataAssignment(t *testing.T) { - // Initialize global SSE-S3 key manager - globalSSES3KeyManager = NewSSES3KeyManager() - defer func() { - globalSSES3KeyManager = NewSSES3KeyManager() - }() - - keyManager := GetSSES3KeyManager() - keyManager.superKey = make([]byte, 32) - rand.Read(keyManager.superKey) - - // Generate SSE-S3 key - sseS3Key, err := GenerateSSES3Key() - if err != nil { - t.Fatalf("Failed to generate SSE-S3 key: %v", err) - } - - // Generate base IV - baseIV := make([]byte, aes.BlockSize) - rand.Read(baseIV) - sseS3Key.IV = baseIV - - // Serialize base metadata (what putToFiler receives) - baseMetadata, err := SerializeSSES3Metadata(sseS3Key) - if err != nil { - t.Fatalf("Failed to serialize base SSE-S3 metadata: %v", err) - } - - // Simulate multi-chunk upload scenario (what putToFiler does after UploadReaderInChunks) - simulatedChunks := []*filer_pb.FileChunk{ - {FileId: "chunk1", Offset: 0, Size: 8 * 1024 * 1024}, // 8MB chunk at offset 0 - {FileId: "chunk2", Offset: 8 * 1024 * 1024, Size: 8 * 1024 * 1024}, // 8MB chunk at offset 8MB - {FileId: "chunk3", Offset: 16 * 1024 * 1024, Size: 4 * 1024 * 1024}, // 4MB chunk at offset 16MB - } - - // THIS IS THE CRITICAL FIX: Create per-chunk metadata (lines 444-468 in putToFiler) - for _, chunk := range simulatedChunks { - chunk.SseType = filer_pb.SSEType_SSE_S3 - - // Calculate chunk-specific IV using base IV and chunk offset - chunkIV, _ := calculateIVWithOffset(sseS3Key.IV, chunk.Offset) - - // Create a copy of the SSE-S3 key with chunk-specific IV - chunkSSEKey := &SSES3Key{ - Key: sseS3Key.Key, - KeyID: sseS3Key.KeyID, - Algorithm: sseS3Key.Algorithm, - IV: chunkIV, // Use chunk-specific IV - } - - // Serialize per-chunk metadata - chunkMetadata, serErr := SerializeSSES3Metadata(chunkSSEKey) - if serErr != nil { - t.Fatalf("Failed to serialize SSE-S3 metadata for chunk at offset %d: %v", chunk.Offset, serErr) - } - chunk.SseMetadata = chunkMetadata - } - - // VERIFICATION 1: Each chunk should have different metadata (due to different IVs) - metadataSet := make(map[string]bool) - for i, chunk := range simulatedChunks { - metadataStr := string(chunk.SseMetadata) - if metadataSet[metadataStr] { - t.Errorf("Chunk %d has duplicate metadata (should be unique per chunk)", i) - } - metadataSet[metadataStr] = true - - // Deserialize and verify IV - deserializedKey, err := DeserializeSSES3Metadata(chunk.SseMetadata, keyManager) - if err != nil { - t.Fatalf("Failed to deserialize chunk %d metadata: %v", i, err) - } - - // Calculate expected IV for this chunk - expectedIV, _ := calculateIVWithOffset(baseIV, chunk.Offset) - if !bytes.Equal(deserializedKey.IV, expectedIV) { - t.Errorf("Chunk %d: IV mismatch\nExpected: %x\nGot: %x", - i, expectedIV[:8], deserializedKey.IV[:8]) - } - - t.Logf("✓ Chunk %d: IV correctly adjusted for offset=%d", i, chunk.Offset) - } - - // VERIFICATION 2: Verify decryption works with per-chunk IVs - for i, chunk := range simulatedChunks { - // Deserialize chunk metadata - deserializedKey, err := DeserializeSSES3Metadata(chunk.SseMetadata, keyManager) - if err != nil { - t.Fatalf("Failed to deserialize chunk %d metadata: %v", i, err) - } - - // Simulate encryption/decryption with the chunk's IV - testData := []byte("Test data for SSE-S3 chunk decryption verification") - block, err := aes.NewCipher(deserializedKey.Key) - if err != nil { - t.Fatalf("Failed to create cipher: %v", err) - } - - // Encrypt with chunk's IV - ciphertext := make([]byte, len(testData)) - stream := cipher.NewCTR(block, deserializedKey.IV) - stream.XORKeyStream(ciphertext, testData) - - // Decrypt with chunk's IV - plaintext := make([]byte, len(ciphertext)) - block2, _ := aes.NewCipher(deserializedKey.Key) - stream2 := cipher.NewCTR(block2, deserializedKey.IV) - stream2.XORKeyStream(plaintext, ciphertext) - - if !bytes.Equal(plaintext, testData) { - t.Errorf("Chunk %d: decryption failed", i) - } - - t.Logf("✓ Chunk %d: encryption/decryption successful with chunk-specific IV", i) - } - - // VERIFICATION 3: Ensure base IV is NOT reused for non-zero offset chunks (the bug we're preventing) - for i := 1; i < len(simulatedChunks); i++ { - if bytes.Equal(simulatedChunks[i].SseMetadata, baseMetadata) { - t.Errorf("CRITICAL BUG: Chunk %d reuses base metadata (should have per-chunk metadata)", i) - } - - // Verify chunk metadata has different IV than base IV - deserializedKey, _ := DeserializeSSES3Metadata(simulatedChunks[i].SseMetadata, keyManager) - if bytes.Equal(deserializedKey.IV, baseIV) { - t.Errorf("CRITICAL BUG: Chunk %d uses base IV (should use offset-adjusted IV)", i) - } - } - - t.Log("✓ All chunks have unique per-chunk IVs (bug prevented)") -} - -// TestSSEChunkMetadataComparison tests that the bug (reusing same metadata for all chunks) -// would cause decryption failures, while the fix (per-chunk metadata) works correctly -func TestSSEChunkMetadataComparison(t *testing.T) { - // Generate test key and IV - key := make([]byte, 32) - rand.Read(key) - baseIV := make([]byte, aes.BlockSize) - rand.Read(baseIV) - - // Create test data for 3 chunks - chunk0Data := []byte("Chunk 0 data at offset 0") - chunk1Data := []byte("Chunk 1 data at offset 8MB") - chunk2Data := []byte("Chunk 2 data at offset 16MB") - - chunkOffsets := []int64{0, 8 * 1024 * 1024, 16 * 1024 * 1024} - chunkDataList := [][]byte{chunk0Data, chunk1Data, chunk2Data} - - // Scenario 1: BUG - Using same IV for all chunks (what the old code did) - t.Run("Bug: Reusing base IV causes decryption failures", func(t *testing.T) { - var encryptedChunks [][]byte - - // Encrypt each chunk with offset-adjusted IV (what encryption does) - for i, offset := range chunkOffsets { - adjustedIV, _ := calculateIVWithOffset(baseIV, offset) - block, _ := aes.NewCipher(key) - stream := cipher.NewCTR(block, adjustedIV) - - ciphertext := make([]byte, len(chunkDataList[i])) - stream.XORKeyStream(ciphertext, chunkDataList[i]) - encryptedChunks = append(encryptedChunks, ciphertext) - } - - // Try to decrypt with base IV (THE BUG) - for i := range encryptedChunks { - block, _ := aes.NewCipher(key) - stream := cipher.NewCTR(block, baseIV) // BUG: Always using base IV - - plaintext := make([]byte, len(encryptedChunks[i])) - stream.XORKeyStream(plaintext, encryptedChunks[i]) - - if i == 0 { - // Chunk 0 should work (offset 0 means base IV = adjusted IV) - if !bytes.Equal(plaintext, chunkDataList[i]) { - t.Errorf("Chunk 0 decryption failed (unexpected)") - } - } else { - // Chunks 1 and 2 should FAIL (wrong IV) - if bytes.Equal(plaintext, chunkDataList[i]) { - t.Errorf("BUG NOT REPRODUCED: Chunk %d decrypted correctly with base IV (should fail)", i) - } else { - t.Logf("✓ Chunk %d: Correctly failed to decrypt with base IV (bug reproduced)", i) - } - } - } - }) - - // Scenario 2: FIX - Using per-chunk offset-adjusted IVs (what the new code does) - t.Run("Fix: Per-chunk IVs enable correct decryption", func(t *testing.T) { - var encryptedChunks [][]byte - var chunkIVs [][]byte - - // Encrypt each chunk with offset-adjusted IV - for i, offset := range chunkOffsets { - adjustedIV, _ := calculateIVWithOffset(baseIV, offset) - chunkIVs = append(chunkIVs, adjustedIV) - - block, _ := aes.NewCipher(key) - stream := cipher.NewCTR(block, adjustedIV) - - ciphertext := make([]byte, len(chunkDataList[i])) - stream.XORKeyStream(ciphertext, chunkDataList[i]) - encryptedChunks = append(encryptedChunks, ciphertext) - } - - // Decrypt with per-chunk IVs (THE FIX) - for i := range encryptedChunks { - block, _ := aes.NewCipher(key) - stream := cipher.NewCTR(block, chunkIVs[i]) // FIX: Using per-chunk IV - - plaintext := make([]byte, len(encryptedChunks[i])) - stream.XORKeyStream(plaintext, encryptedChunks[i]) - - if !bytes.Equal(plaintext, chunkDataList[i]) { - t.Errorf("Chunk %d decryption failed with per-chunk IV (unexpected)", i) - } else { - t.Logf("✓ Chunk %d: Successfully decrypted with per-chunk IV", i) - } - } - }) -} diff --git a/weed/s3api/s3api_streaming_copy.go b/weed/s3api/s3api_streaming_copy.go deleted file mode 100644 index f50f715e3..000000000 --- a/weed/s3api/s3api_streaming_copy.go +++ /dev/null @@ -1,601 +0,0 @@ -package s3api - -import ( - "context" - "crypto/md5" - "crypto/sha256" - "encoding/hex" - "fmt" - "hash" - "io" - "net/http" - - "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" - "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" - "github.com/seaweedfs/seaweedfs/weed/util" -) - -// StreamingCopySpec defines the specification for streaming copy operations -type StreamingCopySpec struct { - SourceReader io.Reader - TargetSize int64 - EncryptionSpec *EncryptionSpec - CompressionSpec *CompressionSpec - HashCalculation bool - BufferSize int -} - -// EncryptionSpec defines encryption parameters for streaming -type EncryptionSpec struct { - NeedsDecryption bool - NeedsEncryption bool - SourceKey interface{} // SSECustomerKey or SSEKMSKey - DestinationKey interface{} // SSECustomerKey or SSEKMSKey - SourceType EncryptionType - DestinationType EncryptionType - SourceMetadata map[string][]byte // Source metadata for IV extraction - DestinationIV []byte // Generated IV for destination -} - -// CompressionSpec defines compression parameters for streaming -type CompressionSpec struct { - IsCompressed bool - CompressionType string - NeedsDecompression bool - NeedsCompression bool -} - -// StreamingCopyManager handles streaming copy operations -type StreamingCopyManager struct { - s3a *S3ApiServer - bufferSize int -} - -// NewStreamingCopyManager creates a new streaming copy manager -func NewStreamingCopyManager(s3a *S3ApiServer) *StreamingCopyManager { - return &StreamingCopyManager{ - s3a: s3a, - bufferSize: 64 * 1024, // 64KB default buffer - } -} - -// ExecuteStreamingCopy performs a streaming copy operation and returns the encryption spec -// The encryption spec is needed for SSE-S3 to properly set destination metadata (fixes GitHub #7562) -func (scm *StreamingCopyManager) ExecuteStreamingCopy(ctx context.Context, entry *filer_pb.Entry, r *http.Request, dstPath string, state *EncryptionState) ([]*filer_pb.FileChunk, *EncryptionSpec, error) { - // Create streaming copy specification - spec, err := scm.createStreamingSpec(entry, r, state) - if err != nil { - return nil, nil, fmt.Errorf("create streaming spec: %w", err) - } - - // Create source reader from entry - sourceReader, err := scm.createSourceReader(entry) - if err != nil { - return nil, nil, fmt.Errorf("create source reader: %w", err) - } - defer sourceReader.Close() - - spec.SourceReader = sourceReader - - // Create processing pipeline - processedReader, err := scm.createProcessingPipeline(spec) - if err != nil { - return nil, nil, fmt.Errorf("create processing pipeline: %w", err) - } - - // Stream to destination - chunks, err := scm.streamToDestination(ctx, processedReader, spec, dstPath) - if err != nil { - return nil, nil, err - } - - return chunks, spec.EncryptionSpec, nil -} - -// createStreamingSpec creates a streaming specification based on copy parameters -func (scm *StreamingCopyManager) createStreamingSpec(entry *filer_pb.Entry, r *http.Request, state *EncryptionState) (*StreamingCopySpec, error) { - spec := &StreamingCopySpec{ - BufferSize: scm.bufferSize, - HashCalculation: true, - } - - // Calculate target size - sizeCalc := NewCopySizeCalculator(entry, r) - spec.TargetSize = sizeCalc.CalculateTargetSize() - - // Create encryption specification - encSpec, err := scm.createEncryptionSpec(entry, r, state) - if err != nil { - return nil, err - } - spec.EncryptionSpec = encSpec - - // Create compression specification - spec.CompressionSpec = scm.createCompressionSpec(entry, r) - - return spec, nil -} - -// createEncryptionSpec creates encryption specification for streaming -func (scm *StreamingCopyManager) createEncryptionSpec(entry *filer_pb.Entry, r *http.Request, state *EncryptionState) (*EncryptionSpec, error) { - spec := &EncryptionSpec{ - NeedsDecryption: state.IsSourceEncrypted(), - NeedsEncryption: state.IsTargetEncrypted(), - SourceMetadata: entry.Extended, // Pass source metadata for IV extraction - } - - // Set source encryption details - if state.SrcSSEC { - spec.SourceType = EncryptionTypeSSEC - sourceKey, err := ParseSSECCopySourceHeaders(r) - if err != nil { - return nil, fmt.Errorf("parse SSE-C copy source headers: %w", err) - } - spec.SourceKey = sourceKey - } else if state.SrcSSEKMS { - spec.SourceType = EncryptionTypeSSEKMS - // Extract SSE-KMS key from metadata - if keyData, exists := entry.Extended[s3_constants.SeaweedFSSSEKMSKey]; exists { - sseKey, err := DeserializeSSEKMSMetadata(keyData) - if err != nil { - return nil, fmt.Errorf("deserialize SSE-KMS metadata: %w", err) - } - spec.SourceKey = sseKey - } - } else if state.SrcSSES3 { - spec.SourceType = EncryptionTypeSSES3 - // Extract SSE-S3 key from metadata - if keyData, exists := entry.Extended[s3_constants.SeaweedFSSSES3Key]; exists { - keyManager := GetSSES3KeyManager() - sseKey, err := DeserializeSSES3Metadata(keyData, keyManager) - if err != nil { - return nil, fmt.Errorf("deserialize SSE-S3 metadata: %w", err) - } - spec.SourceKey = sseKey - } - } - - // Set destination encryption details - if state.DstSSEC { - spec.DestinationType = EncryptionTypeSSEC - destKey, err := ParseSSECHeaders(r) - if err != nil { - return nil, fmt.Errorf("parse SSE-C headers: %w", err) - } - spec.DestinationKey = destKey - } else if state.DstSSEKMS { - spec.DestinationType = EncryptionTypeSSEKMS - // Parse KMS parameters - keyID, encryptionContext, bucketKeyEnabled, err := ParseSSEKMSCopyHeaders(r) - if err != nil { - return nil, fmt.Errorf("parse SSE-KMS copy headers: %w", err) - } - - // Create SSE-KMS key for destination - sseKey := &SSEKMSKey{ - KeyID: keyID, - EncryptionContext: encryptionContext, - BucketKeyEnabled: bucketKeyEnabled, - } - spec.DestinationKey = sseKey - } else if state.DstSSES3 { - spec.DestinationType = EncryptionTypeSSES3 - // Generate or retrieve SSE-S3 key - keyManager := GetSSES3KeyManager() - sseKey, err := keyManager.GetOrCreateKey("") - if err != nil { - return nil, fmt.Errorf("get SSE-S3 key: %w", err) - } - spec.DestinationKey = sseKey - } - - return spec, nil -} - -// createCompressionSpec creates compression specification for streaming -func (scm *StreamingCopyManager) createCompressionSpec(entry *filer_pb.Entry, r *http.Request) *CompressionSpec { - return &CompressionSpec{ - IsCompressed: isCompressedEntry(entry), - // For now, we don't change compression during copy - NeedsDecompression: false, - NeedsCompression: false, - } -} - -// createSourceReader creates a reader for the source entry -func (scm *StreamingCopyManager) createSourceReader(entry *filer_pb.Entry) (io.ReadCloser, error) { - // Create a multi-chunk reader that streams from all chunks - return scm.s3a.createMultiChunkReader(entry) -} - -// createProcessingPipeline creates a processing pipeline for the copy operation -func (scm *StreamingCopyManager) createProcessingPipeline(spec *StreamingCopySpec) (io.Reader, error) { - reader := spec.SourceReader - - // Add decryption if needed - if spec.EncryptionSpec.NeedsDecryption { - decryptedReader, err := scm.createDecryptionReader(reader, spec.EncryptionSpec) - if err != nil { - return nil, fmt.Errorf("create decryption reader: %w", err) - } - reader = decryptedReader - } - - // Add decompression if needed - if spec.CompressionSpec.NeedsDecompression { - decompressedReader, err := scm.createDecompressionReader(reader, spec.CompressionSpec) - if err != nil { - return nil, fmt.Errorf("create decompression reader: %w", err) - } - reader = decompressedReader - } - - // Add compression if needed - if spec.CompressionSpec.NeedsCompression { - compressedReader, err := scm.createCompressionReader(reader, spec.CompressionSpec) - if err != nil { - return nil, fmt.Errorf("create compression reader: %w", err) - } - reader = compressedReader - } - - // Add encryption if needed - if spec.EncryptionSpec.NeedsEncryption { - encryptedReader, err := scm.createEncryptionReader(reader, spec.EncryptionSpec) - if err != nil { - return nil, fmt.Errorf("create encryption reader: %w", err) - } - reader = encryptedReader - } - - // Add hash calculation if needed - if spec.HashCalculation { - reader = scm.createHashReader(reader) - } - - return reader, nil -} - -// createDecryptionReader creates a decryption reader based on encryption type -func (scm *StreamingCopyManager) createDecryptionReader(reader io.Reader, encSpec *EncryptionSpec) (io.Reader, error) { - switch encSpec.SourceType { - case EncryptionTypeSSEC: - if sourceKey, ok := encSpec.SourceKey.(*SSECustomerKey); ok { - // Get IV from metadata - iv, err := GetSSECIVFromMetadata(encSpec.SourceMetadata) - if err != nil { - return nil, fmt.Errorf("get IV from metadata: %w", err) - } - return CreateSSECDecryptedReader(reader, sourceKey, iv) - } - return nil, fmt.Errorf("invalid SSE-C source key type") - - case EncryptionTypeSSEKMS: - if sseKey, ok := encSpec.SourceKey.(*SSEKMSKey); ok { - return CreateSSEKMSDecryptedReader(reader, sseKey) - } - return nil, fmt.Errorf("invalid SSE-KMS source key type") - - case EncryptionTypeSSES3: - if sseKey, ok := encSpec.SourceKey.(*SSES3Key); ok { - // For SSE-S3, the IV is stored within the SSES3Key metadata, not as separate metadata - iv := sseKey.IV - if len(iv) == 0 { - return nil, fmt.Errorf("SSE-S3 key is missing IV for streaming copy") - } - return CreateSSES3DecryptedReader(reader, sseKey, iv) - } - return nil, fmt.Errorf("invalid SSE-S3 source key type") - - default: - return reader, nil - } -} - -// createEncryptionReader creates an encryption reader based on encryption type -func (scm *StreamingCopyManager) createEncryptionReader(reader io.Reader, encSpec *EncryptionSpec) (io.Reader, error) { - switch encSpec.DestinationType { - case EncryptionTypeSSEC: - if destKey, ok := encSpec.DestinationKey.(*SSECustomerKey); ok { - encryptedReader, iv, err := CreateSSECEncryptedReader(reader, destKey) - if err != nil { - return nil, err - } - // Store IV in destination metadata (this would need to be handled by caller) - encSpec.DestinationIV = iv - return encryptedReader, nil - } - return nil, fmt.Errorf("invalid SSE-C destination key type") - - case EncryptionTypeSSEKMS: - if sseKey, ok := encSpec.DestinationKey.(*SSEKMSKey); ok { - encryptedReader, updatedKey, err := CreateSSEKMSEncryptedReaderWithBucketKey(reader, sseKey.KeyID, sseKey.EncryptionContext, sseKey.BucketKeyEnabled) - if err != nil { - return nil, err - } - // Store IV from the updated key - encSpec.DestinationIV = updatedKey.IV - return encryptedReader, nil - } - return nil, fmt.Errorf("invalid SSE-KMS destination key type") - - case EncryptionTypeSSES3: - if sseKey, ok := encSpec.DestinationKey.(*SSES3Key); ok { - encryptedReader, iv, err := CreateSSES3EncryptedReader(reader, sseKey) - if err != nil { - return nil, err - } - // Store IV for metadata - encSpec.DestinationIV = iv - return encryptedReader, nil - } - return nil, fmt.Errorf("invalid SSE-S3 destination key type") - - default: - return reader, nil - } -} - -// createDecompressionReader creates a decompression reader -func (scm *StreamingCopyManager) createDecompressionReader(reader io.Reader, compSpec *CompressionSpec) (io.Reader, error) { - if !compSpec.NeedsDecompression { - return reader, nil - } - - switch compSpec.CompressionType { - case "gzip": - // Use SeaweedFS's streaming gzip decompression - pr, pw := io.Pipe() - go func() { - defer pw.Close() - _, err := util.GunzipStream(pw, reader) - if err != nil { - pw.CloseWithError(fmt.Errorf("gzip decompression failed: %v", err)) - } - }() - return pr, nil - default: - // Unknown compression type, return as-is - return reader, nil - } -} - -// createCompressionReader creates a compression reader -func (scm *StreamingCopyManager) createCompressionReader(reader io.Reader, compSpec *CompressionSpec) (io.Reader, error) { - if !compSpec.NeedsCompression { - return reader, nil - } - - switch compSpec.CompressionType { - case "gzip": - // Use SeaweedFS's streaming gzip compression - pr, pw := io.Pipe() - go func() { - defer pw.Close() - _, err := util.GzipStream(pw, reader) - if err != nil { - pw.CloseWithError(fmt.Errorf("gzip compression failed: %v", err)) - } - }() - return pr, nil - default: - // Unknown compression type, return as-is - return reader, nil - } -} - -// HashReader wraps an io.Reader to calculate MD5 and SHA256 hashes -type HashReader struct { - reader io.Reader - md5Hash hash.Hash - sha256Hash hash.Hash -} - -// NewHashReader creates a new hash calculating reader -func NewHashReader(reader io.Reader) *HashReader { - return &HashReader{ - reader: reader, - md5Hash: md5.New(), - sha256Hash: sha256.New(), - } -} - -// Read implements io.Reader and calculates hashes as data flows through -func (hr *HashReader) Read(p []byte) (n int, err error) { - n, err = hr.reader.Read(p) - if n > 0 { - // Update both hashes with the data read - hr.md5Hash.Write(p[:n]) - hr.sha256Hash.Write(p[:n]) - } - return n, err -} - -// MD5Sum returns the current MD5 hash -func (hr *HashReader) MD5Sum() []byte { - return hr.md5Hash.Sum(nil) -} - -// SHA256Sum returns the current SHA256 hash -func (hr *HashReader) SHA256Sum() []byte { - return hr.sha256Hash.Sum(nil) -} - -// MD5Hex returns the MD5 hash as a hex string -func (hr *HashReader) MD5Hex() string { - return hex.EncodeToString(hr.MD5Sum()) -} - -// SHA256Hex returns the SHA256 hash as a hex string -func (hr *HashReader) SHA256Hex() string { - return hex.EncodeToString(hr.SHA256Sum()) -} - -// createHashReader creates a hash calculation reader -func (scm *StreamingCopyManager) createHashReader(reader io.Reader) io.Reader { - return NewHashReader(reader) -} - -// streamToDestination streams the processed data to the destination -func (scm *StreamingCopyManager) streamToDestination(ctx context.Context, reader io.Reader, spec *StreamingCopySpec, dstPath string) ([]*filer_pb.FileChunk, error) { - // For now, we'll use the existing chunk-based approach - // In a full implementation, this would stream directly to the destination - // without creating intermediate chunks - - // This is a placeholder that converts back to chunk-based approach - // A full streaming implementation would write directly to the destination - return scm.streamToChunks(ctx, reader, spec, dstPath) -} - -// streamToChunks converts streaming data back to chunks (temporary implementation) -func (scm *StreamingCopyManager) streamToChunks(ctx context.Context, reader io.Reader, spec *StreamingCopySpec, dstPath string) ([]*filer_pb.FileChunk, error) { - // This is a simplified implementation that reads the stream and creates chunks - // A full implementation would be more sophisticated - - var chunks []*filer_pb.FileChunk - buffer := make([]byte, spec.BufferSize) - offset := int64(0) - - for { - n, err := reader.Read(buffer) - if n > 0 { - // Create chunk for this data, setting SSE type and per-chunk metadata (including chunk-specific IVs for SSE-S3) - chunk, chunkErr := scm.createChunkFromData(buffer[:n], offset, dstPath, spec.EncryptionSpec) - if chunkErr != nil { - return nil, fmt.Errorf("create chunk from data: %w", chunkErr) - } - chunks = append(chunks, chunk) - offset += int64(n) - } - - if err == io.EOF { - break - } - if err != nil { - return nil, fmt.Errorf("read stream: %w", err) - } - } - - return chunks, nil -} - -// createChunkFromData creates a chunk from streaming data -func (scm *StreamingCopyManager) createChunkFromData(data []byte, offset int64, dstPath string, encSpec *EncryptionSpec) (*filer_pb.FileChunk, error) { - // Assign new volume - assignResult, err := scm.s3a.assignNewVolume(dstPath) - if err != nil { - return nil, fmt.Errorf("assign volume: %w", err) - } - - // Create chunk - chunk := &filer_pb.FileChunk{ - Offset: offset, - Size: uint64(len(data)), - } - - // Set SSE type and metadata on chunk if destination is encrypted - // This is critical for GetObject to know to decrypt the data - fixes GitHub #7562 - if encSpec != nil && encSpec.NeedsEncryption { - switch encSpec.DestinationType { - case EncryptionTypeSSEC: - chunk.SseType = filer_pb.SSEType_SSE_C - // SSE-C metadata is handled at object level, not per-chunk for streaming copy - case EncryptionTypeSSEKMS: - chunk.SseType = filer_pb.SSEType_SSE_KMS - // SSE-KMS metadata is handled at object level, not per-chunk for streaming copy - case EncryptionTypeSSES3: - chunk.SseType = filer_pb.SSEType_SSE_S3 - // Create per-chunk SSE-S3 metadata with chunk-specific IV - if sseKey, ok := encSpec.DestinationKey.(*SSES3Key); ok { - // Calculate chunk-specific IV using base IV and chunk offset - baseIV := encSpec.DestinationIV - if len(baseIV) == 0 { - return nil, fmt.Errorf("SSE-S3 encryption requires DestinationIV to be set for chunk at offset %d", offset) - } - chunkIV, _ := calculateIVWithOffset(baseIV, offset) - // Create chunk key with the chunk-specific IV - chunkSSEKey := &SSES3Key{ - Key: sseKey.Key, - KeyID: sseKey.KeyID, - Algorithm: sseKey.Algorithm, - IV: chunkIV, - } - chunkMetadata, serErr := SerializeSSES3Metadata(chunkSSEKey) - if serErr != nil { - return nil, fmt.Errorf("failed to serialize chunk SSE-S3 metadata: %w", serErr) - } - chunk.SseMetadata = chunkMetadata - } - } - } - - // Set file ID - if err := scm.s3a.setChunkFileId(chunk, assignResult); err != nil { - return nil, err - } - - // Upload data - if err := scm.s3a.uploadChunkData(data, assignResult, false); err != nil { - return nil, fmt.Errorf("upload chunk data: %w", err) - } - - return chunk, nil -} - -// createMultiChunkReader creates a reader that streams from multiple chunks -func (s3a *S3ApiServer) createMultiChunkReader(entry *filer_pb.Entry) (io.ReadCloser, error) { - // Create a multi-reader that combines all chunks - var readers []io.Reader - - for _, chunk := range entry.GetChunks() { - chunkReader, err := s3a.createChunkReader(chunk) - if err != nil { - return nil, fmt.Errorf("create chunk reader: %w", err) - } - readers = append(readers, chunkReader) - } - - multiReader := io.MultiReader(readers...) - return &multiReadCloser{reader: multiReader}, nil -} - -// createChunkReader creates a reader for a single chunk -func (s3a *S3ApiServer) createChunkReader(chunk *filer_pb.FileChunk) (io.Reader, error) { - // Get chunk URL - srcUrl, err := s3a.lookupVolumeUrl(chunk.GetFileIdString()) - if err != nil { - return nil, fmt.Errorf("lookup volume URL: %w", err) - } - - // Create HTTP request for chunk data - req, err := http.NewRequest("GET", srcUrl, nil) - if err != nil { - return nil, fmt.Errorf("create HTTP request: %w", err) - } - - // Execute request - resp, err := http.DefaultClient.Do(req) - if err != nil { - return nil, fmt.Errorf("execute HTTP request: %w", err) - } - - if resp.StatusCode != http.StatusOK { - resp.Body.Close() - return nil, fmt.Errorf("HTTP request failed: %d", resp.StatusCode) - } - - return resp.Body, nil -} - -// multiReadCloser wraps a multi-reader with a close method -type multiReadCloser struct { - reader io.Reader -} - -func (mrc *multiReadCloser) Read(p []byte) (int, error) { - return mrc.reader.Read(p) -} - -func (mrc *multiReadCloser) Close() error { - return nil -} diff --git a/weed/s3api/s3err/audit_fluent.go b/weed/s3api/s3err/audit_fluent.go index b63533f1c..ad101cca2 100644 --- a/weed/s3api/s3err/audit_fluent.go +++ b/weed/s3api/s3err/audit_fluent.go @@ -128,13 +128,6 @@ func getOperation(object string, r *http.Request) string { return operation } -func GetAccessHttpLog(r *http.Request, statusCode int, s3errCode ErrorCode) AccessLogHTTP { - return AccessLogHTTP{ - RequestURI: r.RequestURI, - Referer: r.Header.Get("Referer"), - } -} - func GetAccessLog(r *http.Request, HTTPStatusCode int, s3errCode ErrorCode) *AccessLog { bucket, key := s3_constants.GetBucketAndObject(r) var errorCode string diff --git a/weed/s3api/s3lifecycle/evaluator.go b/weed/s3api/s3lifecycle/evaluator.go deleted file mode 100644 index 181b08e44..000000000 --- a/weed/s3api/s3lifecycle/evaluator.go +++ /dev/null @@ -1,127 +0,0 @@ -package s3lifecycle - -import "time" - -// Evaluate checks the given lifecycle rules against an object and returns -// the highest-priority action that applies. The evaluation follows S3's -// action priority: -// 1. ExpiredObjectDeleteMarker (delete marker is sole version) -// 2. NoncurrentVersionExpiration (non-current version age/count) -// 3. Current version Expiration (Days or Date) -// -// AbortIncompleteMultipartUpload is evaluated separately since it applies -// to uploads, not objects. Use EvaluateMPUAbort for that. -func Evaluate(rules []Rule, obj ObjectInfo, now time.Time) EvalResult { - // Phase 1: ExpiredObjectDeleteMarker - if obj.IsDeleteMarker && obj.IsLatest && obj.NumVersions == 1 { - for _, rule := range rules { - if rule.Status != "Enabled" { - continue - } - if !MatchesFilter(rule, obj) { - continue - } - if rule.ExpiredObjectDeleteMarker { - return EvalResult{Action: ActionExpireDeleteMarker, RuleID: rule.ID} - } - } - } - - // Phase 2: NoncurrentVersionExpiration - if !obj.IsLatest && !obj.SuccessorModTime.IsZero() { - for _, rule := range rules { - if ShouldExpireNoncurrentVersion(rule, obj, obj.NoncurrentIndex, now) { - return EvalResult{Action: ActionDeleteVersion, RuleID: rule.ID} - } - } - } - - // Phase 3: Current version Expiration - if obj.IsLatest && !obj.IsDeleteMarker { - for _, rule := range rules { - if rule.Status != "Enabled" { - continue - } - if !MatchesFilter(rule, obj) { - continue - } - // Date-based expiration - if !rule.ExpirationDate.IsZero() && !now.Before(rule.ExpirationDate) { - return EvalResult{Action: ActionDeleteObject, RuleID: rule.ID} - } - // Days-based expiration - if rule.ExpirationDays > 0 { - expiryTime := expectedExpiryTime(obj.ModTime, rule.ExpirationDays) - if !now.Before(expiryTime) { - return EvalResult{Action: ActionDeleteObject, RuleID: rule.ID} - } - } - } - } - - return EvalResult{Action: ActionNone} -} - -// ShouldExpireNoncurrentVersion checks whether a non-current version should -// be expired considering both NoncurrentDays and NewerNoncurrentVersions. -// noncurrentIndex is the 0-based position among non-current versions sorted -// newest-first (0 = newest non-current version). -func ShouldExpireNoncurrentVersion(rule Rule, obj ObjectInfo, noncurrentIndex int, now time.Time) bool { - if rule.Status != "Enabled" { - return false - } - if rule.NoncurrentVersionExpirationDays <= 0 { - return false - } - if obj.IsLatest || obj.SuccessorModTime.IsZero() { - return false - } - if !MatchesFilter(rule, obj) { - return false - } - - // Check age threshold. - expiryTime := expectedExpiryTime(obj.SuccessorModTime, rule.NoncurrentVersionExpirationDays) - if now.Before(expiryTime) { - return false - } - - // Check NewerNoncurrentVersions count threshold. - if rule.NewerNoncurrentVersions > 0 && noncurrentIndex < rule.NewerNoncurrentVersions { - return false - } - - return true -} - -// EvaluateMPUAbort finds the applicable AbortIncompleteMultipartUpload rule -// for a multipart upload with the given key prefix and creation time. -func EvaluateMPUAbort(rules []Rule, uploadKey string, createdAt time.Time, now time.Time) EvalResult { - for _, rule := range rules { - if rule.Status != "Enabled" { - continue - } - if rule.AbortMPUDaysAfterInitiation <= 0 { - continue - } - if !matchesPrefix(rule.Prefix, uploadKey) { - continue - } - cutoff := expectedExpiryTime(createdAt, rule.AbortMPUDaysAfterInitiation) - if !now.Before(cutoff) { - return EvalResult{Action: ActionAbortMultipartUpload, RuleID: rule.ID} - } - } - return EvalResult{Action: ActionNone} -} - -// expectedExpiryTime computes the expiration time given a reference time and -// a number of days. Following S3 semantics, expiration happens at midnight UTC -// of the day after the specified number of days. -func expectedExpiryTime(refTime time.Time, days int) time.Time { - if days == 0 { - return refTime - } - t := refTime.UTC().Add(time.Duration(days+1) * 24 * time.Hour) - return t.Truncate(24 * time.Hour) -} diff --git a/weed/s3api/s3lifecycle/evaluator_test.go b/weed/s3api/s3lifecycle/evaluator_test.go deleted file mode 100644 index aa58e4bc8..000000000 --- a/weed/s3api/s3lifecycle/evaluator_test.go +++ /dev/null @@ -1,495 +0,0 @@ -package s3lifecycle - -import ( - "testing" - "time" -) - -var now = time.Date(2026, 3, 27, 12, 0, 0, 0, time.UTC) - -func TestEvaluate_ExpirationDays(t *testing.T) { - rules := []Rule{{ - ID: "expire-30d", Status: "Enabled", - ExpirationDays: 30, - }} - - t.Run("object_older_than_days_is_expired", func(t *testing.T) { - obj := ObjectInfo{ - Key: "data/file.txt", IsLatest: true, - ModTime: now.Add(-31 * 24 * time.Hour), - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionDeleteObject, result.Action) - assertEqual(t, "expire-30d", result.RuleID) - }) - - t.Run("object_younger_than_days_is_not_expired", func(t *testing.T) { - obj := ObjectInfo{ - Key: "data/file.txt", IsLatest: true, - ModTime: now.Add(-10 * 24 * time.Hour), - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionNone, result.Action) - }) - - t.Run("non_latest_version_not_affected_by_expiration_days", func(t *testing.T) { - obj := ObjectInfo{ - Key: "data/file.txt", IsLatest: false, - ModTime: now.Add(-60 * 24 * time.Hour), - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionNone, result.Action) - }) - - t.Run("delete_marker_not_affected_by_expiration_days", func(t *testing.T) { - obj := ObjectInfo{ - Key: "data/file.txt", IsLatest: true, IsDeleteMarker: true, - ModTime: now.Add(-60 * 24 * time.Hour), NumVersions: 3, - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionNone, result.Action) - }) -} - -func TestEvaluate_ExpirationDate(t *testing.T) { - expirationDate := time.Date(2026, 3, 15, 0, 0, 0, 0, time.UTC) - rules := []Rule{{ - ID: "expire-date", Status: "Enabled", - ExpirationDate: expirationDate, - }} - - t.Run("object_expired_after_date", func(t *testing.T) { - obj := ObjectInfo{ - Key: "file.txt", IsLatest: true, - ModTime: now.Add(-60 * 24 * time.Hour), - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionDeleteObject, result.Action) - }) - - t.Run("object_not_expired_before_date", func(t *testing.T) { - obj := ObjectInfo{ - Key: "file.txt", IsLatest: true, - ModTime: now.Add(-1 * time.Hour), - } - beforeDate := time.Date(2026, 3, 10, 0, 0, 0, 0, time.UTC) - result := Evaluate(rules, obj, beforeDate) - assertAction(t, ActionNone, result.Action) - }) -} - -func TestEvaluate_ExpiredObjectDeleteMarker(t *testing.T) { - rules := []Rule{{ - ID: "cleanup-markers", Status: "Enabled", - ExpiredObjectDeleteMarker: true, - }} - - t.Run("sole_delete_marker_is_expired", func(t *testing.T) { - obj := ObjectInfo{ - Key: "file.txt", IsLatest: true, IsDeleteMarker: true, - NumVersions: 1, - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionExpireDeleteMarker, result.Action) - }) - - t.Run("delete_marker_with_other_versions_not_expired", func(t *testing.T) { - obj := ObjectInfo{ - Key: "file.txt", IsLatest: true, IsDeleteMarker: true, - NumVersions: 3, - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionNone, result.Action) - }) - - t.Run("non_latest_delete_marker_not_expired", func(t *testing.T) { - obj := ObjectInfo{ - Key: "file.txt", IsLatest: false, IsDeleteMarker: true, - NumVersions: 1, - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionNone, result.Action) - }) - - t.Run("non_delete_marker_not_affected", func(t *testing.T) { - obj := ObjectInfo{ - Key: "file.txt", IsLatest: true, IsDeleteMarker: false, - NumVersions: 1, - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionNone, result.Action) - }) -} - -func TestEvaluate_NoncurrentVersionExpiration(t *testing.T) { - rules := []Rule{{ - ID: "expire-noncurrent", Status: "Enabled", - NoncurrentVersionExpirationDays: 30, - }} - - t.Run("old_noncurrent_version_is_expired", func(t *testing.T) { - obj := ObjectInfo{ - Key: "file.txt", IsLatest: false, - SuccessorModTime: now.Add(-45 * 24 * time.Hour), - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionDeleteVersion, result.Action) - }) - - t.Run("recent_noncurrent_version_is_not_expired", func(t *testing.T) { - obj := ObjectInfo{ - Key: "file.txt", IsLatest: false, - SuccessorModTime: now.Add(-10 * 24 * time.Hour), - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionNone, result.Action) - }) - - t.Run("latest_version_not_affected", func(t *testing.T) { - obj := ObjectInfo{ - Key: "file.txt", IsLatest: true, - ModTime: now.Add(-60 * 24 * time.Hour), - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionNone, result.Action) - }) -} - -func TestShouldExpireNoncurrentVersion(t *testing.T) { - rule := Rule{ - ID: "noncurrent-rule", Status: "Enabled", - NoncurrentVersionExpirationDays: 30, - NewerNoncurrentVersions: 2, - } - - t.Run("old_version_beyond_count_is_expired", func(t *testing.T) { - obj := ObjectInfo{ - Key: "file.txt", IsLatest: false, - SuccessorModTime: now.Add(-45 * 24 * time.Hour), - } - // noncurrentIndex=2 means this is the 3rd noncurrent version (0-indexed) - // With NewerNoncurrentVersions=2, indices 0 and 1 are kept. - if !ShouldExpireNoncurrentVersion(rule, obj, 2, now) { - t.Error("expected version at index 2 to be expired") - } - }) - - t.Run("old_version_within_count_is_kept", func(t *testing.T) { - obj := ObjectInfo{ - Key: "file.txt", IsLatest: false, - SuccessorModTime: now.Add(-45 * 24 * time.Hour), - } - // noncurrentIndex=1 is within the keep threshold (NewerNoncurrentVersions=2). - if ShouldExpireNoncurrentVersion(rule, obj, 1, now) { - t.Error("expected version at index 1 to be kept") - } - }) - - t.Run("recent_version_beyond_count_is_kept", func(t *testing.T) { - obj := ObjectInfo{ - Key: "file.txt", IsLatest: false, - SuccessorModTime: now.Add(-5 * 24 * time.Hour), - } - // Even at index 5 (beyond count), if too young, it's kept. - if ShouldExpireNoncurrentVersion(rule, obj, 5, now) { - t.Error("expected recent version to be kept regardless of index") - } - }) - - t.Run("disabled_rule_never_expires", func(t *testing.T) { - disabled := Rule{ - ID: "disabled", Status: "Disabled", - NoncurrentVersionExpirationDays: 1, - } - obj := ObjectInfo{ - Key: "file.txt", IsLatest: false, - SuccessorModTime: now.Add(-365 * 24 * time.Hour), - } - if ShouldExpireNoncurrentVersion(disabled, obj, 10, now) { - t.Error("disabled rule should never expire") - } - }) -} - -func TestEvaluate_PrefixFilter(t *testing.T) { - rules := []Rule{{ - ID: "logs-only", Status: "Enabled", - Prefix: "logs/", - ExpirationDays: 7, - }} - - t.Run("matching_prefix", func(t *testing.T) { - obj := ObjectInfo{ - Key: "logs/app.log", IsLatest: true, - ModTime: now.Add(-10 * 24 * time.Hour), - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionDeleteObject, result.Action) - }) - - t.Run("non_matching_prefix", func(t *testing.T) { - obj := ObjectInfo{ - Key: "data/file.txt", IsLatest: true, - ModTime: now.Add(-10 * 24 * time.Hour), - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionNone, result.Action) - }) -} - -func TestEvaluate_TagFilter(t *testing.T) { - rules := []Rule{{ - ID: "temp-only", Status: "Enabled", - ExpirationDays: 1, - FilterTags: map[string]string{"env": "temp"}, - }} - - t.Run("matching_tags", func(t *testing.T) { - obj := ObjectInfo{ - Key: "file.txt", IsLatest: true, - ModTime: now.Add(-5 * 24 * time.Hour), - Tags: map[string]string{"env": "temp", "project": "foo"}, - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionDeleteObject, result.Action) - }) - - t.Run("missing_tag", func(t *testing.T) { - obj := ObjectInfo{ - Key: "file.txt", IsLatest: true, - ModTime: now.Add(-5 * 24 * time.Hour), - Tags: map[string]string{"project": "foo"}, - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionNone, result.Action) - }) - - t.Run("wrong_tag_value", func(t *testing.T) { - obj := ObjectInfo{ - Key: "file.txt", IsLatest: true, - ModTime: now.Add(-5 * 24 * time.Hour), - Tags: map[string]string{"env": "prod"}, - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionNone, result.Action) - }) - - t.Run("nil_object_tags", func(t *testing.T) { - obj := ObjectInfo{ - Key: "file.txt", IsLatest: true, - ModTime: now.Add(-5 * 24 * time.Hour), - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionNone, result.Action) - }) -} - -func TestEvaluate_SizeFilter(t *testing.T) { - rules := []Rule{{ - ID: "large-files", Status: "Enabled", - ExpirationDays: 7, - FilterSizeGreaterThan: 1024 * 1024, // > 1 MB - FilterSizeLessThan: 100 * 1024 * 1024, // < 100 MB - }} - - t.Run("matching_size", func(t *testing.T) { - obj := ObjectInfo{ - Key: "file.bin", IsLatest: true, - ModTime: now.Add(-10 * 24 * time.Hour), - Size: 10 * 1024 * 1024, // 10 MB - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionDeleteObject, result.Action) - }) - - t.Run("too_small", func(t *testing.T) { - obj := ObjectInfo{ - Key: "file.bin", IsLatest: true, - ModTime: now.Add(-10 * 24 * time.Hour), - Size: 512, // 512 bytes - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionNone, result.Action) - }) - - t.Run("too_large", func(t *testing.T) { - obj := ObjectInfo{ - Key: "file.bin", IsLatest: true, - ModTime: now.Add(-10 * 24 * time.Hour), - Size: 200 * 1024 * 1024, // 200 MB - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionNone, result.Action) - }) -} - -func TestEvaluate_CombinedFilters(t *testing.T) { - rules := []Rule{{ - ID: "combined", Status: "Enabled", - Prefix: "logs/", - ExpirationDays: 7, - FilterTags: map[string]string{"env": "dev"}, - FilterSizeGreaterThan: 100, - }} - - t.Run("all_filters_match", func(t *testing.T) { - obj := ObjectInfo{ - Key: "logs/app.log", IsLatest: true, - ModTime: now.Add(-10 * 24 * time.Hour), - Size: 1024, - Tags: map[string]string{"env": "dev"}, - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionDeleteObject, result.Action) - }) - - t.Run("prefix_doesnt_match", func(t *testing.T) { - obj := ObjectInfo{ - Key: "data/app.log", IsLatest: true, - ModTime: now.Add(-10 * 24 * time.Hour), - Size: 1024, - Tags: map[string]string{"env": "dev"}, - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionNone, result.Action) - }) - - t.Run("tag_doesnt_match", func(t *testing.T) { - obj := ObjectInfo{ - Key: "logs/app.log", IsLatest: true, - ModTime: now.Add(-10 * 24 * time.Hour), - Size: 1024, - Tags: map[string]string{"env": "prod"}, - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionNone, result.Action) - }) - - t.Run("size_doesnt_match", func(t *testing.T) { - obj := ObjectInfo{ - Key: "logs/app.log", IsLatest: true, - ModTime: now.Add(-10 * 24 * time.Hour), - Size: 50, // too small - Tags: map[string]string{"env": "dev"}, - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionNone, result.Action) - }) -} - -func TestEvaluate_DisabledRule(t *testing.T) { - rules := []Rule{{ - ID: "disabled", Status: "Disabled", - ExpirationDays: 1, - }} - obj := ObjectInfo{ - Key: "file.txt", IsLatest: true, - ModTime: now.Add(-365 * 24 * time.Hour), - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionNone, result.Action) -} - -func TestEvaluate_MultipleRules_Priority(t *testing.T) { - t.Run("delete_marker_takes_priority_over_expiration", func(t *testing.T) { - rules := []Rule{ - {ID: "expire", Status: "Enabled", ExpirationDays: 1}, - {ID: "marker", Status: "Enabled", ExpiredObjectDeleteMarker: true}, - } - obj := ObjectInfo{ - Key: "file.txt", IsLatest: true, IsDeleteMarker: true, - NumVersions: 1, ModTime: now.Add(-10 * 24 * time.Hour), - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionExpireDeleteMarker, result.Action) - assertEqual(t, "marker", result.RuleID) - }) - - t.Run("first_matching_expiration_rule_wins", func(t *testing.T) { - rules := []Rule{ - {ID: "rule1", Status: "Enabled", ExpirationDays: 30, Prefix: "logs/"}, - {ID: "rule2", Status: "Enabled", ExpirationDays: 7}, - } - obj := ObjectInfo{ - Key: "logs/app.log", IsLatest: true, - ModTime: now.Add(-31 * 24 * time.Hour), - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionDeleteObject, result.Action) - assertEqual(t, "rule1", result.RuleID) - }) -} - -func TestEvaluate_EmptyPrefix(t *testing.T) { - rules := []Rule{{ - ID: "all", Status: "Enabled", - ExpirationDays: 30, - }} - obj := ObjectInfo{ - Key: "any/path/file.txt", IsLatest: true, - ModTime: now.Add(-31 * 24 * time.Hour), - } - result := Evaluate(rules, obj, now) - assertAction(t, ActionDeleteObject, result.Action) -} - -func TestEvaluateMPUAbort(t *testing.T) { - rules := []Rule{{ - ID: "abort-mpu", Status: "Enabled", - AbortMPUDaysAfterInitiation: 7, - }} - - t.Run("old_upload_is_aborted", func(t *testing.T) { - result := EvaluateMPUAbort(rules, "uploads/file.bin", now.Add(-10*24*time.Hour), now) - assertAction(t, ActionAbortMultipartUpload, result.Action) - }) - - t.Run("recent_upload_is_not_aborted", func(t *testing.T) { - result := EvaluateMPUAbort(rules, "uploads/file.bin", now.Add(-3*24*time.Hour), now) - assertAction(t, ActionNone, result.Action) - }) - - t.Run("prefix_scoped_abort", func(t *testing.T) { - prefixRules := []Rule{{ - ID: "abort-logs", Status: "Enabled", - Prefix: "logs/", - AbortMPUDaysAfterInitiation: 1, - }} - result := EvaluateMPUAbort(prefixRules, "data/file.bin", now.Add(-5*24*time.Hour), now) - assertAction(t, ActionNone, result.Action) - }) -} - -func TestExpectedExpiryTime(t *testing.T) { - ref := time.Date(2026, 3, 1, 15, 30, 0, 0, time.UTC) - - t.Run("30_days", func(t *testing.T) { - // S3 spec: expires at midnight UTC of day 32 (ref + 31 days, truncated). - expiry := expectedExpiryTime(ref, 30) - expected := time.Date(2026, 4, 1, 0, 0, 0, 0, time.UTC) - if !expiry.Equal(expected) { - t.Errorf("expected %v, got %v", expected, expiry) - } - }) - - t.Run("zero_days_returns_ref", func(t *testing.T) { - expiry := expectedExpiryTime(ref, 0) - if !expiry.Equal(ref) { - t.Errorf("expected %v, got %v", ref, expiry) - } - }) -} - -func assertAction(t *testing.T, expected, actual Action) { - t.Helper() - if expected != actual { - t.Errorf("expected action %d, got %d", expected, actual) - } -} - -func assertEqual(t *testing.T, expected, actual string) { - t.Helper() - if expected != actual { - t.Errorf("expected %q, got %q", expected, actual) - } -} diff --git a/weed/s3api/s3lifecycle/filter.go b/weed/s3api/s3lifecycle/filter.go deleted file mode 100644 index 394425d60..000000000 --- a/weed/s3api/s3lifecycle/filter.go +++ /dev/null @@ -1,56 +0,0 @@ -package s3lifecycle - -import "strings" - -// MatchesFilter checks if an object matches the rule's filter criteria -// (prefix, tags, and size constraints). -func MatchesFilter(rule Rule, obj ObjectInfo) bool { - if !matchesPrefix(rule.Prefix, obj.Key) { - return false - } - if !matchesTags(rule.FilterTags, obj.Tags) { - return false - } - if !matchesSize(rule.FilterSizeGreaterThan, rule.FilterSizeLessThan, obj.Size) { - return false - } - return true -} - -// matchesPrefix returns true if the object key starts with the given prefix. -// An empty prefix matches all keys. -func matchesPrefix(prefix, key string) bool { - if prefix == "" { - return true - } - return strings.HasPrefix(key, prefix) -} - -// matchesTags returns true if all rule tags are present in the object's tags -// with matching values. An empty or nil rule tag set matches all objects. -func matchesTags(ruleTags, objTags map[string]string) bool { - if len(ruleTags) == 0 { - return true - } - if len(objTags) == 0 { - return false - } - for k, v := range ruleTags { - if objVal, ok := objTags[k]; !ok || objVal != v { - return false - } - } - return true -} - -// matchesSize returns true if the object's size falls within the specified -// bounds. Zero values mean no constraint on that side. -func matchesSize(greaterThan, lessThan, objSize int64) bool { - if greaterThan > 0 && objSize <= greaterThan { - return false - } - if lessThan > 0 && objSize >= lessThan { - return false - } - return true -} diff --git a/weed/s3api/s3lifecycle/filter_test.go b/weed/s3api/s3lifecycle/filter_test.go deleted file mode 100644 index c8bcfeb10..000000000 --- a/weed/s3api/s3lifecycle/filter_test.go +++ /dev/null @@ -1,79 +0,0 @@ -package s3lifecycle - -import "testing" - -func TestMatchesPrefix(t *testing.T) { - tests := []struct { - name string - prefix string - key string - want bool - }{ - {"empty_prefix_matches_all", "", "any/key.txt", true}, - {"exact_prefix_match", "logs/", "logs/app.log", true}, - {"prefix_mismatch", "logs/", "data/file.txt", false}, - {"key_shorter_than_prefix", "very/long/prefix/", "short", false}, - {"prefix_equals_key", "exact", "exact", true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := matchesPrefix(tt.prefix, tt.key); got != tt.want { - t.Errorf("matchesPrefix(%q, %q) = %v, want %v", tt.prefix, tt.key, got, tt.want) - } - }) - } -} - -func TestMatchesTags(t *testing.T) { - tests := []struct { - name string - ruleTags map[string]string - objTags map[string]string - want bool - }{ - {"nil_rule_tags_match_all", nil, map[string]string{"a": "1"}, true}, - {"empty_rule_tags_match_all", map[string]string{}, map[string]string{"a": "1"}, true}, - {"nil_obj_tags_no_match", map[string]string{"a": "1"}, nil, false}, - {"single_tag_match", map[string]string{"env": "dev"}, map[string]string{"env": "dev", "foo": "bar"}, true}, - {"single_tag_value_mismatch", map[string]string{"env": "dev"}, map[string]string{"env": "prod"}, false}, - {"single_tag_key_missing", map[string]string{"env": "dev"}, map[string]string{"foo": "bar"}, false}, - {"multi_tag_all_match", map[string]string{"env": "dev", "tier": "hot"}, map[string]string{"env": "dev", "tier": "hot", "extra": "x"}, true}, - {"multi_tag_partial_match", map[string]string{"env": "dev", "tier": "hot"}, map[string]string{"env": "dev"}, false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := matchesTags(tt.ruleTags, tt.objTags); got != tt.want { - t.Errorf("matchesTags() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestMatchesSize(t *testing.T) { - tests := []struct { - name string - greaterThan int64 - lessThan int64 - objSize int64 - want bool - }{ - {"no_constraints", 0, 0, 1000, true}, - {"only_greater_than_pass", 100, 0, 200, true}, - {"only_greater_than_fail", 100, 0, 50, false}, - {"only_greater_than_equal_fail", 100, 0, 100, false}, - {"only_less_than_pass", 0, 1000, 500, true}, - {"only_less_than_fail", 0, 1000, 2000, false}, - {"only_less_than_equal_fail", 0, 1000, 1000, false}, - {"both_constraints_pass", 100, 1000, 500, true}, - {"both_constraints_too_small", 100, 1000, 50, false}, - {"both_constraints_too_large", 100, 1000, 2000, false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := matchesSize(tt.greaterThan, tt.lessThan, tt.objSize); got != tt.want { - t.Errorf("matchesSize(%d, %d, %d) = %v, want %v", - tt.greaterThan, tt.lessThan, tt.objSize, got, tt.want) - } - }) - } -} diff --git a/weed/s3api/s3lifecycle/tags.go b/weed/s3api/s3lifecycle/tags.go index 57092ed56..49bbaab66 100644 --- a/weed/s3api/s3lifecycle/tags.go +++ b/weed/s3api/s3lifecycle/tags.go @@ -1,34 +1,3 @@ package s3lifecycle -import "strings" - const tagPrefix = "X-Amz-Tagging-" - -// ExtractTags extracts S3 object tags from a filer entry's Extended metadata. -// Tags are stored with the key prefix "X-Amz-Tagging-" followed by the tag key. -func ExtractTags(extended map[string][]byte) map[string]string { - if len(extended) == 0 { - return nil - } - var tags map[string]string - for k, v := range extended { - if strings.HasPrefix(k, tagPrefix) { - if tags == nil { - tags = make(map[string]string) - } - tags[k[len(tagPrefix):]] = string(v) - } - } - return tags -} - -// HasTagRules returns true if any enabled rule in the set uses tag-based filtering. -// This is used as an optimization to skip tag extraction when no rules need it. -func HasTagRules(rules []Rule) bool { - for _, r := range rules { - if r.Status == "Enabled" && len(r.FilterTags) > 0 { - return true - } - } - return false -} diff --git a/weed/s3api/s3lifecycle/tags_test.go b/weed/s3api/s3lifecycle/tags_test.go deleted file mode 100644 index 0eb198c5f..000000000 --- a/weed/s3api/s3lifecycle/tags_test.go +++ /dev/null @@ -1,89 +0,0 @@ -package s3lifecycle - -import "testing" - -func TestExtractTags(t *testing.T) { - t.Run("extracts_tags_with_prefix", func(t *testing.T) { - extended := map[string][]byte{ - "X-Amz-Tagging-env": []byte("prod"), - "X-Amz-Tagging-project": []byte("foo"), - "Content-Type": []byte("text/plain"), - "X-Amz-Meta-Custom": []byte("value"), - } - tags := ExtractTags(extended) - if len(tags) != 2 { - t.Fatalf("expected 2 tags, got %d", len(tags)) - } - if tags["env"] != "prod" { - t.Errorf("expected env=prod, got %q", tags["env"]) - } - if tags["project"] != "foo" { - t.Errorf("expected project=foo, got %q", tags["project"]) - } - }) - - t.Run("nil_extended_returns_nil", func(t *testing.T) { - tags := ExtractTags(nil) - if tags != nil { - t.Errorf("expected nil, got %v", tags) - } - }) - - t.Run("no_tags_returns_nil", func(t *testing.T) { - extended := map[string][]byte{ - "Content-Type": []byte("text/plain"), - } - tags := ExtractTags(extended) - if tags != nil { - t.Errorf("expected nil, got %v", tags) - } - }) - - t.Run("empty_tag_value", func(t *testing.T) { - extended := map[string][]byte{ - "X-Amz-Tagging-empty": []byte(""), - } - tags := ExtractTags(extended) - if len(tags) != 1 { - t.Fatalf("expected 1 tag, got %d", len(tags)) - } - if tags["empty"] != "" { - t.Errorf("expected empty value, got %q", tags["empty"]) - } - }) -} - -func TestHasTagRules(t *testing.T) { - t.Run("has_tag_rules", func(t *testing.T) { - rules := []Rule{ - {Status: "Enabled", FilterTags: map[string]string{"env": "dev"}}, - } - if !HasTagRules(rules) { - t.Error("expected true") - } - }) - - t.Run("no_tag_rules", func(t *testing.T) { - rules := []Rule{ - {Status: "Enabled", ExpirationDays: 30}, - } - if HasTagRules(rules) { - t.Error("expected false") - } - }) - - t.Run("disabled_tag_rule", func(t *testing.T) { - rules := []Rule{ - {Status: "Disabled", FilterTags: map[string]string{"env": "dev"}}, - } - if HasTagRules(rules) { - t.Error("expected false for disabled rule") - } - }) - - t.Run("empty_rules", func(t *testing.T) { - if HasTagRules(nil) { - t.Error("expected false for nil rules") - } - }) -} diff --git a/weed/s3api/s3lifecycle/version_time.go b/weed/s3api/s3lifecycle/version_time.go index fb6cfbbf5..a9d2e9ae2 100644 --- a/weed/s3api/s3lifecycle/version_time.go +++ b/weed/s3api/s3lifecycle/version_time.go @@ -1,99 +1,6 @@ package s3lifecycle -import ( - "math" - "strconv" - "time" -) - // versionIdFormatThreshold distinguishes old vs new format version IDs. // New format (inverted timestamps) produces values above this threshold; // old format (raw timestamps) produces values below it. const versionIdFormatThreshold = 0x4000000000000000 - -// GetVersionTimestamp extracts the actual timestamp from a SeaweedFS version ID, -// handling both old (raw nanosecond) and new (inverted nanosecond) formats. -// Returns zero time if the version ID is invalid or "null". -func GetVersionTimestamp(versionId string) time.Time { - ns := getVersionTimestampNanos(versionId) - if ns == 0 { - return time.Time{} - } - return time.Unix(0, ns) -} - -// getVersionTimestampNanos extracts the raw nanosecond timestamp from a version ID. -func getVersionTimestampNanos(versionId string) int64 { - if len(versionId) < 16 || versionId == "null" { - return 0 - } - timestampPart, err := strconv.ParseUint(versionId[:16], 16, 64) - if err != nil { - return 0 - } - if timestampPart > math.MaxInt64 { - return 0 - } - if timestampPart > versionIdFormatThreshold { - // New format: inverted timestamp, convert back. - return int64(math.MaxInt64 - timestampPart) - } - return int64(timestampPart) -} - -// isNewFormatVersionId returns true if the version ID uses inverted timestamps. -func isNewFormatVersionId(versionId string) bool { - if len(versionId) < 16 || versionId == "null" { - return false - } - timestampPart, err := strconv.ParseUint(versionId[:16], 16, 64) - if err != nil { - return false - } - return timestampPart > versionIdFormatThreshold && timestampPart <= math.MaxInt64 -} - -// CompareVersionIds compares two version IDs for sorting (newest first). -// Returns negative if a is newer, positive if b is newer, 0 if equal. -// Handles both old and new format version IDs and uses full lexicographic -// comparison (not just timestamps) to break ties from the random suffix. -func CompareVersionIds(a, b string) int { - if a == b { - return 0 - } - if a == "null" { - return 1 - } - if b == "null" { - return -1 - } - - aIsNew := isNewFormatVersionId(a) - bIsNew := isNewFormatVersionId(b) - - if aIsNew == bIsNew { - if aIsNew { - // New format: smaller hex = newer (inverted timestamps). - if a < b { - return -1 - } - return 1 - } - // Old format: smaller hex = older. - if a < b { - return 1 - } - return -1 - } - - // Mixed formats: compare by actual timestamp. - aTime := getVersionTimestampNanos(a) - bTime := getVersionTimestampNanos(b) - if aTime > bTime { - return -1 - } - if aTime < bTime { - return 1 - } - return 0 -} diff --git a/weed/s3api/s3lifecycle/version_time_test.go b/weed/s3api/s3lifecycle/version_time_test.go deleted file mode 100644 index 460cbec58..000000000 --- a/weed/s3api/s3lifecycle/version_time_test.go +++ /dev/null @@ -1,74 +0,0 @@ -package s3lifecycle - -import ( - "fmt" - "math" - "testing" - "time" -) - -func TestGetVersionTimestamp(t *testing.T) { - t.Run("new_format_inverted_timestamp", func(t *testing.T) { - // Simulate a new-format version ID (inverted timestamp above threshold). - now := time.Now() - inverted := math.MaxInt64 - now.UnixNano() - versionId := fmt.Sprintf("%016x", inverted) + "0000000000000000" - - got := GetVersionTimestamp(versionId) - // Should recover the original timestamp within 1 second. - diff := got.Sub(now) - if diff < -time.Second || diff > time.Second { - t.Errorf("timestamp diff too large: %v (got %v, want ~%v)", diff, got, now) - } - }) - - t.Run("old_format_raw_timestamp", func(t *testing.T) { - // Simulate an old-format version ID (raw nanosecond timestamp below threshold). - // Use a timestamp from 2023 which would be below threshold. - ts := time.Date(2023, 6, 15, 12, 0, 0, 0, time.UTC) - versionId := fmt.Sprintf("%016x", ts.UnixNano()) + "abcdef0123456789" - - got := GetVersionTimestamp(versionId) - if !got.Equal(ts) { - t.Errorf("expected %v, got %v", ts, got) - } - }) - - t.Run("null_version_id", func(t *testing.T) { - got := GetVersionTimestamp("null") - if !got.IsZero() { - t.Errorf("expected zero time for null version, got %v", got) - } - }) - - t.Run("empty_version_id", func(t *testing.T) { - got := GetVersionTimestamp("") - if !got.IsZero() { - t.Errorf("expected zero time for empty version, got %v", got) - } - }) - - t.Run("short_version_id", func(t *testing.T) { - got := GetVersionTimestamp("abc") - if !got.IsZero() { - t.Errorf("expected zero time for short version, got %v", got) - } - }) - - t.Run("high_bit_overflow_returns_zero", func(t *testing.T) { - // Version ID with first 16 hex chars > math.MaxInt64 should return zero, - // not a wrapped negative timestamp. - versionId := "80000000000000000000000000000000" - got := GetVersionTimestamp(versionId) - if !got.IsZero() { - t.Errorf("expected zero time for overflow version ID, got %v", got) - } - }) - - t.Run("invalid_hex", func(t *testing.T) { - got := GetVersionTimestamp("zzzzzzzzzzzzzzzz0000000000000000") - if !got.IsZero() { - t.Errorf("expected zero time for invalid hex, got %v", got) - } - }) -} diff --git a/weed/s3api/s3tables/filer_ops.go b/weed/s3api/s3tables/filer_ops.go index 7edb8a2a5..7a0ad66ff 100644 --- a/weed/s3api/s3tables/filer_ops.go +++ b/weed/s3api/s3tables/filer_ops.go @@ -50,46 +50,6 @@ func (h *S3TablesHandler) ensureDirectory(ctx context.Context, client filer_pb.S return err } -// upsertFile creates or updates a small file with the given content -func (h *S3TablesHandler) upsertFile(ctx context.Context, client filer_pb.SeaweedFilerClient, path string, data []byte) error { - dir, name := splitPath(path) - now := time.Now().Unix() - resp, err := filer_pb.LookupEntry(ctx, client, &filer_pb.LookupDirectoryEntryRequest{ - Directory: dir, - Name: name, - }) - if err != nil { - if !errors.Is(err, filer_pb.ErrNotFound) { - return err - } - return filer_pb.CreateEntry(ctx, client, &filer_pb.CreateEntryRequest{ - Directory: dir, - Entry: &filer_pb.Entry{ - Name: name, - Content: data, - Attributes: &filer_pb.FuseAttributes{ - Mtime: now, - Crtime: now, - FileMode: uint32(0644), - FileSize: uint64(len(data)), - }, - }, - }) - } - - entry := resp.Entry - if entry.Attributes == nil { - entry.Attributes = &filer_pb.FuseAttributes{} - } - entry.Attributes.Mtime = now - entry.Attributes.FileSize = uint64(len(data)) - entry.Content = data - return filer_pb.UpdateEntry(ctx, client, &filer_pb.UpdateEntryRequest{ - Directory: dir, - Entry: entry, - }) -} - // deleteEntryIfExists removes an entry if it exists, ignoring missing errors func (h *S3TablesHandler) deleteEntryIfExists(ctx context.Context, client filer_pb.SeaweedFilerClient, path string) error { dir, name := splitPath(path) diff --git a/weed/s3api/s3tables/iceberg_layout.go b/weed/s3api/s3tables/iceberg_layout.go index a71fb221d..a754b5d06 100644 --- a/weed/s3api/s3tables/iceberg_layout.go +++ b/weed/s3api/s3tables/iceberg_layout.go @@ -1,14 +1,9 @@ package s3tables import ( - "context" - "encoding/json" - "errors" pathpkg "path" "regexp" "strings" - - "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" ) // Iceberg file layout validation @@ -307,130 +302,3 @@ func (v *TableBucketFileValidator) ValidateTableBucketUpload(fullPath string) er return v.layoutValidator.ValidateFilePath(tableRelativePath) } - -// IsTableBucketPath checks if a path is under the table buckets directory -func IsTableBucketPath(fullPath string) bool { - return strings.HasPrefix(fullPath, TablesPath+"/") -} - -// GetTableInfoFromPath extracts bucket, namespace, and table names from a table bucket path -// Returns empty strings if the path doesn't contain enough components -func GetTableInfoFromPath(fullPath string) (bucket, namespace, table string) { - if !strings.HasPrefix(fullPath, TablesPath+"/") { - return "", "", "" - } - - relativePath := strings.TrimPrefix(fullPath, TablesPath+"/") - parts := strings.SplitN(relativePath, "/", 4) - - if len(parts) >= 1 { - bucket = parts[0] - } - if len(parts) >= 2 { - namespace = parts[1] - } - if len(parts) >= 3 { - table = parts[2] - } - - return -} - -// ValidateTableBucketUploadWithClient validates upload and checks that the table exists and is ICEBERG format -func (v *TableBucketFileValidator) ValidateTableBucketUploadWithClient( - ctx context.Context, - client filer_pb.SeaweedFilerClient, - fullPath string, -) error { - // If not a table bucket path, nothing more to check - if !IsTableBucketPath(fullPath) { - return nil - } - - // Get table info and verify it exists - bucket, namespace, table := GetTableInfoFromPath(fullPath) - if bucket == "" || namespace == "" || table == "" { - return nil // Not deep enough to need validation - } - - if strings.HasPrefix(bucket, ".") { - return nil - } - - resp, err := filer_pb.LookupEntry(ctx, client, &filer_pb.LookupDirectoryEntryRequest{ - Directory: TablesPath, - Name: bucket, - }) - if err != nil { - if errors.Is(err, filer_pb.ErrNotFound) { - return nil - } - return &IcebergLayoutError{ - Code: ErrCodeInvalidIcebergLayout, - Message: "failed to verify table bucket: " + err.Error(), - } - } - if resp == nil || !IsTableBucketEntry(resp.Entry) { - return nil - } - - // Now check basic layout once we know this is a table bucket path. - if err := v.ValidateTableBucketUpload(fullPath); err != nil { - return err - } - - // Verify the table exists and has ICEBERG format by checking its metadata - tablePath := GetTablePath(bucket, namespace, table) - dir, name := splitPath(tablePath) - - resp, err = filer_pb.LookupEntry(ctx, client, &filer_pb.LookupDirectoryEntryRequest{ - Directory: dir, - Name: name, - }) - if err != nil { - // Distinguish between "not found" and other errors - if errors.Is(err, filer_pb.ErrNotFound) { - return &IcebergLayoutError{ - Code: ErrCodeInvalidIcebergLayout, - Message: "table does not exist", - } - } - return &IcebergLayoutError{ - Code: ErrCodeInvalidIcebergLayout, - Message: "failed to verify table existence: " + err.Error(), - } - } - - // Check if table has metadata indicating ICEBERG format - if resp.Entry == nil || resp.Entry.Extended == nil { - return &IcebergLayoutError{ - Code: ErrCodeInvalidIcebergLayout, - Message: "table is not a valid ICEBERG table (missing metadata)", - } - } - - metadataBytes, ok := resp.Entry.Extended[ExtendedKeyMetadata] - if !ok { - return &IcebergLayoutError{ - Code: ErrCodeInvalidIcebergLayout, - Message: "table is not in ICEBERG format (missing format metadata)", - } - } - - var metadata tableMetadataInternal - if err := json.Unmarshal(metadataBytes, &metadata); err != nil { - return &IcebergLayoutError{ - Code: ErrCodeInvalidIcebergLayout, - Message: "failed to parse table metadata: " + err.Error(), - } - } - const TableFormatIceberg = "ICEBERG" - if metadata.Format != TableFormatIceberg { - return &IcebergLayoutError{ - Code: ErrCodeInvalidIcebergLayout, - Message: "table is not in " + TableFormatIceberg + " format", - } - } - - return nil -} diff --git a/weed/s3api/s3tables/iceberg_layout_test.go b/weed/s3api/s3tables/iceberg_layout_test.go deleted file mode 100644 index d68b77b46..000000000 --- a/weed/s3api/s3tables/iceberg_layout_test.go +++ /dev/null @@ -1,186 +0,0 @@ -package s3tables - -import ( - "testing" -) - -func TestIcebergLayoutValidator_ValidateFilePath(t *testing.T) { - v := NewIcebergLayoutValidator() - - tests := []struct { - name string - path string - wantErr bool - }{ - // Valid metadata files - {"valid metadata v1", "metadata/v1.metadata.json", false}, - {"valid metadata v123", "metadata/v123.metadata.json", false}, - {"valid snapshot manifest", "metadata/snap-123-1-abc12345-1234-5678-9abc-def012345678.avro", false}, - {"valid manifest file", "metadata/abc12345-1234-5678-9abc-def012345678-m0.avro", false}, - {"valid general manifest", "metadata/abc12345-1234-5678-9abc-def012345678.avro", false}, - {"valid version hint", "metadata/version-hint.text", false}, - {"valid uuid metadata", "metadata/abc12345-1234-5678-9abc-def012345678.metadata.json", false}, - {"valid trino stats", "metadata/20260208_212535_00007_bn4hb-d3599c32-1709-4b94-b6b2-1957b6d6db04.stats", false}, - - // Valid data files - {"valid parquet file", "data/file.parquet", false}, - {"valid orc file", "data/file.orc", false}, - {"valid avro data file", "data/file.avro", false}, - {"valid parquet with path", "data/00000-0-abc12345.parquet", false}, - - // Valid partitioned data - {"valid partitioned parquet", "data/year=2024/file.parquet", false}, - {"valid multi-partition", "data/year=2024/month=01/file.parquet", false}, - {"valid bucket subdirectory", "data/bucket0/file.parquet", false}, - - // Directories only - {"metadata directory bare", "metadata", true}, - {"data directory bare", "data", true}, - {"metadata directory with slash", "metadata/", false}, - {"data directory with slash", "data/", false}, - - // Invalid paths - {"empty path", "", true}, - {"invalid top dir", "invalid/file.parquet", true}, - {"root file", "file.parquet", true}, - {"invalid metadata file", "metadata/random.txt", true}, - {"nested metadata directory", "metadata/nested/v1.metadata.json", true}, - {"nested metadata directory no file", "metadata/nested/", true}, - {"metadata subdir no slash", "metadata/nested", true}, - {"invalid data file", "data/file.csv", true}, - {"invalid data file json", "data/file.json", true}, - - // Partition/subdirectory without trailing slashes - {"partition directory no slash", "data/year=2024", false}, - {"data subdirectory no slash", "data/my_subdir", false}, - {"multi-level partition", "data/event_date=2025-01-01/hour=00/file.parquet", false}, - {"multi-level partition directory", "data/event_date=2025-01-01/hour=00/", false}, - {"multi-level partition directory no slash", "data/event_date=2025-01-01/hour=00", false}, - - // Double slashes - {"data double slash", "data//file.parquet", true}, - {"data redundant slash", "data/year=2024//file.parquet", true}, - {"metadata redundant slash", "metadata//v1.metadata.json", true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := v.ValidateFilePath(tt.path) - if (err != nil) != tt.wantErr { - t.Errorf("ValidateFilePath(%q) error = %v, wantErr %v", tt.path, err, tt.wantErr) - } - }) - } -} - -func TestIcebergLayoutValidator_PartitionPaths(t *testing.T) { - v := NewIcebergLayoutValidator() - - validPaths := []string{ - "data/year=2024/file.parquet", - "data/date=2024-01-15/file.parquet", - "data/category=electronics/file.parquet", - "data/user_id=12345/file.parquet", - "data/region=us-east-1/file.parquet", - "data/year=2024/month=01/day=15/file.parquet", - } - - for _, path := range validPaths { - if err := v.ValidateFilePath(path); err != nil { - t.Errorf("ValidateFilePath(%q) should be valid, got error: %v", path, err) - } - } -} - -func TestTableBucketFileValidator_ValidateTableBucketUpload(t *testing.T) { - v := NewTableBucketFileValidator() - - tests := []struct { - name string - path string - wantErr bool - }{ - // Non-table bucket paths should pass (no validation) - {"regular bucket path", "/buckets/mybucket/file.txt", false}, - {"filer path", "/home/user/file.txt", false}, - - // Table bucket structure paths (creating directories) - {"table bucket root", "/buckets/mybucket", false}, - {"namespace dir", "/buckets/mybucket/myns", false}, - {"table dir", "/buckets/mybucket/myns/mytable", false}, - {"table dir trailing slash", "/buckets/mybucket/myns/mytable/", false}, - - // Valid table bucket file uploads - {"valid parquet upload", "/buckets/mybucket/myns/mytable/data/file.parquet", false}, - {"valid metadata upload", "/buckets/mybucket/myns/mytable/metadata/v1.metadata.json", false}, - {"valid trino stats upload", "/buckets/mybucket/myns/mytable/metadata/20260208_212535_00007_bn4hb-d3599c32-1709-4b94-b6b2-1957b6d6db04.stats", false}, - {"valid partitioned data", "/buckets/mybucket/myns/mytable/data/year=2024/file.parquet", false}, - - // Invalid table bucket file uploads - {"invalid file type", "/buckets/mybucket/myns/mytable/data/file.csv", true}, - {"invalid top-level dir", "/buckets/mybucket/myns/mytable/invalid/file.parquet", true}, - {"root file in table", "/buckets/mybucket/myns/mytable/file.parquet", true}, - - // Empty segment cases - {"empty bucket", "/buckets//myns/mytable/data/file.parquet", true}, - {"empty namespace", "/buckets/mybucket//mytable/data/file.parquet", true}, - {"empty table", "/buckets/mybucket/myns//data/file.parquet", true}, - {"empty bucket dir", "/buckets//", true}, - {"empty namespace dir", "/buckets/mybucket//", true}, - {"table double slash bypass", "/buckets/mybucket/myns/mytable//data/file.parquet", true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := v.ValidateTableBucketUpload(tt.path) - if (err != nil) != tt.wantErr { - t.Errorf("ValidateTableBucketUpload(%q) error = %v, wantErr %v", tt.path, err, tt.wantErr) - } - }) - } -} - -func TestIsTableBucketPath(t *testing.T) { - tests := []struct { - path string - want bool - }{ - {"/buckets/mybucket", true}, - {"/buckets/mybucket/ns/table/data/file.parquet", true}, - {"/home/user/file.txt", false}, - {"buckets/mybucket", false}, // missing leading slash - } - - for _, tt := range tests { - t.Run(tt.path, func(t *testing.T) { - if got := IsTableBucketPath(tt.path); got != tt.want { - t.Errorf("IsTableBucketPath(%q) = %v, want %v", tt.path, got, tt.want) - } - }) - } -} - -func TestGetTableInfoFromPath(t *testing.T) { - tests := []struct { - path string - wantBucket string - wantNamespace string - wantTable string - }{ - {"/buckets/mybucket/myns/mytable/data/file.parquet", "mybucket", "myns", "mytable"}, - {"/buckets/mybucket/myns/mytable", "mybucket", "myns", "mytable"}, - {"/buckets/mybucket/myns", "mybucket", "myns", ""}, - {"/buckets/mybucket", "mybucket", "", ""}, - {"/home/user/file.txt", "", "", ""}, - } - - for _, tt := range tests { - t.Run(tt.path, func(t *testing.T) { - bucket, namespace, table := GetTableInfoFromPath(tt.path) - if bucket != tt.wantBucket || namespace != tt.wantNamespace || table != tt.wantTable { - t.Errorf("GetTableInfoFromPath(%q) = (%q, %q, %q), want (%q, %q, %q)", - tt.path, bucket, namespace, table, tt.wantBucket, tt.wantNamespace, tt.wantTable) - } - }) - } -} diff --git a/weed/s3api/s3tables/permissions.go b/weed/s3api/s3tables/permissions.go index 4ce198b6d..e5cb45a01 100644 --- a/weed/s3api/s3tables/permissions.go +++ b/weed/s3api/s3tables/permissions.go @@ -90,17 +90,6 @@ type PolicyContext struct { DefaultAllow bool } -// CheckPermissionWithResource checks if a principal has permission to perform an operation on a specific resource -func CheckPermissionWithResource(operation, principal, owner, resourcePolicy, resourceARN string) bool { - return CheckPermissionWithContext(operation, principal, owner, resourcePolicy, resourceARN, nil) -} - -// CheckPermission checks if a principal has permission to perform an operation -// (without resource-specific validation - for backward compatibility) -func CheckPermission(operation, principal, owner, resourcePolicy string) bool { - return CheckPermissionWithContext(operation, principal, owner, resourcePolicy, "", nil) -} - // CheckPermissionWithContext checks permission with optional resource and condition context. func CheckPermissionWithContext(operation, principal, owner, resourcePolicy, resourceARN string, ctx *PolicyContext) bool { // Deny access if identities are empty @@ -415,113 +404,6 @@ func matchesResourcePattern(pattern, resourceARN string) bool { return wildcard.MatchesWildcard(pattern, resourceARN) } -// Helper functions for specific permissions - -// CanCreateTableBucket checks if principal can create table buckets -func CanCreateTableBucket(principal, owner, resourcePolicy string) bool { - return CheckPermission("CreateTableBucket", principal, owner, resourcePolicy) -} - -// CanGetTableBucket checks if principal can get table bucket details -func CanGetTableBucket(principal, owner, resourcePolicy string) bool { - return CheckPermission("GetTableBucket", principal, owner, resourcePolicy) -} - -// CanListTableBuckets checks if principal can list table buckets -func CanListTableBuckets(principal, owner, resourcePolicy string) bool { - return CheckPermission("ListTableBuckets", principal, owner, resourcePolicy) -} - -// CanDeleteTableBucket checks if principal can delete table buckets -func CanDeleteTableBucket(principal, owner, resourcePolicy string) bool { - return CheckPermission("DeleteTableBucket", principal, owner, resourcePolicy) -} - -// CanPutTableBucketPolicy checks if principal can put table bucket policies -func CanPutTableBucketPolicy(principal, owner, resourcePolicy string) bool { - return CheckPermission("PutTableBucketPolicy", principal, owner, resourcePolicy) -} - -// CanGetTableBucketPolicy checks if principal can get table bucket policies -func CanGetTableBucketPolicy(principal, owner, resourcePolicy string) bool { - return CheckPermission("GetTableBucketPolicy", principal, owner, resourcePolicy) -} - -// CanDeleteTableBucketPolicy checks if principal can delete table bucket policies -func CanDeleteTableBucketPolicy(principal, owner, resourcePolicy string) bool { - return CheckPermission("DeleteTableBucketPolicy", principal, owner, resourcePolicy) -} - -// CanCreateNamespace checks if principal can create namespaces -func CanCreateNamespace(principal, owner, resourcePolicy string) bool { - return CheckPermission("CreateNamespace", principal, owner, resourcePolicy) -} - -// CanGetNamespace checks if principal can get namespace details -func CanGetNamespace(principal, owner, resourcePolicy string) bool { - return CheckPermission("GetNamespace", principal, owner, resourcePolicy) -} - -// CanListNamespaces checks if principal can list namespaces -func CanListNamespaces(principal, owner, resourcePolicy string) bool { - return CheckPermission("ListNamespaces", principal, owner, resourcePolicy) -} - -// CanDeleteNamespace checks if principal can delete namespaces -func CanDeleteNamespace(principal, owner, resourcePolicy string) bool { - return CheckPermission("DeleteNamespace", principal, owner, resourcePolicy) -} - -// CanCreateTable checks if principal can create tables -func CanCreateTable(principal, owner, resourcePolicy string) bool { - return CheckPermission("CreateTable", principal, owner, resourcePolicy) -} - -// CanGetTable checks if principal can get table details -func CanGetTable(principal, owner, resourcePolicy string) bool { - return CheckPermission("GetTable", principal, owner, resourcePolicy) -} - -// CanListTables checks if principal can list tables -func CanListTables(principal, owner, resourcePolicy string) bool { - return CheckPermission("ListTables", principal, owner, resourcePolicy) -} - -// CanDeleteTable checks if principal can delete tables -func CanDeleteTable(principal, owner, resourcePolicy string) bool { - return CheckPermission("DeleteTable", principal, owner, resourcePolicy) -} - -// CanPutTablePolicy checks if principal can put table policies -func CanPutTablePolicy(principal, owner, resourcePolicy string) bool { - return CheckPermission("PutTablePolicy", principal, owner, resourcePolicy) -} - -// CanGetTablePolicy checks if principal can get table policies -func CanGetTablePolicy(principal, owner, resourcePolicy string) bool { - return CheckPermission("GetTablePolicy", principal, owner, resourcePolicy) -} - -// CanDeleteTablePolicy checks if principal can delete table policies -func CanDeleteTablePolicy(principal, owner, resourcePolicy string) bool { - return CheckPermission("DeleteTablePolicy", principal, owner, resourcePolicy) -} - -// CanTagResource checks if principal can tag a resource -func CanTagResource(principal, owner, resourcePolicy string) bool { - return CheckPermission("TagResource", principal, owner, resourcePolicy) -} - -// CanUntagResource checks if principal can untag a resource -func CanUntagResource(principal, owner, resourcePolicy string) bool { - return CheckPermission("UntagResource", principal, owner, resourcePolicy) -} - -// CanManageTags checks if principal can manage tags (tag or untag) -func CanManageTags(principal, owner, resourcePolicy string) bool { - return CanTagResource(principal, owner, resourcePolicy) || CanUntagResource(principal, owner, resourcePolicy) -} - // AuthError represents an authorization error type AuthError struct { Operation string diff --git a/weed/s3api/s3tables/utils.go b/weed/s3api/s3tables/utils.go index ff5dd0fe2..2aedefa2b 100644 --- a/weed/s3api/s3tables/utils.go +++ b/weed/s3api/s3tables/utils.go @@ -200,11 +200,6 @@ func validateBucketName(name string) error { return nil } -// ValidateBucketName validates bucket name and returns an error if invalid. -func ValidateBucketName(name string) error { - return validateBucketName(name) -} - // BuildBucketARN builds a bucket ARN with the provided region and account ID. // If region is empty, the ARN will omit the region field. func BuildBucketARN(region, accountID, bucketName string) (string, error) { @@ -367,11 +362,6 @@ func validateNamespace(namespace []string) (string, error) { return flattenNamespace(parts), nil } -// ValidateNamespace is a wrapper to validate namespace for other packages. -func ValidateNamespace(namespace []string) (string, error) { - return validateNamespace(namespace) -} - // ParseNamespace parses a namespace string into namespace parts. func ParseNamespace(namespace string) ([]string, error) { return normalizeNamespace([]string{namespace}) diff --git a/weed/server/common.go b/weed/server/common.go index 32662ada9..9a6b2a7da 100644 --- a/weed/server/common.go +++ b/weed/server/common.go @@ -19,7 +19,6 @@ import ( "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" "github.com/seaweedfs/seaweedfs/weed/util/request_id" - "github.com/seaweedfs/seaweedfs/weed/util/version" "google.golang.org/grpc/metadata" "github.com/seaweedfs/seaweedfs/weed/filer" @@ -237,25 +236,6 @@ func parseURLPath(path string) (vid, fid, filename, ext string, isVolumeIdOnly b return } -func statsHealthHandler(w http.ResponseWriter, r *http.Request) { - m := make(map[string]interface{}) - m["Version"] = version.Version() - writeJsonQuiet(w, r, http.StatusOK, m) -} -func statsCounterHandler(w http.ResponseWriter, r *http.Request) { - m := make(map[string]interface{}) - m["Version"] = version.Version() - m["Counters"] = serverStats - writeJsonQuiet(w, r, http.StatusOK, m) -} - -func statsMemoryHandler(w http.ResponseWriter, r *http.Request) { - m := make(map[string]interface{}) - m["Version"] = version.Version() - m["Memory"] = stats.MemStat() - writeJsonQuiet(w, r, http.StatusOK, m) -} - var StaticFS fs.FS func handleStaticResources(defaultMux *http.ServeMux) { diff --git a/weed/server/filer_server_handlers_proxy.go b/weed/server/filer_server_handlers_proxy.go index 31ee47cdb..cdbb95321 100644 --- a/weed/server/filer_server_handlers_proxy.go +++ b/weed/server/filer_server_handlers_proxy.go @@ -5,7 +5,6 @@ import ( "sync" "github.com/seaweedfs/seaweedfs/weed/glog" - "github.com/seaweedfs/seaweedfs/weed/security" util_http "github.com/seaweedfs/seaweedfs/weed/util/http" "github.com/seaweedfs/seaweedfs/weed/util/mem" "github.com/seaweedfs/seaweedfs/weed/util/request_id" @@ -25,26 +24,6 @@ var ( proxySemaphores sync.Map // host -> chan struct{} ) -func (fs *FilerServer) maybeAddVolumeJwtAuthorization(r *http.Request, fileId string, isWrite bool) { - encodedJwt := fs.maybeGetVolumeJwtAuthorizationToken(fileId, isWrite) - - if encodedJwt == "" { - return - } - - r.Header.Set("Authorization", "BEARER "+string(encodedJwt)) -} - -func (fs *FilerServer) maybeGetVolumeJwtAuthorizationToken(fileId string, isWrite bool) string { - var encodedJwt security.EncodedJwt - if isWrite { - encodedJwt = security.GenJwtForVolumeServer(fs.volumeGuard.SigningKey, fs.volumeGuard.ExpiresAfterSec, fileId) - } else { - encodedJwt = security.GenJwtForVolumeServer(fs.volumeGuard.ReadSigningKey, fs.volumeGuard.ReadExpiresAfterSec, fileId) - } - return string(encodedJwt) -} - func acquireProxySemaphore(ctx context.Context, host string) error { v, _ := proxySemaphores.LoadOrStore(host, make(chan struct{}, proxyReadConcurrencyPerVolumeServer)) sem := v.(chan struct{}) diff --git a/weed/server/filer_server_handlers_write_cipher.go b/weed/server/filer_server_handlers_write_cipher.go deleted file mode 100644 index 2a3fb6b68..000000000 --- a/weed/server/filer_server_handlers_write_cipher.go +++ /dev/null @@ -1,107 +0,0 @@ -package weed_server - -import ( - "bytes" - "context" - "fmt" - "net/http" - "strings" - "time" - - "github.com/seaweedfs/seaweedfs/weed/filer" - "github.com/seaweedfs/seaweedfs/weed/glog" - "github.com/seaweedfs/seaweedfs/weed/operation" - "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" - "github.com/seaweedfs/seaweedfs/weed/storage/needle" - "github.com/seaweedfs/seaweedfs/weed/util" -) - -// handling single chunk POST or PUT upload -func (fs *FilerServer) encrypt(ctx context.Context, w http.ResponseWriter, r *http.Request, so *operation.StorageOption) (filerResult *FilerPostResult, err error) { - - fileId, urlLocation, auth, err := fs.assignNewFileInfo(ctx, so) - - if err != nil || fileId == "" || urlLocation == "" { - return nil, fmt.Errorf("fail to allocate volume for %s, collection:%s, datacenter:%s", r.URL.Path, so.Collection, so.DataCenter) - } - - glog.V(4).InfofCtx(ctx, "write %s to %v", r.URL.Path, urlLocation) - - // Note: encrypt(gzip(data)), encrypt data first, then gzip - - sizeLimit := int64(fs.option.MaxMB) * 1024 * 1024 - - bytesBuffer := bufPool.Get().(*bytes.Buffer) - defer bufPool.Put(bytesBuffer) - - pu, err := needle.ParseUpload(r, sizeLimit, bytesBuffer) - uncompressedData := pu.Data - if pu.IsGzipped { - uncompressedData = pu.UncompressedData - } - if pu.MimeType == "" { - pu.MimeType = http.DetectContentType(uncompressedData) - // println("detect2 mimetype to", pu.MimeType) - } - - uploadOption := &operation.UploadOption{ - UploadUrl: urlLocation, - Filename: pu.FileName, - Cipher: true, - IsInputCompressed: false, - MimeType: pu.MimeType, - PairMap: pu.PairMap, - Jwt: auth, - } - - uploader, uploaderErr := operation.NewUploader() - if uploaderErr != nil { - return nil, fmt.Errorf("uploader initialization error: %w", uploaderErr) - } - - uploadResult, uploadError := uploader.UploadData(ctx, uncompressedData, uploadOption) - if uploadError != nil { - return nil, fmt.Errorf("upload to volume server: %w", uploadError) - } - - // Save to chunk manifest structure - fileChunks := []*filer_pb.FileChunk{uploadResult.ToPbFileChunk(fileId, 0, time.Now().UnixNano())} - - // fmt.Printf("uploaded: %+v\n", uploadResult) - - path := r.URL.Path - if strings.HasSuffix(path, "/") { - if pu.FileName != "" { - path += pu.FileName - } - } - - entry := &filer.Entry{ - FullPath: util.FullPath(path), - Attr: filer.Attr{ - Mtime: time.Now(), - Crtime: time.Now(), - Mode: 0660, - Uid: OS_UID, - Gid: OS_GID, - TtlSec: so.TtlSeconds, - Mime: pu.MimeType, - Md5: util.Base64Md5ToBytes(pu.ContentMd5), - }, - Chunks: fileChunks, - } - - filerResult = &FilerPostResult{ - Name: pu.FileName, - Size: int64(pu.OriginalDataSize), - } - - if dbErr := fs.filer.CreateEntry(ctx, entry, false, false, nil, false, so.MaxFileNameLength); dbErr != nil { - fs.filer.DeleteUncommittedChunks(ctx, entry.GetChunks()) - err = dbErr - filerResult.Error = dbErr.Error() - return - } - - return -} diff --git a/weed/server/filer_server_handlers_write_upload.go b/weed/server/filer_server_handlers_write_upload.go index 2f67e7860..40a5ca4f5 100644 --- a/weed/server/filer_server_handlers_write_upload.go +++ b/weed/server/filer_server_handlers_write_upload.go @@ -196,10 +196,6 @@ func (fs *FilerServer) doUpload(ctx context.Context, urlLocation string, limited return uploadResult, err, data } -func (fs *FilerServer) dataToChunk(ctx context.Context, fileName, contentType string, data []byte, chunkOffset int64, so *operation.StorageOption) ([]*filer_pb.FileChunk, error) { - return fs.dataToChunkWithSSE(ctx, nil, fileName, contentType, data, chunkOffset, so) -} - func (fs *FilerServer) dataToChunkWithSSE(ctx context.Context, r *http.Request, fileName, contentType string, data []byte, chunkOffset int64, so *operation.StorageOption) ([]*filer_pb.FileChunk, error) { dataReader := util.NewBytesReader(data) diff --git a/weed/server/postgres/server.go b/weed/server/postgres/server.go index f35d3704e..1ac4d8b3e 100644 --- a/weed/server/postgres/server.go +++ b/weed/server/postgres/server.go @@ -697,8 +697,3 @@ func (s *PostgreSQLServer) cleanupIdleSessions() { } } } - -// GetAddress returns the server address -func (s *PostgreSQLServer) GetAddress() string { - return fmt.Sprintf("%s:%d", s.config.Host, s.config.Port) -} diff --git a/weed/server/volume_grpc_client_to_master.go b/weed/server/volume_grpc_client_to_master.go index e2523543a..2c484e7ce 100644 --- a/weed/server/volume_grpc_client_to_master.go +++ b/weed/server/volume_grpc_client_to_master.go @@ -106,10 +106,6 @@ func (vs *VolumeServer) StopHeartbeat() (isAlreadyStopping bool) { return false } -func (vs *VolumeServer) doHeartbeat(masterAddress pb.ServerAddress, grpcDialOption grpc.DialOption, sleepInterval time.Duration) (newLeader pb.ServerAddress, err error) { - return vs.doHeartbeatWithRetry(masterAddress, grpcDialOption, sleepInterval, 0) -} - func (vs *VolumeServer) doHeartbeatWithRetry(masterAddress pb.ServerAddress, grpcDialOption grpc.DialOption, sleepInterval time.Duration, duplicateRetryCount int) (newLeader pb.ServerAddress, err error) { ctx, cancel := context.WithCancel(context.Background()) diff --git a/weed/server/volume_server_handlers_admin.go b/weed/server/volume_server_handlers_admin.go index a54369277..dfb90befd 100644 --- a/weed/server/volume_server_handlers_admin.go +++ b/weed/server/volume_server_handlers_admin.go @@ -50,19 +50,3 @@ func (vs *VolumeServer) statusHandler(w http.ResponseWriter, r *http.Request) { m["Volumes"] = vs.store.VolumeInfos() writeJsonQuiet(w, r, http.StatusOK, m) } - -func (vs *VolumeServer) statsDiskHandler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Server", "SeaweedFS Volume "+version.VERSION) - m := make(map[string]interface{}) - m["Version"] = version.Version() - var ds []*volume_server_pb.DiskStatus - for _, loc := range vs.store.Locations { - if dir, e := filepath.Abs(loc.Directory); e == nil { - newDiskStatus := stats.NewDiskStatus(dir) - newDiskStatus.DiskType = loc.DiskType.String() - ds = append(ds, newDiskStatus) - } - } - m["DiskStatuses"] = ds - writeJsonQuiet(w, r, http.StatusOK, m) -} diff --git a/weed/server/volume_server_handlers_write.go b/weed/server/volume_server_handlers_write.go index 44a2abc34..418f3c235 100644 --- a/weed/server/volume_server_handlers_write.go +++ b/weed/server/volume_server_handlers_write.go @@ -160,11 +160,3 @@ func SetEtag(w http.ResponseWriter, etag string) { } } } - -func getEtag(resp *http.Response) (etag string) { - etag = resp.Header.Get("ETag") - if strings.HasPrefix(etag, "\"") && strings.HasSuffix(etag, "\"") { - return etag[1 : len(etag)-1] - } - return -} diff --git a/weed/server/volume_server_test.go b/weed/server/volume_server_test.go deleted file mode 100644 index ac1ad774e..000000000 --- a/weed/server/volume_server_test.go +++ /dev/null @@ -1,69 +0,0 @@ -package weed_server - -import ( - "testing" - - "github.com/seaweedfs/seaweedfs/weed/pb/volume_server_pb" - "github.com/seaweedfs/seaweedfs/weed/storage" -) - -func TestMaintenanceMode(t *testing.T) { - testCases := []struct { - name string - pb *volume_server_pb.VolumeServerState - want bool - wantCheckErr string - }{ - { - name: "non-initialized state", - pb: nil, - want: false, - wantCheckErr: "", - }, - { - name: "maintenance mode disabled", - pb: &volume_server_pb.VolumeServerState{ - Maintenance: false, - }, - want: false, - wantCheckErr: "", - }, - { - name: "maintenance mode enabled", - pb: &volume_server_pb.VolumeServerState{ - Maintenance: true, - }, - want: true, - wantCheckErr: "volume server test_1234 is in maintenance mode", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - vs := VolumeServer{ - store: &storage.Store{ - Id: "test_1234", - State: storage.NewStateFromProto("/some/path.pb", tc.pb), - }, - } - - if got, want := vs.MaintenanceMode(), tc.want; got != want { - t.Errorf("MaintenanceMode() returned %v, want %v", got, want) - } - - err, wantErrStr := vs.CheckMaintenanceMode(), tc.wantCheckErr - if err != nil { - if wantErrStr == "" { - t.Errorf("CheckMaintenanceMode() returned error %v, want nil", err) - } - if errStr := err.Error(); errStr != wantErrStr { - t.Errorf("CheckMaintenanceMode() returned error %q, want %q", errStr, wantErrStr) - } - } else { - if wantErrStr != "" { - t.Errorf("CheckMaintenanceMode() returned no error, want %q", wantErrStr) - } - } - }) - } -} diff --git a/weed/sftpd/sftp_file_writer.go b/weed/sftpd/sftp_file_writer.go index fed60eec0..3f6d915b0 100644 --- a/weed/sftpd/sftp_file_writer.go +++ b/weed/sftpd/sftp_file_writer.go @@ -32,28 +32,6 @@ type bufferReader struct { i int64 } -func NewBufferReader(b []byte) *bufferReader { return &bufferReader{b: b} } - -func (r *bufferReader) Read(p []byte) (int, error) { - if r.i >= int64(len(r.b)) { - return 0, io.EOF - } - n := copy(p, r.b[r.i:]) - r.i += int64(n) - return n, nil -} - -func (r *bufferReader) ReadAt(p []byte, off int64) (int, error) { - if off >= int64(len(r.b)) { - return 0, io.EOF - } - n := copy(p, r.b[off:]) - if n < len(p) { - return n, io.EOF - } - return n, nil -} - // listerat implements sftp.ListerAt. type listerat []os.FileInfo diff --git a/weed/shell/command_ec_common.go b/weed/shell/command_ec_common.go index ba84fc7f7..aef207772 100644 --- a/weed/shell/command_ec_common.go +++ b/weed/shell/command_ec_common.go @@ -406,30 +406,6 @@ func sortEcNodesByFreeslotsAscending(ecNodes []*EcNode) { }) } -// if the index node changed the freeEcSlot, need to keep every EcNode still sorted -func ensureSortedEcNodes(data []*CandidateEcNode, index int, lessThan func(i, j int) bool) { - for i := index - 1; i >= 0; i-- { - if lessThan(i+1, i) { - swap(data, i, i+1) - } else { - break - } - } - for i := index + 1; i < len(data); i++ { - if lessThan(i, i-1) { - swap(data, i, i-1) - } else { - break - } - } -} - -func swap(data []*CandidateEcNode, i, j int) { - t := data[i] - data[i] = data[j] - data[j] = t -} - func countShards(ecShardInfos []*master_pb.VolumeEcShardInformationMessage) (count int) { for _, eci := range ecShardInfos { count += erasure_coding.GetShardCount(eci) @@ -1135,48 +1111,6 @@ func (ecb *ecBalancer) pickRackForShardType( return selected.id, nil } -func (ecb *ecBalancer) pickRackToBalanceShardsInto(rackToEcNodes map[RackId]*EcRack, rackToShardCount map[string]int) (RackId, error) { - targets := []RackId{} - targetShards := -1 - for _, shards := range rackToShardCount { - if shards > targetShards { - targetShards = shards - } - } - - details := "" - for rackId, rack := range rackToEcNodes { - shards := rackToShardCount[string(rackId)] - - if rack.freeEcSlot <= 0 { - details += fmt.Sprintf(" Skipped %s because it has no free slots\n", rackId) - continue - } - // For EC shards, replica placement constraint only applies when DiffRackCount > 0. - // When DiffRackCount = 0 (e.g., replica placement "000"), EC shards should be - // distributed freely across racks for fault tolerance - the "000" means - // "no volume replication needed" because erasure coding provides redundancy. - if ecb.replicaPlacement != nil && ecb.replicaPlacement.DiffRackCount > 0 && shards > ecb.replicaPlacement.DiffRackCount { - details += fmt.Sprintf(" Skipped %s because shards %d > replica placement limit for other racks (%d)\n", rackId, shards, ecb.replicaPlacement.DiffRackCount) - continue - } - - if shards < targetShards { - // Favor racks with less shards, to ensure an uniform distribution. - targets = nil - targetShards = shards - } - if shards == targetShards { - targets = append(targets, rackId) - } - } - - if len(targets) == 0 { - return "", errors.New(details) - } - return targets[rand.IntN(len(targets))], nil -} - func (ecb *ecBalancer) balanceEcShardsWithinRacks(collection string) error { // collect vid => []ecNode, since previous steps can change the locations vidLocations := ecb.collectVolumeIdToEcNodes(collection) @@ -1567,46 +1501,6 @@ func (ecb *ecBalancer) pickOneEcNodeAndMoveOneShard(existingLocation *EcNode, co return moveMountedShardToEcNode(ecb.commandEnv, existingLocation, collection, vid, shardId, destNode, destDiskId, ecb.applyBalancing, ecb.diskType) } -func pickNEcShardsToMoveFrom(ecNodes []*EcNode, vid needle.VolumeId, n int, diskType types.DiskType) map[erasure_coding.ShardId]*EcNode { - picked := make(map[erasure_coding.ShardId]*EcNode) - var candidateEcNodes []*CandidateEcNode - for _, ecNode := range ecNodes { - si := findEcVolumeShardsInfo(ecNode, vid, diskType) - if si.Count() > 0 { - candidateEcNodes = append(candidateEcNodes, &CandidateEcNode{ - ecNode: ecNode, - shardCount: si.Count(), - }) - } - } - slices.SortFunc(candidateEcNodes, func(a, b *CandidateEcNode) int { - return b.shardCount - a.shardCount - }) - for i := 0; i < n; i++ { - selectedEcNodeIndex := -1 - for i, candidateEcNode := range candidateEcNodes { - si := findEcVolumeShardsInfo(candidateEcNode.ecNode, vid, diskType) - if si.Count() > 0 { - selectedEcNodeIndex = i - for _, shardId := range si.Ids() { - candidateEcNode.shardCount-- - picked[shardId] = candidateEcNode.ecNode - candidateEcNode.ecNode.deleteEcVolumeShards(vid, []erasure_coding.ShardId{shardId}, diskType) - break - } - break - } - } - if selectedEcNodeIndex >= 0 { - ensureSortedEcNodes(candidateEcNodes, selectedEcNodeIndex, func(i, j int) bool { - return candidateEcNodes[i].shardCount > candidateEcNodes[j].shardCount - }) - } - - } - return picked -} - func (ecb *ecBalancer) collectVolumeIdToEcNodes(collection string) map[needle.VolumeId][]*EcNode { vidLocations := make(map[needle.VolumeId][]*EcNode) for _, ecNode := range ecb.ecNodes { diff --git a/weed/shell/command_ec_common_test.go b/weed/shell/command_ec_common_test.go deleted file mode 100644 index ff186f21d..000000000 --- a/weed/shell/command_ec_common_test.go +++ /dev/null @@ -1,354 +0,0 @@ -package shell - -import ( - "fmt" - "reflect" - "strings" - "testing" - - "github.com/seaweedfs/seaweedfs/weed/pb/master_pb" - "github.com/seaweedfs/seaweedfs/weed/storage/erasure_coding" - "github.com/seaweedfs/seaweedfs/weed/storage/needle" - "github.com/seaweedfs/seaweedfs/weed/storage/super_block" - "github.com/seaweedfs/seaweedfs/weed/storage/types" -) - -func errorCheck(got error, want string) error { - if got == nil && want == "" { - return nil - } - if got != nil && want == "" { - return fmt.Errorf("expected no error, got %q", got.Error()) - } - if got == nil && want != "" { - return fmt.Errorf("got no error, expected %q", want) - } - if !strings.Contains(got.Error(), want) { - return fmt.Errorf("expected error %q, got %q", want, got.Error()) - } - return nil -} - -func TestCollectCollectionsForVolumeIds(t *testing.T) { - testCases := []struct { - topology *master_pb.TopologyInfo - vids []needle.VolumeId - want []string - }{ - // normal volumes - {testTopology1, nil, nil}, - {testTopology1, []needle.VolumeId{}, nil}, - {testTopology1, []needle.VolumeId{needle.VolumeId(9999)}, nil}, - {testTopology1, []needle.VolumeId{needle.VolumeId(2)}, []string{""}}, - {testTopology1, []needle.VolumeId{needle.VolumeId(2), needle.VolumeId(272)}, []string{"", "collection2"}}, - {testTopology1, []needle.VolumeId{needle.VolumeId(2), needle.VolumeId(272), needle.VolumeId(299)}, []string{"", "collection2"}}, - {testTopology1, []needle.VolumeId{needle.VolumeId(272), needle.VolumeId(299), needle.VolumeId(95)}, []string{"collection1", "collection2"}}, - {testTopology1, []needle.VolumeId{needle.VolumeId(272), needle.VolumeId(299), needle.VolumeId(95), needle.VolumeId(51)}, []string{"collection1", "collection2"}}, - {testTopology1, []needle.VolumeId{needle.VolumeId(272), needle.VolumeId(299), needle.VolumeId(95), needle.VolumeId(51), needle.VolumeId(15)}, []string{"collection0", "collection1", "collection2"}}, - // EC volumes - {testTopology2, []needle.VolumeId{needle.VolumeId(9577)}, []string{"s3qldata"}}, - {testTopology2, []needle.VolumeId{needle.VolumeId(9577), needle.VolumeId(12549)}, []string{"s3qldata"}}, - // normal + EC volumes - {testTopology2, []needle.VolumeId{needle.VolumeId(18111)}, []string{"s3qldata"}}, - {testTopology2, []needle.VolumeId{needle.VolumeId(8677)}, []string{"s3qldata"}}, - {testTopology2, []needle.VolumeId{needle.VolumeId(18111), needle.VolumeId(8677)}, []string{"s3qldata"}}, - } - - for _, tc := range testCases { - got := collectCollectionsForVolumeIds(tc.topology, tc.vids) - if !reflect.DeepEqual(got, tc.want) { - t.Errorf("for %v: got %v, want %v", tc.vids, got, tc.want) - } - } -} - -func TestParseReplicaPlacementArg(t *testing.T) { - getDefaultReplicaPlacementOrig := getDefaultReplicaPlacement - getDefaultReplicaPlacement = func(commandEnv *CommandEnv) (*super_block.ReplicaPlacement, error) { - return super_block.NewReplicaPlacementFromString("123") - } - defer func() { - getDefaultReplicaPlacement = getDefaultReplicaPlacementOrig - }() - - testCases := []struct { - argument string - want string - wantErr string - }{ - {"lalala", "lal", "unexpected replication type"}, - {"", "123", ""}, - {"021", "021", ""}, - } - - for _, tc := range testCases { - commandEnv := &CommandEnv{} - got, gotErr := parseReplicaPlacementArg(commandEnv, tc.argument) - - if err := errorCheck(gotErr, tc.wantErr); err != nil { - t.Errorf("argument %q: %s", tc.argument, err.Error()) - continue - } - - want, _ := super_block.NewReplicaPlacementFromString(tc.want) - if !got.Equals(want) { - t.Errorf("got replica placement %q, want %q", got.String(), want.String()) - } - } -} - -func TestEcDistribution(t *testing.T) { - - // find out all volume servers with one slot left. - ecNodes, totalFreeEcSlots := collectEcVolumeServersByDc(testTopology1, "", types.HardDriveType) - - sortEcNodesByFreeslotsDescending(ecNodes) - - if totalFreeEcSlots < erasure_coding.TotalShardsCount { - t.Errorf("not enough free ec shard slots: %d", totalFreeEcSlots) - } - allocatedDataNodes := ecNodes - if len(allocatedDataNodes) > erasure_coding.TotalShardsCount { - allocatedDataNodes = allocatedDataNodes[:erasure_coding.TotalShardsCount] - } - - for _, dn := range allocatedDataNodes { - // fmt.Printf("info %+v %+v\n", dn.info, dn) - fmt.Printf("=> %+v %+v\n", dn.info.Id, dn.freeEcSlot) - } -} - -func TestPickRackToBalanceShardsInto(t *testing.T) { - testCases := []struct { - topology *master_pb.TopologyInfo - vid string - replicaPlacement string - wantOneOf []string - wantErr string - }{ - // Non-EC volumes. We don't care about these, but the function should return all racks as a safeguard. - {testTopologyEc, "", "123", []string{"rack1", "rack2", "rack3", "rack4", "rack5", "rack6"}, ""}, - {testTopologyEc, "6225", "123", []string{"rack1", "rack2", "rack3", "rack4", "rack5", "rack6"}, ""}, - {testTopologyEc, "6226", "123", []string{"rack1", "rack2", "rack3", "rack4", "rack5", "rack6"}, ""}, - {testTopologyEc, "6241", "123", []string{"rack1", "rack2", "rack3", "rack4", "rack5", "rack6"}, ""}, - {testTopologyEc, "6242", "123", []string{"rack1", "rack2", "rack3", "rack4", "rack5", "rack6"}, ""}, - // EC volumes. - // With replication "000" (DiffRackCount=0), EC shards should be distributed freely - // because erasure coding provides its own redundancy. No replica placement error. - {testTopologyEc, "9577", "", []string{"rack1", "rack2", "rack3"}, ""}, - {testTopologyEc, "9577", "111", []string{"rack1", "rack2", "rack3"}, ""}, - {testTopologyEc, "9577", "222", []string{"rack1", "rack2", "rack3"}, ""}, - {testTopologyEc, "10457", "222", []string{"rack1"}, ""}, - {testTopologyEc, "12737", "222", []string{"rack2"}, ""}, - {testTopologyEc, "14322", "222", []string{"rack3"}, ""}, - } - - for _, tc := range testCases { - vid, _ := needle.NewVolumeId(tc.vid) - ecNodes, _ := collectEcVolumeServersByDc(tc.topology, "", types.HardDriveType) - rp, _ := super_block.NewReplicaPlacementFromString(tc.replicaPlacement) - - ecb := &ecBalancer{ - ecNodes: ecNodes, - replicaPlacement: rp, - diskType: types.HardDriveType, - } - - racks := ecb.racks() - rackToShardCount := countShardsByRack(vid, ecNodes, types.HardDriveType) - - got, gotErr := ecb.pickRackToBalanceShardsInto(racks, rackToShardCount) - if err := errorCheck(gotErr, tc.wantErr); err != nil { - t.Errorf("volume %q: %s", tc.vid, err.Error()) - continue - } - - if string(got) == "" && len(tc.wantOneOf) == 0 { - continue - } - found := false - for _, want := range tc.wantOneOf { - if got := string(got); got == want { - found = true - break - } - } - if !(found) { - t.Errorf("expected one of %v for volume %q, got %q", tc.wantOneOf, tc.vid, got) - } - } -} -func TestPickEcNodeToBalanceShardsInto(t *testing.T) { - testCases := []struct { - topology *master_pb.TopologyInfo - nodeId string - vid string - wantOneOf []string - wantErr string - }{ - {testTopologyEc, "", "", nil, "INTERNAL: missing source nodes"}, - {testTopologyEc, "idontexist", "12737", nil, "INTERNAL: missing source nodes"}, - // Non-EC nodes. We don't care about these, but the function should return all available target nodes as a safeguard. - { - testTopologyEc, "172.19.0.10:8702", "6225", []string{ - "172.19.0.13:8701", "172.19.0.14:8711", "172.19.0.16:8704", "172.19.0.17:8703", - "172.19.0.19:8700", "172.19.0.20:8706", "172.19.0.21:8710", "172.19.0.3:8708", - "172.19.0.4:8707", "172.19.0.5:8705", "172.19.0.6:8713", "172.19.0.8:8709", - "172.19.0.9:8712"}, - "", - }, - { - testTopologyEc, "172.19.0.8:8709", "6226", []string{ - "172.19.0.10:8702", "172.19.0.13:8701", "172.19.0.14:8711", "172.19.0.16:8704", - "172.19.0.17:8703", "172.19.0.19:8700", "172.19.0.20:8706", "172.19.0.21:8710", - "172.19.0.3:8708", "172.19.0.4:8707", "172.19.0.5:8705", "172.19.0.6:8713", - "172.19.0.9:8712"}, - "", - }, - // EC volumes. - {testTopologyEc, "172.19.0.10:8702", "14322", []string{ - "172.19.0.14:8711", "172.19.0.5:8705", "172.19.0.6:8713"}, - ""}, - {testTopologyEc, "172.19.0.13:8701", "10457", []string{ - "172.19.0.10:8702", "172.19.0.6:8713"}, - ""}, - {testTopologyEc, "172.19.0.17:8703", "12737", []string{ - "172.19.0.13:8701"}, - ""}, - {testTopologyEc, "172.19.0.20:8706", "14322", []string{ - "172.19.0.14:8711", "172.19.0.5:8705", "172.19.0.6:8713"}, - ""}, - } - - for _, tc := range testCases { - vid, _ := needle.NewVolumeId(tc.vid) - allEcNodes, _ := collectEcVolumeServersByDc(tc.topology, "", types.HardDriveType) - - ecb := &ecBalancer{ - ecNodes: allEcNodes, - diskType: types.HardDriveType, - } - - // Resolve target node by name - var ecNode *EcNode - for _, n := range allEcNodes { - if n.info.Id == tc.nodeId { - ecNode = n - break - } - } - - got, gotErr := ecb.pickEcNodeToBalanceShardsInto(vid, ecNode, allEcNodes) - if err := errorCheck(gotErr, tc.wantErr); err != nil { - t.Errorf("node %q, volume %q: %s", tc.nodeId, tc.vid, err.Error()) - continue - } - - if got == nil { - if len(tc.wantOneOf) == 0 { - continue - } - t.Errorf("node %q, volume %q: got no node, want %q", tc.nodeId, tc.vid, tc.wantOneOf) - continue - } - found := false - for _, want := range tc.wantOneOf { - if got := got.info.Id; got == want { - found = true - break - } - } - if !(found) { - t.Errorf("expected one of %v for volume %q, got %q", tc.wantOneOf, tc.vid, got.info.Id) - } - } -} - -func TestCountFreeShardSlots(t *testing.T) { - testCases := []struct { - name string - topology *master_pb.TopologyInfo - diskType types.DiskType - want map[string]int - }{ - { - name: "topology #1, free HDD shards", - topology: testTopology1, - diskType: types.HardDriveType, - want: map[string]int{ - "192.168.1.1:8080": 17330, - "192.168.1.2:8080": 1540, - "192.168.1.4:8080": 1900, - "192.168.1.5:8080": 27010, - "192.168.1.6:8080": 17420, - }, - }, - { - name: "topology #1, no free SSD shards available", - topology: testTopology1, - diskType: types.SsdType, - want: map[string]int{ - "192.168.1.1:8080": 0, - "192.168.1.2:8080": 0, - "192.168.1.4:8080": 0, - "192.168.1.5:8080": 0, - "192.168.1.6:8080": 0, - }, - }, - { - name: "topology #2, no negative free HDD shards", - topology: testTopology2, - diskType: types.HardDriveType, - want: map[string]int{ - "172.19.0.3:8708": 0, - "172.19.0.4:8707": 8, - "172.19.0.5:8705": 58, - "172.19.0.6:8713": 39, - "172.19.0.8:8709": 8, - "172.19.0.9:8712": 0, - "172.19.0.10:8702": 0, - "172.19.0.13:8701": 0, - "172.19.0.14:8711": 0, - "172.19.0.16:8704": 89, - "172.19.0.17:8703": 0, - "172.19.0.19:8700": 9, - "172.19.0.20:8706": 0, - "172.19.0.21:8710": 9, - }, - }, - { - name: "topology #2, no free SSD shards available", - topology: testTopology2, - diskType: types.SsdType, - want: map[string]int{ - "172.19.0.10:8702": 0, - "172.19.0.13:8701": 0, - "172.19.0.14:8711": 0, - "172.19.0.16:8704": 0, - "172.19.0.17:8703": 0, - "172.19.0.19:8700": 0, - "172.19.0.20:8706": 0, - "172.19.0.21:8710": 0, - "172.19.0.3:8708": 0, - "172.19.0.4:8707": 0, - "172.19.0.5:8705": 0, - "172.19.0.6:8713": 0, - "172.19.0.8:8709": 0, - "172.19.0.9:8712": 0, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - got := map[string]int{} - eachDataNode(tc.topology, func(dc DataCenterId, rack RackId, dn *master_pb.DataNodeInfo) { - got[dn.Id] = countFreeShardSlots(dn, tc.diskType) - }) - - if !reflect.DeepEqual(got, tc.want) { - t.Errorf("got %v, want %v", got, tc.want) - } - }) - } -} diff --git a/weed/shell/commands.go b/weed/shell/commands.go index 741dff6b0..6679c15c9 100644 --- a/weed/shell/commands.go +++ b/weed/shell/commands.go @@ -4,8 +4,6 @@ import ( "context" "fmt" "io" - "net/url" - "strconv" "strings" "github.com/seaweedfs/seaweedfs/weed/operation" @@ -138,25 +136,6 @@ func (ce *CommandEnv) GetDataCenter() string { return ce.MasterClient.GetDataCenter() } -func parseFilerUrl(entryPath string) (filerServer string, filerPort int64, path string, err error) { - if strings.HasPrefix(entryPath, "http") { - var u *url.URL - u, err = url.Parse(entryPath) - if err != nil { - return - } - filerServer = u.Hostname() - portString := u.Port() - if portString != "" { - filerPort, err = strconv.ParseInt(portString, 10, 32) - } - path = u.Path - } else { - err = fmt.Errorf("path should have full url /path/to/dirOrFile : %s", entryPath) - } - return -} - func findInputDirectory(args []string) (input string) { input = "." if len(args) > 0 { diff --git a/weed/shell/ec_proportional_rebalance.go b/weed/shell/ec_proportional_rebalance.go index 52adf4297..8d6b1c1b7 100644 --- a/weed/shell/ec_proportional_rebalance.go +++ b/weed/shell/ec_proportional_rebalance.go @@ -1,8 +1,6 @@ package shell import ( - "fmt" - "github.com/seaweedfs/seaweedfs/weed/storage/erasure_coding" "github.com/seaweedfs/seaweedfs/weed/storage/erasure_coding/distribution" "github.com/seaweedfs/seaweedfs/weed/storage/needle" @@ -13,18 +11,6 @@ import ( // ECDistribution is an alias to the distribution package type for backward compatibility type ECDistribution = distribution.ECDistribution -// CalculateECDistribution computes the target EC shard distribution based on replication policy. -// This is a convenience wrapper that uses the default 10+4 EC configuration. -// For custom EC ratios, use the distribution package directly. -func CalculateECDistribution(totalShards, parityShards int, rp *super_block.ReplicaPlacement) *ECDistribution { - ec := distribution.ECConfig{ - DataShards: totalShards - parityShards, - ParityShards: parityShards, - } - rep := distribution.NewReplicationConfig(rp) - return distribution.CalculateDistribution(ec, rep) -} - // TopologyDistributionAnalysis holds the current shard distribution analysis // This wraps the distribution package's TopologyAnalysis with shell-specific EcNode handling type TopologyDistributionAnalysis struct { @@ -34,99 +20,6 @@ type TopologyDistributionAnalysis struct { nodeMap map[string]*EcNode // nodeID -> EcNode } -// NewTopologyDistributionAnalysis creates a new analysis structure -func NewTopologyDistributionAnalysis() *TopologyDistributionAnalysis { - return &TopologyDistributionAnalysis{ - inner: distribution.NewTopologyAnalysis(), - nodeMap: make(map[string]*EcNode), - } -} - -// AddNode adds a node and its shards to the analysis -func (a *TopologyDistributionAnalysis) AddNode(node *EcNode, shardsInfo *erasure_coding.ShardsInfo) { - nodeId := node.info.Id - - // Create distribution.TopologyNode from EcNode - topoNode := &distribution.TopologyNode{ - NodeID: nodeId, - DataCenter: string(node.dc), - Rack: string(node.rack), - FreeSlots: node.freeEcSlot, - TotalShards: shardsInfo.Count(), - ShardIDs: shardsInfo.IdsInt(), - } - - a.inner.AddNode(topoNode) - a.nodeMap[nodeId] = node - - // Add shard locations - for _, shardId := range shardsInfo.Ids() { - a.inner.AddShardLocation(distribution.ShardLocation{ - ShardID: int(shardId), - NodeID: nodeId, - DataCenter: string(node.dc), - Rack: string(node.rack), - }) - } -} - -// Finalize completes the analysis -func (a *TopologyDistributionAnalysis) Finalize() { - a.inner.Finalize() -} - -// String returns a summary -func (a *TopologyDistributionAnalysis) String() string { - return a.inner.String() -} - -// DetailedString returns detailed analysis -func (a *TopologyDistributionAnalysis) DetailedString() string { - return a.inner.DetailedString() -} - -// GetShardsByDC returns shard counts by DC -func (a *TopologyDistributionAnalysis) GetShardsByDC() map[DataCenterId]int { - result := make(map[DataCenterId]int) - for dc, count := range a.inner.ShardsByDC { - result[DataCenterId(dc)] = count - } - return result -} - -// GetShardsByRack returns shard counts by rack -func (a *TopologyDistributionAnalysis) GetShardsByRack() map[RackId]int { - result := make(map[RackId]int) - for rack, count := range a.inner.ShardsByRack { - result[RackId(rack)] = count - } - return result -} - -// GetShardsByNode returns shard counts by node -func (a *TopologyDistributionAnalysis) GetShardsByNode() map[EcNodeId]int { - result := make(map[EcNodeId]int) - for nodeId, count := range a.inner.ShardsByNode { - result[EcNodeId(nodeId)] = count - } - return result -} - -// AnalyzeVolumeDistribution creates an analysis of current shard distribution for a volume -func AnalyzeVolumeDistribution(volumeId needle.VolumeId, locations []*EcNode, diskType types.DiskType) *TopologyDistributionAnalysis { - analysis := NewTopologyDistributionAnalysis() - - for _, node := range locations { - si := findEcVolumeShardsInfo(node, volumeId, diskType) - if si.Count() > 0 { - analysis.AddNode(node, si) - } - } - - analysis.Finalize() - return analysis -} - // ECShardMove represents a planned shard move (shell-specific with EcNode references) type ECShardMove struct { VolumeId needle.VolumeId @@ -136,12 +29,6 @@ type ECShardMove struct { Reason string } -// String returns a human-readable description -func (m ECShardMove) String() string { - return fmt.Sprintf("volume %d shard %d: %s -> %s (%s)", - m.VolumeId, m.ShardId, m.SourceNode.info.Id, m.DestNode.info.Id, m.Reason) -} - // ProportionalECRebalancer implements proportional shard distribution for shell commands type ProportionalECRebalancer struct { ecNodes []*EcNode @@ -149,133 +36,3 @@ type ProportionalECRebalancer struct { diskType types.DiskType ecConfig distribution.ECConfig } - -// NewProportionalECRebalancer creates a new proportional rebalancer with default EC config -func NewProportionalECRebalancer( - ecNodes []*EcNode, - rp *super_block.ReplicaPlacement, - diskType types.DiskType, -) *ProportionalECRebalancer { - return NewProportionalECRebalancerWithConfig( - ecNodes, - rp, - diskType, - distribution.DefaultECConfig(), - ) -} - -// NewProportionalECRebalancerWithConfig creates a rebalancer with custom EC configuration -func NewProportionalECRebalancerWithConfig( - ecNodes []*EcNode, - rp *super_block.ReplicaPlacement, - diskType types.DiskType, - ecConfig distribution.ECConfig, -) *ProportionalECRebalancer { - return &ProportionalECRebalancer{ - ecNodes: ecNodes, - replicaPlacement: rp, - diskType: diskType, - ecConfig: ecConfig, - } -} - -// PlanMoves generates a plan for moving shards to achieve proportional distribution -func (r *ProportionalECRebalancer) PlanMoves( - volumeId needle.VolumeId, - locations []*EcNode, -) ([]ECShardMove, error) { - // Build topology analysis - analysis := distribution.NewTopologyAnalysis() - nodeMap := make(map[string]*EcNode) - - // Add all EC nodes to the analysis (even those without shards) - for _, node := range r.ecNodes { - nodeId := node.info.Id - topoNode := &distribution.TopologyNode{ - NodeID: nodeId, - DataCenter: string(node.dc), - Rack: string(node.rack), - FreeSlots: node.freeEcSlot, - } - analysis.AddNode(topoNode) - nodeMap[nodeId] = node - } - - // Add shard locations from nodes that have shards - for _, node := range locations { - nodeId := node.info.Id - si := findEcVolumeShardsInfo(node, volumeId, r.diskType) - for _, shardId := range si.Ids() { - analysis.AddShardLocation(distribution.ShardLocation{ - ShardID: int(shardId), - NodeID: nodeId, - DataCenter: string(node.dc), - Rack: string(node.rack), - }) - } - if _, exists := nodeMap[nodeId]; !exists { - nodeMap[nodeId] = node - } - } - - analysis.Finalize() - - // Create rebalancer and plan moves - rep := distribution.NewReplicationConfig(r.replicaPlacement) - rebalancer := distribution.NewRebalancer(r.ecConfig, rep) - - plan, err := rebalancer.PlanRebalance(analysis) - if err != nil { - return nil, err - } - - // Convert distribution moves to shell moves - var moves []ECShardMove - for _, move := range plan.Moves { - srcNode := nodeMap[move.SourceNode.NodeID] - destNode := nodeMap[move.DestNode.NodeID] - if srcNode == nil || destNode == nil { - continue - } - - moves = append(moves, ECShardMove{ - VolumeId: volumeId, - ShardId: erasure_coding.ShardId(move.ShardID), - SourceNode: srcNode, - DestNode: destNode, - Reason: move.Reason, - }) - } - - return moves, nil -} - -// GetDistributionSummary returns a summary of the planned distribution -func GetDistributionSummary(rp *super_block.ReplicaPlacement) string { - ec := distribution.DefaultECConfig() - rep := distribution.NewReplicationConfig(rp) - dist := distribution.CalculateDistribution(ec, rep) - return dist.Summary() -} - -// GetDistributionSummaryWithConfig returns a summary with custom EC configuration -func GetDistributionSummaryWithConfig(rp *super_block.ReplicaPlacement, ecConfig distribution.ECConfig) string { - rep := distribution.NewReplicationConfig(rp) - dist := distribution.CalculateDistribution(ecConfig, rep) - return dist.Summary() -} - -// GetFaultToleranceAnalysis returns fault tolerance analysis for the given configuration -func GetFaultToleranceAnalysis(rp *super_block.ReplicaPlacement) string { - ec := distribution.DefaultECConfig() - rep := distribution.NewReplicationConfig(rp) - dist := distribution.CalculateDistribution(ec, rep) - return dist.FaultToleranceAnalysis() -} - -// GetFaultToleranceAnalysisWithConfig returns fault tolerance analysis with custom EC configuration -func GetFaultToleranceAnalysisWithConfig(rp *super_block.ReplicaPlacement, ecConfig distribution.ECConfig) string { - rep := distribution.NewReplicationConfig(rp) - dist := distribution.CalculateDistribution(ecConfig, rep) - return dist.FaultToleranceAnalysis() -} diff --git a/weed/shell/ec_proportional_rebalance_test.go b/weed/shell/ec_proportional_rebalance_test.go deleted file mode 100644 index c8ec99e0a..000000000 --- a/weed/shell/ec_proportional_rebalance_test.go +++ /dev/null @@ -1,251 +0,0 @@ -package shell - -import ( - "testing" - - "github.com/seaweedfs/seaweedfs/weed/pb/master_pb" - "github.com/seaweedfs/seaweedfs/weed/storage/erasure_coding" - "github.com/seaweedfs/seaweedfs/weed/storage/erasure_coding/distribution" - "github.com/seaweedfs/seaweedfs/weed/storage/needle" - "github.com/seaweedfs/seaweedfs/weed/storage/super_block" - "github.com/seaweedfs/seaweedfs/weed/storage/types" -) - -func TestCalculateECDistributionShell(t *testing.T) { - // Test the shell wrapper function - rp, _ := super_block.NewReplicaPlacementFromString("110") - - dist := CalculateECDistribution( - erasure_coding.TotalShardsCount, - erasure_coding.ParityShardsCount, - rp, - ) - - if dist.ReplicationConfig.MinDataCenters != 2 { - t.Errorf("Expected 2 DCs, got %d", dist.ReplicationConfig.MinDataCenters) - } - if dist.TargetShardsPerDC != 7 { - t.Errorf("Expected 7 shards per DC, got %d", dist.TargetShardsPerDC) - } - - t.Log(dist.Summary()) -} - -func TestAnalyzeVolumeDistributionShell(t *testing.T) { - diskType := types.HardDriveType - diskTypeKey := string(diskType) - - // Build a topology with unbalanced distribution - node1 := &EcNode{ - info: &master_pb.DataNodeInfo{ - Id: "127.0.0.1:8080", - DiskInfos: map[string]*master_pb.DiskInfo{ - diskTypeKey: { - Type: diskTypeKey, - MaxVolumeCount: 10, - EcShardInfos: []*master_pb.VolumeEcShardInformationMessage{ - { - Id: 1, - Collection: "test", - EcIndexBits: 0x3FFF, // All 14 shards - }, - }, - }, - }, - }, - dc: "dc1", - rack: "rack1", - freeEcSlot: 5, - } - - node2 := &EcNode{ - info: &master_pb.DataNodeInfo{ - Id: "127.0.0.1:8081", - DiskInfos: map[string]*master_pb.DiskInfo{ - diskTypeKey: { - Type: diskTypeKey, - MaxVolumeCount: 10, - EcShardInfos: []*master_pb.VolumeEcShardInformationMessage{}, - }, - }, - }, - dc: "dc2", - rack: "rack2", - freeEcSlot: 10, - } - - locations := []*EcNode{node1, node2} - volumeId := needle.VolumeId(1) - - analysis := AnalyzeVolumeDistribution(volumeId, locations, diskType) - - shardsByDC := analysis.GetShardsByDC() - if shardsByDC["dc1"] != 14 { - t.Errorf("Expected 14 shards in dc1, got %d", shardsByDC["dc1"]) - } - - t.Log(analysis.DetailedString()) -} - -func TestProportionalRebalancerShell(t *testing.T) { - diskType := types.HardDriveType - diskTypeKey := string(diskType) - - // Build topology: 2 DCs, 2 racks each, all shards on one node - nodes := []*EcNode{ - { - info: &master_pb.DataNodeInfo{ - Id: "dc1-rack1-node1", - DiskInfos: map[string]*master_pb.DiskInfo{ - diskTypeKey: { - Type: diskTypeKey, - MaxVolumeCount: 10, - EcShardInfos: []*master_pb.VolumeEcShardInformationMessage{ - {Id: 1, Collection: "test", EcIndexBits: 0x3FFF}, - }, - }, - }, - }, - dc: "dc1", rack: "dc1-rack1", freeEcSlot: 0, - }, - { - info: &master_pb.DataNodeInfo{ - Id: "dc1-rack2-node1", - DiskInfos: map[string]*master_pb.DiskInfo{ - diskTypeKey: {Type: diskTypeKey, MaxVolumeCount: 10}, - }, - }, - dc: "dc1", rack: "dc1-rack2", freeEcSlot: 10, - }, - { - info: &master_pb.DataNodeInfo{ - Id: "dc2-rack1-node1", - DiskInfos: map[string]*master_pb.DiskInfo{ - diskTypeKey: {Type: diskTypeKey, MaxVolumeCount: 10}, - }, - }, - dc: "dc2", rack: "dc2-rack1", freeEcSlot: 10, - }, - { - info: &master_pb.DataNodeInfo{ - Id: "dc2-rack2-node1", - DiskInfos: map[string]*master_pb.DiskInfo{ - diskTypeKey: {Type: diskTypeKey, MaxVolumeCount: 10}, - }, - }, - dc: "dc2", rack: "dc2-rack2", freeEcSlot: 10, - }, - } - - rp, _ := super_block.NewReplicaPlacementFromString("110") - rebalancer := NewProportionalECRebalancer(nodes, rp, diskType) - - volumeId := needle.VolumeId(1) - moves, err := rebalancer.PlanMoves(volumeId, []*EcNode{nodes[0]}) - - if err != nil { - t.Fatalf("PlanMoves failed: %v", err) - } - - t.Logf("Planned %d moves", len(moves)) - for i, move := range moves { - t.Logf(" %d. %s", i+1, move.String()) - } - - // Verify moves to dc2 - movedToDC2 := 0 - for _, move := range moves { - if move.DestNode.dc == "dc2" { - movedToDC2++ - } - } - - if movedToDC2 == 0 { - t.Error("Expected some moves to dc2") - } -} - -func TestCustomECConfigRebalancer(t *testing.T) { - diskType := types.HardDriveType - diskTypeKey := string(diskType) - - // Test with custom 8+4 EC configuration - ecConfig, err := distribution.NewECConfig(8, 4) - if err != nil { - t.Fatalf("Failed to create EC config: %v", err) - } - - // Build topology for 12 shards (8+4) - nodes := []*EcNode{ - { - info: &master_pb.DataNodeInfo{ - Id: "dc1-node1", - DiskInfos: map[string]*master_pb.DiskInfo{ - diskTypeKey: { - Type: diskTypeKey, - MaxVolumeCount: 10, - EcShardInfos: []*master_pb.VolumeEcShardInformationMessage{ - {Id: 1, Collection: "test", EcIndexBits: 0x0FFF}, // 12 shards (bits 0-11) - }, - }, - }, - }, - dc: "dc1", rack: "dc1-rack1", freeEcSlot: 0, - }, - { - info: &master_pb.DataNodeInfo{ - Id: "dc2-node1", - DiskInfos: map[string]*master_pb.DiskInfo{ - diskTypeKey: {Type: diskTypeKey, MaxVolumeCount: 10}, - }, - }, - dc: "dc2", rack: "dc2-rack1", freeEcSlot: 10, - }, - { - info: &master_pb.DataNodeInfo{ - Id: "dc3-node1", - DiskInfos: map[string]*master_pb.DiskInfo{ - diskTypeKey: {Type: diskTypeKey, MaxVolumeCount: 10}, - }, - }, - dc: "dc3", rack: "dc3-rack1", freeEcSlot: 10, - }, - } - - rp, _ := super_block.NewReplicaPlacementFromString("200") // 3 DCs - rebalancer := NewProportionalECRebalancerWithConfig(nodes, rp, diskType, ecConfig) - - volumeId := needle.VolumeId(1) - moves, err := rebalancer.PlanMoves(volumeId, []*EcNode{nodes[0]}) - - if err != nil { - t.Fatalf("PlanMoves failed: %v", err) - } - - t.Logf("Custom 8+4 EC with 200 replication: planned %d moves", len(moves)) - - // Get the distribution summary - summary := GetDistributionSummaryWithConfig(rp, ecConfig) - t.Log(summary) - - analysis := GetFaultToleranceAnalysisWithConfig(rp, ecConfig) - t.Log(analysis) -} - -func TestGetDistributionSummaryShell(t *testing.T) { - rp, _ := super_block.NewReplicaPlacementFromString("110") - - summary := GetDistributionSummary(rp) - t.Log(summary) - - if len(summary) == 0 { - t.Error("Summary should not be empty") - } - - analysis := GetFaultToleranceAnalysis(rp) - t.Log(analysis) - - if len(analysis) == 0 { - t.Error("Analysis should not be empty") - } -} diff --git a/weed/shell/shell_liner.go b/weed/shell/shell_liner.go index 00831d42e..78afc7880 100644 --- a/weed/shell/shell_liner.go +++ b/weed/shell/shell_liner.go @@ -126,17 +126,6 @@ func processEachCmd(cmd string, commandEnv *CommandEnv) bool { return false } -func stripQuotes(s string) string { - tokens, unbalanced := parseShellInput(s, false) - if unbalanced { - return s - } - if len(tokens) > 0 { - return tokens[0] - } - return "" -} - func splitCommandLine(line string) []string { tokens, _ := parseShellInput(line, true) return tokens diff --git a/weed/shell/shell_liner_test.go b/weed/shell/shell_liner_test.go deleted file mode 100644 index bfdd2b378..000000000 --- a/weed/shell/shell_liner_test.go +++ /dev/null @@ -1,105 +0,0 @@ -package shell - -import ( - "flag" - "reflect" - "testing" -) - -func TestSplitCommandLine(t *testing.T) { - tests := []struct { - input string - expected []string - }{ - { - input: `s3.configure -user=test`, - expected: []string{`s3.configure`, `-user=test`}, - }, - { - input: `s3.configure -user=Test_number_004 -account_display_name="Test number 004" -actions=write -apply`, - expected: []string{`s3.configure`, `-user=Test_number_004`, `-account_display_name=Test number 004`, `-actions=write`, `-apply`}, - }, - { - input: `s3.configure -user=Test_number_004 -account_display_name='Test number 004' -actions=write -apply`, - expected: []string{`s3.configure`, `-user=Test_number_004`, `-account_display_name=Test number 004`, `-actions=write`, `-apply`}, - }, - { - input: `s3.configure -flag="a b"c'd e'`, - expected: []string{`s3.configure`, `-flag=a bcd e`}, - }, - { - input: `s3.configure -name="a\"b"`, - expected: []string{`s3.configure`, `-name=a"b`}, - }, - { - input: `s3.configure -path='a\ b'`, - expected: []string{`s3.configure`, `-path=a\ b`}, - }, - } - - for _, tt := range tests { - got := splitCommandLine(tt.input) - if !reflect.DeepEqual(got, tt.expected) { - t.Errorf("input: %s\ngot: %v\nwant: %v", tt.input, got, tt.expected) - } - } -} - -func TestStripQuotes(t *testing.T) { - tests := []struct { - input string - expected string - }{ - {input: `"Test number 004"`, expected: `Test number 004`}, - {input: `'Test number 004'`, expected: `Test number 004`}, - {input: `-account_display_name="Test number 004"`, expected: `-account_display_name=Test number 004`}, - {input: `-flag="a"b'c'`, expected: `-flag=abc`}, - {input: `-name="a\"b"`, expected: `-name=a"b`}, - {input: `-path='a\ b'`, expected: `-path=a\ b`}, - {input: `"unbalanced`, expected: `"unbalanced`}, - {input: `'unbalanced`, expected: `'unbalanced`}, - {input: `-name="a\"b`, expected: `-name="a\"b`}, - {input: `trailing\`, expected: `trailing\`}, - } - - for _, tt := range tests { - got := stripQuotes(tt.input) - if got != tt.expected { - t.Errorf("input: %s, got: %s, want: %s", tt.input, got, tt.expected) - } - } -} - -func TestFlagParsing(t *testing.T) { - fs := flag.NewFlagSet("test", flag.ContinueOnError) - displayName := fs.String("account_display_name", "", "display name") - - rawArg := `-account_display_name="Test number 004"` - args := []string{stripQuotes(rawArg)} - err := fs.Parse(args) - if err != nil { - t.Fatal(err) - } - - expected := "Test number 004" - if *displayName != expected { - t.Errorf("got: [%s], want: [%s]", *displayName, expected) - } -} - -func TestEscapedFlagParsing(t *testing.T) { - fs := flag.NewFlagSet("test", flag.ContinueOnError) - name := fs.String("name", "", "name") - - rawArg := `-name="a\"b"` - args := []string{stripQuotes(rawArg)} - err := fs.Parse(args) - if err != nil { - t.Fatal(err) - } - - expected := `a"b` - if *name != expected { - t.Errorf("got: [%s], want: [%s]", *name, expected) - } -} diff --git a/weed/stats/disk_common.go b/weed/stats/disk_common.go deleted file mode 100644 index 936c77e91..000000000 --- a/weed/stats/disk_common.go +++ /dev/null @@ -1,17 +0,0 @@ -package stats - -import "github.com/seaweedfs/seaweedfs/weed/pb/volume_server_pb" - -func calculateDiskRemaining(disk *volume_server_pb.DiskStatus) { - disk.Used = disk.All - disk.Free - - if disk.All > 0 { - disk.PercentFree = float32((float64(disk.Free) / float64(disk.All)) * 100) - disk.PercentUsed = float32((float64(disk.Used) / float64(disk.All)) * 100) - } else { - disk.PercentFree = 0 - disk.PercentUsed = 0 - } - - return -} diff --git a/weed/stats/stats.go b/weed/stats/stats.go index 6d3d55cc6..f875f3780 100644 --- a/weed/stats/stats.go +++ b/weed/stats/stats.go @@ -62,12 +62,6 @@ func ConnectionOpen() { func ConnectionClose() { Chan.Connections <- NewTimedValue(time.Now(), -1) } -func RequestOpen() { - Chan.Requests <- NewTimedValue(time.Now(), 1) -} -func RequestClose() { - Chan.Requests <- NewTimedValue(time.Now(), -1) -} func AssignRequest() { Chan.AssignRequests <- NewTimedValue(time.Now(), 1) } diff --git a/weed/storage/erasure_coding/distribution/analysis.go b/weed/storage/erasure_coding/distribution/analysis.go index 22923e671..b939df53e 100644 --- a/weed/storage/erasure_coding/distribution/analysis.go +++ b/weed/storage/erasure_coding/distribution/analysis.go @@ -1,10 +1,5 @@ package distribution -import ( - "fmt" - "slices" -) - // ShardLocation represents where a shard is located in the topology type ShardLocation struct { ShardID int @@ -47,101 +42,6 @@ type TopologyAnalysis struct { TotalDCs int } -// NewTopologyAnalysis creates a new empty analysis -func NewTopologyAnalysis() *TopologyAnalysis { - return &TopologyAnalysis{ - ShardsByDC: make(map[string]int), - ShardsByRack: make(map[string]int), - ShardsByNode: make(map[string]int), - DCToShards: make(map[string][]int), - RackToShards: make(map[string][]int), - NodeToShards: make(map[string][]int), - DCToRacks: make(map[string][]string), - RackToNodes: make(map[string][]*TopologyNode), - AllNodes: make(map[string]*TopologyNode), - } -} - -// AddShardLocation adds a shard location to the analysis -func (a *TopologyAnalysis) AddShardLocation(loc ShardLocation) { - // Update counts - a.ShardsByDC[loc.DataCenter]++ - a.ShardsByRack[loc.Rack]++ - a.ShardsByNode[loc.NodeID]++ - - // Update shard lists - a.DCToShards[loc.DataCenter] = append(a.DCToShards[loc.DataCenter], loc.ShardID) - a.RackToShards[loc.Rack] = append(a.RackToShards[loc.Rack], loc.ShardID) - a.NodeToShards[loc.NodeID] = append(a.NodeToShards[loc.NodeID], loc.ShardID) - - a.TotalShards++ -} - -// AddNode adds a node to the topology (even if it has no shards) -func (a *TopologyAnalysis) AddNode(node *TopologyNode) { - if _, exists := a.AllNodes[node.NodeID]; exists { - return // Already added - } - - a.AllNodes[node.NodeID] = node - a.TotalNodes++ - - // Update topology structure - if !slices.Contains(a.DCToRacks[node.DataCenter], node.Rack) { - a.DCToRacks[node.DataCenter] = append(a.DCToRacks[node.DataCenter], node.Rack) - } - a.RackToNodes[node.Rack] = append(a.RackToNodes[node.Rack], node) - - // Update counts - if _, exists := a.ShardsByDC[node.DataCenter]; !exists { - a.TotalDCs++ - } - if _, exists := a.ShardsByRack[node.Rack]; !exists { - a.TotalRacks++ - } -} - -// Finalize computes final statistics after all data is added -func (a *TopologyAnalysis) Finalize() { - // Ensure we have accurate DC and rack counts - dcSet := make(map[string]bool) - rackSet := make(map[string]bool) - for _, node := range a.AllNodes { - dcSet[node.DataCenter] = true - rackSet[node.Rack] = true - } - a.TotalDCs = len(dcSet) - a.TotalRacks = len(rackSet) - a.TotalNodes = len(a.AllNodes) -} - -// String returns a summary of the analysis -func (a *TopologyAnalysis) String() string { - return fmt.Sprintf("TopologyAnalysis{shards:%d, nodes:%d, racks:%d, dcs:%d}", - a.TotalShards, a.TotalNodes, a.TotalRacks, a.TotalDCs) -} - -// DetailedString returns a detailed multi-line summary -func (a *TopologyAnalysis) DetailedString() string { - s := fmt.Sprintf("Topology Analysis:\n") - s += fmt.Sprintf(" Total Shards: %d\n", a.TotalShards) - s += fmt.Sprintf(" Data Centers: %d\n", a.TotalDCs) - for dc, count := range a.ShardsByDC { - s += fmt.Sprintf(" %s: %d shards\n", dc, count) - } - s += fmt.Sprintf(" Racks: %d\n", a.TotalRacks) - for rack, count := range a.ShardsByRack { - s += fmt.Sprintf(" %s: %d shards\n", rack, count) - } - s += fmt.Sprintf(" Nodes: %d\n", a.TotalNodes) - for nodeID, count := range a.ShardsByNode { - if count > 0 { - s += fmt.Sprintf(" %s: %d shards\n", nodeID, count) - } - } - return s -} - // TopologyExcess represents a topology level (DC/rack/node) with excess shards type TopologyExcess struct { ID string // DC/rack/node ID @@ -150,91 +50,3 @@ type TopologyExcess struct { Shards []int // Shard IDs at this level Nodes []*TopologyNode // Nodes at this level (for finding sources) } - -// CalculateDCExcess returns DCs with more shards than the target -func CalculateDCExcess(analysis *TopologyAnalysis, dist *ECDistribution) []TopologyExcess { - var excess []TopologyExcess - - for dc, count := range analysis.ShardsByDC { - if count > dist.TargetShardsPerDC { - // Collect nodes in this DC - var nodes []*TopologyNode - for _, rack := range analysis.DCToRacks[dc] { - nodes = append(nodes, analysis.RackToNodes[rack]...) - } - excess = append(excess, TopologyExcess{ - ID: dc, - Level: "dc", - Excess: count - dist.TargetShardsPerDC, - Shards: analysis.DCToShards[dc], - Nodes: nodes, - }) - } - } - - // Sort by excess (most excess first) - slices.SortFunc(excess, func(a, b TopologyExcess) int { - return b.Excess - a.Excess - }) - - return excess -} - -// CalculateRackExcess returns racks with more shards than the target (within a DC) -func CalculateRackExcess(analysis *TopologyAnalysis, dc string, targetPerRack int) []TopologyExcess { - var excess []TopologyExcess - - for _, rack := range analysis.DCToRacks[dc] { - count := analysis.ShardsByRack[rack] - if count > targetPerRack { - excess = append(excess, TopologyExcess{ - ID: rack, - Level: "rack", - Excess: count - targetPerRack, - Shards: analysis.RackToShards[rack], - Nodes: analysis.RackToNodes[rack], - }) - } - } - - slices.SortFunc(excess, func(a, b TopologyExcess) int { - return b.Excess - a.Excess - }) - - return excess -} - -// CalculateUnderservedDCs returns DCs that have fewer shards than target -func CalculateUnderservedDCs(analysis *TopologyAnalysis, dist *ECDistribution) []string { - var underserved []string - - // Check existing DCs - for dc, count := range analysis.ShardsByDC { - if count < dist.TargetShardsPerDC { - underserved = append(underserved, dc) - } - } - - // Check DCs with nodes but no shards - for dc := range analysis.DCToRacks { - if _, exists := analysis.ShardsByDC[dc]; !exists { - underserved = append(underserved, dc) - } - } - - return underserved -} - -// CalculateUnderservedRacks returns racks that have fewer shards than target -func CalculateUnderservedRacks(analysis *TopologyAnalysis, dc string, targetPerRack int) []string { - var underserved []string - - for _, rack := range analysis.DCToRacks[dc] { - count := analysis.ShardsByRack[rack] - if count < targetPerRack { - underserved = append(underserved, rack) - } - } - - return underserved -} diff --git a/weed/storage/erasure_coding/distribution/config.go b/weed/storage/erasure_coding/distribution/config.go index e89d6eeb6..b4935b0c7 100644 --- a/weed/storage/erasure_coding/distribution/config.go +++ b/weed/storage/erasure_coding/distribution/config.go @@ -1,12 +1,6 @@ // Package distribution provides EC shard distribution algorithms with configurable EC ratios. package distribution -import ( - "fmt" - - "github.com/seaweedfs/seaweedfs/weed/storage/super_block" -) - // ECConfig holds erasure coding configuration parameters. // This replaces hard-coded constants like DataShardsCount=10, ParityShardsCount=4. type ECConfig struct { @@ -14,113 +8,6 @@ type ECConfig struct { ParityShards int // Number of parity shards (e.g., 4) } -// DefaultECConfig returns the standard 10+4 EC configuration -func DefaultECConfig() ECConfig { - return ECConfig{ - DataShards: 10, - ParityShards: 4, - } -} - -// NewECConfig creates a new EC configuration with validation -func NewECConfig(dataShards, parityShards int) (ECConfig, error) { - if dataShards <= 0 { - return ECConfig{}, fmt.Errorf("dataShards must be positive, got %d", dataShards) - } - if parityShards <= 0 { - return ECConfig{}, fmt.Errorf("parityShards must be positive, got %d", parityShards) - } - if dataShards+parityShards > 32 { - return ECConfig{}, fmt.Errorf("total shards (%d+%d=%d) exceeds maximum of 32", - dataShards, parityShards, dataShards+parityShards) - } - return ECConfig{ - DataShards: dataShards, - ParityShards: parityShards, - }, nil -} - -// TotalShards returns the total number of shards (data + parity) -func (c ECConfig) TotalShards() int { - return c.DataShards + c.ParityShards -} - -// MaxTolerableLoss returns the maximum number of shards that can be lost -// while still being able to reconstruct the data -func (c ECConfig) MaxTolerableLoss() int { - return c.ParityShards -} - -// MinShardsForReconstruction returns the minimum number of shards needed -// to reconstruct the original data -func (c ECConfig) MinShardsForReconstruction() int { - return c.DataShards -} - -// String returns a human-readable representation -func (c ECConfig) String() string { - return fmt.Sprintf("%d+%d (total: %d, can lose: %d)", - c.DataShards, c.ParityShards, c.TotalShards(), c.MaxTolerableLoss()) -} - -// IsDataShard returns true if the shard ID is a data shard (0 to DataShards-1) -func (c ECConfig) IsDataShard(shardID int) bool { - return shardID >= 0 && shardID < c.DataShards -} - -// IsParityShard returns true if the shard ID is a parity shard (DataShards to TotalShards-1) -func (c ECConfig) IsParityShard(shardID int) bool { - return shardID >= c.DataShards && shardID < c.TotalShards() -} - -// SortShardsDataFirst returns a copy of shards sorted with data shards first. -// This is useful for initial placement where data shards should be spread out first. -func (c ECConfig) SortShardsDataFirst(shards []int) []int { - result := make([]int, len(shards)) - copy(result, shards) - - // Partition: data shards first, then parity shards - dataIdx := 0 - parityIdx := len(result) - 1 - - sorted := make([]int, len(result)) - for _, s := range result { - if c.IsDataShard(s) { - sorted[dataIdx] = s - dataIdx++ - } else { - sorted[parityIdx] = s - parityIdx-- - } - } - - return sorted -} - -// SortShardsParityFirst returns a copy of shards sorted with parity shards first. -// This is useful for rebalancing where we prefer to move parity shards. -func (c ECConfig) SortShardsParityFirst(shards []int) []int { - result := make([]int, len(shards)) - copy(result, shards) - - // Partition: parity shards first, then data shards - parityIdx := 0 - dataIdx := len(result) - 1 - - sorted := make([]int, len(result)) - for _, s := range result { - if c.IsParityShard(s) { - sorted[parityIdx] = s - parityIdx++ - } else { - sorted[dataIdx] = s - dataIdx-- - } - } - - return sorted -} - // ReplicationConfig holds the parsed replication policy type ReplicationConfig struct { MinDataCenters int // X+1 from XYZ replication (minimum DCs to use) @@ -130,42 +17,3 @@ type ReplicationConfig struct { // Original replication string (for logging/debugging) Original string } - -// NewReplicationConfig creates a ReplicationConfig from a ReplicaPlacement -func NewReplicationConfig(rp *super_block.ReplicaPlacement) ReplicationConfig { - if rp == nil { - return ReplicationConfig{ - MinDataCenters: 1, - MinRacksPerDC: 1, - MinNodesPerRack: 1, - Original: "000", - } - } - return ReplicationConfig{ - MinDataCenters: rp.DiffDataCenterCount + 1, - MinRacksPerDC: rp.DiffRackCount + 1, - MinNodesPerRack: rp.SameRackCount + 1, - Original: rp.String(), - } -} - -// NewReplicationConfigFromString creates a ReplicationConfig from a replication string -func NewReplicationConfigFromString(replication string) (ReplicationConfig, error) { - rp, err := super_block.NewReplicaPlacementFromString(replication) - if err != nil { - return ReplicationConfig{}, err - } - return NewReplicationConfig(rp), nil -} - -// TotalPlacementSlots returns the minimum number of unique placement locations -// based on the replication policy -func (r ReplicationConfig) TotalPlacementSlots() int { - return r.MinDataCenters * r.MinRacksPerDC * r.MinNodesPerRack -} - -// String returns a human-readable representation -func (r ReplicationConfig) String() string { - return fmt.Sprintf("replication=%s (DCs:%d, Racks/DC:%d, Nodes/Rack:%d)", - r.Original, r.MinDataCenters, r.MinRacksPerDC, r.MinNodesPerRack) -} diff --git a/weed/storage/erasure_coding/distribution/distribution.go b/weed/storage/erasure_coding/distribution/distribution.go index 03deea710..1ef05c55d 100644 --- a/weed/storage/erasure_coding/distribution/distribution.go +++ b/weed/storage/erasure_coding/distribution/distribution.go @@ -1,9 +1,5 @@ package distribution -import ( - "fmt" -) - // ECDistribution represents the target distribution of EC shards // based on EC configuration and replication policy. type ECDistribution struct { @@ -24,137 +20,3 @@ type ECDistribution struct { MaxShardsPerRack int MaxShardsPerNode int } - -// CalculateDistribution computes the target EC shard distribution based on -// EC configuration and replication policy. -// -// The algorithm: -// 1. Uses replication policy to determine minimum topology spread -// 2. Calculates target shards per level (evenly distributed) -// 3. Calculates max shards per level (for fault tolerance) -func CalculateDistribution(ec ECConfig, rep ReplicationConfig) *ECDistribution { - totalShards := ec.TotalShards() - - // Target distribution (balanced, rounded up to ensure all shards placed) - targetShardsPerDC := ceilDivide(totalShards, rep.MinDataCenters) - targetShardsPerRack := ceilDivide(targetShardsPerDC, rep.MinRacksPerDC) - targetShardsPerNode := ceilDivide(targetShardsPerRack, rep.MinNodesPerRack) - - // Maximum limits for fault tolerance - // The key constraint: losing one failure domain shouldn't lose more than parityShards - // So max shards per domain = totalShards - parityShards + tolerance - // We add small tolerance (+2) to allow for imbalanced topologies - faultToleranceLimit := totalShards - ec.ParityShards + 1 - - maxShardsPerDC := min(faultToleranceLimit, targetShardsPerDC+2) - maxShardsPerRack := min(faultToleranceLimit, targetShardsPerRack+2) - maxShardsPerNode := min(faultToleranceLimit, targetShardsPerNode+2) - - return &ECDistribution{ - ECConfig: ec, - ReplicationConfig: rep, - TargetShardsPerDC: targetShardsPerDC, - TargetShardsPerRack: targetShardsPerRack, - TargetShardsPerNode: targetShardsPerNode, - MaxShardsPerDC: maxShardsPerDC, - MaxShardsPerRack: maxShardsPerRack, - MaxShardsPerNode: maxShardsPerNode, - } -} - -// String returns a human-readable description of the distribution -func (d *ECDistribution) String() string { - return fmt.Sprintf( - "ECDistribution{EC:%s, DCs:%d (target:%d/max:%d), Racks/DC:%d (target:%d/max:%d), Nodes/Rack:%d (target:%d/max:%d)}", - d.ECConfig.String(), - d.ReplicationConfig.MinDataCenters, d.TargetShardsPerDC, d.MaxShardsPerDC, - d.ReplicationConfig.MinRacksPerDC, d.TargetShardsPerRack, d.MaxShardsPerRack, - d.ReplicationConfig.MinNodesPerRack, d.TargetShardsPerNode, d.MaxShardsPerNode, - ) -} - -// Summary returns a multi-line summary of the distribution plan -func (d *ECDistribution) Summary() string { - summary := fmt.Sprintf("EC Configuration: %s\n", d.ECConfig.String()) - summary += fmt.Sprintf("Replication: %s\n", d.ReplicationConfig.String()) - summary += fmt.Sprintf("Distribution Plan:\n") - summary += fmt.Sprintf(" Data Centers: %d (target %d shards each, max %d)\n", - d.ReplicationConfig.MinDataCenters, d.TargetShardsPerDC, d.MaxShardsPerDC) - summary += fmt.Sprintf(" Racks per DC: %d (target %d shards each, max %d)\n", - d.ReplicationConfig.MinRacksPerDC, d.TargetShardsPerRack, d.MaxShardsPerRack) - summary += fmt.Sprintf(" Nodes per Rack: %d (target %d shards each, max %d)\n", - d.ReplicationConfig.MinNodesPerRack, d.TargetShardsPerNode, d.MaxShardsPerNode) - return summary -} - -// CanSurviveDCFailure returns true if the distribution can survive -// complete loss of one data center -func (d *ECDistribution) CanSurviveDCFailure() bool { - // After losing one DC with max shards, check if remaining shards are enough - remainingAfterDCLoss := d.ECConfig.TotalShards() - d.TargetShardsPerDC - return remainingAfterDCLoss >= d.ECConfig.MinShardsForReconstruction() -} - -// CanSurviveRackFailure returns true if the distribution can survive -// complete loss of one rack -func (d *ECDistribution) CanSurviveRackFailure() bool { - remainingAfterRackLoss := d.ECConfig.TotalShards() - d.TargetShardsPerRack - return remainingAfterRackLoss >= d.ECConfig.MinShardsForReconstruction() -} - -// MinDCsForDCFaultTolerance calculates the minimum number of DCs needed -// to survive complete DC failure with this EC configuration -func (d *ECDistribution) MinDCsForDCFaultTolerance() int { - // To survive DC failure, max shards per DC = parityShards - maxShardsPerDC := d.ECConfig.MaxTolerableLoss() - if maxShardsPerDC == 0 { - return d.ECConfig.TotalShards() // Would need one DC per shard - } - return ceilDivide(d.ECConfig.TotalShards(), maxShardsPerDC) -} - -// FaultToleranceAnalysis returns a detailed analysis of fault tolerance -func (d *ECDistribution) FaultToleranceAnalysis() string { - analysis := fmt.Sprintf("Fault Tolerance Analysis for %s:\n", d.ECConfig.String()) - - // DC failure - dcSurvive := d.CanSurviveDCFailure() - shardsAfterDC := d.ECConfig.TotalShards() - d.TargetShardsPerDC - analysis += fmt.Sprintf(" DC Failure: %s\n", boolToResult(dcSurvive)) - analysis += fmt.Sprintf(" - Losing one DC loses ~%d shards\n", d.TargetShardsPerDC) - analysis += fmt.Sprintf(" - Remaining: %d shards (need %d)\n", shardsAfterDC, d.ECConfig.DataShards) - if !dcSurvive { - analysis += fmt.Sprintf(" - Need at least %d DCs for DC fault tolerance\n", d.MinDCsForDCFaultTolerance()) - } - - // Rack failure - rackSurvive := d.CanSurviveRackFailure() - shardsAfterRack := d.ECConfig.TotalShards() - d.TargetShardsPerRack - analysis += fmt.Sprintf(" Rack Failure: %s\n", boolToResult(rackSurvive)) - analysis += fmt.Sprintf(" - Losing one rack loses ~%d shards\n", d.TargetShardsPerRack) - analysis += fmt.Sprintf(" - Remaining: %d shards (need %d)\n", shardsAfterRack, d.ECConfig.DataShards) - - // Node failure (usually survivable) - shardsAfterNode := d.ECConfig.TotalShards() - d.TargetShardsPerNode - nodeSurvive := shardsAfterNode >= d.ECConfig.DataShards - analysis += fmt.Sprintf(" Node Failure: %s\n", boolToResult(nodeSurvive)) - analysis += fmt.Sprintf(" - Losing one node loses ~%d shards\n", d.TargetShardsPerNode) - analysis += fmt.Sprintf(" - Remaining: %d shards (need %d)\n", shardsAfterNode, d.ECConfig.DataShards) - - return analysis -} - -func boolToResult(b bool) string { - if b { - return "SURVIVABLE ✓" - } - return "NOT SURVIVABLE ✗" -} - -// ceilDivide performs ceiling division -func ceilDivide(a, b int) int { - if b <= 0 { - return a - } - return (a + b - 1) / b -} diff --git a/weed/storage/erasure_coding/distribution/distribution_test.go b/weed/storage/erasure_coding/distribution/distribution_test.go deleted file mode 100644 index dc6a19192..000000000 --- a/weed/storage/erasure_coding/distribution/distribution_test.go +++ /dev/null @@ -1,565 +0,0 @@ -package distribution - -import ( - "testing" -) - -func TestNewECConfig(t *testing.T) { - tests := []struct { - name string - dataShards int - parityShards int - wantErr bool - }{ - {"valid 10+4", 10, 4, false}, - {"valid 8+4", 8, 4, false}, - {"valid 6+3", 6, 3, false}, - {"valid 4+2", 4, 2, false}, - {"invalid data=0", 0, 4, true}, - {"invalid parity=0", 10, 0, true}, - {"invalid total>32", 20, 15, true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - config, err := NewECConfig(tt.dataShards, tt.parityShards) - if (err != nil) != tt.wantErr { - t.Errorf("NewECConfig() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !tt.wantErr { - if config.DataShards != tt.dataShards { - t.Errorf("DataShards = %d, want %d", config.DataShards, tt.dataShards) - } - if config.ParityShards != tt.parityShards { - t.Errorf("ParityShards = %d, want %d", config.ParityShards, tt.parityShards) - } - if config.TotalShards() != tt.dataShards+tt.parityShards { - t.Errorf("TotalShards() = %d, want %d", config.TotalShards(), tt.dataShards+tt.parityShards) - } - } - }) - } -} - -func TestCalculateDistribution(t *testing.T) { - tests := []struct { - name string - ecConfig ECConfig - replication string - expectedMinDCs int - expectedMinRacksPerDC int - expectedMinNodesPerRack int - expectedTargetPerDC int - expectedTargetPerRack int - expectedTargetPerNode int - }{ - { - name: "10+4 with 000", - ecConfig: DefaultECConfig(), - replication: "000", - expectedMinDCs: 1, - expectedMinRacksPerDC: 1, - expectedMinNodesPerRack: 1, - expectedTargetPerDC: 14, - expectedTargetPerRack: 14, - expectedTargetPerNode: 14, - }, - { - name: "10+4 with 100", - ecConfig: DefaultECConfig(), - replication: "100", - expectedMinDCs: 2, - expectedMinRacksPerDC: 1, - expectedMinNodesPerRack: 1, - expectedTargetPerDC: 7, - expectedTargetPerRack: 7, - expectedTargetPerNode: 7, - }, - { - name: "10+4 with 110", - ecConfig: DefaultECConfig(), - replication: "110", - expectedMinDCs: 2, - expectedMinRacksPerDC: 2, - expectedMinNodesPerRack: 1, - expectedTargetPerDC: 7, - expectedTargetPerRack: 4, - expectedTargetPerNode: 4, - }, - { - name: "10+4 with 200", - ecConfig: DefaultECConfig(), - replication: "200", - expectedMinDCs: 3, - expectedMinRacksPerDC: 1, - expectedMinNodesPerRack: 1, - expectedTargetPerDC: 5, - expectedTargetPerRack: 5, - expectedTargetPerNode: 5, - }, - { - name: "8+4 with 110", - ecConfig: ECConfig{ - DataShards: 8, - ParityShards: 4, - }, - replication: "110", - expectedMinDCs: 2, - expectedMinRacksPerDC: 2, - expectedMinNodesPerRack: 1, - expectedTargetPerDC: 6, // 12/2 = 6 - expectedTargetPerRack: 3, // 6/2 = 3 - expectedTargetPerNode: 3, - }, - { - name: "6+3 with 100", - ecConfig: ECConfig{ - DataShards: 6, - ParityShards: 3, - }, - replication: "100", - expectedMinDCs: 2, - expectedMinRacksPerDC: 1, - expectedMinNodesPerRack: 1, - expectedTargetPerDC: 5, // ceil(9/2) = 5 - expectedTargetPerRack: 5, - expectedTargetPerNode: 5, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - rep, err := NewReplicationConfigFromString(tt.replication) - if err != nil { - t.Fatalf("Failed to parse replication %s: %v", tt.replication, err) - } - - dist := CalculateDistribution(tt.ecConfig, rep) - - if dist.ReplicationConfig.MinDataCenters != tt.expectedMinDCs { - t.Errorf("MinDataCenters = %d, want %d", dist.ReplicationConfig.MinDataCenters, tt.expectedMinDCs) - } - if dist.ReplicationConfig.MinRacksPerDC != tt.expectedMinRacksPerDC { - t.Errorf("MinRacksPerDC = %d, want %d", dist.ReplicationConfig.MinRacksPerDC, tt.expectedMinRacksPerDC) - } - if dist.ReplicationConfig.MinNodesPerRack != tt.expectedMinNodesPerRack { - t.Errorf("MinNodesPerRack = %d, want %d", dist.ReplicationConfig.MinNodesPerRack, tt.expectedMinNodesPerRack) - } - if dist.TargetShardsPerDC != tt.expectedTargetPerDC { - t.Errorf("TargetShardsPerDC = %d, want %d", dist.TargetShardsPerDC, tt.expectedTargetPerDC) - } - if dist.TargetShardsPerRack != tt.expectedTargetPerRack { - t.Errorf("TargetShardsPerRack = %d, want %d", dist.TargetShardsPerRack, tt.expectedTargetPerRack) - } - if dist.TargetShardsPerNode != tt.expectedTargetPerNode { - t.Errorf("TargetShardsPerNode = %d, want %d", dist.TargetShardsPerNode, tt.expectedTargetPerNode) - } - - t.Logf("Distribution for %s: %s", tt.name, dist.String()) - }) - } -} - -func TestFaultToleranceAnalysis(t *testing.T) { - tests := []struct { - name string - ecConfig ECConfig - replication string - canSurviveDC bool - canSurviveRack bool - }{ - // 10+4 = 14 shards, need 10 to reconstruct, can lose 4 - {"10+4 000", DefaultECConfig(), "000", false, false}, // All in one, any failure is fatal - {"10+4 100", DefaultECConfig(), "100", false, false}, // 7 per DC/rack, 7 remaining < 10 - {"10+4 200", DefaultECConfig(), "200", false, false}, // 5 per DC/rack, 9 remaining < 10 - {"10+4 110", DefaultECConfig(), "110", false, true}, // 4 per rack, 10 remaining = enough for rack - - // 8+4 = 12 shards, need 8 to reconstruct, can lose 4 - {"8+4 100", ECConfig{8, 4}, "100", false, false}, // 6 per DC/rack, 6 remaining < 8 - {"8+4 200", ECConfig{8, 4}, "200", true, true}, // 4 per DC/rack, 8 remaining = enough! - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - rep, _ := NewReplicationConfigFromString(tt.replication) - dist := CalculateDistribution(tt.ecConfig, rep) - - if dist.CanSurviveDCFailure() != tt.canSurviveDC { - t.Errorf("CanSurviveDCFailure() = %v, want %v", dist.CanSurviveDCFailure(), tt.canSurviveDC) - } - if dist.CanSurviveRackFailure() != tt.canSurviveRack { - t.Errorf("CanSurviveRackFailure() = %v, want %v", dist.CanSurviveRackFailure(), tt.canSurviveRack) - } - - t.Log(dist.FaultToleranceAnalysis()) - }) - } -} - -func TestMinDCsForDCFaultTolerance(t *testing.T) { - tests := []struct { - name string - ecConfig ECConfig - minDCs int - }{ - // 10+4: can lose 4, so max 4 per DC, 14/4 = 4 DCs needed - {"10+4", DefaultECConfig(), 4}, - // 8+4: can lose 4, so max 4 per DC, 12/4 = 3 DCs needed - {"8+4", ECConfig{8, 4}, 3}, - // 6+3: can lose 3, so max 3 per DC, 9/3 = 3 DCs needed - {"6+3", ECConfig{6, 3}, 3}, - // 4+2: can lose 2, so max 2 per DC, 6/2 = 3 DCs needed - {"4+2", ECConfig{4, 2}, 3}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - rep, _ := NewReplicationConfigFromString("000") - dist := CalculateDistribution(tt.ecConfig, rep) - - if dist.MinDCsForDCFaultTolerance() != tt.minDCs { - t.Errorf("MinDCsForDCFaultTolerance() = %d, want %d", - dist.MinDCsForDCFaultTolerance(), tt.minDCs) - } - - t.Logf("%s: needs %d DCs for DC fault tolerance", tt.name, dist.MinDCsForDCFaultTolerance()) - }) - } -} - -func TestTopologyAnalysis(t *testing.T) { - analysis := NewTopologyAnalysis() - - // Add nodes to topology - node1 := &TopologyNode{ - NodeID: "node1", - DataCenter: "dc1", - Rack: "rack1", - FreeSlots: 5, - } - node2 := &TopologyNode{ - NodeID: "node2", - DataCenter: "dc1", - Rack: "rack2", - FreeSlots: 10, - } - node3 := &TopologyNode{ - NodeID: "node3", - DataCenter: "dc2", - Rack: "rack3", - FreeSlots: 10, - } - - analysis.AddNode(node1) - analysis.AddNode(node2) - analysis.AddNode(node3) - - // Add shard locations (all on node1) - for i := 0; i < 14; i++ { - analysis.AddShardLocation(ShardLocation{ - ShardID: i, - NodeID: "node1", - DataCenter: "dc1", - Rack: "rack1", - }) - } - - analysis.Finalize() - - // Verify counts - if analysis.TotalShards != 14 { - t.Errorf("TotalShards = %d, want 14", analysis.TotalShards) - } - if analysis.ShardsByDC["dc1"] != 14 { - t.Errorf("ShardsByDC[dc1] = %d, want 14", analysis.ShardsByDC["dc1"]) - } - if analysis.ShardsByRack["rack1"] != 14 { - t.Errorf("ShardsByRack[rack1] = %d, want 14", analysis.ShardsByRack["rack1"]) - } - if analysis.ShardsByNode["node1"] != 14 { - t.Errorf("ShardsByNode[node1] = %d, want 14", analysis.ShardsByNode["node1"]) - } - - t.Log(analysis.DetailedString()) -} - -func TestRebalancer(t *testing.T) { - // Build topology: 2 DCs, 2 racks each, all shards on one node - analysis := NewTopologyAnalysis() - - // Add nodes - nodes := []*TopologyNode{ - {NodeID: "dc1-rack1-node1", DataCenter: "dc1", Rack: "dc1-rack1", FreeSlots: 0}, - {NodeID: "dc1-rack2-node1", DataCenter: "dc1", Rack: "dc1-rack2", FreeSlots: 10}, - {NodeID: "dc2-rack1-node1", DataCenter: "dc2", Rack: "dc2-rack1", FreeSlots: 10}, - {NodeID: "dc2-rack2-node1", DataCenter: "dc2", Rack: "dc2-rack2", FreeSlots: 10}, - } - for _, node := range nodes { - analysis.AddNode(node) - } - - // Add all 14 shards to first node - for i := 0; i < 14; i++ { - analysis.AddShardLocation(ShardLocation{ - ShardID: i, - NodeID: "dc1-rack1-node1", - DataCenter: "dc1", - Rack: "dc1-rack1", - }) - } - analysis.Finalize() - - // Create rebalancer with 110 replication (2 DCs, 2 racks each) - ec := DefaultECConfig() - rep, _ := NewReplicationConfigFromString("110") - rebalancer := NewRebalancer(ec, rep) - - plan, err := rebalancer.PlanRebalance(analysis) - if err != nil { - t.Fatalf("PlanRebalance failed: %v", err) - } - - t.Logf("Planned %d moves", plan.TotalMoves) - t.Log(plan.DetailedString()) - - // Verify we're moving shards to dc2 - movedToDC2 := 0 - for _, move := range plan.Moves { - if move.DestNode.DataCenter == "dc2" { - movedToDC2++ - } - } - - if movedToDC2 == 0 { - t.Error("Expected some moves to dc2") - } - - // With "110" replication, target is 7 shards per DC - // Starting with 14 in dc1, should plan to move 7 to dc2 - if plan.MovesAcrossDC < 7 { - t.Errorf("Expected at least 7 cross-DC moves for 110 replication, got %d", plan.MovesAcrossDC) - } -} - -func TestCustomECRatios(t *testing.T) { - // Test various custom EC ratios that seaweed-enterprise might use - ratios := []struct { - name string - data int - parity int - }{ - {"4+2", 4, 2}, - {"6+3", 6, 3}, - {"8+2", 8, 2}, - {"8+4", 8, 4}, - {"10+4", 10, 4}, - {"12+4", 12, 4}, - {"16+4", 16, 4}, - } - - for _, ratio := range ratios { - t.Run(ratio.name, func(t *testing.T) { - ec, err := NewECConfig(ratio.data, ratio.parity) - if err != nil { - t.Fatalf("Failed to create EC config: %v", err) - } - - rep, _ := NewReplicationConfigFromString("110") - dist := CalculateDistribution(ec, rep) - - t.Logf("EC %s with replication 110:", ratio.name) - t.Logf(" Total shards: %d", ec.TotalShards()) - t.Logf(" Can lose: %d shards", ec.MaxTolerableLoss()) - t.Logf(" Target per DC: %d", dist.TargetShardsPerDC) - t.Logf(" Target per rack: %d", dist.TargetShardsPerRack) - t.Logf(" Min DCs for DC fault tolerance: %d", dist.MinDCsForDCFaultTolerance()) - - // Verify basic sanity - if dist.TargetShardsPerDC*2 < ec.TotalShards() { - t.Errorf("Target per DC (%d) * 2 should be >= total (%d)", - dist.TargetShardsPerDC, ec.TotalShards()) - } - }) - } -} - -func TestShardClassification(t *testing.T) { - ec := DefaultECConfig() // 10+4 - - // Test IsDataShard - for i := 0; i < 10; i++ { - if !ec.IsDataShard(i) { - t.Errorf("Shard %d should be a data shard", i) - } - if ec.IsParityShard(i) { - t.Errorf("Shard %d should not be a parity shard", i) - } - } - - // Test IsParityShard - for i := 10; i < 14; i++ { - if ec.IsDataShard(i) { - t.Errorf("Shard %d should not be a data shard", i) - } - if !ec.IsParityShard(i) { - t.Errorf("Shard %d should be a parity shard", i) - } - } - - // Test with custom 8+4 EC - ec84, _ := NewECConfig(8, 4) - for i := 0; i < 8; i++ { - if !ec84.IsDataShard(i) { - t.Errorf("8+4 EC: Shard %d should be a data shard", i) - } - } - for i := 8; i < 12; i++ { - if !ec84.IsParityShard(i) { - t.Errorf("8+4 EC: Shard %d should be a parity shard", i) - } - } -} - -func TestSortShardsDataFirst(t *testing.T) { - ec := DefaultECConfig() // 10+4 - - // Mixed shards: [0, 10, 5, 11, 2, 12, 7, 13] - shards := []int{0, 10, 5, 11, 2, 12, 7, 13} - sorted := ec.SortShardsDataFirst(shards) - - t.Logf("Original: %v", shards) - t.Logf("Sorted (data first): %v", sorted) - - // First 4 should be data shards (0, 5, 2, 7) - for i := 0; i < 4; i++ { - if !ec.IsDataShard(sorted[i]) { - t.Errorf("Position %d should be a data shard, got %d", i, sorted[i]) - } - } - - // Last 4 should be parity shards (10, 11, 12, 13) - for i := 4; i < 8; i++ { - if !ec.IsParityShard(sorted[i]) { - t.Errorf("Position %d should be a parity shard, got %d", i, sorted[i]) - } - } -} - -func TestSortShardsParityFirst(t *testing.T) { - ec := DefaultECConfig() // 10+4 - - // Mixed shards: [0, 10, 5, 11, 2, 12, 7, 13] - shards := []int{0, 10, 5, 11, 2, 12, 7, 13} - sorted := ec.SortShardsParityFirst(shards) - - t.Logf("Original: %v", shards) - t.Logf("Sorted (parity first): %v", sorted) - - // First 4 should be parity shards (10, 11, 12, 13) - for i := 0; i < 4; i++ { - if !ec.IsParityShard(sorted[i]) { - t.Errorf("Position %d should be a parity shard, got %d", i, sorted[i]) - } - } - - // Last 4 should be data shards (0, 5, 2, 7) - for i := 4; i < 8; i++ { - if !ec.IsDataShard(sorted[i]) { - t.Errorf("Position %d should be a data shard, got %d", i, sorted[i]) - } - } -} - -func TestRebalancerPrefersMovingParityShards(t *testing.T) { - // Build topology where one node has all shards including mix of data and parity - analysis := NewTopologyAnalysis() - - // Node 1: Has all 14 shards (mixed data and parity) - node1 := &TopologyNode{ - NodeID: "node1", - DataCenter: "dc1", - Rack: "rack1", - FreeSlots: 0, - } - analysis.AddNode(node1) - - // Node 2: Empty, ready to receive - node2 := &TopologyNode{ - NodeID: "node2", - DataCenter: "dc1", - Rack: "rack1", - FreeSlots: 10, - } - analysis.AddNode(node2) - - // Add all 14 shards to node1 - for i := 0; i < 14; i++ { - analysis.AddShardLocation(ShardLocation{ - ShardID: i, - NodeID: "node1", - DataCenter: "dc1", - Rack: "rack1", - }) - } - analysis.Finalize() - - // Create rebalancer - ec := DefaultECConfig() - rep, _ := NewReplicationConfigFromString("000") - rebalancer := NewRebalancer(ec, rep) - - plan, err := rebalancer.PlanRebalance(analysis) - if err != nil { - t.Fatalf("PlanRebalance failed: %v", err) - } - - t.Logf("Planned %d moves", len(plan.Moves)) - - // Check that parity shards are moved first - parityMovesFirst := 0 - dataMovesFirst := 0 - seenDataMove := false - - for _, move := range plan.Moves { - isParity := ec.IsParityShard(move.ShardID) - t.Logf("Move shard %d (parity=%v): %s -> %s", - move.ShardID, isParity, move.SourceNode.NodeID, move.DestNode.NodeID) - - if isParity && !seenDataMove { - parityMovesFirst++ - } else if !isParity { - seenDataMove = true - dataMovesFirst++ - } - } - - t.Logf("Parity moves before first data move: %d", parityMovesFirst) - t.Logf("Data moves: %d", dataMovesFirst) - - // With 10+4 EC, there are 4 parity shards - // They should be moved before data shards when possible - if parityMovesFirst < 4 && len(plan.Moves) >= 4 { - t.Logf("Note: Expected parity shards to be moved first, but got %d parity moves before data moves", parityMovesFirst) - } -} - -func TestDistributionSummary(t *testing.T) { - ec := DefaultECConfig() - rep, _ := NewReplicationConfigFromString("110") - dist := CalculateDistribution(ec, rep) - - summary := dist.Summary() - t.Log(summary) - - if len(summary) == 0 { - t.Error("Summary should not be empty") - } - - analysis := dist.FaultToleranceAnalysis() - t.Log(analysis) - - if len(analysis) == 0 { - t.Error("Fault tolerance analysis should not be empty") - } -} diff --git a/weed/storage/erasure_coding/distribution/rebalancer.go b/weed/storage/erasure_coding/distribution/rebalancer.go index cd8b87de6..2442e59a9 100644 --- a/weed/storage/erasure_coding/distribution/rebalancer.go +++ b/weed/storage/erasure_coding/distribution/rebalancer.go @@ -1,10 +1,5 @@ package distribution -import ( - "fmt" - "slices" -) - // ShardMove represents a planned shard move type ShardMove struct { ShardID int @@ -13,12 +8,6 @@ type ShardMove struct { Reason string } -// String returns a human-readable description of the move -func (m ShardMove) String() string { - return fmt.Sprintf("shard %d: %s -> %s (%s)", - m.ShardID, m.SourceNode.NodeID, m.DestNode.NodeID, m.Reason) -} - // RebalancePlan contains the complete plan for rebalancing EC shards type RebalancePlan struct { Moves []ShardMove @@ -32,346 +21,8 @@ type RebalancePlan struct { MovesWithinRack int } -// String returns a summary of the plan -func (p *RebalancePlan) String() string { - return fmt.Sprintf("RebalancePlan{moves:%d, acrossDC:%d, acrossRack:%d, withinRack:%d}", - p.TotalMoves, p.MovesAcrossDC, p.MovesAcrossRack, p.MovesWithinRack) -} - -// DetailedString returns a detailed multi-line summary -func (p *RebalancePlan) DetailedString() string { - s := fmt.Sprintf("Rebalance Plan:\n") - s += fmt.Sprintf(" Total Moves: %d\n", p.TotalMoves) - s += fmt.Sprintf(" Across DC: %d\n", p.MovesAcrossDC) - s += fmt.Sprintf(" Across Rack: %d\n", p.MovesAcrossRack) - s += fmt.Sprintf(" Within Rack: %d\n", p.MovesWithinRack) - s += fmt.Sprintf("\nMoves:\n") - for i, move := range p.Moves { - s += fmt.Sprintf(" %d. %s\n", i+1, move.String()) - } - return s -} - // Rebalancer plans shard moves to achieve proportional distribution type Rebalancer struct { ecConfig ECConfig repConfig ReplicationConfig } - -// NewRebalancer creates a new rebalancer with the given configuration -func NewRebalancer(ec ECConfig, rep ReplicationConfig) *Rebalancer { - return &Rebalancer{ - ecConfig: ec, - repConfig: rep, - } -} - -// PlanRebalance creates a rebalancing plan based on current topology analysis -func (r *Rebalancer) PlanRebalance(analysis *TopologyAnalysis) (*RebalancePlan, error) { - dist := CalculateDistribution(r.ecConfig, r.repConfig) - - plan := &RebalancePlan{ - Distribution: dist, - Analysis: analysis, - } - - // Step 1: Balance across data centers - dcMoves := r.planDCMoves(analysis, dist) - for _, move := range dcMoves { - plan.Moves = append(plan.Moves, move) - plan.MovesAcrossDC++ - } - - // Update analysis after DC moves (for planning purposes) - r.applyMovesToAnalysis(analysis, dcMoves) - - // Step 2: Balance across racks within each DC - rackMoves := r.planRackMoves(analysis, dist) - for _, move := range rackMoves { - plan.Moves = append(plan.Moves, move) - plan.MovesAcrossRack++ - } - - // Update analysis after rack moves - r.applyMovesToAnalysis(analysis, rackMoves) - - // Step 3: Balance across nodes within each rack - nodeMoves := r.planNodeMoves(analysis, dist) - for _, move := range nodeMoves { - plan.Moves = append(plan.Moves, move) - plan.MovesWithinRack++ - } - - plan.TotalMoves = len(plan.Moves) - - return plan, nil -} - -// planDCMoves plans moves to balance shards across data centers -func (r *Rebalancer) planDCMoves(analysis *TopologyAnalysis, dist *ECDistribution) []ShardMove { - var moves []ShardMove - - overDCs := CalculateDCExcess(analysis, dist) - underDCs := CalculateUnderservedDCs(analysis, dist) - - underIdx := 0 - for _, over := range overDCs { - for over.Excess > 0 && underIdx < len(underDCs) { - destDC := underDCs[underIdx] - - // Find a shard and source node - shardID, srcNode := r.pickShardToMove(analysis, over.Nodes) - if srcNode == nil { - break - } - - // Find destination node in target DC - destNode := r.pickBestDestination(analysis, destDC, "", dist) - if destNode == nil { - underIdx++ - continue - } - - moves = append(moves, ShardMove{ - ShardID: shardID, - SourceNode: srcNode, - DestNode: destNode, - Reason: fmt.Sprintf("balance DC: %s -> %s", srcNode.DataCenter, destDC), - }) - - over.Excess-- - analysis.ShardsByDC[srcNode.DataCenter]-- - analysis.ShardsByDC[destDC]++ - - // Check if destDC reached target - if analysis.ShardsByDC[destDC] >= dist.TargetShardsPerDC { - underIdx++ - } - } - } - - return moves -} - -// planRackMoves plans moves to balance shards across racks within each DC -func (r *Rebalancer) planRackMoves(analysis *TopologyAnalysis, dist *ECDistribution) []ShardMove { - var moves []ShardMove - - for dc := range analysis.DCToRacks { - dcShards := analysis.ShardsByDC[dc] - numRacks := len(analysis.DCToRacks[dc]) - if numRacks == 0 { - continue - } - - targetPerRack := ceilDivide(dcShards, max(numRacks, dist.ReplicationConfig.MinRacksPerDC)) - - overRacks := CalculateRackExcess(analysis, dc, targetPerRack) - underRacks := CalculateUnderservedRacks(analysis, dc, targetPerRack) - - underIdx := 0 - for _, over := range overRacks { - for over.Excess > 0 && underIdx < len(underRacks) { - destRack := underRacks[underIdx] - - // Find shard and source node - shardID, srcNode := r.pickShardToMove(analysis, over.Nodes) - if srcNode == nil { - break - } - - // Find destination node in target rack - destNode := r.pickBestDestination(analysis, dc, destRack, dist) - if destNode == nil { - underIdx++ - continue - } - - moves = append(moves, ShardMove{ - ShardID: shardID, - SourceNode: srcNode, - DestNode: destNode, - Reason: fmt.Sprintf("balance rack: %s -> %s", srcNode.Rack, destRack), - }) - - over.Excess-- - analysis.ShardsByRack[srcNode.Rack]-- - analysis.ShardsByRack[destRack]++ - - if analysis.ShardsByRack[destRack] >= targetPerRack { - underIdx++ - } - } - } - } - - return moves -} - -// planNodeMoves plans moves to balance shards across nodes within each rack -func (r *Rebalancer) planNodeMoves(analysis *TopologyAnalysis, dist *ECDistribution) []ShardMove { - var moves []ShardMove - - for rack, nodes := range analysis.RackToNodes { - if len(nodes) <= 1 { - continue - } - - rackShards := analysis.ShardsByRack[rack] - targetPerNode := ceilDivide(rackShards, max(len(nodes), dist.ReplicationConfig.MinNodesPerRack)) - - // Find over and under nodes - var overNodes []*TopologyNode - var underNodes []*TopologyNode - - for _, node := range nodes { - count := analysis.ShardsByNode[node.NodeID] - if count > targetPerNode { - overNodes = append(overNodes, node) - } else if count < targetPerNode { - underNodes = append(underNodes, node) - } - } - - // Sort by excess/deficit - slices.SortFunc(overNodes, func(a, b *TopologyNode) int { - return analysis.ShardsByNode[b.NodeID] - analysis.ShardsByNode[a.NodeID] - }) - - underIdx := 0 - for _, srcNode := range overNodes { - excess := analysis.ShardsByNode[srcNode.NodeID] - targetPerNode - - for excess > 0 && underIdx < len(underNodes) { - destNode := underNodes[underIdx] - - // Pick a shard from this node, preferring parity shards - shards := analysis.NodeToShards[srcNode.NodeID] - if len(shards) == 0 { - break - } - - // Find a parity shard first, fallback to data shard - shardID := -1 - shardIdx := -1 - for i, s := range shards { - if r.ecConfig.IsParityShard(s) { - shardID = s - shardIdx = i - break - } - } - if shardID == -1 { - shardID = shards[0] - shardIdx = 0 - } - - moves = append(moves, ShardMove{ - ShardID: shardID, - SourceNode: srcNode, - DestNode: destNode, - Reason: fmt.Sprintf("balance node: %s -> %s", srcNode.NodeID, destNode.NodeID), - }) - - excess-- - analysis.ShardsByNode[srcNode.NodeID]-- - analysis.ShardsByNode[destNode.NodeID]++ - - // Update shard lists - remove the specific shard we picked - analysis.NodeToShards[srcNode.NodeID] = append( - shards[:shardIdx], shards[shardIdx+1:]...) - analysis.NodeToShards[destNode.NodeID] = append( - analysis.NodeToShards[destNode.NodeID], shardID) - - if analysis.ShardsByNode[destNode.NodeID] >= targetPerNode { - underIdx++ - } - } - } - } - - return moves -} - -// pickShardToMove selects a shard and its node from the given nodes. -// It prefers to move parity shards first, keeping data shards spread out -// since data shards serve read requests while parity shards are only for reconstruction. -func (r *Rebalancer) pickShardToMove(analysis *TopologyAnalysis, nodes []*TopologyNode) (int, *TopologyNode) { - // Sort by shard count (most shards first) - slices.SortFunc(nodes, func(a, b *TopologyNode) int { - return analysis.ShardsByNode[b.NodeID] - analysis.ShardsByNode[a.NodeID] - }) - - // First pass: try to find a parity shard to move (prefer moving parity) - for _, node := range nodes { - shards := analysis.NodeToShards[node.NodeID] - for _, shardID := range shards { - if r.ecConfig.IsParityShard(shardID) { - return shardID, node - } - } - } - - // Second pass: if no parity shards, move a data shard - for _, node := range nodes { - shards := analysis.NodeToShards[node.NodeID] - if len(shards) > 0 { - return shards[0], node - } - } - - return -1, nil -} - -// pickBestDestination selects the best destination node -func (r *Rebalancer) pickBestDestination(analysis *TopologyAnalysis, targetDC, targetRack string, dist *ECDistribution) *TopologyNode { - var candidates []*TopologyNode - - // Collect candidates - for _, node := range analysis.AllNodes { - // Filter by DC if specified - if targetDC != "" && node.DataCenter != targetDC { - continue - } - // Filter by rack if specified - if targetRack != "" && node.Rack != targetRack { - continue - } - // Check capacity - if node.FreeSlots <= 0 { - continue - } - // Check max shards limit - if analysis.ShardsByNode[node.NodeID] >= dist.MaxShardsPerNode { - continue - } - - candidates = append(candidates, node) - } - - if len(candidates) == 0 { - return nil - } - - // Sort by: 1) fewer shards, 2) more free slots - slices.SortFunc(candidates, func(a, b *TopologyNode) int { - aShards := analysis.ShardsByNode[a.NodeID] - bShards := analysis.ShardsByNode[b.NodeID] - if aShards != bShards { - return aShards - bShards - } - return b.FreeSlots - a.FreeSlots - }) - - return candidates[0] -} - -// applyMovesToAnalysis is a no-op placeholder for potential future use. -// Note: All planners (planDCMoves, planRackMoves, planNodeMoves) update -// their respective counts (ShardsByDC, ShardsByRack, ShardsByNode) and -// shard lists (NodeToShards) inline during planning. This avoids duplicate -// updates that would occur if we also updated counts here. -func (r *Rebalancer) applyMovesToAnalysis(analysis *TopologyAnalysis, moves []ShardMove) { - // Counts are already updated by the individual planners. - // This function is kept for API compatibility and potential future use. -} diff --git a/weed/storage/erasure_coding/ec_shards_info.go b/weed/storage/erasure_coding/ec_shards_info.go index 55838eb4e..0d2ce5b63 100644 --- a/weed/storage/erasure_coding/ec_shards_info.go +++ b/weed/storage/erasure_coding/ec_shards_info.go @@ -53,19 +53,6 @@ func NewShardsInfo() *ShardsInfo { } } -// Initializes a ShardsInfo from a ECVolume. -func ShardsInfoFromVolume(ev *EcVolume) *ShardsInfo { - res := &ShardsInfo{ - shards: make([]ShardInfo, len(ev.Shards)), - } - // Build shards directly to avoid locking in Set() since res is not yet shared - for i, s := range ev.Shards { - res.shards[i] = NewShardInfo(s.ShardId, ShardSize(s.Size())) - res.shardBits = res.shardBits.Set(s.ShardId) - } - return res -} - // Initializes a ShardsInfo from a VolumeEcShardInformationMessage proto. func ShardsInfoFromVolumeEcShardInformationMessage(vi *master_pb.VolumeEcShardInformationMessage) *ShardsInfo { res := NewShardsInfo() diff --git a/weed/storage/erasure_coding/placement/placement.go b/weed/storage/erasure_coding/placement/placement.go index 67e21c1f8..bda050b82 100644 --- a/weed/storage/erasure_coding/placement/placement.go +++ b/weed/storage/erasure_coding/placement/placement.go @@ -64,18 +64,6 @@ type PlacementRequest struct { PreferDifferentRacks bool } -// DefaultPlacementRequest returns the default placement configuration -func DefaultPlacementRequest() PlacementRequest { - return PlacementRequest{ - ShardsNeeded: 14, - MaxShardsPerServer: 0, - MaxShardsPerRack: 0, - MaxTaskLoad: 5, - PreferDifferentServers: true, - PreferDifferentRacks: true, - } -} - // PlacementResult contains the selected destinations for EC shards type PlacementResult struct { SelectedDisks []*DiskCandidate @@ -270,15 +258,6 @@ func groupDisksByRack(disks []*DiskCandidate) map[string][]*DiskCandidate { return result } -// groupDisksByServer groups disks by their server -func groupDisksByServer(disks []*DiskCandidate) map[string][]*DiskCandidate { - result := make(map[string][]*DiskCandidate) - for _, disk := range disks { - result[disk.NodeID] = append(result[disk.NodeID], disk) - } - return result -} - // getRackKey returns the unique key for a rack (dc:rack) func getRackKey(disk *DiskCandidate) string { return fmt.Sprintf("%s:%s", disk.DataCenter, disk.Rack) @@ -393,28 +372,3 @@ func addDiskToResult(result *PlacementResult, disk *DiskCandidate, result.ShardsPerRack[rackKey]++ result.ShardsPerDC[disk.DataCenter]++ } - -// VerifySpread checks if the placement result meets diversity requirements -func VerifySpread(result *PlacementResult, minServers, minRacks int) error { - if result.ServersUsed < minServers { - return fmt.Errorf("only %d servers used, need at least %d", result.ServersUsed, minServers) - } - if result.RacksUsed < minRacks { - return fmt.Errorf("only %d racks used, need at least %d", result.RacksUsed, minRacks) - } - return nil -} - -// CalculateIdealDistribution returns the ideal number of shards per server -// when we have a certain number of shards and servers -func CalculateIdealDistribution(totalShards, numServers int) (min, max int) { - if numServers <= 0 { - return 0, totalShards - } - min = totalShards / numServers - max = min - if totalShards%numServers != 0 { - max = min + 1 - } - return -} diff --git a/weed/storage/erasure_coding/placement/placement_test.go b/weed/storage/erasure_coding/placement/placement_test.go deleted file mode 100644 index 7501dfa9e..000000000 --- a/weed/storage/erasure_coding/placement/placement_test.go +++ /dev/null @@ -1,517 +0,0 @@ -package placement - -import ( - "testing" -) - -// Helper function to create disk candidates for testing -func makeDisk(nodeID string, diskID uint32, dc, rack string, freeSlots int) *DiskCandidate { - return &DiskCandidate{ - NodeID: nodeID, - DiskID: diskID, - DataCenter: dc, - Rack: rack, - VolumeCount: 0, - MaxVolumeCount: 100, - ShardCount: 0, - FreeSlots: freeSlots, - LoadCount: 0, - } -} - -func TestSelectDestinations_SingleRack(t *testing.T) { - // Test: 3 servers in same rack, each with 2 disks, need 6 shards - // Expected: Should spread across all 6 disks (one per disk) - disks := []*DiskCandidate{ - makeDisk("server1", 0, "dc1", "rack1", 10), - makeDisk("server1", 1, "dc1", "rack1", 10), - makeDisk("server2", 0, "dc1", "rack1", 10), - makeDisk("server2", 1, "dc1", "rack1", 10), - makeDisk("server3", 0, "dc1", "rack1", 10), - makeDisk("server3", 1, "dc1", "rack1", 10), - } - - config := PlacementRequest{ - ShardsNeeded: 6, - PreferDifferentServers: true, - PreferDifferentRacks: true, - } - - result, err := SelectDestinations(disks, config) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if len(result.SelectedDisks) != 6 { - t.Errorf("expected 6 selected disks, got %d", len(result.SelectedDisks)) - } - - // Verify all 3 servers are used - if result.ServersUsed != 3 { - t.Errorf("expected 3 servers used, got %d", result.ServersUsed) - } - - // Verify each disk is unique - diskSet := make(map[string]bool) - for _, disk := range result.SelectedDisks { - key := getDiskKey(disk) - if diskSet[key] { - t.Errorf("disk %s selected multiple times", key) - } - diskSet[key] = true - } -} - -func TestSelectDestinations_MultipleRacks(t *testing.T) { - // Test: 2 racks with 2 servers each, each server has 2 disks - // Need 8 shards - // Expected: Should spread across all 8 disks - disks := []*DiskCandidate{ - makeDisk("server1", 0, "dc1", "rack1", 10), - makeDisk("server1", 1, "dc1", "rack1", 10), - makeDisk("server2", 0, "dc1", "rack1", 10), - makeDisk("server2", 1, "dc1", "rack1", 10), - makeDisk("server3", 0, "dc1", "rack2", 10), - makeDisk("server3", 1, "dc1", "rack2", 10), - makeDisk("server4", 0, "dc1", "rack2", 10), - makeDisk("server4", 1, "dc1", "rack2", 10), - } - - config := PlacementRequest{ - ShardsNeeded: 8, - PreferDifferentServers: true, - PreferDifferentRacks: true, - } - - result, err := SelectDestinations(disks, config) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if len(result.SelectedDisks) != 8 { - t.Errorf("expected 8 selected disks, got %d", len(result.SelectedDisks)) - } - - // Verify all 4 servers are used - if result.ServersUsed != 4 { - t.Errorf("expected 4 servers used, got %d", result.ServersUsed) - } - - // Verify both racks are used - if result.RacksUsed != 2 { - t.Errorf("expected 2 racks used, got %d", result.RacksUsed) - } -} - -func TestSelectDestinations_PrefersDifferentServers(t *testing.T) { - // Test: 4 servers with 4 disks each, need 4 shards - // Expected: Should use one disk from each server - disks := []*DiskCandidate{ - makeDisk("server1", 0, "dc1", "rack1", 10), - makeDisk("server1", 1, "dc1", "rack1", 10), - makeDisk("server1", 2, "dc1", "rack1", 10), - makeDisk("server1", 3, "dc1", "rack1", 10), - makeDisk("server2", 0, "dc1", "rack1", 10), - makeDisk("server2", 1, "dc1", "rack1", 10), - makeDisk("server2", 2, "dc1", "rack1", 10), - makeDisk("server2", 3, "dc1", "rack1", 10), - makeDisk("server3", 0, "dc1", "rack1", 10), - makeDisk("server3", 1, "dc1", "rack1", 10), - makeDisk("server3", 2, "dc1", "rack1", 10), - makeDisk("server3", 3, "dc1", "rack1", 10), - makeDisk("server4", 0, "dc1", "rack1", 10), - makeDisk("server4", 1, "dc1", "rack1", 10), - makeDisk("server4", 2, "dc1", "rack1", 10), - makeDisk("server4", 3, "dc1", "rack1", 10), - } - - config := PlacementRequest{ - ShardsNeeded: 4, - PreferDifferentServers: true, - PreferDifferentRacks: true, - } - - result, err := SelectDestinations(disks, config) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if len(result.SelectedDisks) != 4 { - t.Errorf("expected 4 selected disks, got %d", len(result.SelectedDisks)) - } - - // Verify all 4 servers are used (one shard per server) - if result.ServersUsed != 4 { - t.Errorf("expected 4 servers used, got %d", result.ServersUsed) - } - - // Each server should have exactly 1 shard - for server, count := range result.ShardsPerServer { - if count != 1 { - t.Errorf("server %s has %d shards, expected 1", server, count) - } - } -} - -func TestSelectDestinations_SpilloverToMultipleDisksPerServer(t *testing.T) { - // Test: 2 servers with 4 disks each, need 6 shards - // Expected: First pick one from each server (2 shards), then one more from each (4 shards), - // then fill remaining from any server (6 shards) - disks := []*DiskCandidate{ - makeDisk("server1", 0, "dc1", "rack1", 10), - makeDisk("server1", 1, "dc1", "rack1", 10), - makeDisk("server1", 2, "dc1", "rack1", 10), - makeDisk("server1", 3, "dc1", "rack1", 10), - makeDisk("server2", 0, "dc1", "rack1", 10), - makeDisk("server2", 1, "dc1", "rack1", 10), - makeDisk("server2", 2, "dc1", "rack1", 10), - makeDisk("server2", 3, "dc1", "rack1", 10), - } - - config := PlacementRequest{ - ShardsNeeded: 6, - PreferDifferentServers: true, - PreferDifferentRacks: true, - } - - result, err := SelectDestinations(disks, config) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if len(result.SelectedDisks) != 6 { - t.Errorf("expected 6 selected disks, got %d", len(result.SelectedDisks)) - } - - // Both servers should be used - if result.ServersUsed != 2 { - t.Errorf("expected 2 servers used, got %d", result.ServersUsed) - } - - // Each server should have exactly 3 shards (balanced) - for server, count := range result.ShardsPerServer { - if count != 3 { - t.Errorf("server %s has %d shards, expected 3", server, count) - } - } -} - -func TestSelectDestinations_MaxShardsPerServer(t *testing.T) { - // Test: 2 servers with 4 disks each, need 6 shards, max 2 per server - // Expected: Should only select 4 shards (2 per server limit) - disks := []*DiskCandidate{ - makeDisk("server1", 0, "dc1", "rack1", 10), - makeDisk("server1", 1, "dc1", "rack1", 10), - makeDisk("server1", 2, "dc1", "rack1", 10), - makeDisk("server1", 3, "dc1", "rack1", 10), - makeDisk("server2", 0, "dc1", "rack1", 10), - makeDisk("server2", 1, "dc1", "rack1", 10), - makeDisk("server2", 2, "dc1", "rack1", 10), - makeDisk("server2", 3, "dc1", "rack1", 10), - } - - config := PlacementRequest{ - ShardsNeeded: 6, - MaxShardsPerServer: 2, - PreferDifferentServers: true, - PreferDifferentRacks: true, - } - - result, err := SelectDestinations(disks, config) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - // Should only get 4 shards due to server limit - if len(result.SelectedDisks) != 4 { - t.Errorf("expected 4 selected disks (limit 2 per server), got %d", len(result.SelectedDisks)) - } - - // No server should exceed the limit - for server, count := range result.ShardsPerServer { - if count > 2 { - t.Errorf("server %s has %d shards, exceeds limit of 2", server, count) - } - } -} - -func TestSelectDestinations_14ShardsAcross7Servers(t *testing.T) { - // Test: Real-world EC scenario - 14 shards across 7 servers with 2 disks each - // Expected: Should spread evenly (2 shards per server) - var disks []*DiskCandidate - for i := 1; i <= 7; i++ { - serverID := "server" + string(rune('0'+i)) - disks = append(disks, makeDisk(serverID, 0, "dc1", "rack1", 10)) - disks = append(disks, makeDisk(serverID, 1, "dc1", "rack1", 10)) - } - - config := PlacementRequest{ - ShardsNeeded: 14, - PreferDifferentServers: true, - PreferDifferentRacks: true, - } - - result, err := SelectDestinations(disks, config) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if len(result.SelectedDisks) != 14 { - t.Errorf("expected 14 selected disks, got %d", len(result.SelectedDisks)) - } - - // All 7 servers should be used - if result.ServersUsed != 7 { - t.Errorf("expected 7 servers used, got %d", result.ServersUsed) - } - - // Each server should have exactly 2 shards - for server, count := range result.ShardsPerServer { - if count != 2 { - t.Errorf("server %s has %d shards, expected 2", server, count) - } - } -} - -func TestSelectDestinations_FewerServersThanShards(t *testing.T) { - // Test: Only 3 servers but need 6 shards - // Expected: Should distribute evenly (2 per server) - disks := []*DiskCandidate{ - makeDisk("server1", 0, "dc1", "rack1", 10), - makeDisk("server1", 1, "dc1", "rack1", 10), - makeDisk("server1", 2, "dc1", "rack1", 10), - makeDisk("server2", 0, "dc1", "rack1", 10), - makeDisk("server2", 1, "dc1", "rack1", 10), - makeDisk("server2", 2, "dc1", "rack1", 10), - makeDisk("server3", 0, "dc1", "rack1", 10), - makeDisk("server3", 1, "dc1", "rack1", 10), - makeDisk("server3", 2, "dc1", "rack1", 10), - } - - config := PlacementRequest{ - ShardsNeeded: 6, - PreferDifferentServers: true, - PreferDifferentRacks: true, - } - - result, err := SelectDestinations(disks, config) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if len(result.SelectedDisks) != 6 { - t.Errorf("expected 6 selected disks, got %d", len(result.SelectedDisks)) - } - - // All 3 servers should be used - if result.ServersUsed != 3 { - t.Errorf("expected 3 servers used, got %d", result.ServersUsed) - } - - // Each server should have exactly 2 shards - for server, count := range result.ShardsPerServer { - if count != 2 { - t.Errorf("server %s has %d shards, expected 2", server, count) - } - } -} - -func TestSelectDestinations_NoSuitableDisks(t *testing.T) { - // Test: All disks have no free slots - disks := []*DiskCandidate{ - {NodeID: "server1", DiskID: 0, DataCenter: "dc1", Rack: "rack1", FreeSlots: 0}, - {NodeID: "server2", DiskID: 0, DataCenter: "dc1", Rack: "rack1", FreeSlots: 0}, - } - - config := PlacementRequest{ - ShardsNeeded: 4, - PreferDifferentServers: true, - PreferDifferentRacks: true, - } - - _, err := SelectDestinations(disks, config) - if err == nil { - t.Error("expected error for no suitable disks, got nil") - } -} - -func TestSelectDestinations_EmptyInput(t *testing.T) { - config := DefaultPlacementRequest() - _, err := SelectDestinations([]*DiskCandidate{}, config) - if err == nil { - t.Error("expected error for empty input, got nil") - } -} - -func TestSelectDestinations_FiltersByLoad(t *testing.T) { - // Test: Some disks have too high load - disks := []*DiskCandidate{ - {NodeID: "server1", DiskID: 0, DataCenter: "dc1", Rack: "rack1", FreeSlots: 10, LoadCount: 10}, - {NodeID: "server2", DiskID: 0, DataCenter: "dc1", Rack: "rack1", FreeSlots: 10, LoadCount: 2}, - {NodeID: "server3", DiskID: 0, DataCenter: "dc1", Rack: "rack1", FreeSlots: 10, LoadCount: 1}, - } - - config := PlacementRequest{ - ShardsNeeded: 2, - MaxTaskLoad: 5, - PreferDifferentServers: true, - PreferDifferentRacks: true, - } - - result, err := SelectDestinations(disks, config) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - // Should only select from server2 and server3 (server1 has too high load) - for _, disk := range result.SelectedDisks { - if disk.NodeID == "server1" { - t.Errorf("disk from server1 should not be selected (load too high)") - } - } -} - -func TestCalculateDiskScore(t *testing.T) { - // Test that score calculation works as expected - lowUtilDisk := &DiskCandidate{ - VolumeCount: 10, - MaxVolumeCount: 100, - ShardCount: 0, - LoadCount: 0, - } - - highUtilDisk := &DiskCandidate{ - VolumeCount: 90, - MaxVolumeCount: 100, - ShardCount: 5, - LoadCount: 5, - } - - lowScore := calculateDiskScore(lowUtilDisk) - highScore := calculateDiskScore(highUtilDisk) - - if lowScore <= highScore { - t.Errorf("low utilization disk should have higher score: low=%f, high=%f", lowScore, highScore) - } -} - -func TestCalculateIdealDistribution(t *testing.T) { - tests := []struct { - totalShards int - numServers int - expectedMin int - expectedMax int - }{ - {14, 7, 2, 2}, // Even distribution - {14, 4, 3, 4}, // Uneven: 14/4 = 3 remainder 2 - {6, 3, 2, 2}, // Even distribution - {7, 3, 2, 3}, // Uneven: 7/3 = 2 remainder 1 - {10, 0, 0, 10}, // Edge case: no servers - {0, 5, 0, 0}, // Edge case: no shards - } - - for _, tt := range tests { - min, max := CalculateIdealDistribution(tt.totalShards, tt.numServers) - if min != tt.expectedMin || max != tt.expectedMax { - t.Errorf("CalculateIdealDistribution(%d, %d) = (%d, %d), want (%d, %d)", - tt.totalShards, tt.numServers, min, max, tt.expectedMin, tt.expectedMax) - } - } -} - -func TestVerifySpread(t *testing.T) { - result := &PlacementResult{ - ServersUsed: 3, - RacksUsed: 2, - } - - // Should pass - if err := VerifySpread(result, 3, 2); err != nil { - t.Errorf("unexpected error: %v", err) - } - - // Should fail - not enough servers - if err := VerifySpread(result, 4, 2); err == nil { - t.Error("expected error for insufficient servers") - } - - // Should fail - not enough racks - if err := VerifySpread(result, 3, 3); err == nil { - t.Error("expected error for insufficient racks") - } -} - -func TestSelectDestinations_MultiDC(t *testing.T) { - // Test: 2 DCs, each with 2 racks, each rack has 2 servers - disks := []*DiskCandidate{ - // DC1, Rack1 - makeDisk("dc1-r1-s1", 0, "dc1", "rack1", 10), - makeDisk("dc1-r1-s1", 1, "dc1", "rack1", 10), - makeDisk("dc1-r1-s2", 0, "dc1", "rack1", 10), - makeDisk("dc1-r1-s2", 1, "dc1", "rack1", 10), - // DC1, Rack2 - makeDisk("dc1-r2-s1", 0, "dc1", "rack2", 10), - makeDisk("dc1-r2-s1", 1, "dc1", "rack2", 10), - makeDisk("dc1-r2-s2", 0, "dc1", "rack2", 10), - makeDisk("dc1-r2-s2", 1, "dc1", "rack2", 10), - // DC2, Rack1 - makeDisk("dc2-r1-s1", 0, "dc2", "rack1", 10), - makeDisk("dc2-r1-s1", 1, "dc2", "rack1", 10), - makeDisk("dc2-r1-s2", 0, "dc2", "rack1", 10), - makeDisk("dc2-r1-s2", 1, "dc2", "rack1", 10), - // DC2, Rack2 - makeDisk("dc2-r2-s1", 0, "dc2", "rack2", 10), - makeDisk("dc2-r2-s1", 1, "dc2", "rack2", 10), - makeDisk("dc2-r2-s2", 0, "dc2", "rack2", 10), - makeDisk("dc2-r2-s2", 1, "dc2", "rack2", 10), - } - - config := PlacementRequest{ - ShardsNeeded: 8, - PreferDifferentServers: true, - PreferDifferentRacks: true, - } - - result, err := SelectDestinations(disks, config) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if len(result.SelectedDisks) != 8 { - t.Errorf("expected 8 selected disks, got %d", len(result.SelectedDisks)) - } - - // Should use all 4 racks - if result.RacksUsed != 4 { - t.Errorf("expected 4 racks used, got %d", result.RacksUsed) - } - - // Should use both DCs - if result.DCsUsed != 2 { - t.Errorf("expected 2 DCs used, got %d", result.DCsUsed) - } -} - -func TestSelectDestinations_SameRackDifferentDC(t *testing.T) { - // Test: Same rack name in different DCs should be treated as different racks - disks := []*DiskCandidate{ - makeDisk("dc1-s1", 0, "dc1", "rack1", 10), - makeDisk("dc2-s1", 0, "dc2", "rack1", 10), - } - - config := PlacementRequest{ - ShardsNeeded: 2, - PreferDifferentServers: true, - PreferDifferentRacks: true, - } - - result, err := SelectDestinations(disks, config) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - // Should use 2 racks (dc1:rack1 and dc2:rack1 are different) - if result.RacksUsed != 2 { - t.Errorf("expected 2 racks used (different DCs), got %d", result.RacksUsed) - } -} diff --git a/weed/storage/idx/binary_search.go b/weed/storage/idx/binary_search.go deleted file mode 100644 index 9f1dcef40..000000000 --- a/weed/storage/idx/binary_search.go +++ /dev/null @@ -1,29 +0,0 @@ -package idx - -import ( - "github.com/seaweedfs/seaweedfs/weed/storage/types" -) - -// FirstInvalidIndex find the first index the failed lessThanOrEqualToFn function's requirement. -func FirstInvalidIndex(bytes []byte, lessThanOrEqualToFn func(key types.NeedleId, offset types.Offset, size types.Size) (bool, error)) (int, error) { - left, right := 0, len(bytes)/types.NeedleMapEntrySize-1 - index := right + 1 - for left <= right { - mid := left + (right-left)>>1 - loc := mid * types.NeedleMapEntrySize - key := types.BytesToNeedleId(bytes[loc : loc+types.NeedleIdSize]) - offset := types.BytesToOffset(bytes[loc+types.NeedleIdSize : loc+types.NeedleIdSize+types.OffsetSize]) - size := types.BytesToSize(bytes[loc+types.NeedleIdSize+types.OffsetSize : loc+types.NeedleIdSize+types.OffsetSize+types.SizeSize]) - res, err := lessThanOrEqualToFn(key, offset, size) - if err != nil { - return -1, err - } - if res { - left = mid + 1 - } else { - index = mid - right = mid - 1 - } - } - return index, nil -} diff --git a/weed/storage/idx_binary_search_test.go b/weed/storage/idx_binary_search_test.go deleted file mode 100644 index 77d38e562..000000000 --- a/weed/storage/idx_binary_search_test.go +++ /dev/null @@ -1,71 +0,0 @@ -package storage - -import ( - "os" - "testing" - - "github.com/seaweedfs/seaweedfs/weed/storage/idx" - "github.com/seaweedfs/seaweedfs/weed/storage/needle" - "github.com/seaweedfs/seaweedfs/weed/storage/super_block" - "github.com/seaweedfs/seaweedfs/weed/storage/types" - "github.com/stretchr/testify/assert" -) - -func TestFirstInvalidIndex(t *testing.T) { - dir := t.TempDir() - - v, err := NewVolume(dir, dir, "", 1, NeedleMapInMemory, &super_block.ReplicaPlacement{}, &needle.TTL{}, 0, needle.GetCurrentVersion(), 0, 0) - if err != nil { - t.Fatalf("volume creation: %v", err) - } - defer v.Close() - type WriteInfo struct { - offset int64 - size int32 - } - // initialize 20 needles then update first 10 needles - for i := 1; i <= 30; i++ { - n := newRandomNeedle(uint64(i)) - n.Flags = 0x08 - _, _, _, err := v.writeNeedle2(n, true, false) - if err != nil { - t.Fatalf("write needle %d: %v", i, err) - } - } - b, err := os.ReadFile(v.IndexFileName() + ".idx") - if err != nil { - t.Fatal(err) - } - // base case every record is valid -> nothing is filtered - index, err := idx.FirstInvalidIndex(b, func(key types.NeedleId, offset types.Offset, size types.Size) (bool, error) { - return true, nil - }) - if err != nil { - t.Fatalf("failed to complete binary search %v", err) - } - assert.Equal(t, 30, index, "when every record is valid nothing should be filtered from binary search") - index, err = idx.FirstInvalidIndex(b, func(key types.NeedleId, offset types.Offset, size types.Size) (bool, error) { - return false, nil - }) - if err != nil { - t.Fatal(err) - } - assert.Equal(t, 0, index, "when every record is invalid everything should be filtered from binary search") - index, err = idx.FirstInvalidIndex(b, func(key types.NeedleId, offset types.Offset, size types.Size) (bool, error) { - return key < 20, nil - }) - if err != nil { - t.Fatal(err) - } - // needle key range from 1 to 30 so < 20 means 19 keys are valid and cutoff the bytes at 19 * 16 = 304 - assert.Equal(t, 19, index, "when every record is invalid everything should be filtered from binary search") - - index, err = idx.FirstInvalidIndex(b, func(key types.NeedleId, offset types.Offset, size types.Size) (bool, error) { - return key <= 1, nil - }) - if err != nil { - t.Fatal(err) - } - // needle key range from 1 to 30 so <=1 1 means 1 key is valid and cutoff the bytes at 1 * 16 = 16 - assert.Equal(t, 1, index, "when every record is invalid everything should be filtered from binary search") -} diff --git a/weed/storage/needle/crc.go b/weed/storage/needle/crc.go index 6ac31cb43..b1c092c49 100644 --- a/weed/storage/needle/crc.go +++ b/weed/storage/needle/crc.go @@ -32,24 +32,7 @@ func (n *Needle) Etag() string { return fmt.Sprintf("%x", bits) } -func NewCRCwriter(w io.Writer) *CRCwriter { - - return &CRCwriter{ - crc: CRC(0), - w: w, - } - -} - type CRCwriter struct { crc CRC w io.Writer } - -func (c *CRCwriter) Write(p []byte) (n int, err error) { - n, err = c.w.Write(p) // with each write ... - c.crc = c.crc.Update(p) - return -} - -func (c *CRCwriter) Sum() uint32 { return uint32(c.crc) } // final hash diff --git a/weed/storage/needle/needle_write.go b/weed/storage/needle/needle_write.go index 009bf393e..d90807d70 100644 --- a/weed/storage/needle/needle_write.go +++ b/weed/storage/needle/needle_write.go @@ -1,7 +1,6 @@ package needle import ( - "bytes" "fmt" "github.com/seaweedfs/seaweedfs/weed/glog" @@ -83,27 +82,3 @@ func WriteNeedleBlob(w backend.BackendStorageFile, dataSlice []byte, size Size, return } - -// prepareNeedleWrite encapsulates the common beginning logic for all versioned writeNeedle functions. -func prepareNeedleWrite(w backend.BackendStorageFile, n *Needle) (offset uint64, bytesBuffer *bytes.Buffer, cleanup func(err error), err error) { - end, _, e := w.GetStat() - if e != nil { - err = fmt.Errorf("Cannot Read Current Volume Position: %w", e) - return - } - offset = uint64(end) - if offset >= MaxPossibleVolumeSize && len(n.Data) != 0 { - err = fmt.Errorf("Volume Size %d Exceeded %d", offset, MaxPossibleVolumeSize) - return - } - bytesBuffer = buffer_pool.SyncPoolGetBuffer() - cleanup = func(err error) { - if err != nil { - if te := w.Truncate(end); te != nil { - // handle error or log - } - } - buffer_pool.SyncPoolPutBuffer(bytesBuffer) - } - return -} diff --git a/weed/storage/store_state.go b/weed/storage/store_state.go index 2bac4fae6..9014b2a2e 100644 --- a/weed/storage/store_state.go +++ b/weed/storage/store_state.go @@ -34,16 +34,6 @@ func NewState(dir string) (*State, error) { return state, err } -func NewStateFromProto(filePath string, state *volume_server_pb.VolumeServerState) *State { - pb := &volume_server_pb.VolumeServerState{} - proto.Merge(pb, state) - - return &State{ - filePath: filePath, - pb: pb, - } -} - func (st *State) Proto() *volume_server_pb.VolumeServerState { st.mu.Lock() defer st.mu.Unlock() diff --git a/weed/topology/capacity_reservation_test.go b/weed/topology/capacity_reservation_test.go deleted file mode 100644 index 38cb14c50..000000000 --- a/weed/topology/capacity_reservation_test.go +++ /dev/null @@ -1,215 +0,0 @@ -package topology - -import ( - "sync" - "testing" - "time" - - "github.com/seaweedfs/seaweedfs/weed/storage/types" -) - -func TestCapacityReservations_BasicOperations(t *testing.T) { - cr := newCapacityReservations() - diskType := types.HardDriveType - - // Test initial state - if count := cr.getReservedCount(diskType); count != 0 { - t.Errorf("Expected 0 reserved count initially, got %d", count) - } - - // Test add reservation - reservationId := cr.addReservation(diskType, 5) - if reservationId == "" { - t.Error("Expected non-empty reservation ID") - } - - if count := cr.getReservedCount(diskType); count != 5 { - t.Errorf("Expected 5 reserved count, got %d", count) - } - - // Test multiple reservations - cr.addReservation(diskType, 3) - if count := cr.getReservedCount(diskType); count != 8 { - t.Errorf("Expected 8 reserved count after second reservation, got %d", count) - } - - // Test remove reservation - success := cr.removeReservation(reservationId) - if !success { - t.Error("Expected successful removal of existing reservation") - } - - if count := cr.getReservedCount(diskType); count != 3 { - t.Errorf("Expected 3 reserved count after removal, got %d", count) - } - - // Test remove non-existent reservation - success = cr.removeReservation("non-existent-id") - if success { - t.Error("Expected failure when removing non-existent reservation") - } -} - -func TestCapacityReservations_ExpiredCleaning(t *testing.T) { - cr := newCapacityReservations() - diskType := types.HardDriveType - - // Add reservations and manipulate their creation time - reservationId1 := cr.addReservation(diskType, 3) - reservationId2 := cr.addReservation(diskType, 2) - - // Make one reservation "old" - cr.Lock() - if reservation, exists := cr.reservations[reservationId1]; exists { - reservation.createdAt = time.Now().Add(-10 * time.Minute) // 10 minutes ago - } - cr.Unlock() - - // Clean expired reservations (5 minute expiration) - cr.cleanExpiredReservations(5 * time.Minute) - - // Only the non-expired reservation should remain - if count := cr.getReservedCount(diskType); count != 2 { - t.Errorf("Expected 2 reserved count after cleaning, got %d", count) - } - - // Verify the right reservation was kept - if !cr.removeReservation(reservationId2) { - t.Error("Expected recent reservation to still exist") - } - - if cr.removeReservation(reservationId1) { - t.Error("Expected old reservation to be cleaned up") - } -} - -func TestCapacityReservations_DifferentDiskTypes(t *testing.T) { - cr := newCapacityReservations() - - // Add reservations for different disk types - cr.addReservation(types.HardDriveType, 5) - cr.addReservation(types.SsdType, 3) - - // Check counts are separate - if count := cr.getReservedCount(types.HardDriveType); count != 5 { - t.Errorf("Expected 5 HDD reserved count, got %d", count) - } - - if count := cr.getReservedCount(types.SsdType); count != 3 { - t.Errorf("Expected 3 SSD reserved count, got %d", count) - } -} - -func TestNodeImpl_ReservationMethods(t *testing.T) { - // Create a test data node - dn := NewDataNode("test-node") - diskType := types.HardDriveType - - // Set up some capacity - diskUsage := dn.diskUsages.getOrCreateDisk(diskType) - diskUsage.maxVolumeCount = 10 - diskUsage.volumeCount = 5 // 5 volumes free initially - - option := &VolumeGrowOption{DiskType: diskType} - - // Test available space calculation - available := dn.AvailableSpaceFor(option) - if available != 5 { - t.Errorf("Expected 5 available slots, got %d", available) - } - - availableForReservation := dn.AvailableSpaceForReservation(option) - if availableForReservation != 5 { - t.Errorf("Expected 5 available slots for reservation, got %d", availableForReservation) - } - - // Test successful reservation - reservationId, success := dn.TryReserveCapacity(diskType, 3) - if !success { - t.Error("Expected successful reservation") - } - if reservationId == "" { - t.Error("Expected non-empty reservation ID") - } - - // Available space should be reduced by reservations - availableForReservation = dn.AvailableSpaceForReservation(option) - if availableForReservation != 2 { - t.Errorf("Expected 2 available slots after reservation, got %d", availableForReservation) - } - - // Base available space should remain unchanged - available = dn.AvailableSpaceFor(option) - if available != 5 { - t.Errorf("Expected base available to remain 5, got %d", available) - } - - // Test reservation failure when insufficient capacity - _, success = dn.TryReserveCapacity(diskType, 3) - if success { - t.Error("Expected reservation failure due to insufficient capacity") - } - - // Test release reservation - dn.ReleaseReservedCapacity(reservationId) - availableForReservation = dn.AvailableSpaceForReservation(option) - if availableForReservation != 5 { - t.Errorf("Expected 5 available slots after release, got %d", availableForReservation) - } -} - -func TestNodeImpl_ConcurrentReservations(t *testing.T) { - dn := NewDataNode("test-node") - diskType := types.HardDriveType - - // Set up capacity - diskUsage := dn.diskUsages.getOrCreateDisk(diskType) - diskUsage.maxVolumeCount = 10 - diskUsage.volumeCount = 0 // 10 volumes free initially - - // Test concurrent reservations using goroutines - var wg sync.WaitGroup - var reservationIds sync.Map - concurrentRequests := 10 - wg.Add(concurrentRequests) - - for i := 0; i < concurrentRequests; i++ { - go func(i int) { - defer wg.Done() - if reservationId, success := dn.TryReserveCapacity(diskType, 1); success { - reservationIds.Store(reservationId, true) - t.Logf("goroutine %d: Successfully reserved %s", i, reservationId) - } else { - t.Errorf("goroutine %d: Expected successful reservation", i) - } - }(i) - } - - wg.Wait() - - // Should have no more capacity - option := &VolumeGrowOption{DiskType: diskType} - if available := dn.AvailableSpaceForReservation(option); available != 0 { - t.Errorf("Expected 0 available slots after all reservations, got %d", available) - // Debug: check total reserved - reservedCount := dn.capacityReservations.getReservedCount(diskType) - t.Logf("Debug: Total reserved count: %d", reservedCount) - } - - // Next reservation should fail - _, success := dn.TryReserveCapacity(diskType, 1) - if success { - t.Error("Expected reservation failure when at capacity") - } - - // Release all reservations - reservationIds.Range(func(key, value interface{}) bool { - dn.ReleaseReservedCapacity(key.(string)) - return true - }) - - // Should have full capacity back - if available := dn.AvailableSpaceForReservation(option); available != 10 { - t.Errorf("Expected 10 available slots after releasing all, got %d", available) - } -} diff --git a/weed/topology/disk.go b/weed/topology/disk.go index fa99ef37a..3616ff928 100644 --- a/weed/topology/disk.go +++ b/weed/topology/disk.go @@ -118,16 +118,6 @@ func (a *DiskUsageCounts) FreeSpace() int64 { return freeVolumeSlotCount } -func (a *DiskUsageCounts) minus(b *DiskUsageCounts) *DiskUsageCounts { - return &DiskUsageCounts{ - volumeCount: a.volumeCount - b.volumeCount, - remoteVolumeCount: a.remoteVolumeCount - b.remoteVolumeCount, - activeVolumeCount: a.activeVolumeCount - b.activeVolumeCount, - ecShardCount: a.ecShardCount - b.ecShardCount, - maxVolumeCount: a.maxVolumeCount - b.maxVolumeCount, - } -} - func (du *DiskUsages) getOrCreateDisk(diskType types.DiskType) *DiskUsageCounts { du.Lock() defer du.Unlock() diff --git a/weed/topology/node.go b/weed/topology/node.go index d32927fca..66d44a8e1 100644 --- a/weed/topology/node.go +++ b/weed/topology/node.go @@ -40,13 +40,6 @@ func newCapacityReservations() *CapacityReservations { } } -func (cr *CapacityReservations) addReservation(diskType types.DiskType, count int64) string { - cr.Lock() - defer cr.Unlock() - - return cr.doAddReservation(diskType, count) -} - func (cr *CapacityReservations) removeReservation(reservationId string) bool { cr.Lock() defer cr.Unlock() diff --git a/weed/topology/volume_layout.go b/weed/topology/volume_layout.go index ecbacef75..6a7ca2c89 100644 --- a/weed/topology/volume_layout.go +++ b/weed/topology/volume_layout.go @@ -40,10 +40,6 @@ func ExistCopies() stateIndicator { return func(state copyState) bool { return state != noCopies } } -func NoCopies() stateIndicator { - return func(state copyState) bool { return state == noCopies } -} - type volumesBinaryState struct { rp *super_block.ReplicaPlacement name volumeState // the name for volume state (eg. "Readonly", "Oversized") @@ -264,12 +260,6 @@ func (vl *VolumeLayout) isCrowdedVolume(v *storage.VolumeInfo) bool { return float64(v.Size) > float64(vl.volumeSizeLimit)*VolumeGrowStrategy.Threshold } -func (vl *VolumeLayout) isWritable(v *storage.VolumeInfo) bool { - return !vl.isOversized(v) && - v.Version == needle.GetCurrentVersion() && - !v.ReadOnly -} - func (vl *VolumeLayout) isEmpty() bool { vl.accessLock.RLock() defer vl.accessLock.RUnlock() diff --git a/weed/topology/volume_layout_test.go b/weed/topology/volume_layout_test.go deleted file mode 100644 index 999c8de8e..000000000 --- a/weed/topology/volume_layout_test.go +++ /dev/null @@ -1,190 +0,0 @@ -package topology - -import ( - "testing" - - "github.com/seaweedfs/seaweedfs/weed/storage" - "github.com/seaweedfs/seaweedfs/weed/storage/needle" - "github.com/seaweedfs/seaweedfs/weed/storage/super_block" - "github.com/seaweedfs/seaweedfs/weed/storage/types" -) - -func TestVolumesBinaryState(t *testing.T) { - vids := []needle.VolumeId{ - needle.VolumeId(1), - needle.VolumeId(2), - needle.VolumeId(3), - needle.VolumeId(4), - needle.VolumeId(5), - } - - dns := []*DataNode{ - &DataNode{ - Ip: "127.0.0.1", - Port: 8081, - }, - &DataNode{ - Ip: "127.0.0.1", - Port: 8082, - }, - &DataNode{ - Ip: "127.0.0.1", - Port: 8083, - }, - } - - rp, _ := super_block.NewReplicaPlacementFromString("002") - - state_exist := NewVolumesBinaryState(readOnlyState, rp, ExistCopies()) - state_exist.Add(vids[0], dns[0]) - state_exist.Add(vids[0], dns[1]) - state_exist.Add(vids[1], dns[2]) - state_exist.Add(vids[2], dns[1]) - state_exist.Add(vids[4], dns[1]) - state_exist.Add(vids[4], dns[2]) - - state_no := NewVolumesBinaryState(readOnlyState, rp, NoCopies()) - state_no.Add(vids[0], dns[0]) - state_no.Add(vids[0], dns[1]) - state_no.Add(vids[3], dns[1]) - - tests := []struct { - name string - state *volumesBinaryState - expectResult []bool - update func() - expectResultAfterUpdate []bool - }{ - { - name: "mark true when copies exist", - state: state_exist, - expectResult: []bool{true, true, true, false, true}, - update: func() { - state_exist.Remove(vids[0], dns[2]) - state_exist.Remove(vids[1], dns[2]) - state_exist.Remove(vids[3], dns[2]) - state_exist.Remove(vids[4], dns[1]) - state_exist.Remove(vids[4], dns[2]) - }, - expectResultAfterUpdate: []bool{true, false, true, false, false}, - }, - { - name: "mark true when no copies exist", - state: state_no, - expectResult: []bool{false, true, true, false, true}, - update: func() { - state_no.Remove(vids[0], dns[2]) - state_no.Remove(vids[1], dns[2]) - state_no.Add(vids[2], dns[1]) - state_no.Remove(vids[3], dns[1]) - state_no.Remove(vids[4], dns[2]) - }, - expectResultAfterUpdate: []bool{false, true, false, true, true}, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var result []bool - for index, _ := range vids { - result = append(result, test.state.IsTrue(vids[index])) - } - if len(result) != len(test.expectResult) { - t.Fatalf("len(result) != len(expectResult), got %d, expected %d\n", - len(result), len(test.expectResult)) - } - for index, val := range result { - if val != test.expectResult[index] { - t.Fatalf("result not matched, index %d, got %v, expected %v\n", - index, val, test.expectResult[index]) - } - } - test.update() - var updateResult []bool - for index, _ := range vids { - updateResult = append(updateResult, test.state.IsTrue(vids[index])) - } - if len(updateResult) != len(test.expectResultAfterUpdate) { - t.Fatalf("len(updateResult) != len(expectResultAfterUpdate), got %d, expected %d\n", - len(updateResult), len(test.expectResultAfterUpdate)) - } - for index, val := range updateResult { - if val != test.expectResultAfterUpdate[index] { - t.Fatalf("update result not matched, index %d, got %v, expected %v\n", - index, val, test.expectResultAfterUpdate[index]) - } - } - }) - } -} - -func TestVolumeLayoutCrowdedState(t *testing.T) { - rp, _ := super_block.NewReplicaPlacementFromString("000") - ttl, _ := needle.ReadTTL("") - diskType := types.HardDriveType - - vl := NewVolumeLayout(rp, ttl, diskType, 1024*1024*1024, false) - - vid := needle.VolumeId(1) - dn := &DataNode{ - NodeImpl: NodeImpl{ - id: "test-node", - }, - Ip: "127.0.0.1", - Port: 8080, - } - - // Create a volume info - volumeInfo := &storage.VolumeInfo{ - Id: vid, - ReplicaPlacement: rp, - Ttl: ttl, - DiskType: string(diskType), - } - - // Register the volume - vl.RegisterVolume(volumeInfo, dn) - - // Add the volume to writables - vl.accessLock.Lock() - vl.setVolumeWritable(vid) - vl.accessLock.Unlock() - - // Mark the volume as crowded - vl.SetVolumeCrowded(vid) - - t.Run("should be crowded after being marked", func(t *testing.T) { - vl.accessLock.RLock() - _, isCrowded := vl.crowded[vid] - vl.accessLock.RUnlock() - if !isCrowded { - t.Fatal("Volume should be marked as crowded after SetVolumeCrowded") - } - }) - - // Remove from writable (simulating temporary unwritable state) - vl.accessLock.Lock() - vl.removeFromWritable(vid) - vl.accessLock.Unlock() - - t.Run("should remain crowded after becoming unwritable", func(t *testing.T) { - // This is the fix for issue #6712 - crowded state should persist - vl.accessLock.RLock() - _, stillCrowded := vl.crowded[vid] - vl.accessLock.RUnlock() - if !stillCrowded { - t.Fatal("Volume should remain crowded after becoming unwritable (fix for issue #6712)") - } - }) - - // Now unregister the volume completely - vl.UnRegisterVolume(volumeInfo, dn) - - t.Run("should not be crowded after unregistering", func(t *testing.T) { - vl.accessLock.RLock() - _, stillCrowdedAfterUnregister := vl.crowded[vid] - vl.accessLock.RUnlock() - if stillCrowdedAfterUnregister { - t.Fatal("Volume should be removed from crowded map after full unregistration") - } - }) -} diff --git a/weed/util/bytes.go b/weed/util/bytes.go index faf7df916..43008c42f 100644 --- a/weed/util/bytes.go +++ b/weed/util/bytes.go @@ -120,10 +120,6 @@ func Base64Encode(data []byte) string { return base64.StdEncoding.EncodeToString(data) } -func Base64Md5(data []byte) string { - return Base64Encode(Md5(data)) -} - func Md5(data []byte) []byte { hash := md5.New() hash.Write(data) diff --git a/weed/util/limited_async_pool.go b/weed/util/limited_async_pool.go deleted file mode 100644 index 51dfd6252..000000000 --- a/weed/util/limited_async_pool.go +++ /dev/null @@ -1,66 +0,0 @@ -package util - -// initial version comes from https://hackernoon.com/asyncawait-in-golang-an-introductory-guide-ol1e34sg - -import ( - "container/list" - "context" - "sync" -) - -type Future interface { - Await() interface{} -} - -type future struct { - await func(ctx context.Context) interface{} -} - -func (f future) Await() interface{} { - return f.await(context.Background()) -} - -type LimitedAsyncExecutor struct { - executor *LimitedConcurrentExecutor - futureList *list.List - futureListCond *sync.Cond -} - -func NewLimitedAsyncExecutor(limit int) *LimitedAsyncExecutor { - return &LimitedAsyncExecutor{ - executor: NewLimitedConcurrentExecutor(limit), - futureList: list.New(), - futureListCond: sync.NewCond(&sync.Mutex{}), - } -} - -func (ae *LimitedAsyncExecutor) Execute(job func() interface{}) { - var result interface{} - c := make(chan struct{}) - ae.executor.Execute(func() { - defer close(c) - result = job() - }) - f := future{await: func(ctx context.Context) interface{} { - select { - case <-ctx.Done(): - return ctx.Err() - case <-c: - return result - } - }} - ae.futureListCond.L.Lock() - ae.futureList.PushBack(f) - ae.futureListCond.Signal() - ae.futureListCond.L.Unlock() -} - -func (ae *LimitedAsyncExecutor) NextFuture() Future { - ae.futureListCond.L.Lock() - for ae.futureList.Len() == 0 { - ae.futureListCond.Wait() - } - f := ae.futureList.Remove(ae.futureList.Front()) - ae.futureListCond.L.Unlock() - return f.(Future) -} diff --git a/weed/util/limited_async_pool_test.go b/weed/util/limited_async_pool_test.go deleted file mode 100644 index 1289f4f33..000000000 --- a/weed/util/limited_async_pool_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package util - -import ( - "fmt" - "sort" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestAsyncPool(t *testing.T) { - p := NewLimitedAsyncExecutor(3) - - p.Execute(FirstFunc) - p.Execute(SecondFunc) - p.Execute(ThirdFunc) - p.Execute(FourthFunc) - p.Execute(FifthFunc) - - var sorted_results []int - for i := 0; i < 5; i++ { - f := p.NextFuture() - x := f.Await().(int) - println(x) - sorted_results = append(sorted_results, x) - } - assert.True(t, sort.IntsAreSorted(sorted_results), "results should be sorted") -} - -func FirstFunc() any { - fmt.Println("-- Executing first function --") - time.Sleep(70 * time.Millisecond) - fmt.Println("-- First Function finished --") - return 1 -} - -func SecondFunc() any { - fmt.Println("-- Executing second function --") - time.Sleep(50 * time.Millisecond) - fmt.Println("-- Second Function finished --") - return 2 -} - -func ThirdFunc() any { - fmt.Println("-- Executing third function --") - time.Sleep(20 * time.Millisecond) - fmt.Println("-- Third Function finished --") - return 3 -} - -func FourthFunc() any { - fmt.Println("-- Executing fourth function --") - time.Sleep(100 * time.Millisecond) - fmt.Println("-- Fourth Function finished --") - return 4 -} - -func FifthFunc() any { - fmt.Println("-- Executing fifth function --") - time.Sleep(40 * time.Millisecond) - fmt.Println("-- Fourth fifth finished --") - return 5 -} diff --git a/weed/util/lock_table.go b/weed/util/lock_table.go index 8f65aac06..65daae39c 100644 --- a/weed/util/lock_table.go +++ b/weed/util/lock_table.go @@ -175,7 +175,3 @@ func (lt *LockTable[T]) ReleaseLock(key T, lock *ActiveLock) { // Notify the next waiter entry.cond.Broadcast() } - -func main() { - -} diff --git a/weed/wdclient/net2/base_connection_pool.go b/weed/wdclient/net2/base_connection_pool.go deleted file mode 100644 index 0b79130e3..000000000 --- a/weed/wdclient/net2/base_connection_pool.go +++ /dev/null @@ -1,159 +0,0 @@ -package net2 - -import ( - "net" - "strings" - "time" - - rp "github.com/seaweedfs/seaweedfs/weed/wdclient/resource_pool" -) - -const defaultDialTimeout = 1 * time.Second - -func defaultDialFunc(network string, address string) (net.Conn, error) { - return net.DialTimeout(network, address, defaultDialTimeout) -} - -func parseResourceLocation(resourceLocation string) ( - network string, - address string) { - - idx := strings.Index(resourceLocation, " ") - if idx >= 0 { - return resourceLocation[:idx], resourceLocation[idx+1:] - } - - return "", resourceLocation -} - -// A thin wrapper around the underlying resource pool. -type connectionPoolImpl struct { - options ConnectionOptions - - pool rp.ResourcePool -} - -// This returns a connection pool where all connections are connected -// to the same (network, address) -func newBaseConnectionPool( - options ConnectionOptions, - createPool func(rp.Options) rp.ResourcePool) ConnectionPool { - - dial := options.Dial - if dial == nil { - dial = defaultDialFunc - } - - openFunc := func(loc string) (interface{}, error) { - network, address := parseResourceLocation(loc) - return dial(network, address) - } - - closeFunc := func(handle interface{}) error { - return handle.(net.Conn).Close() - } - - poolOptions := rp.Options{ - MaxActiveHandles: options.MaxActiveConnections, - MaxIdleHandles: options.MaxIdleConnections, - MaxIdleTime: options.MaxIdleTime, - OpenMaxConcurrency: options.DialMaxConcurrency, - Open: openFunc, - Close: closeFunc, - NowFunc: options.NowFunc, - } - - return &connectionPoolImpl{ - options: options, - pool: createPool(poolOptions), - } -} - -// This returns a connection pool where all connections are connected -// to the same (network, address) -func NewSimpleConnectionPool(options ConnectionOptions) ConnectionPool { - return newBaseConnectionPool(options, rp.NewSimpleResourcePool) -} - -// This returns a connection pool that manages multiple (network, address) -// entries. The connections to each (network, address) entry acts -// independently. For example ("tcp", "localhost:11211") could act as memcache -// shard 0 and ("tcp", "localhost:11212") could act as memcache shard 1. -func NewMultiConnectionPool(options ConnectionOptions) ConnectionPool { - return newBaseConnectionPool( - options, - func(poolOptions rp.Options) rp.ResourcePool { - return rp.NewMultiResourcePool(poolOptions, nil) - }) -} - -// See ConnectionPool for documentation. -func (p *connectionPoolImpl) NumActive() int32 { - return p.pool.NumActive() -} - -// See ConnectionPool for documentation. -func (p *connectionPoolImpl) ActiveHighWaterMark() int32 { - return p.pool.ActiveHighWaterMark() -} - -// This returns the number of alive idle connections. This method is not part -// of ConnectionPool's API. It is used only for testing. -func (p *connectionPoolImpl) NumIdle() int { - return p.pool.NumIdle() -} - -// BaseConnectionPool can only register a single (network, address) entry. -// Register should be call before any Get calls. -func (p *connectionPoolImpl) Register(network string, address string) error { - return p.pool.Register(network + " " + address) -} - -// BaseConnectionPool has nothing to do on Unregister. -func (p *connectionPoolImpl) Unregister(network string, address string) error { - return nil -} - -func (p *connectionPoolImpl) ListRegistered() []NetworkAddress { - result := make([]NetworkAddress, 0, 1) - for _, location := range p.pool.ListRegistered() { - network, address := parseResourceLocation(location) - - result = append( - result, - NetworkAddress{ - Network: network, - Address: address, - }) - } - return result -} - -// This gets an active connection from the connection pool. Note that network -// and address arguments are ignored (The connections with point to the -// network/address provided by the first Register call). -func (p *connectionPoolImpl) Get( - network string, - address string) (ManagedConn, error) { - - handle, err := p.pool.Get(network + " " + address) - if err != nil { - return nil, err - } - return NewManagedConn(network, address, handle, p, p.options), nil -} - -// See ConnectionPool for documentation. -func (p *connectionPoolImpl) Release(conn ManagedConn) error { - return conn.ReleaseConnection() -} - -// See ConnectionPool for documentation. -func (p *connectionPoolImpl) Discard(conn ManagedConn) error { - return conn.DiscardConnection() -} - -// See ConnectionPool for documentation. -func (p *connectionPoolImpl) EnterLameDuckMode() { - p.pool.EnterLameDuckMode() -} diff --git a/weed/wdclient/net2/connection_pool.go b/weed/wdclient/net2/connection_pool.go deleted file mode 100644 index 5b8d4d232..000000000 --- a/weed/wdclient/net2/connection_pool.go +++ /dev/null @@ -1,97 +0,0 @@ -package net2 - -import ( - "net" - "time" -) - -type ConnectionOptions struct { - // The maximum number of connections that can be active per host at any - // given time (A non-positive value indicates the number of connections - // is unbounded). - MaxActiveConnections int32 - - // The maximum number of idle connections per host that are kept alive by - // the connection pool. - MaxIdleConnections uint32 - - // The maximum amount of time an idle connection can alive (if specified). - MaxIdleTime *time.Duration - - // This limits the number of concurrent Dial calls (there's no limit when - // DialMaxConcurrency is non-positive). - DialMaxConcurrency int - - // Dial specifies the dial function for creating network connections. - // If Dial is nil, net.DialTimeout is used, with timeout set to 1 second. - Dial func(network string, address string) (net.Conn, error) - - // This specifies the now time function. When the function is non-nil, the - // connection pool will use the specified function instead of time.Now to - // generate the current time. - NowFunc func() time.Time - - // This specifies the timeout for any Read() operation. - // Note that setting this to 0 (i.e. not setting it) will make - // read operations block indefinitely. - ReadTimeout time.Duration - - // This specifies the timeout for any Write() operation. - // Note that setting this to 0 (i.e. not setting it) will make - // write operations block indefinitely. - WriteTimeout time.Duration -} - -func (o ConnectionOptions) getCurrentTime() time.Time { - if o.NowFunc == nil { - return time.Now() - } else { - return o.NowFunc() - } -} - -// A generic interface for managed connection pool. All connection pool -// implementations must be threadsafe. -type ConnectionPool interface { - // This returns the number of active connections that are on loan. - NumActive() int32 - - // This returns the highest number of active connections for the entire - // lifetime of the pool. - ActiveHighWaterMark() int32 - - // This returns the number of idle connections that are in the pool. - NumIdle() int - - // This associates (network, address) to the connection pool; afterwhich, - // the user can get connections to (network, address). - Register(network string, address string) error - - // This dissociate (network, address) from the connection pool; - // afterwhich, the user can no longer get connections to - // (network, address). - Unregister(network string, address string) error - - // This returns the list of registered (network, address) entries. - ListRegistered() []NetworkAddress - - // This gets an active connection from the connection pool. The connection - // will remain active until one of the following is called: - // 1. conn.ReleaseConnection() - // 2. conn.DiscardConnection() - // 3. pool.Release(conn) - // 4. pool.Discard(conn) - Get(network string, address string) (ManagedConn, error) - - // This releases an active connection back to the connection pool. - Release(conn ManagedConn) error - - // This discards an active connection from the connection pool. - Discard(conn ManagedConn) error - - // Enter the connection pool into lame duck mode. The connection pool - // will no longer return connections, and all idle connections are closed - // immediately (including active connections that are released back to the - // pool afterward). - EnterLameDuckMode() -} diff --git a/weed/wdclient/net2/doc.go b/weed/wdclient/net2/doc.go deleted file mode 100644 index fd1c6323d..000000000 --- a/weed/wdclient/net2/doc.go +++ /dev/null @@ -1,6 +0,0 @@ -// net2 is a collection of functions meant to supplement the capabilities -// provided by the standard "net" package. -package net2 - -// copied from https://github.com/dropbox/godropbox/tree/master/net2 -// removed other dependencies diff --git a/weed/wdclient/net2/managed_connection.go b/weed/wdclient/net2/managed_connection.go deleted file mode 100644 index d4696739e..000000000 --- a/weed/wdclient/net2/managed_connection.go +++ /dev/null @@ -1,186 +0,0 @@ -package net2 - -import ( - "fmt" - "net" - "time" - - "errors" - - "github.com/seaweedfs/seaweedfs/weed/wdclient/resource_pool" -) - -// Dial's arguments. -type NetworkAddress struct { - Network string - Address string -} - -// A connection managed by a connection pool. NOTE: SetDeadline, -// SetReadDeadline and SetWriteDeadline are disabled for managed connections. -// (The deadlines are set by the connection pool). -type ManagedConn interface { - net.Conn - - // This returns the original (network, address) entry used for creating - // the connection. - Key() NetworkAddress - - // This returns the underlying net.Conn implementation. - RawConn() net.Conn - - // This returns the connection pool which owns this connection. - Owner() ConnectionPool - - // This indicates a user is done with the connection and releases the - // connection back to the connection pool. - ReleaseConnection() error - - // This indicates the connection is an invalid state, and that the - // connection should be discarded from the connection pool. - DiscardConnection() error -} - -// A physical implementation of ManagedConn -type managedConnImpl struct { - addr NetworkAddress - handle resource_pool.ManagedHandle - pool ConnectionPool - options ConnectionOptions -} - -// This creates a managed connection wrapper. -func NewManagedConn( - network string, - address string, - handle resource_pool.ManagedHandle, - pool ConnectionPool, - options ConnectionOptions) ManagedConn { - - addr := NetworkAddress{ - Network: network, - Address: address, - } - - return &managedConnImpl{ - addr: addr, - handle: handle, - pool: pool, - options: options, - } -} - -func (c *managedConnImpl) rawConn() (net.Conn, error) { - h, err := c.handle.Handle() - return h.(net.Conn), err -} - -// See ManagedConn for documentation. -func (c *managedConnImpl) RawConn() net.Conn { - h, _ := c.handle.Handle() - return h.(net.Conn) -} - -// See ManagedConn for documentation. -func (c *managedConnImpl) Key() NetworkAddress { - return c.addr -} - -// See ManagedConn for documentation. -func (c *managedConnImpl) Owner() ConnectionPool { - return c.pool -} - -// See ManagedConn for documentation. -func (c *managedConnImpl) ReleaseConnection() error { - return c.handle.Release() -} - -// See ManagedConn for documentation. -func (c *managedConnImpl) DiscardConnection() error { - return c.handle.Discard() -} - -// See net.Conn for documentation -func (c *managedConnImpl) Read(b []byte) (n int, err error) { - conn, err := c.rawConn() - if err != nil { - return 0, err - } - - if c.options.ReadTimeout > 0 { - deadline := c.options.getCurrentTime().Add(c.options.ReadTimeout) - _ = conn.SetReadDeadline(deadline) - } - n, err = conn.Read(b) - if err != nil { - var localAddr string - if conn.LocalAddr() != nil { - localAddr = conn.LocalAddr().String() - } else { - localAddr = "(nil)" - } - - var remoteAddr string - if conn.RemoteAddr() != nil { - remoteAddr = conn.RemoteAddr().String() - } else { - remoteAddr = "(nil)" - } - err = fmt.Errorf("Read error from host: %s <-> %s: %v", localAddr, remoteAddr, err) - } - return -} - -// See net.Conn for documentation -func (c *managedConnImpl) Write(b []byte) (n int, err error) { - conn, err := c.rawConn() - if err != nil { - return 0, err - } - - if c.options.WriteTimeout > 0 { - deadline := c.options.getCurrentTime().Add(c.options.WriteTimeout) - _ = conn.SetWriteDeadline(deadline) - } - n, err = conn.Write(b) - if err != nil { - err = fmt.Errorf("Write error: %w", err) - } - return -} - -// See net.Conn for documentation -func (c *managedConnImpl) Close() error { - return c.handle.Discard() -} - -// See net.Conn for documentation -func (c *managedConnImpl) LocalAddr() net.Addr { - conn, _ := c.rawConn() - return conn.LocalAddr() -} - -// See net.Conn for documentation -func (c *managedConnImpl) RemoteAddr() net.Addr { - conn, _ := c.rawConn() - return conn.RemoteAddr() -} - -// SetDeadline is disabled for managed connection (The deadline is set by -// us, with respect to the read/write timeouts specified in ConnectionOptions). -func (c *managedConnImpl) SetDeadline(t time.Time) error { - return errors.New("Cannot set deadline for managed connection") -} - -// SetReadDeadline is disabled for managed connection (The deadline is set by -// us with respect to the read timeout specified in ConnectionOptions). -func (c *managedConnImpl) SetReadDeadline(t time.Time) error { - return errors.New("Cannot set read deadline for managed connection") -} - -// SetWriteDeadline is disabled for managed connection (The deadline is set by -// us with respect to the write timeout specified in ConnectionOptions). -func (c *managedConnImpl) SetWriteDeadline(t time.Time) error { - return errors.New("Cannot set write deadline for managed connection") -} diff --git a/weed/wdclient/net2/port.go b/weed/wdclient/net2/port.go deleted file mode 100644 index f83adba28..000000000 --- a/weed/wdclient/net2/port.go +++ /dev/null @@ -1,19 +0,0 @@ -package net2 - -import ( - "net" - "strconv" -) - -// Returns the port information. -func GetPort(addr net.Addr) (int, error) { - _, lport, err := net.SplitHostPort(addr.String()) - if err != nil { - return -1, err - } - lportInt, err := strconv.Atoi(lport) - if err != nil { - return -1, err - } - return lportInt, nil -} diff --git a/weed/wdclient/resource_pool/doc.go b/weed/wdclient/resource_pool/doc.go deleted file mode 100644 index b8b3f92fa..000000000 --- a/weed/wdclient/resource_pool/doc.go +++ /dev/null @@ -1,5 +0,0 @@ -// A generic resource pool for managing resources such as network connections. -package resource_pool - -// copied from https://github.com/dropbox/godropbox/tree/master/resource_pool -// removed other dependencies diff --git a/weed/wdclient/resource_pool/managed_handle.go b/weed/wdclient/resource_pool/managed_handle.go deleted file mode 100644 index 936c2d7c3..000000000 --- a/weed/wdclient/resource_pool/managed_handle.go +++ /dev/null @@ -1,97 +0,0 @@ -package resource_pool - -import ( - "sync/atomic" - - "errors" -) - -// A resource handle managed by a resource pool. -type ManagedHandle interface { - // This returns the handle's resource location. - ResourceLocation() string - - // This returns the underlying resource handle (or error if the handle - // is no longer active). - Handle() (interface{}, error) - - // This returns the resource pool which owns this handle. - Owner() ResourcePool - - // The releases the underlying resource handle to the caller and marks the - // managed handle as inactive. The caller is responsible for cleaning up - // the released handle. This returns nil if the managed handle no longer - // owns the resource. - ReleaseUnderlyingHandle() interface{} - - // This indicates a user is done with the handle and releases the handle - // back to the resource pool. - Release() error - - // This indicates the handle is an invalid state, and that the - // connection should be discarded from the connection pool. - Discard() error -} - -// A physical implementation of ManagedHandle -type managedHandleImpl struct { - location string - handle interface{} - pool ResourcePool - isActive int32 // atomic bool - options Options -} - -// This creates a managed handle wrapper. -func NewManagedHandle( - resourceLocation string, - handle interface{}, - pool ResourcePool, - options Options) ManagedHandle { - - h := &managedHandleImpl{ - location: resourceLocation, - handle: handle, - pool: pool, - options: options, - } - atomic.StoreInt32(&h.isActive, 1) - - return h -} - -// See ManagedHandle for documentation. -func (c *managedHandleImpl) ResourceLocation() string { - return c.location -} - -// See ManagedHandle for documentation. -func (c *managedHandleImpl) Handle() (interface{}, error) { - if atomic.LoadInt32(&c.isActive) == 0 { - return c.handle, errors.New("Resource handle is no longer valid") - } - return c.handle, nil -} - -// See ManagedHandle for documentation. -func (c *managedHandleImpl) Owner() ResourcePool { - return c.pool -} - -// See ManagedHandle for documentation. -func (c *managedHandleImpl) ReleaseUnderlyingHandle() interface{} { - if atomic.CompareAndSwapInt32(&c.isActive, 1, 0) { - return c.handle - } - return nil -} - -// See ManagedHandle for documentation. -func (c *managedHandleImpl) Release() error { - return c.pool.Release(c) -} - -// See ManagedHandle for documentation. -func (c *managedHandleImpl) Discard() error { - return c.pool.Discard(c) -} diff --git a/weed/wdclient/resource_pool/multi_resource_pool.go b/weed/wdclient/resource_pool/multi_resource_pool.go deleted file mode 100644 index 9ac25526d..000000000 --- a/weed/wdclient/resource_pool/multi_resource_pool.go +++ /dev/null @@ -1,200 +0,0 @@ -package resource_pool - -import ( - "fmt" - "sync" - - "errors" -) - -// A resource pool implementation that manages multiple resource location -// entries. The handles to each resource location entry acts independently. -// For example "tcp localhost:11211" could act as memcache -// shard 0 and "tcp localhost:11212" could act as memcache shard 1. -type multiResourcePool struct { - options Options - - createPool func(Options) ResourcePool - - rwMutex sync.RWMutex - isLameDuck bool // guarded by rwMutex - // NOTE: the locationPools is guarded by rwMutex, but the pool entries - // are not. - locationPools map[string]ResourcePool -} - -// This returns a MultiResourcePool, which manages multiple -// resource location entries. The handles to each resource location -// entry acts independently. -// -// When createPool is nil, NewSimpleResourcePool is used as default. -func NewMultiResourcePool( - options Options, - createPool func(Options) ResourcePool) ResourcePool { - - if createPool == nil { - createPool = NewSimpleResourcePool - } - - return &multiResourcePool{ - options: options, - createPool: createPool, - rwMutex: sync.RWMutex{}, - isLameDuck: false, - locationPools: make(map[string]ResourcePool), - } -} - -// See ResourcePool for documentation. -func (p *multiResourcePool) NumActive() int32 { - total := int32(0) - - p.rwMutex.RLock() - defer p.rwMutex.RUnlock() - - for _, pool := range p.locationPools { - total += pool.NumActive() - } - return total -} - -// See ResourcePool for documentation. -func (p *multiResourcePool) ActiveHighWaterMark() int32 { - high := int32(0) - - p.rwMutex.RLock() - defer p.rwMutex.RUnlock() - - for _, pool := range p.locationPools { - val := pool.ActiveHighWaterMark() - if val > high { - high = val - } - } - return high -} - -// See ResourcePool for documentation. -func (p *multiResourcePool) NumIdle() int { - total := 0 - - p.rwMutex.RLock() - defer p.rwMutex.RUnlock() - - for _, pool := range p.locationPools { - total += pool.NumIdle() - } - return total -} - -// See ResourcePool for documentation. -func (p *multiResourcePool) Register(resourceLocation string) error { - if resourceLocation == "" { - return errors.New("Registering invalid resource location") - } - - p.rwMutex.Lock() - defer p.rwMutex.Unlock() - - if p.isLameDuck { - return fmt.Errorf( - "Cannot register %s to lame duck resource pool", - resourceLocation) - } - - if _, inMap := p.locationPools[resourceLocation]; inMap { - return nil - } - - pool := p.createPool(p.options) - if err := pool.Register(resourceLocation); err != nil { - return err - } - - p.locationPools[resourceLocation] = pool - return nil -} - -// See ResourcePool for documentation. -func (p *multiResourcePool) Unregister(resourceLocation string) error { - p.rwMutex.Lock() - defer p.rwMutex.Unlock() - - if pool, inMap := p.locationPools[resourceLocation]; inMap { - _ = pool.Unregister("") - pool.EnterLameDuckMode() - delete(p.locationPools, resourceLocation) - } - return nil -} - -func (p *multiResourcePool) ListRegistered() []string { - p.rwMutex.RLock() - defer p.rwMutex.RUnlock() - - result := make([]string, 0, len(p.locationPools)) - for key, _ := range p.locationPools { - result = append(result, key) - } - - return result -} - -// See ResourcePool for documentation. -func (p *multiResourcePool) Get( - resourceLocation string) (ManagedHandle, error) { - - pool := p.getPool(resourceLocation) - if pool == nil { - return nil, fmt.Errorf( - "%s is not registered in the resource pool", - resourceLocation) - } - return pool.Get(resourceLocation) -} - -// See ResourcePool for documentation. -func (p *multiResourcePool) Release(handle ManagedHandle) error { - pool := p.getPool(handle.ResourceLocation()) - if pool == nil { - return errors.New( - "Resource pool cannot take control of a handle owned " + - "by another resource pool") - } - - return pool.Release(handle) -} - -// See ResourcePool for documentation. -func (p *multiResourcePool) Discard(handle ManagedHandle) error { - pool := p.getPool(handle.ResourceLocation()) - if pool == nil { - return errors.New( - "Resource pool cannot take control of a handle owned " + - "by another resource pool") - } - - return pool.Discard(handle) -} - -// See ResourcePool for documentation. -func (p *multiResourcePool) EnterLameDuckMode() { - p.rwMutex.Lock() - defer p.rwMutex.Unlock() - - p.isLameDuck = true - - for _, pool := range p.locationPools { - pool.EnterLameDuckMode() - } -} - -func (p *multiResourcePool) getPool(resourceLocation string) ResourcePool { - p.rwMutex.RLock() - defer p.rwMutex.RUnlock() - - if pool, inMap := p.locationPools[resourceLocation]; inMap { - return pool - } - return nil -} diff --git a/weed/wdclient/resource_pool/resource_pool.go b/weed/wdclient/resource_pool/resource_pool.go deleted file mode 100644 index 26c433f50..000000000 --- a/weed/wdclient/resource_pool/resource_pool.go +++ /dev/null @@ -1,96 +0,0 @@ -package resource_pool - -import ( - "time" -) - -type Options struct { - // The maximum number of active resource handles per resource location. (A - // non-positive value indicates the number of active resource handles is - // unbounded). - MaxActiveHandles int32 - - // The maximum number of idle resource handles per resource location that - // are kept alive by the resource pool. - MaxIdleHandles uint32 - - // The maximum amount of time an idle resource handle can remain alive (if - // specified). - MaxIdleTime *time.Duration - - // This limits the number of concurrent Open calls (there's no limit when - // OpenMaxConcurrency is non-positive). - OpenMaxConcurrency int - - // This function creates a resource handle (e.g., a connection) for a - // resource location. The function must be thread-safe. - Open func(resourceLocation string) ( - handle interface{}, - err error) - - // This function destroys a resource handle and performs the necessary - // cleanup to free up resources. The function must be thread-safe. - Close func(handle interface{}) error - - // This specifies the now time function. When the function is non-nil, the - // resource pool will use the specified function instead of time.Now to - // generate the current time. - NowFunc func() time.Time -} - -func (o Options) getCurrentTime() time.Time { - if o.NowFunc == nil { - return time.Now() - } else { - return o.NowFunc() - } -} - -// A generic interface for managed resource pool. All resource pool -// implementations must be threadsafe. -type ResourcePool interface { - // This returns the number of active resource handles. - NumActive() int32 - - // This returns the highest number of actives handles for the entire - // lifetime of the pool. If the pool contains multiple sub-pools, the - // high water mark is the max of the sub-pools' high water marks. - ActiveHighWaterMark() int32 - - // This returns the number of alive idle handles. NOTE: This is only used - // for testing. - NumIdle() int - - // This associates a resource location to the resource pool; afterwhich, - // the user can get resource handles for the resource location. - Register(resourceLocation string) error - - // This dissociates a resource location from the resource pool; afterwhich, - // the user can no longer get resource handles for the resource location. - // If the given resource location corresponds to a sub-pool, the unregistered - // sub-pool will enter lame duck mode. - Unregister(resourceLocation string) error - - // This returns the list of registered resource location entries. - ListRegistered() []string - - // This gets an active resource handle from the resource pool. The - // handle will remain active until one of the following is called: - // 1. handle.Release() - // 2. handle.Discard() - // 3. pool.Release(handle) - // 4. pool.Discard(handle) - Get(key string) (ManagedHandle, error) - - // This releases an active resource handle back to the resource pool. - Release(handle ManagedHandle) error - - // This discards an active resource from the resource pool. - Discard(handle ManagedHandle) error - - // Enter the resource pool into lame duck mode. The resource pool - // will no longer return resource handles, and all idle resource handles - // are closed immediately (including active resource handles that are - // released back to the pool afterward). - EnterLameDuckMode() -} diff --git a/weed/wdclient/resource_pool/semaphore.go b/weed/wdclient/resource_pool/semaphore.go deleted file mode 100644 index 9bd6afc33..000000000 --- a/weed/wdclient/resource_pool/semaphore.go +++ /dev/null @@ -1,154 +0,0 @@ -package resource_pool - -import ( - "fmt" - "sync" - "sync/atomic" - "time" -) - -type Semaphore interface { - // Increment the semaphore counter by one. - Release() - - // Decrement the semaphore counter by one, and block if counter < 0 - Acquire() - - // Decrement the semaphore counter by one, and block if counter < 0 - // Wait for up to the given duration. Returns true if did not timeout - TryAcquire(timeout time.Duration) bool -} - -// A simple counting Semaphore. -type boundedSemaphore struct { - slots chan struct{} -} - -// Create a bounded semaphore. The count parameter must be a positive number. -// NOTE: The bounded semaphore will panic if the user tries to Release -// beyond the specified count. -func NewBoundedSemaphore(count uint) Semaphore { - sem := &boundedSemaphore{ - slots: make(chan struct{}, int(count)), - } - for i := 0; i < cap(sem.slots); i++ { - sem.slots <- struct{}{} - } - return sem -} - -// Acquire returns on successful acquisition. -func (sem *boundedSemaphore) Acquire() { - <-sem.slots -} - -// TryAcquire returns true if it acquires a resource slot within the -// timeout, false otherwise. -func (sem *boundedSemaphore) TryAcquire(timeout time.Duration) bool { - if timeout > 0 { - // Wait until we get a slot or timeout expires. - tm := time.NewTimer(timeout) - defer tm.Stop() - select { - case <-sem.slots: - return true - case <-tm.C: - // Timeout expired. In very rare cases this might happen even if - // there is a slot available, e.g. GC pause after we create the timer - // and select randomly picked this one out of the two available channels. - // We should do one final immediate check below. - } - } - - // Return true if we have a slot available immediately and false otherwise. - select { - case <-sem.slots: - return true - default: - return false - } -} - -// Release the acquired semaphore. You must not release more than you -// have acquired. -func (sem *boundedSemaphore) Release() { - select { - case sem.slots <- struct{}{}: - default: - // slots is buffered. If a send blocks, it indicates a programming - // error. - panic(fmt.Errorf("too many releases for boundedSemaphore")) - } -} - -// This returns an unbound counting semaphore with the specified initial count. -// The semaphore counter can be arbitrary large (i.e., Release can be called -// unlimited amount of times). -// -// NOTE: In general, users should use bounded semaphore since it is more -// efficient than unbounded semaphore. -func NewUnboundedSemaphore(initialCount int) Semaphore { - res := &unboundedSemaphore{ - counter: int64(initialCount), - } - res.cond.L = &res.lock - return res -} - -type unboundedSemaphore struct { - lock sync.Mutex - cond sync.Cond - counter int64 -} - -func (s *unboundedSemaphore) Release() { - s.lock.Lock() - s.counter += 1 - if s.counter > 0 { - // Not broadcasting here since it's unlike we can satisfy all waiting - // goroutines. Instead, we will Signal again if there are left over - // quota after Acquire, in case of lost wakeups. - s.cond.Signal() - } - s.lock.Unlock() -} - -func (s *unboundedSemaphore) Acquire() { - s.lock.Lock() - for s.counter < 1 { - s.cond.Wait() - } - s.counter -= 1 - if s.counter > 0 { - s.cond.Signal() - } - s.lock.Unlock() -} - -func (s *unboundedSemaphore) TryAcquire(timeout time.Duration) bool { - done := make(chan bool, 1) - // Gate used to communicate between the threads and decide what the result - // is. If the main thread decides, we have timed out, otherwise we succeed. - decided := new(int32) - atomic.StoreInt32(decided, 0) - go func() { - s.Acquire() - if atomic.SwapInt32(decided, 1) == 0 { - // Acquire won the race - done <- true - } else { - // If we already decided the result, and this thread did not win - s.Release() - } - }() - select { - case <-done: - return true - case <-time.After(timeout): - if atomic.SwapInt32(decided, 1) == 1 { - // The other thread already decided the result - return true - } - return false - } -} diff --git a/weed/wdclient/resource_pool/simple_resource_pool.go b/weed/wdclient/resource_pool/simple_resource_pool.go deleted file mode 100644 index 99f555a02..000000000 --- a/weed/wdclient/resource_pool/simple_resource_pool.go +++ /dev/null @@ -1,343 +0,0 @@ -package resource_pool - -import ( - "errors" - "fmt" - "sync" - "sync/atomic" - "time" -) - -type idleHandle struct { - handle interface{} - keepUntil *time.Time -} - -type TooManyHandles struct { - location string -} - -func (t TooManyHandles) Error() string { - return fmt.Sprintf("Too many handles to %s", t.location) -} - -type OpenHandleError struct { - location string - err error -} - -func (o OpenHandleError) Error() string { - return fmt.Sprintf("Failed to open resource handle: %s (%v)", o.location, o.err) -} - -// A resource pool implementation where all handles are associated to the -// same resource location. -type simpleResourcePool struct { - options Options - - numActive *int32 // atomic counter - - activeHighWaterMark *int32 // atomic / monotonically increasing value - - openTokens Semaphore - - mutex sync.Mutex - location string // guard by mutex - idleHandles []*idleHandle // guarded by mutex - isLameDuck bool // guarded by mutex -} - -// This returns a SimpleResourcePool, where all handles are associated to a -// single resource location. -func NewSimpleResourcePool(options Options) ResourcePool { - numActive := new(int32) - atomic.StoreInt32(numActive, 0) - - activeHighWaterMark := new(int32) - atomic.StoreInt32(activeHighWaterMark, 0) - - var tokens Semaphore - if options.OpenMaxConcurrency > 0 { - tokens = NewBoundedSemaphore(uint(options.OpenMaxConcurrency)) - } - - return &simpleResourcePool{ - location: "", - options: options, - numActive: numActive, - activeHighWaterMark: activeHighWaterMark, - openTokens: tokens, - mutex: sync.Mutex{}, - idleHandles: make([]*idleHandle, 0, 0), - isLameDuck: false, - } -} - -// See ResourcePool for documentation. -func (p *simpleResourcePool) NumActive() int32 { - return atomic.LoadInt32(p.numActive) -} - -// See ResourcePool for documentation. -func (p *simpleResourcePool) ActiveHighWaterMark() int32 { - return atomic.LoadInt32(p.activeHighWaterMark) -} - -// See ResourcePool for documentation. -func (p *simpleResourcePool) NumIdle() int { - p.mutex.Lock() - defer p.mutex.Unlock() - return len(p.idleHandles) -} - -// SimpleResourcePool can only register a single (network, address) entry. -// Register should be call before any Get calls. -func (p *simpleResourcePool) Register(resourceLocation string) error { - if resourceLocation == "" { - return errors.New("Invalid resource location") - } - - p.mutex.Lock() - defer p.mutex.Unlock() - - if p.isLameDuck { - return fmt.Errorf( - "cannot register %s to lame duck resource pool", - resourceLocation) - } - - if p.location == "" { - p.location = resourceLocation - return nil - } - return errors.New("SimpleResourcePool can only register one location") -} - -// SimpleResourcePool will enter lame duck mode upon calling Unregister. -func (p *simpleResourcePool) Unregister(resourceLocation string) error { - p.EnterLameDuckMode() - return nil -} - -func (p *simpleResourcePool) ListRegistered() []string { - p.mutex.Lock() - defer p.mutex.Unlock() - - if p.location != "" { - return []string{p.location} - } - return []string{} -} - -func (p *simpleResourcePool) getLocation() (string, error) { - p.mutex.Lock() - defer p.mutex.Unlock() - - if p.location == "" { - return "", fmt.Errorf( - "resource location is not set for SimpleResourcePool") - } - - if p.isLameDuck { - return "", fmt.Errorf( - "lame duck resource pool cannot return handles to %s", - p.location) - } - - return p.location, nil -} - -// This gets an active resource from the resource pool. Note that the -// resourceLocation argument is ignored (The handles are associated to the -// resource location provided by the first Register call). -func (p *simpleResourcePool) Get(unused string) (ManagedHandle, error) { - activeCount := atomic.AddInt32(p.numActive, 1) - if p.options.MaxActiveHandles > 0 && - activeCount > p.options.MaxActiveHandles { - - atomic.AddInt32(p.numActive, -1) - return nil, TooManyHandles{p.location} - } - - highest := atomic.LoadInt32(p.activeHighWaterMark) - for activeCount > highest && - !atomic.CompareAndSwapInt32( - p.activeHighWaterMark, - highest, - activeCount) { - - highest = atomic.LoadInt32(p.activeHighWaterMark) - } - - if h := p.getIdleHandle(); h != nil { - return h, nil - } - - location, err := p.getLocation() - if err != nil { - atomic.AddInt32(p.numActive, -1) - return nil, err - } - - if p.openTokens != nil { - // Current implementation does not wait for tokens to become available. - // If that causes availability hits, we could increase the wait, - // similar to simple_pool.go. - if p.openTokens.TryAcquire(0) { - defer p.openTokens.Release() - } else { - // We could not immediately acquire a token. - // Instead of waiting - atomic.AddInt32(p.numActive, -1) - return nil, OpenHandleError{ - p.location, errors.New("Open Error: reached OpenMaxConcurrency")} - } - } - - handle, err := p.options.Open(location) - if err != nil { - atomic.AddInt32(p.numActive, -1) - return nil, OpenHandleError{p.location, err} - } - - return NewManagedHandle(p.location, handle, p, p.options), nil -} - -// See ResourcePool for documentation. -func (p *simpleResourcePool) Release(handle ManagedHandle) error { - if pool, ok := handle.Owner().(*simpleResourcePool); !ok || pool != p { - return errors.New( - "Resource pool cannot take control of a handle owned " + - "by another resource pool") - } - - h := handle.ReleaseUnderlyingHandle() - if h != nil { - // We can unref either before or after queuing the idle handle. - // The advantage of unref-ing before queuing is that there is - // a higher chance of successful Get when number of active handles - // is close to the limit (but potentially more handle creation). - // The advantage of queuing before unref-ing is that there's a - // higher chance of reusing handle (but potentially more Get failures). - atomic.AddInt32(p.numActive, -1) - p.queueIdleHandles(h) - } - - return nil -} - -// See ResourcePool for documentation. -func (p *simpleResourcePool) Discard(handle ManagedHandle) error { - if pool, ok := handle.Owner().(*simpleResourcePool); !ok || pool != p { - return errors.New( - "Resource pool cannot take control of a handle owned " + - "by another resource pool") - } - - h := handle.ReleaseUnderlyingHandle() - if h != nil { - atomic.AddInt32(p.numActive, -1) - if err := p.options.Close(h); err != nil { - return fmt.Errorf("failed to close resource handle: %w", err) - } - } - return nil -} - -// See ResourcePool for documentation. -func (p *simpleResourcePool) EnterLameDuckMode() { - p.mutex.Lock() - - toClose := p.idleHandles - p.isLameDuck = true - p.idleHandles = []*idleHandle{} - - p.mutex.Unlock() - - p.closeHandles(toClose) -} - -// This returns an idle resource, if there is one. -func (p *simpleResourcePool) getIdleHandle() ManagedHandle { - var toClose []*idleHandle - defer func() { - // NOTE: Must keep the closure around to late bind the toClose slice. - p.closeHandles(toClose) - }() - - now := p.options.getCurrentTime() - - p.mutex.Lock() - defer p.mutex.Unlock() - - var i int - for i = 0; i < len(p.idleHandles); i++ { - idle := p.idleHandles[i] - if idle.keepUntil == nil || now.Before(*idle.keepUntil) { - break - } - } - if i > 0 { - toClose = p.idleHandles[0:i] - } - - if i < len(p.idleHandles) { - idle := p.idleHandles[i] - p.idleHandles = p.idleHandles[i+1:] - return NewManagedHandle(p.location, idle.handle, p, p.options) - } - - if len(p.idleHandles) > 0 { - p.idleHandles = []*idleHandle{} - } - return nil -} - -// This adds an idle resource to the pool. -func (p *simpleResourcePool) queueIdleHandles(handle interface{}) { - var toClose []*idleHandle - defer func() { - // NOTE: Must keep the closure around to late bind the toClose slice. - p.closeHandles(toClose) - }() - - now := p.options.getCurrentTime() - var keepUntil *time.Time - if p.options.MaxIdleTime != nil { - // NOTE: Assign to temp variable first to work around compiler bug - x := now.Add(*p.options.MaxIdleTime) - keepUntil = &x - } - - p.mutex.Lock() - defer p.mutex.Unlock() - - if p.isLameDuck { - toClose = []*idleHandle{ - {handle: handle}, - } - return - } - - p.idleHandles = append( - p.idleHandles, - &idleHandle{ - handle: handle, - keepUntil: keepUntil, - }) - - nIdleHandles := uint32(len(p.idleHandles)) - if nIdleHandles > p.options.MaxIdleHandles { - handlesToClose := nIdleHandles - p.options.MaxIdleHandles - toClose = p.idleHandles[0:handlesToClose] - p.idleHandles = p.idleHandles[handlesToClose:nIdleHandles] - } -} - -// Closes resources, at this point it is assumed that this resources -// are no longer referenced from the main idleHandles slice. -func (p *simpleResourcePool) closeHandles(handles []*idleHandle) { - for _, handle := range handles { - _ = p.options.Close(handle.handle) - } -} diff --git a/weed/weed.go b/weed/weed.go index f940cdacd..f83777bf5 100644 --- a/weed/weed.go +++ b/weed/weed.go @@ -196,17 +196,9 @@ func help(args []string) { var atexitFuncs []func() -func atexit(f func()) { - atexitFuncs = append(atexitFuncs, f) -} - func exit() { for _, f := range atexitFuncs { f() } os.Exit(exitStatus) } - -func debug(params ...interface{}) { - glog.V(4).Infoln(params...) -} diff --git a/weed/worker/registry.go b/weed/worker/registry.go index 0b40ddec4..fd6cecf30 100644 --- a/weed/worker/registry.go +++ b/weed/worker/registry.go @@ -1,9 +1,7 @@ package worker import ( - "fmt" "sync" - "time" "github.com/seaweedfs/seaweedfs/weed/worker/types" ) @@ -15,334 +13,6 @@ type Registry struct { mutex sync.RWMutex } -// NewRegistry creates a new worker registry -func NewRegistry() *Registry { - return &Registry{ - workers: make(map[string]*types.WorkerData), - stats: &types.RegistryStats{ - TotalWorkers: 0, - ActiveWorkers: 0, - BusyWorkers: 0, - IdleWorkers: 0, - TotalTasks: 0, - CompletedTasks: 0, - FailedTasks: 0, - StartTime: time.Now(), - }, - } -} - -// RegisterWorker registers a new worker -func (r *Registry) RegisterWorker(worker *types.WorkerData) error { - r.mutex.Lock() - defer r.mutex.Unlock() - - if _, exists := r.workers[worker.ID]; exists { - return fmt.Errorf("worker %s already registered", worker.ID) - } - - r.workers[worker.ID] = worker - r.updateStats() - return nil -} - -// UnregisterWorker removes a worker from the registry -func (r *Registry) UnregisterWorker(workerID string) error { - r.mutex.Lock() - defer r.mutex.Unlock() - - if _, exists := r.workers[workerID]; !exists { - return fmt.Errorf("worker %s not found", workerID) - } - - delete(r.workers, workerID) - r.updateStats() - return nil -} - -// GetWorker returns a worker by ID -func (r *Registry) GetWorker(workerID string) (*types.WorkerData, bool) { - r.mutex.RLock() - defer r.mutex.RUnlock() - - worker, exists := r.workers[workerID] - return worker, exists -} - -// ListWorkers returns all registered workers -func (r *Registry) ListWorkers() []*types.WorkerData { - r.mutex.RLock() - defer r.mutex.RUnlock() - - workers := make([]*types.WorkerData, 0, len(r.workers)) - for _, worker := range r.workers { - workers = append(workers, worker) - } - return workers -} - -// GetWorkersByCapability returns workers that support a specific capability -func (r *Registry) GetWorkersByCapability(capability types.TaskType) []*types.WorkerData { - r.mutex.RLock() - defer r.mutex.RUnlock() - - var workers []*types.WorkerData - for _, worker := range r.workers { - for _, cap := range worker.Capabilities { - if cap == capability { - workers = append(workers, worker) - break - } - } - } - return workers -} - -// GetAvailableWorkers returns workers that are available for new tasks -func (r *Registry) GetAvailableWorkers() []*types.WorkerData { - r.mutex.RLock() - defer r.mutex.RUnlock() - - var workers []*types.WorkerData - for _, worker := range r.workers { - if worker.Status == "active" && worker.CurrentLoad < worker.MaxConcurrent { - workers = append(workers, worker) - } - } - return workers -} - -// GetBestWorkerForTask returns the best worker for a specific task -func (r *Registry) GetBestWorkerForTask(taskType types.TaskType) *types.WorkerData { - r.mutex.RLock() - defer r.mutex.RUnlock() - - var bestWorker *types.WorkerData - var bestScore float64 - - for _, worker := range r.workers { - // Check if worker supports this task type - supportsTask := false - for _, cap := range worker.Capabilities { - if cap == taskType { - supportsTask = true - break - } - } - - if !supportsTask { - continue - } - - // Check if worker is available - if worker.Status != "active" || worker.CurrentLoad >= worker.MaxConcurrent { - continue - } - - // Calculate score based on current load and capacity - score := float64(worker.MaxConcurrent-worker.CurrentLoad) / float64(worker.MaxConcurrent) - if bestWorker == nil || score > bestScore { - bestWorker = worker - bestScore = score - } - } - - return bestWorker -} - -// UpdateWorkerHeartbeat updates the last heartbeat time for a worker -func (r *Registry) UpdateWorkerHeartbeat(workerID string) error { - r.mutex.Lock() - defer r.mutex.Unlock() - - worker, exists := r.workers[workerID] - if !exists { - return fmt.Errorf("worker %s not found", workerID) - } - - worker.LastHeartbeat = time.Now() - return nil -} - -// UpdateWorkerLoad updates the current load for a worker -func (r *Registry) UpdateWorkerLoad(workerID string, load int) error { - r.mutex.Lock() - defer r.mutex.Unlock() - - worker, exists := r.workers[workerID] - if !exists { - return fmt.Errorf("worker %s not found", workerID) - } - - worker.CurrentLoad = load - if load >= worker.MaxConcurrent { - worker.Status = "busy" - } else { - worker.Status = "active" - } - - r.updateStats() - return nil -} - -// UpdateWorkerStatus updates the status of a worker -func (r *Registry) UpdateWorkerStatus(workerID string, status string) error { - r.mutex.Lock() - defer r.mutex.Unlock() - - worker, exists := r.workers[workerID] - if !exists { - return fmt.Errorf("worker %s not found", workerID) - } - - worker.Status = status - r.updateStats() - return nil -} - -// CleanupStaleWorkers removes workers that haven't sent heartbeats recently -func (r *Registry) CleanupStaleWorkers(timeout time.Duration) int { - r.mutex.Lock() - defer r.mutex.Unlock() - - var removedCount int - cutoff := time.Now().Add(-timeout) - - for workerID, worker := range r.workers { - if worker.LastHeartbeat.Before(cutoff) { - delete(r.workers, workerID) - removedCount++ - } - } - - if removedCount > 0 { - r.updateStats() - } - - return removedCount -} - -// GetStats returns current registry statistics -func (r *Registry) GetStats() *types.RegistryStats { - r.mutex.RLock() - defer r.mutex.RUnlock() - - // Create a copy of the stats to avoid race conditions - stats := *r.stats - return &stats -} - -// updateStats updates the registry statistics (must be called with lock held) -func (r *Registry) updateStats() { - r.stats.TotalWorkers = len(r.workers) - r.stats.ActiveWorkers = 0 - r.stats.BusyWorkers = 0 - r.stats.IdleWorkers = 0 - - for _, worker := range r.workers { - switch worker.Status { - case "active": - if worker.CurrentLoad > 0 { - r.stats.ActiveWorkers++ - } else { - r.stats.IdleWorkers++ - } - case "busy": - r.stats.BusyWorkers++ - } - } - - r.stats.Uptime = time.Since(r.stats.StartTime) - r.stats.LastUpdated = time.Now() -} - -// GetTaskCapabilities returns all task capabilities available in the registry -func (r *Registry) GetTaskCapabilities() []types.TaskType { - r.mutex.RLock() - defer r.mutex.RUnlock() - - capabilitySet := make(map[types.TaskType]bool) - for _, worker := range r.workers { - for _, cap := range worker.Capabilities { - capabilitySet[cap] = true - } - } - - var capabilities []types.TaskType - for cap := range capabilitySet { - capabilities = append(capabilities, cap) - } - - return capabilities -} - -// GetWorkersByStatus returns workers filtered by status -func (r *Registry) GetWorkersByStatus(status string) []*types.WorkerData { - r.mutex.RLock() - defer r.mutex.RUnlock() - - var workers []*types.WorkerData - for _, worker := range r.workers { - if worker.Status == status { - workers = append(workers, worker) - } - } - return workers -} - -// GetWorkerCount returns the total number of registered workers -func (r *Registry) GetWorkerCount() int { - r.mutex.RLock() - defer r.mutex.RUnlock() - return len(r.workers) -} - -// GetWorkerIDs returns all worker IDs -func (r *Registry) GetWorkerIDs() []string { - r.mutex.RLock() - defer r.mutex.RUnlock() - - ids := make([]string, 0, len(r.workers)) - for id := range r.workers { - ids = append(ids, id) - } - return ids -} - -// GetWorkerSummary returns a summary of all workers -func (r *Registry) GetWorkerSummary() *types.WorkerSummary { - r.mutex.RLock() - defer r.mutex.RUnlock() - - summary := &types.WorkerSummary{ - TotalWorkers: len(r.workers), - ByStatus: make(map[string]int), - ByCapability: make(map[types.TaskType]int), - TotalLoad: 0, - MaxCapacity: 0, - } - - for _, worker := range r.workers { - summary.ByStatus[worker.Status]++ - summary.TotalLoad += worker.CurrentLoad - summary.MaxCapacity += worker.MaxConcurrent - - for _, cap := range worker.Capabilities { - summary.ByCapability[cap]++ - } - } - - return summary -} - // Default global registry instance var defaultRegistry *Registry var registryOnce sync.Once - -// GetDefaultRegistry returns the default global registry -func GetDefaultRegistry() *Registry { - registryOnce.Do(func() { - defaultRegistry = NewRegistry() - }) - return defaultRegistry -} diff --git a/weed/worker/tasks/balance/monitoring.go b/weed/worker/tasks/balance/monitoring.go deleted file mode 100644 index 517de2484..000000000 --- a/weed/worker/tasks/balance/monitoring.go +++ /dev/null @@ -1,138 +0,0 @@ -package balance - -import ( - "sync" - "time" -) - -// BalanceMetrics contains balance-specific monitoring data -type BalanceMetrics struct { - // Execution metrics - VolumesBalanced int64 `json:"volumes_balanced"` - TotalDataTransferred int64 `json:"total_data_transferred"` - AverageImbalance float64 `json:"average_imbalance"` - LastBalanceTime time.Time `json:"last_balance_time"` - - // Performance metrics - AverageTransferSpeed float64 `json:"average_transfer_speed_mbps"` - TotalExecutionTime int64 `json:"total_execution_time_seconds"` - SuccessfulOperations int64 `json:"successful_operations"` - FailedOperations int64 `json:"failed_operations"` - - // Current task metrics - CurrentImbalanceScore float64 `json:"current_imbalance_score"` - PlannedDestinations int `json:"planned_destinations"` - - mutex sync.RWMutex -} - -// NewBalanceMetrics creates a new balance metrics instance -func NewBalanceMetrics() *BalanceMetrics { - return &BalanceMetrics{ - LastBalanceTime: time.Now(), - } -} - -// RecordVolumeBalanced records a successful volume balance operation -func (m *BalanceMetrics) RecordVolumeBalanced(volumeSize int64, transferTime time.Duration) { - m.mutex.Lock() - defer m.mutex.Unlock() - - m.VolumesBalanced++ - m.TotalDataTransferred += volumeSize - m.SuccessfulOperations++ - m.LastBalanceTime = time.Now() - m.TotalExecutionTime += int64(transferTime.Seconds()) - - // Calculate average transfer speed (MB/s) - if transferTime > 0 { - speedMBps := float64(volumeSize) / (1024 * 1024) / transferTime.Seconds() - if m.AverageTransferSpeed == 0 { - m.AverageTransferSpeed = speedMBps - } else { - // Exponential moving average - m.AverageTransferSpeed = 0.8*m.AverageTransferSpeed + 0.2*speedMBps - } - } -} - -// RecordFailure records a failed balance operation -func (m *BalanceMetrics) RecordFailure() { - m.mutex.Lock() - defer m.mutex.Unlock() - - m.FailedOperations++ -} - -// UpdateImbalanceScore updates the current cluster imbalance score -func (m *BalanceMetrics) UpdateImbalanceScore(score float64) { - m.mutex.Lock() - defer m.mutex.Unlock() - - m.CurrentImbalanceScore = score - - // Update average imbalance with exponential moving average - if m.AverageImbalance == 0 { - m.AverageImbalance = score - } else { - m.AverageImbalance = 0.9*m.AverageImbalance + 0.1*score - } -} - -// SetPlannedDestinations sets the number of planned destinations -func (m *BalanceMetrics) SetPlannedDestinations(count int) { - m.mutex.Lock() - defer m.mutex.Unlock() - - m.PlannedDestinations = count -} - -// GetMetrics returns a copy of the current metrics (without the mutex) -func (m *BalanceMetrics) GetMetrics() BalanceMetrics { - m.mutex.RLock() - defer m.mutex.RUnlock() - - // Create a copy without the mutex to avoid copying lock value - return BalanceMetrics{ - VolumesBalanced: m.VolumesBalanced, - TotalDataTransferred: m.TotalDataTransferred, - AverageImbalance: m.AverageImbalance, - LastBalanceTime: m.LastBalanceTime, - AverageTransferSpeed: m.AverageTransferSpeed, - TotalExecutionTime: m.TotalExecutionTime, - SuccessfulOperations: m.SuccessfulOperations, - FailedOperations: m.FailedOperations, - CurrentImbalanceScore: m.CurrentImbalanceScore, - PlannedDestinations: m.PlannedDestinations, - } -} - -// GetSuccessRate returns the success rate as a percentage -func (m *BalanceMetrics) GetSuccessRate() float64 { - m.mutex.RLock() - defer m.mutex.RUnlock() - - total := m.SuccessfulOperations + m.FailedOperations - if total == 0 { - return 100.0 - } - return float64(m.SuccessfulOperations) / float64(total) * 100.0 -} - -// Reset resets all metrics to zero -func (m *BalanceMetrics) Reset() { - m.mutex.Lock() - defer m.mutex.Unlock() - - *m = BalanceMetrics{ - LastBalanceTime: time.Now(), - } -} - -// Global metrics instance for balance tasks -var globalBalanceMetrics = NewBalanceMetrics() - -// GetGlobalBalanceMetrics returns the global balance metrics instance -func GetGlobalBalanceMetrics() *BalanceMetrics { - return globalBalanceMetrics -} diff --git a/weed/worker/tasks/base/registration.go b/weed/worker/tasks/base/registration.go index f69db6b48..12335eb15 100644 --- a/weed/worker/tasks/base/registration.go +++ b/weed/worker/tasks/base/registration.go @@ -74,26 +74,6 @@ type GenericUIProvider struct { taskDef *TaskDefinition } -// GetTaskType returns the task type -func (ui *GenericUIProvider) GetTaskType() types.TaskType { - return ui.taskDef.Type -} - -// GetDisplayName returns the human-readable name -func (ui *GenericUIProvider) GetDisplayName() string { - return ui.taskDef.DisplayName -} - -// GetDescription returns a description of what this task does -func (ui *GenericUIProvider) GetDescription() string { - return ui.taskDef.Description -} - -// GetIcon returns the icon CSS class for this task type -func (ui *GenericUIProvider) GetIcon() string { - return ui.taskDef.Icon -} - // GetCurrentConfig returns current config as TaskConfig func (ui *GenericUIProvider) GetCurrentConfig() types.TaskConfig { return ui.taskDef.Config diff --git a/weed/worker/tasks/base/task_definition.go b/weed/worker/tasks/base/task_definition.go index 5ebc2a4b6..04c5de06f 100644 --- a/weed/worker/tasks/base/task_definition.go +++ b/weed/worker/tasks/base/task_definition.go @@ -2,8 +2,6 @@ package base import ( "fmt" - "reflect" - "strings" "time" "github.com/seaweedfs/seaweedfs/weed/admin/config" @@ -75,108 +73,6 @@ func (c *BaseConfig) Validate() error { return nil } -// StructToMap converts any struct to a map using reflection -func StructToMap(obj interface{}) map[string]interface{} { - result := make(map[string]interface{}) - val := reflect.ValueOf(obj) - - // Handle pointer to struct - if val.Kind() == reflect.Ptr { - val = val.Elem() - } - - if val.Kind() != reflect.Struct { - return result - } - - typ := val.Type() - - for i := 0; i < val.NumField(); i++ { - field := val.Field(i) - fieldType := typ.Field(i) - - // Skip unexported fields - if !field.CanInterface() { - continue - } - - // Handle embedded structs recursively (before JSON tag check) - if field.Kind() == reflect.Struct && fieldType.Anonymous { - embeddedMap := StructToMap(field.Interface()) - for k, v := range embeddedMap { - result[k] = v - } - continue - } - - // Get JSON tag name - jsonTag := fieldType.Tag.Get("json") - if jsonTag == "" || jsonTag == "-" { - continue - } - - // Remove options like ",omitempty" - if commaIdx := strings.Index(jsonTag, ","); commaIdx >= 0 { - jsonTag = jsonTag[:commaIdx] - } - - result[jsonTag] = field.Interface() - } - return result -} - -// MapToStruct loads data from map into struct using reflection -func MapToStruct(data map[string]interface{}, obj interface{}) error { - val := reflect.ValueOf(obj) - - // Must be pointer to struct - if val.Kind() != reflect.Ptr || val.Elem().Kind() != reflect.Struct { - return fmt.Errorf("obj must be pointer to struct") - } - - val = val.Elem() - typ := val.Type() - - for i := 0; i < val.NumField(); i++ { - field := val.Field(i) - fieldType := typ.Field(i) - - // Skip unexported fields - if !field.CanSet() { - continue - } - - // Handle embedded structs recursively (before JSON tag check) - if field.Kind() == reflect.Struct && fieldType.Anonymous { - err := MapToStruct(data, field.Addr().Interface()) - if err != nil { - return err - } - continue - } - - // Get JSON tag name - jsonTag := fieldType.Tag.Get("json") - if jsonTag == "" || jsonTag == "-" { - continue - } - - // Remove options like ",omitempty" - if commaIdx := strings.Index(jsonTag, ","); commaIdx >= 0 { - jsonTag = jsonTag[:commaIdx] - } - - if value, exists := data[jsonTag]; exists { - err := setFieldValue(field, value) - if err != nil { - return fmt.Errorf("failed to set field %s: %v", jsonTag, err) - } - } - } - - return nil -} - // ToMap converts config to map using reflection // ToTaskPolicy converts BaseConfig to protobuf (partial implementation) // Note: Concrete implementations should override this to include task-specific config @@ -207,66 +103,3 @@ func (c *BaseConfig) ApplySchemaDefaults(schema *config.Schema) error { // Use reflection-based approach for BaseConfig since it needs to handle embedded structs return schema.ApplyDefaultsToProtobuf(c) } - -// setFieldValue sets a field value with type conversion -func setFieldValue(field reflect.Value, value interface{}) error { - if value == nil { - return nil - } - - valueVal := reflect.ValueOf(value) - fieldType := field.Type() - valueType := valueVal.Type() - - // Direct assignment if types match - if valueType.AssignableTo(fieldType) { - field.Set(valueVal) - return nil - } - - // Type conversion for common cases - switch fieldType.Kind() { - case reflect.Bool: - if b, ok := value.(bool); ok { - field.SetBool(b) - } else { - return fmt.Errorf("cannot convert %T to bool", value) - } - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - switch v := value.(type) { - case int: - field.SetInt(int64(v)) - case int32: - field.SetInt(int64(v)) - case int64: - field.SetInt(v) - case float64: - field.SetInt(int64(v)) - default: - return fmt.Errorf("cannot convert %T to int", value) - } - case reflect.Float32, reflect.Float64: - switch v := value.(type) { - case float32: - field.SetFloat(float64(v)) - case float64: - field.SetFloat(v) - case int: - field.SetFloat(float64(v)) - case int64: - field.SetFloat(float64(v)) - default: - return fmt.Errorf("cannot convert %T to float", value) - } - case reflect.String: - if s, ok := value.(string); ok { - field.SetString(s) - } else { - return fmt.Errorf("cannot convert %T to string", value) - } - default: - return fmt.Errorf("unsupported field type %s", fieldType.Kind()) - } - - return nil -} diff --git a/weed/worker/tasks/base/task_definition_test.go b/weed/worker/tasks/base/task_definition_test.go deleted file mode 100644 index a0a0a5a24..000000000 --- a/weed/worker/tasks/base/task_definition_test.go +++ /dev/null @@ -1,338 +0,0 @@ -package base - -import ( - "reflect" - "testing" -) - -// Test structs that mirror the actual configuration structure -type TestBaseConfig struct { - Enabled bool `json:"enabled"` - ScanIntervalSeconds int `json:"scan_interval_seconds"` - MaxConcurrent int `json:"max_concurrent"` -} - -type TestTaskConfig struct { - TestBaseConfig - TaskSpecificField float64 `json:"task_specific_field"` - AnotherSpecificField string `json:"another_specific_field"` -} - -type TestNestedConfig struct { - TestBaseConfig - NestedStruct struct { - NestedField string `json:"nested_field"` - } `json:"nested_struct"` - TaskField int `json:"task_field"` -} - -func TestStructToMap_WithEmbeddedStruct(t *testing.T) { - // Test case 1: Basic embedded struct - config := &TestTaskConfig{ - TestBaseConfig: TestBaseConfig{ - Enabled: true, - ScanIntervalSeconds: 1800, - MaxConcurrent: 3, - }, - TaskSpecificField: 0.25, - AnotherSpecificField: "test_value", - } - - result := StructToMap(config) - - // Verify all fields are present - expectedFields := map[string]interface{}{ - "enabled": true, - "scan_interval_seconds": 1800, - "max_concurrent": 3, - "task_specific_field": 0.25, - "another_specific_field": "test_value", - } - - if len(result) != len(expectedFields) { - t.Errorf("Expected %d fields, got %d. Result: %+v", len(expectedFields), len(result), result) - } - - for key, expectedValue := range expectedFields { - if actualValue, exists := result[key]; !exists { - t.Errorf("Missing field: %s", key) - } else if !reflect.DeepEqual(actualValue, expectedValue) { - t.Errorf("Field %s: expected %v (%T), got %v (%T)", key, expectedValue, expectedValue, actualValue, actualValue) - } - } -} - -func TestStructToMap_WithNestedStruct(t *testing.T) { - config := &TestNestedConfig{ - TestBaseConfig: TestBaseConfig{ - Enabled: false, - ScanIntervalSeconds: 3600, - MaxConcurrent: 1, - }, - NestedStruct: struct { - NestedField string `json:"nested_field"` - }{ - NestedField: "nested_value", - }, - TaskField: 42, - } - - result := StructToMap(config) - - // Verify embedded struct fields are included - if enabled, exists := result["enabled"]; !exists || enabled != false { - t.Errorf("Expected enabled=false from embedded struct, got %v", enabled) - } - - if scanInterval, exists := result["scan_interval_seconds"]; !exists || scanInterval != 3600 { - t.Errorf("Expected scan_interval_seconds=3600 from embedded struct, got %v", scanInterval) - } - - if maxConcurrent, exists := result["max_concurrent"]; !exists || maxConcurrent != 1 { - t.Errorf("Expected max_concurrent=1 from embedded struct, got %v", maxConcurrent) - } - - // Verify regular fields are included - if taskField, exists := result["task_field"]; !exists || taskField != 42 { - t.Errorf("Expected task_field=42, got %v", taskField) - } - - // Verify nested struct is included as a whole - if nestedStruct, exists := result["nested_struct"]; !exists { - t.Errorf("Missing nested_struct field") - } else { - // The nested struct should be included as-is, not flattened - if nested, ok := nestedStruct.(struct { - NestedField string `json:"nested_field"` - }); !ok || nested.NestedField != "nested_value" { - t.Errorf("Expected nested_struct with NestedField='nested_value', got %v", nestedStruct) - } - } -} - -func TestMapToStruct_WithEmbeddedStruct(t *testing.T) { - // Test data with all fields including embedded struct fields - data := map[string]interface{}{ - "enabled": true, - "scan_interval_seconds": 2400, - "max_concurrent": 5, - "task_specific_field": 0.15, - "another_specific_field": "updated_value", - } - - config := &TestTaskConfig{} - err := MapToStruct(data, config) - - if err != nil { - t.Fatalf("MapToStruct failed: %v", err) - } - - // Verify embedded struct fields were set - if config.Enabled != true { - t.Errorf("Expected Enabled=true, got %v", config.Enabled) - } - - if config.ScanIntervalSeconds != 2400 { - t.Errorf("Expected ScanIntervalSeconds=2400, got %v", config.ScanIntervalSeconds) - } - - if config.MaxConcurrent != 5 { - t.Errorf("Expected MaxConcurrent=5, got %v", config.MaxConcurrent) - } - - // Verify regular fields were set - if config.TaskSpecificField != 0.15 { - t.Errorf("Expected TaskSpecificField=0.15, got %v", config.TaskSpecificField) - } - - if config.AnotherSpecificField != "updated_value" { - t.Errorf("Expected AnotherSpecificField='updated_value', got %v", config.AnotherSpecificField) - } -} - -func TestMapToStruct_PartialData(t *testing.T) { - // Test with only some fields present (simulating form data) - data := map[string]interface{}{ - "enabled": false, - "max_concurrent": 2, - "task_specific_field": 0.30, - } - - // Start with some initial values - config := &TestTaskConfig{ - TestBaseConfig: TestBaseConfig{ - Enabled: true, - ScanIntervalSeconds: 1800, - MaxConcurrent: 1, - }, - TaskSpecificField: 0.20, - AnotherSpecificField: "initial_value", - } - - err := MapToStruct(data, config) - - if err != nil { - t.Fatalf("MapToStruct failed: %v", err) - } - - // Verify updated fields - if config.Enabled != false { - t.Errorf("Expected Enabled=false (updated), got %v", config.Enabled) - } - - if config.MaxConcurrent != 2 { - t.Errorf("Expected MaxConcurrent=2 (updated), got %v", config.MaxConcurrent) - } - - if config.TaskSpecificField != 0.30 { - t.Errorf("Expected TaskSpecificField=0.30 (updated), got %v", config.TaskSpecificField) - } - - // Verify unchanged fields remain the same - if config.ScanIntervalSeconds != 1800 { - t.Errorf("Expected ScanIntervalSeconds=1800 (unchanged), got %v", config.ScanIntervalSeconds) - } - - if config.AnotherSpecificField != "initial_value" { - t.Errorf("Expected AnotherSpecificField='initial_value' (unchanged), got %v", config.AnotherSpecificField) - } -} - -func TestRoundTripSerialization(t *testing.T) { - // Test complete round-trip: struct -> map -> struct - original := &TestTaskConfig{ - TestBaseConfig: TestBaseConfig{ - Enabled: true, - ScanIntervalSeconds: 3600, - MaxConcurrent: 4, - }, - TaskSpecificField: 0.18, - AnotherSpecificField: "round_trip_test", - } - - // Convert to map - dataMap := StructToMap(original) - - // Convert back to struct - roundTrip := &TestTaskConfig{} - err := MapToStruct(dataMap, roundTrip) - - if err != nil { - t.Fatalf("Round-trip MapToStruct failed: %v", err) - } - - // Verify all fields match - if !reflect.DeepEqual(original.TestBaseConfig, roundTrip.TestBaseConfig) { - t.Errorf("BaseConfig mismatch:\nOriginal: %+v\nRound-trip: %+v", original.TestBaseConfig, roundTrip.TestBaseConfig) - } - - if original.TaskSpecificField != roundTrip.TaskSpecificField { - t.Errorf("TaskSpecificField mismatch: %v != %v", original.TaskSpecificField, roundTrip.TaskSpecificField) - } - - if original.AnotherSpecificField != roundTrip.AnotherSpecificField { - t.Errorf("AnotherSpecificField mismatch: %v != %v", original.AnotherSpecificField, roundTrip.AnotherSpecificField) - } -} - -func TestStructToMap_EmptyStruct(t *testing.T) { - config := &TestTaskConfig{} - result := StructToMap(config) - - // Should still include all fields, even with zero values - expectedFields := []string{"enabled", "scan_interval_seconds", "max_concurrent", "task_specific_field", "another_specific_field"} - - for _, field := range expectedFields { - if _, exists := result[field]; !exists { - t.Errorf("Missing field: %s", field) - } - } -} - -func TestStructToMap_NilPointer(t *testing.T) { - var config *TestTaskConfig = nil - result := StructToMap(config) - - if len(result) != 0 { - t.Errorf("Expected empty map for nil pointer, got %+v", result) - } -} - -func TestMapToStruct_InvalidInput(t *testing.T) { - data := map[string]interface{}{ - "enabled": "not_a_bool", // Wrong type - } - - config := &TestTaskConfig{} - err := MapToStruct(data, config) - - if err == nil { - t.Errorf("Expected error for invalid input type, but got none") - } -} - -func TestMapToStruct_NonPointer(t *testing.T) { - data := map[string]interface{}{ - "enabled": true, - } - - config := TestTaskConfig{} // Not a pointer - err := MapToStruct(data, config) - - if err == nil { - t.Errorf("Expected error for non-pointer input, but got none") - } -} - -// Benchmark tests to ensure performance is reasonable -func BenchmarkStructToMap(b *testing.B) { - config := &TestTaskConfig{ - TestBaseConfig: TestBaseConfig{ - Enabled: true, - ScanIntervalSeconds: 1800, - MaxConcurrent: 3, - }, - TaskSpecificField: 0.25, - AnotherSpecificField: "benchmark_test", - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = StructToMap(config) - } -} - -func BenchmarkMapToStruct(b *testing.B) { - data := map[string]interface{}{ - "enabled": true, - "scan_interval_seconds": 1800, - "max_concurrent": 3, - "task_specific_field": 0.25, - "another_specific_field": "benchmark_test", - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - config := &TestTaskConfig{} - _ = MapToStruct(data, config) - } -} - -func BenchmarkRoundTrip(b *testing.B) { - original := &TestTaskConfig{ - TestBaseConfig: TestBaseConfig{ - Enabled: true, - ScanIntervalSeconds: 1800, - MaxConcurrent: 3, - }, - TaskSpecificField: 0.25, - AnotherSpecificField: "benchmark_test", - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - dataMap := StructToMap(original) - roundTrip := &TestTaskConfig{} - _ = MapToStruct(dataMap, roundTrip) - } -} diff --git a/weed/worker/tasks/erasure_coding/monitoring.go b/weed/worker/tasks/erasure_coding/monitoring.go deleted file mode 100644 index 799eb62c8..000000000 --- a/weed/worker/tasks/erasure_coding/monitoring.go +++ /dev/null @@ -1,229 +0,0 @@ -package erasure_coding - -import ( - "sync" - "time" -) - -// ErasureCodingMetrics contains erasure coding-specific monitoring data -type ErasureCodingMetrics struct { - // Execution metrics - VolumesEncoded int64 `json:"volumes_encoded"` - TotalShardsCreated int64 `json:"total_shards_created"` - TotalDataProcessed int64 `json:"total_data_processed"` - TotalSourcesRemoved int64 `json:"total_sources_removed"` - LastEncodingTime time.Time `json:"last_encoding_time"` - - // Performance metrics - AverageEncodingTime int64 `json:"average_encoding_time_seconds"` - AverageShardSize int64 `json:"average_shard_size"` - AverageDataShards int `json:"average_data_shards"` - AverageParityShards int `json:"average_parity_shards"` - SuccessfulOperations int64 `json:"successful_operations"` - FailedOperations int64 `json:"failed_operations"` - - // Distribution metrics - ShardsPerDataCenter map[string]int64 `json:"shards_per_datacenter"` - ShardsPerRack map[string]int64 `json:"shards_per_rack"` - PlacementSuccessRate float64 `json:"placement_success_rate"` - - // Current task metrics - CurrentVolumeSize int64 `json:"current_volume_size"` - CurrentShardCount int `json:"current_shard_count"` - VolumesPendingEncoding int `json:"volumes_pending_encoding"` - - mutex sync.RWMutex -} - -// NewErasureCodingMetrics creates a new erasure coding metrics instance -func NewErasureCodingMetrics() *ErasureCodingMetrics { - return &ErasureCodingMetrics{ - LastEncodingTime: time.Now(), - ShardsPerDataCenter: make(map[string]int64), - ShardsPerRack: make(map[string]int64), - } -} - -// RecordVolumeEncoded records a successful volume encoding operation -func (m *ErasureCodingMetrics) RecordVolumeEncoded(volumeSize int64, shardsCreated int, dataShards int, parityShards int, encodingTime time.Duration, sourceRemoved bool) { - m.mutex.Lock() - defer m.mutex.Unlock() - - m.VolumesEncoded++ - m.TotalShardsCreated += int64(shardsCreated) - m.TotalDataProcessed += volumeSize - m.SuccessfulOperations++ - m.LastEncodingTime = time.Now() - - if sourceRemoved { - m.TotalSourcesRemoved++ - } - - // Update average encoding time - if m.AverageEncodingTime == 0 { - m.AverageEncodingTime = int64(encodingTime.Seconds()) - } else { - // Exponential moving average - newTime := int64(encodingTime.Seconds()) - m.AverageEncodingTime = (m.AverageEncodingTime*4 + newTime) / 5 - } - - // Update average shard size - if shardsCreated > 0 { - avgShardSize := volumeSize / int64(shardsCreated) - if m.AverageShardSize == 0 { - m.AverageShardSize = avgShardSize - } else { - m.AverageShardSize = (m.AverageShardSize*4 + avgShardSize) / 5 - } - } - - // Update average data/parity shards - if m.AverageDataShards == 0 { - m.AverageDataShards = dataShards - m.AverageParityShards = parityShards - } else { - m.AverageDataShards = (m.AverageDataShards*4 + dataShards) / 5 - m.AverageParityShards = (m.AverageParityShards*4 + parityShards) / 5 - } -} - -// RecordFailure records a failed erasure coding operation -func (m *ErasureCodingMetrics) RecordFailure() { - m.mutex.Lock() - defer m.mutex.Unlock() - - m.FailedOperations++ -} - -// RecordShardPlacement records shard placement for distribution tracking -func (m *ErasureCodingMetrics) RecordShardPlacement(dataCenter string, rack string) { - m.mutex.Lock() - defer m.mutex.Unlock() - - m.ShardsPerDataCenter[dataCenter]++ - rackKey := dataCenter + ":" + rack - m.ShardsPerRack[rackKey]++ -} - -// UpdateCurrentVolumeInfo updates current volume processing information -func (m *ErasureCodingMetrics) UpdateCurrentVolumeInfo(volumeSize int64, shardCount int) { - m.mutex.Lock() - defer m.mutex.Unlock() - - m.CurrentVolumeSize = volumeSize - m.CurrentShardCount = shardCount -} - -// SetVolumesPendingEncoding sets the number of volumes pending encoding -func (m *ErasureCodingMetrics) SetVolumesPendingEncoding(count int) { - m.mutex.Lock() - defer m.mutex.Unlock() - - m.VolumesPendingEncoding = count -} - -// UpdatePlacementSuccessRate updates the placement success rate -func (m *ErasureCodingMetrics) UpdatePlacementSuccessRate(rate float64) { - m.mutex.Lock() - defer m.mutex.Unlock() - - if m.PlacementSuccessRate == 0 { - m.PlacementSuccessRate = rate - } else { - // Exponential moving average - m.PlacementSuccessRate = 0.8*m.PlacementSuccessRate + 0.2*rate - } -} - -// GetMetrics returns a copy of the current metrics (without the mutex) -func (m *ErasureCodingMetrics) GetMetrics() ErasureCodingMetrics { - m.mutex.RLock() - defer m.mutex.RUnlock() - - // Create deep copy of maps - shardsPerDC := make(map[string]int64) - for k, v := range m.ShardsPerDataCenter { - shardsPerDC[k] = v - } - - shardsPerRack := make(map[string]int64) - for k, v := range m.ShardsPerRack { - shardsPerRack[k] = v - } - - // Create a copy without the mutex to avoid copying lock value - return ErasureCodingMetrics{ - VolumesEncoded: m.VolumesEncoded, - TotalShardsCreated: m.TotalShardsCreated, - TotalDataProcessed: m.TotalDataProcessed, - TotalSourcesRemoved: m.TotalSourcesRemoved, - LastEncodingTime: m.LastEncodingTime, - AverageEncodingTime: m.AverageEncodingTime, - AverageShardSize: m.AverageShardSize, - AverageDataShards: m.AverageDataShards, - AverageParityShards: m.AverageParityShards, - SuccessfulOperations: m.SuccessfulOperations, - FailedOperations: m.FailedOperations, - ShardsPerDataCenter: shardsPerDC, - ShardsPerRack: shardsPerRack, - PlacementSuccessRate: m.PlacementSuccessRate, - CurrentVolumeSize: m.CurrentVolumeSize, - CurrentShardCount: m.CurrentShardCount, - VolumesPendingEncoding: m.VolumesPendingEncoding, - } -} - -// GetSuccessRate returns the success rate as a percentage -func (m *ErasureCodingMetrics) GetSuccessRate() float64 { - m.mutex.RLock() - defer m.mutex.RUnlock() - - total := m.SuccessfulOperations + m.FailedOperations - if total == 0 { - return 100.0 - } - return float64(m.SuccessfulOperations) / float64(total) * 100.0 -} - -// GetAverageDataProcessed returns the average data processed per volume -func (m *ErasureCodingMetrics) GetAverageDataProcessed() float64 { - m.mutex.RLock() - defer m.mutex.RUnlock() - - if m.VolumesEncoded == 0 { - return 0 - } - return float64(m.TotalDataProcessed) / float64(m.VolumesEncoded) -} - -// GetSourceRemovalRate returns the percentage of sources removed after encoding -func (m *ErasureCodingMetrics) GetSourceRemovalRate() float64 { - m.mutex.RLock() - defer m.mutex.RUnlock() - - if m.VolumesEncoded == 0 { - return 0 - } - return float64(m.TotalSourcesRemoved) / float64(m.VolumesEncoded) * 100.0 -} - -// Reset resets all metrics to zero -func (m *ErasureCodingMetrics) Reset() { - m.mutex.Lock() - defer m.mutex.Unlock() - - *m = ErasureCodingMetrics{ - LastEncodingTime: time.Now(), - ShardsPerDataCenter: make(map[string]int64), - ShardsPerRack: make(map[string]int64), - } -} - -// Global metrics instance for erasure coding tasks -var globalErasureCodingMetrics = NewErasureCodingMetrics() - -// GetGlobalErasureCodingMetrics returns the global erasure coding metrics instance -func GetGlobalErasureCodingMetrics() *ErasureCodingMetrics { - return globalErasureCodingMetrics -} diff --git a/weed/worker/tasks/registry.go b/weed/worker/tasks/registry.go index 626a54a14..fb1c477cf 100644 --- a/weed/worker/tasks/registry.go +++ b/weed/worker/tasks/registry.go @@ -64,51 +64,6 @@ func AutoRegisterUI(registerFunc func(*types.UIRegistry)) { glog.V(1).Infof("Auto-registered task UI provider") } -// SetDefaultCapabilitiesFromRegistry sets the default worker capabilities -// based on all registered task types -func SetDefaultCapabilitiesFromRegistry() { - typesRegistry := GetGlobalTypesRegistry() - - var capabilities []types.TaskType - for taskType := range typesRegistry.GetAllDetectors() { - capabilities = append(capabilities, taskType) - } - - // Set the default capabilities in the types package - types.SetDefaultCapabilities(capabilities) - - glog.V(1).Infof("Set default worker capabilities from registry: %v", capabilities) -} - -// BuildMaintenancePolicyFromTasks creates a maintenance policy with default configurations -// from all registered tasks using their UI providers -func BuildMaintenancePolicyFromTasks() *types.MaintenancePolicy { - policy := types.NewMaintenancePolicy() - - // Get all registered task types from the UI registry - uiRegistry := GetGlobalUIRegistry() - - for taskType, provider := range uiRegistry.GetAllProviders() { - // Get the default configuration from the UI provider - defaultConfig := provider.GetCurrentConfig() - - // Set the configuration in the policy - policy.SetTaskConfig(taskType, defaultConfig) - - glog.V(3).Infof("Added default config for task type %s to policy", taskType) - } - - glog.V(2).Infof("Built maintenance policy with %d task configurations", len(policy.TaskConfigs)) - return policy -} - -// SetMaintenancePolicyFromTasks sets the default maintenance policy from registered tasks -func SetMaintenancePolicyFromTasks() { - // This function can be called to initialize the policy from registered tasks - // For now, we'll just log that this should be called by the integration layer - glog.V(1).Infof("SetMaintenancePolicyFromTasks called - policy should be built by the integration layer") -} - // TaskRegistry manages task factories type TaskRegistry struct { factories map[types.TaskType]types.TaskFactory diff --git a/weed/worker/tasks/schema_provider.go b/weed/worker/tasks/schema_provider.go index 4d69556b1..9715aad17 100644 --- a/weed/worker/tasks/schema_provider.go +++ b/weed/worker/tasks/schema_provider.go @@ -36,16 +36,3 @@ func RegisterTaskConfigSchema(taskType string, provider TaskConfigSchemaProvider defer globalSchemaRegistry.mutex.Unlock() globalSchemaRegistry.providers[taskType] = provider } - -// GetTaskConfigSchema returns the schema for the specified task type -func GetTaskConfigSchema(taskType string) *TaskConfigSchema { - globalSchemaRegistry.mutex.RLock() - provider, exists := globalSchemaRegistry.providers[taskType] - globalSchemaRegistry.mutex.RUnlock() - - if !exists { - return nil - } - - return provider.GetConfigSchema() -} diff --git a/weed/worker/tasks/task.go b/weed/worker/tasks/task.go index f3eed8b2d..4ce022326 100644 --- a/weed/worker/tasks/task.go +++ b/weed/worker/tasks/task.go @@ -1,12 +1,9 @@ package tasks import ( - "context" - "fmt" "sync" "time" - "github.com/seaweedfs/seaweedfs/weed/glog" "github.com/seaweedfs/seaweedfs/weed/pb/worker_pb" "github.com/seaweedfs/seaweedfs/weed/worker/types" ) @@ -26,353 +23,11 @@ type BaseTask struct { currentStage string // Current stage description } -// NewBaseTask creates a new base task -func NewBaseTask(taskType types.TaskType) *BaseTask { - return &BaseTask{ - taskType: taskType, - progress: 0.0, - cancelled: false, - loggerConfig: DefaultTaskLoggerConfig(), - } -} - -// NewBaseTaskWithLogger creates a new base task with custom logger configuration -func NewBaseTaskWithLogger(taskType types.TaskType, loggerConfig TaskLoggerConfig) *BaseTask { - return &BaseTask{ - taskType: taskType, - progress: 0.0, - cancelled: false, - loggerConfig: loggerConfig, - } -} - -// InitializeLogger initializes the task logger with task details -func (t *BaseTask) InitializeLogger(taskID string, workerID string, params types.TaskParams) error { - return t.InitializeTaskLogger(taskID, workerID, params) -} - -// InitializeTaskLogger initializes the task logger with task details (LoggerProvider interface) -func (t *BaseTask) InitializeTaskLogger(taskID string, workerID string, params types.TaskParams) error { - t.mutex.Lock() - defer t.mutex.Unlock() - - t.taskID = taskID - - logger, err := NewTaskLogger(taskID, t.taskType, workerID, params, t.loggerConfig) - if err != nil { - return fmt.Errorf("failed to initialize task logger: %w", err) - } - - t.logger = logger - t.logger.Info("BaseTask initialized for task %s (type: %s)", taskID, t.taskType) - - return nil -} - -// Type returns the task type -func (t *BaseTask) Type() types.TaskType { - return t.taskType -} - -// GetProgress returns the current progress (0.0 to 100.0) -func (t *BaseTask) GetProgress() float64 { - t.mutex.RLock() - defer t.mutex.RUnlock() - return t.progress -} - -// SetProgress sets the current progress and logs it -func (t *BaseTask) SetProgress(progress float64) { - t.mutex.Lock() - if progress < 0 { - progress = 0 - } - if progress > 100 { - progress = 100 - } - oldProgress := t.progress - callback := t.progressCallback - stage := t.currentStage - t.progress = progress - t.mutex.Unlock() - - // Log progress change - if t.logger != nil && progress != oldProgress { - message := stage - if message == "" { - message = fmt.Sprintf("Progress updated from %.1f%% to %.1f%%", oldProgress, progress) - } - t.logger.LogProgress(progress, message) - } - - // Call progress callback if set - if callback != nil && progress != oldProgress { - callback(progress, stage) - } -} - -// SetProgressWithStage sets the current progress with a stage description -func (t *BaseTask) SetProgressWithStage(progress float64, stage string) { - t.mutex.Lock() - if progress < 0 { - progress = 0 - } - if progress > 100 { - progress = 100 - } - callback := t.progressCallback - t.progress = progress - t.currentStage = stage - t.mutex.Unlock() - - // Log progress change - if t.logger != nil { - t.logger.LogProgress(progress, stage) - } - - // Call progress callback if set - if callback != nil { - callback(progress, stage) - } -} - -// SetCurrentStage sets the current stage description -func (t *BaseTask) SetCurrentStage(stage string) { - t.mutex.Lock() - defer t.mutex.Unlock() - t.currentStage = stage -} - -// GetCurrentStage returns the current stage description -func (t *BaseTask) GetCurrentStage() string { - t.mutex.RLock() - defer t.mutex.RUnlock() - return t.currentStage -} - -// Cancel cancels the task -func (t *BaseTask) Cancel() error { - t.mutex.Lock() - defer t.mutex.Unlock() - - if t.cancelled { - return nil - } - - t.cancelled = true - - if t.logger != nil { - t.logger.LogStatus("cancelled", "Task cancelled by request") - t.logger.Warning("Task %s was cancelled", t.taskID) - } - - return nil -} - -// IsCancelled returns whether the task is cancelled -func (t *BaseTask) IsCancelled() bool { - t.mutex.RLock() - defer t.mutex.RUnlock() - return t.cancelled -} - -// SetStartTime sets the task start time -func (t *BaseTask) SetStartTime(startTime time.Time) { - t.mutex.Lock() - defer t.mutex.Unlock() - t.startTime = startTime - - if t.logger != nil { - t.logger.LogStatus("running", fmt.Sprintf("Task started at %s", startTime.Format(time.RFC3339))) - } -} - -// GetStartTime returns the task start time -func (t *BaseTask) GetStartTime() time.Time { - t.mutex.RLock() - defer t.mutex.RUnlock() - return t.startTime -} - -// SetEstimatedDuration sets the estimated duration -func (t *BaseTask) SetEstimatedDuration(duration time.Duration) { - t.mutex.Lock() - defer t.mutex.Unlock() - t.estimatedDuration = duration - - if t.logger != nil { - t.logger.LogWithFields("INFO", "Estimated duration set", map[string]interface{}{ - "estimated_duration": duration.String(), - "estimated_seconds": duration.Seconds(), - }) - } -} - -// GetEstimatedDuration returns the estimated duration -func (t *BaseTask) GetEstimatedDuration() time.Duration { - t.mutex.RLock() - defer t.mutex.RUnlock() - return t.estimatedDuration -} - -// SetProgressCallback sets the progress callback function -func (t *BaseTask) SetProgressCallback(callback func(float64, string)) { - t.mutex.Lock() - defer t.mutex.Unlock() - t.progressCallback = callback -} - -// SetLoggerConfig sets the logger configuration for this task -func (t *BaseTask) SetLoggerConfig(config TaskLoggerConfig) { - t.mutex.Lock() - defer t.mutex.Unlock() - t.loggerConfig = config -} - -// GetLogger returns the task logger -func (t *BaseTask) GetLogger() TaskLogger { - t.mutex.RLock() - defer t.mutex.RUnlock() - return t.logger -} - -// GetTaskLogger returns the task logger (LoggerProvider interface) -func (t *BaseTask) GetTaskLogger() TaskLogger { - t.mutex.RLock() - defer t.mutex.RUnlock() - return t.logger -} - -// LogInfo logs an info message -func (t *BaseTask) LogInfo(message string, args ...interface{}) { - if t.logger != nil { - t.logger.Info(message, args...) - } -} - -// LogWarning logs a warning message -func (t *BaseTask) LogWarning(message string, args ...interface{}) { - if t.logger != nil { - t.logger.Warning(message, args...) - } -} - -// LogError logs an error message -func (t *BaseTask) LogError(message string, args ...interface{}) { - if t.logger != nil { - t.logger.Error(message, args...) - } -} - -// LogDebug logs a debug message -func (t *BaseTask) LogDebug(message string, args ...interface{}) { - if t.logger != nil { - t.logger.Debug(message, args...) - } -} - -// LogWithFields logs a message with structured fields -func (t *BaseTask) LogWithFields(level string, message string, fields map[string]interface{}) { - if t.logger != nil { - t.logger.LogWithFields(level, message, fields) - } -} - -// FinishTask finalizes the task and closes the logger -func (t *BaseTask) FinishTask(success bool, errorMsg string) error { - if t.logger != nil { - if success { - t.logger.LogStatus("completed", "Task completed successfully") - t.logger.Info("Task %s finished successfully", t.taskID) - } else { - t.logger.LogStatus("failed", fmt.Sprintf("Task failed: %s", errorMsg)) - t.logger.Error("Task %s failed: %s", t.taskID, errorMsg) - } - - // Close logger - if err := t.logger.Close(); err != nil { - glog.Errorf("Failed to close task logger: %v", err) - } - } - - return nil -} - -// ExecuteTask is a wrapper that handles common task execution logic with logging -func (t *BaseTask) ExecuteTask(ctx context.Context, params types.TaskParams, executor func(context.Context, types.TaskParams) error) error { - // Initialize logger if not already done - if t.logger == nil { - // Generate a temporary task ID if none provided - if t.taskID == "" { - t.taskID = fmt.Sprintf("task_%d", time.Now().UnixNano()) - } - - workerID := "unknown" - if err := t.InitializeLogger(t.taskID, workerID, params); err != nil { - glog.Warningf("Failed to initialize task logger: %v", err) - } - } - - t.SetStartTime(time.Now()) - t.SetProgress(0) - - if t.logger != nil { - t.logger.LogWithFields("INFO", "Task execution started", map[string]interface{}{ - "volume_id": params.VolumeID, - "server": getServerFromSources(params.TypedParams.Sources), - "collection": params.Collection, - }) - } - - // Create a context that can be cancelled - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - // Monitor for cancellation - go func() { - for !t.IsCancelled() { - select { - case <-ctx.Done(): - return - case <-time.After(time.Second): - // Check cancellation every second - } - } - t.LogWarning("Task cancellation detected, cancelling context") - cancel() - }() - - // Execute the actual task - t.LogInfo("Starting task executor") - err := executor(ctx, params) - - if err != nil { - t.LogError("Task executor failed: %v", err) - t.FinishTask(false, err.Error()) - return err - } - - if t.IsCancelled() { - t.LogWarning("Task was cancelled during execution") - t.FinishTask(false, "cancelled") - return context.Canceled - } - - t.SetProgress(100) - t.LogInfo("Task executor completed successfully") - t.FinishTask(true, "") - return nil -} - // UnsupportedTaskTypeError represents an error for unsupported task types type UnsupportedTaskTypeError struct { TaskType types.TaskType } -func (e *UnsupportedTaskTypeError) Error() string { - return "unsupported task type: " + string(e.TaskType) -} - // BaseTaskFactory provides common functionality for task factories type BaseTaskFactory struct { taskType types.TaskType @@ -399,37 +54,12 @@ func (f *BaseTaskFactory) Description() string { return f.description } -// ValidateParams validates task parameters -func ValidateParams(params types.TaskParams, requiredFields ...string) error { - for _, field := range requiredFields { - switch field { - case "volume_id": - if params.VolumeID == 0 { - return &ValidationError{Field: field, Message: "volume_id is required"} - } - case "server": - if len(params.TypedParams.Sources) == 0 { - return &ValidationError{Field: field, Message: "server is required"} - } - case "collection": - if params.Collection == "" { - return &ValidationError{Field: field, Message: "collection is required"} - } - } - } - return nil -} - // ValidationError represents a parameter validation error type ValidationError struct { Field string Message string } -func (e *ValidationError) Error() string { - return e.Field + ": " + e.Message -} - // getServerFromSources extracts the server address from unified sources func getServerFromSources(sources []*worker_pb.TaskSource) string { if len(sources) > 0 { diff --git a/weed/worker/tasks/task_log_handler.go b/weed/worker/tasks/task_log_handler.go index fee62325e..e2d2fc185 100644 --- a/weed/worker/tasks/task_log_handler.go +++ b/weed/worker/tasks/task_log_handler.go @@ -223,36 +223,3 @@ func (h *TaskLogHandler) readTaskLogEntries(logDir string, request *worker_pb.Ta return pbEntries, nil } - -// ListAvailableTaskLogs returns a list of available task log directories -func (h *TaskLogHandler) ListAvailableTaskLogs() ([]string, error) { - entries, err := os.ReadDir(h.baseLogDir) - if err != nil { - return nil, fmt.Errorf("failed to read base log directory: %w", err) - } - - var taskDirs []string - for _, entry := range entries { - if entry.IsDir() { - taskDirs = append(taskDirs, entry.Name()) - } - } - - return taskDirs, nil -} - -// CleanupOldLogs removes old task logs beyond the specified limit -func (h *TaskLogHandler) CleanupOldLogs(maxTasks int) error { - config := TaskLoggerConfig{ - BaseLogDir: h.baseLogDir, - MaxTasks: maxTasks, - } - - // Create a temporary logger to trigger cleanup - tempLogger := &FileTaskLogger{ - config: config, - } - - tempLogger.cleanupOldLogs() - return nil -} diff --git a/weed/worker/tasks/ui_base.go b/weed/worker/tasks/ui_base.go index eb9369337..265914aa6 100644 --- a/weed/worker/tasks/ui_base.go +++ b/weed/worker/tasks/ui_base.go @@ -1,9 +1,6 @@ package tasks import ( - "reflect" - - "github.com/seaweedfs/seaweedfs/weed/glog" "github.com/seaweedfs/seaweedfs/weed/pb/worker_pb" "github.com/seaweedfs/seaweedfs/weed/worker/types" ) @@ -85,100 +82,5 @@ type CommonConfigGetter[T any] struct { schedulerFunc func() T } -// NewCommonConfigGetter creates a new common config getter -func NewCommonConfigGetter[T any]( - defaultConfig T, - detectorFunc func() T, - schedulerFunc func() T, -) *CommonConfigGetter[T] { - return &CommonConfigGetter[T]{ - defaultConfig: defaultConfig, - detectorFunc: detectorFunc, - schedulerFunc: schedulerFunc, - } -} - -// GetConfig returns the merged configuration -func (cg *CommonConfigGetter[T]) GetConfig() T { - config := cg.defaultConfig - - // Apply detector values if available - if cg.detectorFunc != nil { - detectorConfig := cg.detectorFunc() - mergeConfigs(&config, detectorConfig) - } - - // Apply scheduler values if available - if cg.schedulerFunc != nil { - schedulerConfig := cg.schedulerFunc() - mergeConfigs(&config, schedulerConfig) - } - - return config -} - -// mergeConfigs merges non-zero values from source into dest -func mergeConfigs[T any](dest *T, source T) { - destValue := reflect.ValueOf(dest).Elem() - sourceValue := reflect.ValueOf(source) - - if destValue.Kind() != reflect.Struct || sourceValue.Kind() != reflect.Struct { - return - } - - for i := 0; i < destValue.NumField(); i++ { - destField := destValue.Field(i) - sourceField := sourceValue.Field(i) - - if !destField.CanSet() { - continue - } - - // Only copy non-zero values - if !sourceField.IsZero() { - if destField.Type() == sourceField.Type() { - destField.Set(sourceField) - } - } - } -} - // RegisterUIFunc provides a common registration function signature type RegisterUIFunc[D, S any] func(uiRegistry *types.UIRegistry, detector D, scheduler S) - -// CommonRegisterUI provides a common registration implementation -func CommonRegisterUI[D, S any]( - taskType types.TaskType, - displayName string, - uiRegistry *types.UIRegistry, - detector D, - scheduler S, - schemaFunc func() *TaskConfigSchema, - configFunc func() types.TaskConfig, - applyTaskPolicyFunc func(policy *worker_pb.TaskPolicy) error, - applyTaskConfigFunc func(config types.TaskConfig) error, -) { - // Get metadata from schema - schema := schemaFunc() - description := "Task configuration" - icon := "fas fa-cog" - - if schema != nil { - description = schema.Description - icon = schema.Icon - } - - uiProvider := NewBaseUIProvider( - taskType, - displayName, - description, - icon, - schemaFunc, - configFunc, - applyTaskPolicyFunc, - applyTaskConfigFunc, - ) - - uiRegistry.RegisterUI(uiProvider) - glog.V(1).Infof("Registered %s task UI provider", taskType) -} diff --git a/weed/worker/tasks/util/csv.go b/weed/worker/tasks/util/csv.go deleted file mode 100644 index 50fb09bff..000000000 --- a/weed/worker/tasks/util/csv.go +++ /dev/null @@ -1,20 +0,0 @@ -package util - -import "strings" - -// ParseCSVSet splits a comma-separated string into a set of trimmed, -// non-empty values. Returns nil if the input is empty. -func ParseCSVSet(csv string) map[string]bool { - csv = strings.TrimSpace(csv) - if csv == "" { - return nil - } - set := make(map[string]bool) - for _, item := range strings.Split(csv, ",") { - trimmed := strings.TrimSpace(item) - if trimmed != "" { - set[trimmed] = true - } - } - return set -} diff --git a/weed/worker/tasks/vacuum/monitoring.go b/weed/worker/tasks/vacuum/monitoring.go deleted file mode 100644 index c7dfd673e..000000000 --- a/weed/worker/tasks/vacuum/monitoring.go +++ /dev/null @@ -1,151 +0,0 @@ -package vacuum - -import ( - "sync" - "time" -) - -// VacuumMetrics contains vacuum-specific monitoring data -type VacuumMetrics struct { - // Execution metrics - VolumesVacuumed int64 `json:"volumes_vacuumed"` - TotalSpaceReclaimed int64 `json:"total_space_reclaimed"` - TotalFilesProcessed int64 `json:"total_files_processed"` - TotalGarbageCollected int64 `json:"total_garbage_collected"` - LastVacuumTime time.Time `json:"last_vacuum_time"` - - // Performance metrics - AverageVacuumTime int64 `json:"average_vacuum_time_seconds"` - AverageGarbageRatio float64 `json:"average_garbage_ratio"` - SuccessfulOperations int64 `json:"successful_operations"` - FailedOperations int64 `json:"failed_operations"` - - // Current task metrics - CurrentGarbageRatio float64 `json:"current_garbage_ratio"` - VolumesPendingVacuum int `json:"volumes_pending_vacuum"` - - mutex sync.RWMutex -} - -// NewVacuumMetrics creates a new vacuum metrics instance -func NewVacuumMetrics() *VacuumMetrics { - return &VacuumMetrics{ - LastVacuumTime: time.Now(), - } -} - -// RecordVolumeVacuumed records a successful volume vacuum operation -func (m *VacuumMetrics) RecordVolumeVacuumed(spaceReclaimed int64, filesProcessed int64, garbageCollected int64, vacuumTime time.Duration, garbageRatio float64) { - m.mutex.Lock() - defer m.mutex.Unlock() - - m.VolumesVacuumed++ - m.TotalSpaceReclaimed += spaceReclaimed - m.TotalFilesProcessed += filesProcessed - m.TotalGarbageCollected += garbageCollected - m.SuccessfulOperations++ - m.LastVacuumTime = time.Now() - - // Update average vacuum time - if m.AverageVacuumTime == 0 { - m.AverageVacuumTime = int64(vacuumTime.Seconds()) - } else { - // Exponential moving average - newTime := int64(vacuumTime.Seconds()) - m.AverageVacuumTime = (m.AverageVacuumTime*4 + newTime) / 5 - } - - // Update average garbage ratio - if m.AverageGarbageRatio == 0 { - m.AverageGarbageRatio = garbageRatio - } else { - // Exponential moving average - m.AverageGarbageRatio = 0.8*m.AverageGarbageRatio + 0.2*garbageRatio - } -} - -// RecordFailure records a failed vacuum operation -func (m *VacuumMetrics) RecordFailure() { - m.mutex.Lock() - defer m.mutex.Unlock() - - m.FailedOperations++ -} - -// UpdateCurrentGarbageRatio updates the current volume's garbage ratio -func (m *VacuumMetrics) UpdateCurrentGarbageRatio(ratio float64) { - m.mutex.Lock() - defer m.mutex.Unlock() - - m.CurrentGarbageRatio = ratio -} - -// SetVolumesPendingVacuum sets the number of volumes pending vacuum -func (m *VacuumMetrics) SetVolumesPendingVacuum(count int) { - m.mutex.Lock() - defer m.mutex.Unlock() - - m.VolumesPendingVacuum = count -} - -// GetMetrics returns a copy of the current metrics (without the mutex) -func (m *VacuumMetrics) GetMetrics() VacuumMetrics { - m.mutex.RLock() - defer m.mutex.RUnlock() - - // Create a copy without the mutex to avoid copying lock value - return VacuumMetrics{ - VolumesVacuumed: m.VolumesVacuumed, - TotalSpaceReclaimed: m.TotalSpaceReclaimed, - TotalFilesProcessed: m.TotalFilesProcessed, - TotalGarbageCollected: m.TotalGarbageCollected, - LastVacuumTime: m.LastVacuumTime, - AverageVacuumTime: m.AverageVacuumTime, - AverageGarbageRatio: m.AverageGarbageRatio, - SuccessfulOperations: m.SuccessfulOperations, - FailedOperations: m.FailedOperations, - CurrentGarbageRatio: m.CurrentGarbageRatio, - VolumesPendingVacuum: m.VolumesPendingVacuum, - } -} - -// GetSuccessRate returns the success rate as a percentage -func (m *VacuumMetrics) GetSuccessRate() float64 { - m.mutex.RLock() - defer m.mutex.RUnlock() - - total := m.SuccessfulOperations + m.FailedOperations - if total == 0 { - return 100.0 - } - return float64(m.SuccessfulOperations) / float64(total) * 100.0 -} - -// GetAverageSpaceReclaimed returns the average space reclaimed per volume -func (m *VacuumMetrics) GetAverageSpaceReclaimed() float64 { - m.mutex.RLock() - defer m.mutex.RUnlock() - - if m.VolumesVacuumed == 0 { - return 0 - } - return float64(m.TotalSpaceReclaimed) / float64(m.VolumesVacuumed) -} - -// Reset resets all metrics to zero -func (m *VacuumMetrics) Reset() { - m.mutex.Lock() - defer m.mutex.Unlock() - - *m = VacuumMetrics{ - LastVacuumTime: time.Now(), - } -} - -// Global metrics instance for vacuum tasks -var globalVacuumMetrics = NewVacuumMetrics() - -// GetGlobalVacuumMetrics returns the global vacuum metrics instance -func GetGlobalVacuumMetrics() *VacuumMetrics { - return globalVacuumMetrics -} diff --git a/weed/worker/types/config_types.go b/weed/worker/types/config_types.go index 5a9e94fd5..1f91ec085 100644 --- a/weed/worker/types/config_types.go +++ b/weed/worker/types/config_types.go @@ -109,15 +109,6 @@ type MaintenanceWorkersData struct { var defaultCapabilities []TaskType var defaultCapabilitiesMutex sync.RWMutex -// SetDefaultCapabilities sets the default capabilities for workers -// This should be called after task registration is complete -func SetDefaultCapabilities(capabilities []TaskType) { - defaultCapabilitiesMutex.Lock() - defer defaultCapabilitiesMutex.Unlock() - defaultCapabilities = make([]TaskType, len(capabilities)) - copy(defaultCapabilities, capabilities) -} - // GetDefaultCapabilities returns the default capabilities for workers func GetDefaultCapabilities() []TaskType { defaultCapabilitiesMutex.RLock() @@ -129,18 +120,6 @@ func GetDefaultCapabilities() []TaskType { return result } -// DefaultMaintenanceConfig returns default maintenance configuration -func DefaultMaintenanceConfig() *MaintenanceConfig { - return &MaintenanceConfig{ - Enabled: true, - ScanInterval: 30 * time.Minute, - CleanInterval: 6 * time.Hour, - TaskRetention: 7 * 24 * time.Hour, // 7 days - WorkerTimeout: 5 * time.Minute, - Policy: NewMaintenancePolicy(), - } -} - // DefaultWorkerConfig returns default worker configuration func DefaultWorkerConfig() *WorkerConfig { // Get dynamic capabilities from registered task types @@ -154,119 +133,3 @@ func DefaultWorkerConfig() *WorkerConfig { Capabilities: capabilities, } } - -// NewMaintenancePolicy creates a new dynamic maintenance policy -func NewMaintenancePolicy() *MaintenancePolicy { - return &MaintenancePolicy{ - TaskConfigs: make(map[TaskType]interface{}), - GlobalSettings: &GlobalMaintenanceSettings{ - DefaultMaxConcurrent: 2, - MaintenanceEnabled: true, - DefaultScanInterval: 30 * time.Minute, - DefaultTaskTimeout: 5 * time.Minute, - DefaultRetryCount: 3, - DefaultRetryInterval: 5 * time.Minute, - DefaultPriorityBoostAge: 24 * time.Hour, - GlobalConcurrentLimit: 5, - }, - } -} - -// SetTaskConfig sets the configuration for a specific task type -func (p *MaintenancePolicy) SetTaskConfig(taskType TaskType, config interface{}) { - if p.TaskConfigs == nil { - p.TaskConfigs = make(map[TaskType]interface{}) - } - p.TaskConfigs[taskType] = config -} - -// GetTaskConfig returns the configuration for a specific task type -func (p *MaintenancePolicy) GetTaskConfig(taskType TaskType) interface{} { - if p.TaskConfigs == nil { - return nil - } - return p.TaskConfigs[taskType] -} - -// IsTaskEnabled returns whether a task type is enabled (generic helper) -func (p *MaintenancePolicy) IsTaskEnabled(taskType TaskType) bool { - if !p.GlobalSettings.MaintenanceEnabled { - return false - } - - config := p.GetTaskConfig(taskType) - if config == nil { - return false - } - - // Try to get enabled field from config using type assertion - if configMap, ok := config.(map[string]interface{}); ok { - if enabled, exists := configMap["enabled"]; exists { - if enabledBool, ok := enabled.(bool); ok { - return enabledBool - } - } - } - - // If we can't determine from config, default to global setting - return p.GlobalSettings.MaintenanceEnabled -} - -// GetMaxConcurrent returns the max concurrent setting for a task type -func (p *MaintenancePolicy) GetMaxConcurrent(taskType TaskType) int { - config := p.GetTaskConfig(taskType) - if config == nil { - return p.GlobalSettings.DefaultMaxConcurrent - } - - // Try to get max_concurrent field from config - if configMap, ok := config.(map[string]interface{}); ok { - if maxConcurrent, exists := configMap["max_concurrent"]; exists { - if maxConcurrentInt, ok := maxConcurrent.(int); ok { - return maxConcurrentInt - } - if maxConcurrentFloat, ok := maxConcurrent.(float64); ok { - return int(maxConcurrentFloat) - } - } - } - - return p.GlobalSettings.DefaultMaxConcurrent -} - -// GetScanInterval returns the scan interval for a task type -func (p *MaintenancePolicy) GetScanInterval(taskType TaskType) time.Duration { - config := p.GetTaskConfig(taskType) - if config == nil { - return p.GlobalSettings.DefaultScanInterval - } - - // Try to get scan_interval field from config - if configMap, ok := config.(map[string]interface{}); ok { - if scanInterval, exists := configMap["scan_interval"]; exists { - if scanIntervalDuration, ok := scanInterval.(time.Duration); ok { - return scanIntervalDuration - } - if scanIntervalString, ok := scanInterval.(string); ok { - if duration, err := time.ParseDuration(scanIntervalString); err == nil { - return duration - } - } - } - } - - return p.GlobalSettings.DefaultScanInterval -} - -// GetAllTaskTypes returns all configured task types -func (p *MaintenancePolicy) GetAllTaskTypes() []TaskType { - if p.TaskConfigs == nil { - return []TaskType{} - } - - taskTypes := make([]TaskType, 0, len(p.TaskConfigs)) - for taskType := range p.TaskConfigs { - taskTypes = append(taskTypes, taskType) - } - return taskTypes -} diff --git a/weed/worker/types/task.go b/weed/worker/types/task.go index 7e924453c..5ebed89c1 100644 --- a/weed/worker/types/task.go +++ b/weed/worker/types/task.go @@ -53,14 +53,6 @@ type Logger interface { // NoOpLogger is a logger that does nothing (silent) type NoOpLogger struct{} -func (l *NoOpLogger) Info(msg string, args ...interface{}) {} -func (l *NoOpLogger) Warning(msg string, args ...interface{}) {} -func (l *NoOpLogger) Error(msg string, args ...interface{}) {} -func (l *NoOpLogger) Debug(msg string, args ...interface{}) {} -func (l *NoOpLogger) WithFields(fields map[string]interface{}) Logger { - return l // Return self since we're doing nothing anyway -} - // GlogFallbackLogger is a logger that falls back to glog type GlogFallbackLogger struct{} @@ -137,87 +129,3 @@ type UnifiedBaseTask struct { currentStage string workingDir string } - -// NewBaseTask creates a new base task -func NewUnifiedBaseTask(id string, taskType TaskType) *UnifiedBaseTask { - return &UnifiedBaseTask{ - id: id, - taskType: taskType, - } -} - -// ID returns the task ID -func (t *UnifiedBaseTask) ID() string { - return t.id -} - -// Type returns the task type -func (t *UnifiedBaseTask) Type() TaskType { - return t.taskType -} - -// SetProgressCallback sets the progress callback -func (t *UnifiedBaseTask) SetProgressCallback(callback func(float64, string)) { - t.progressCallback = callback -} - -// ReportProgress reports current progress through the callback -func (t *UnifiedBaseTask) ReportProgress(progress float64) { - if t.progressCallback != nil { - t.progressCallback(progress, t.currentStage) - } -} - -// ReportProgressWithStage reports current progress with a specific stage description -func (t *UnifiedBaseTask) ReportProgressWithStage(progress float64, stage string) { - t.currentStage = stage - if t.progressCallback != nil { - t.progressCallback(progress, stage) - } -} - -// SetCurrentStage sets the current stage description -func (t *UnifiedBaseTask) SetCurrentStage(stage string) { - t.currentStage = stage -} - -// GetCurrentStage returns the current stage description -func (t *UnifiedBaseTask) GetCurrentStage() string { - return t.currentStage -} - -// Cancel marks the task as cancelled -func (t *UnifiedBaseTask) Cancel() error { - t.cancelled = true - return nil -} - -// IsCancellable returns true if the task can be cancelled -func (t *UnifiedBaseTask) IsCancellable() bool { - return true -} - -// IsCancelled returns true if the task has been cancelled -func (t *UnifiedBaseTask) IsCancelled() bool { - return t.cancelled -} - -// SetLogger sets the task logger -func (t *UnifiedBaseTask) SetLogger(logger Logger) { - t.logger = logger -} - -// GetLogger returns the task logger -func (t *UnifiedBaseTask) GetLogger() Logger { - return t.logger -} - -// SetWorkingDir sets the task working directory -func (t *UnifiedBaseTask) SetWorkingDir(workingDir string) { - t.workingDir = workingDir -} - -// GetWorkingDir returns the task working directory -func (t *UnifiedBaseTask) GetWorkingDir() string { - return t.workingDir -} diff --git a/weed/worker/types/task_ui.go b/weed/worker/types/task_ui.go index 8a57e83be..e10d727ac 100644 --- a/weed/worker/types/task_ui.go +++ b/weed/worker/types/task_ui.go @@ -6,47 +6,6 @@ import ( "github.com/seaweedfs/seaweedfs/weed/pb/worker_pb" ) -// Helper function to convert seconds to the most appropriate interval unit -func secondsToIntervalValueUnit(totalSeconds int) (int, string) { - if totalSeconds == 0 { - return 0, "minute" - } - - // Preserve seconds when not divisible by minutes - if totalSeconds < 60 || totalSeconds%60 != 0 { - return totalSeconds, "second" - } - - // Check if it's evenly divisible by days - if totalSeconds%(24*3600) == 0 { - return totalSeconds / (24 * 3600), "day" - } - - // Check if it's evenly divisible by hours - if totalSeconds%3600 == 0 { - return totalSeconds / 3600, "hour" - } - - // Default to minutes - return totalSeconds / 60, "minute" -} - -// Helper function to convert interval value and unit to seconds -func IntervalValueUnitToSeconds(value int, unit string) int { - switch unit { - case "day": - return value * 24 * 3600 - case "hour": - return value * 3600 - case "minute": - return value * 60 - case "second": - return value - default: - return value * 60 // Default to minutes - } -} - // TaskConfig defines the interface for task configurations // This matches the interfaces used in base package and handlers type TaskConfig interface { diff --git a/weed/worker/types/typed_task_interface.go b/weed/worker/types/typed_task_interface.go index 39eaa2286..1ff26ab40 100644 --- a/weed/worker/types/typed_task_interface.go +++ b/weed/worker/types/typed_task_interface.go @@ -90,24 +90,6 @@ func (r *TypedTaskRegistry) RegisterTypedTask(taskType TaskType, creator TypedTa r.creators[taskType] = creator } -// CreateTypedTask creates a new typed task instance -func (r *TypedTaskRegistry) CreateTypedTask(taskType TaskType) (TypedTaskInterface, error) { - creator, exists := r.creators[taskType] - if !exists { - return nil, ErrTaskTypeNotFound - } - return creator(), nil -} - -// GetSupportedTypes returns all registered typed task types -func (r *TypedTaskRegistry) GetSupportedTypes() []TaskType { - types := make([]TaskType, 0, len(r.creators)) - for taskType := range r.creators { - types = append(types, taskType) - } - return types -} - // Global typed task registry var globalTypedTaskRegistry = NewTypedTaskRegistry() @@ -115,8 +97,3 @@ var globalTypedTaskRegistry = NewTypedTaskRegistry() func RegisterGlobalTypedTask(taskType TaskType, creator TypedTaskCreator) { globalTypedTaskRegistry.RegisterTypedTask(taskType, creator) } - -// GetGlobalTypedTaskRegistry returns the global typed task registry -func GetGlobalTypedTaskRegistry() *TypedTaskRegistry { - return globalTypedTaskRegistry -} diff --git a/weed/worker/types/worker.go b/weed/worker/types/worker.go index 9db5ba2c4..ac6cfac08 100644 --- a/weed/worker/types/worker.go +++ b/weed/worker/types/worker.go @@ -30,47 +30,3 @@ type BaseWorker struct { currentTasks map[string]Task logger Logger } - -// NewBaseWorker creates a new base worker -func NewBaseWorker(id string) *BaseWorker { - return &BaseWorker{ - id: id, - currentTasks: make(map[string]Task), - } -} - -// Configure applies worker configuration -func (w *BaseWorker) Configure(config WorkerCreationConfig) error { - w.id = config.ID - w.capabilities = config.Capabilities - w.maxConcurrent = config.MaxConcurrent - - if config.LoggerFactory != nil { - logger, err := config.LoggerFactory.CreateLogger(context.Background(), LoggerConfig{ - ServiceName: "worker-" + w.id, - MinLevel: LogLevelInfo, - }) - if err != nil { - return err - } - w.logger = logger - } - - return nil -} - -// GetCapabilities returns worker capabilities -func (w *BaseWorker) GetCapabilities() []TaskType { - return w.capabilities -} - -// GetStatus returns current worker status -func (w *BaseWorker) GetStatus() WorkerStatus { - return WorkerStatus{ - WorkerID: w.id, - Status: "active", - Capabilities: w.capabilities, - MaxConcurrent: w.maxConcurrent, - CurrentLoad: len(w.currentTasks), - } -} diff --git a/weed/worker/worker.go b/weed/worker/worker.go index be2a2e9df..ffcd00b9e 100644 --- a/weed/worker/worker.go +++ b/weed/worker/worker.go @@ -383,31 +383,6 @@ func (w *Worker) setReqTick(tick *time.Ticker) *time.Ticker { return w.getReqTick() } -func (w *Worker) getStartTime() time.Time { - respCh := make(chan time.Time, 1) - w.cmds <- workerCommand{ - action: ActionGetStartTime, - data: respCh, - } - return <-respCh -} -func (w *Worker) getCompletedTasks() int { - respCh := make(chan int, 1) - w.cmds <- workerCommand{ - action: ActionGetCompletedTasks, - data: respCh, - } - return <-respCh -} -func (w *Worker) getFailedTasks() int { - respCh := make(chan int, 1) - w.cmds <- workerCommand{ - action: ActionGetFailedTasks, - data: respCh, - } - return <-respCh -} - // getTaskLoggerConfig returns the task logger configuration with worker's log directory func (w *Worker) getTaskLoggerConfig() tasks.TaskLoggerConfig { config := tasks.DefaultTaskLoggerConfig() @@ -543,27 +518,6 @@ func (w *Worker) handleStop(cmd workerCommand) { cmd.resp <- nil } -// RegisterTask registers a task factory -func (w *Worker) RegisterTask(taskType types.TaskType, factory types.TaskFactory) { - w.registry.Register(taskType, factory) -} - -// GetCapabilities returns the worker capabilities -func (w *Worker) GetCapabilities() []types.TaskType { - return w.config.Capabilities -} - -// GetStatus returns the current worker status -func (w *Worker) GetStatus() types.WorkerStatus { - respCh := make(statusResponse, 1) - w.cmds <- workerCommand{ - action: ActionGetStatus, - data: respCh, - resp: nil, - } - return <-respCh -} - // HandleTask handles a task execution func (w *Worker) HandleTask(task *types.TaskInput) error { glog.V(1).Infof("Worker %s received task %s (type: %s, volume: %d)", @@ -579,26 +533,6 @@ func (w *Worker) HandleTask(task *types.TaskInput) error { return nil } -// SetCapabilities sets the worker capabilities -func (w *Worker) SetCapabilities(capabilities []types.TaskType) { - w.config.Capabilities = capabilities -} - -// SetMaxConcurrent sets the maximum concurrent tasks -func (w *Worker) SetMaxConcurrent(max int) { - w.config.MaxConcurrent = max -} - -// SetHeartbeatInterval sets the heartbeat interval -func (w *Worker) SetHeartbeatInterval(interval time.Duration) { - w.config.HeartbeatInterval = interval -} - -// SetTaskRequestInterval sets the task request interval -func (w *Worker) SetTaskRequestInterval(interval time.Duration) { - w.config.TaskRequestInterval = interval -} - // SetAdminClient sets the admin client func (w *Worker) SetAdminClient(client AdminClient) { w.cmds <- workerCommand{ @@ -828,11 +762,6 @@ func (w *Worker) requestTasks() { } } -// GetTaskRegistry returns the task registry -func (w *Worker) GetTaskRegistry() *tasks.TaskRegistry { - return w.registry -} - // connectionMonitorLoop monitors connection status func (w *Worker) connectionMonitorLoop() { ticker := time.NewTicker(30 * time.Second) // Check every 30 seconds @@ -867,34 +796,6 @@ func (w *Worker) connectionMonitorLoop() { } } -// GetConfig returns the worker configuration -func (w *Worker) GetConfig() *types.WorkerConfig { - return w.config -} - -// GetPerformanceMetrics returns performance metrics -func (w *Worker) GetPerformanceMetrics() *types.WorkerPerformance { - - uptime := time.Since(w.getStartTime()) - var successRate float64 - totalTasks := w.getCompletedTasks() + w.getFailedTasks() - if totalTasks > 0 { - successRate = float64(w.getCompletedTasks()) / float64(totalTasks) * 100 - } - - return &types.WorkerPerformance{ - TasksCompleted: w.getCompletedTasks(), - TasksFailed: w.getFailedTasks(), - AverageTaskTime: 0, // Would need to track this - Uptime: uptime, - SuccessRate: successRate, - } -} - -func (w *Worker) GetAdmin() AdminClient { - return w.getAdmin() -} - // messageProcessingLoop processes incoming admin messages func (w *Worker) messageProcessingLoop() { glog.Infof("MESSAGE LOOP STARTED: Worker %s message processing loop started", w.id)