diff --git a/.claude/scheduled_tasks.lock b/.claude/scheduled_tasks.lock
new file mode 100644
index 000000000..df47ff816
--- /dev/null
+++ b/.claude/scheduled_tasks.lock
@@ -0,0 +1 @@
+{"sessionId":"d6574c47-eafc-4a94-9dce-f9ffea22b53c","pid":10111,"acquiredAt":1775248373916}
\ No newline at end of file
diff --git a/.superset/config.json b/.superset/config.json
new file mode 100644
index 000000000..f806b5255
--- /dev/null
+++ b/.superset/config.json
@@ -0,0 +1,5 @@
+{
+ "setup": [],
+ "teardown": [],
+ "run": []
+}
diff --git a/seaweed-volume/Cargo.lock b/seaweed-volume/Cargo.lock
index b5401c9a5..ad47d1d6f 100644
--- a/seaweed-volume/Cargo.lock
+++ b/seaweed-volume/Cargo.lock
@@ -2561,6 +2561,15 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe"
+[[package]]
+name = "openssl-src"
+version = "300.5.5+3.5.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3f1787d533e03597a7934fd0a765f0d28e94ecc5fb7789f8053b1e699a56f709"
+dependencies = [
+ "cc",
+]
+
[[package]]
name = "openssl-sys"
version = "0.9.111"
@@ -2569,6 +2578,7 @@ checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321"
dependencies = [
"cc",
"libc",
+ "openssl-src",
"pkg-config",
"vcpkg",
]
@@ -4654,6 +4664,7 @@ dependencies = [
"memmap2",
"mime_guess",
"multer",
+ "openssl",
"parking_lot 0.12.5",
"pprof",
"prometheus",
diff --git a/weed/admin/dash/types.go b/weed/admin/dash/types.go
index 965166de4..a7f22c541 100644
--- a/weed/admin/dash/types.go
+++ b/weed/admin/dash/types.go
@@ -447,11 +447,6 @@ type QueueStats = maintenance.QueueStats
type WorkerDetailsData = maintenance.WorkerDetailsData
type WorkerPerformance = maintenance.WorkerPerformance
-// GetTaskIcon returns the icon CSS class for a task type from its UI provider
-func GetTaskIcon(taskType MaintenanceTaskType) string {
- return maintenance.GetTaskIcon(taskType)
-}
-
// Status constants (these are still static)
const (
TaskStatusPending = maintenance.TaskStatusPending
diff --git a/weed/admin/handlers/cluster_handlers.go b/weed/admin/handlers/cluster_handlers.go
index c5303458f..b3bcc4fe3 100644
--- a/weed/admin/handlers/cluster_handlers.go
+++ b/weed/admin/handlers/cluster_handlers.go
@@ -312,29 +312,6 @@ func (h *ClusterHandlers) ShowClusterFilers(w http.ResponseWriter, r *http.Reque
}
}
-// ShowClusterBrokers renders the cluster message brokers page
-func (h *ClusterHandlers) ShowClusterBrokers(w http.ResponseWriter, r *http.Request) {
- // Get cluster brokers data
- brokersData, err := h.adminServer.GetClusterBrokers()
- if err != nil {
- writeJSONError(w, http.StatusInternalServerError, "Failed to get cluster brokers: "+err.Error())
- return
- }
-
- username := usernameOrDefault(r)
- brokersData.Username = username
-
- // Render HTML template
- w.Header().Set("Content-Type", "text/html")
- brokersComponent := app.ClusterBrokers(*brokersData)
- viewCtx := layout.NewViewContext(r, username, dash.CSRFTokenFromContext(r.Context()))
- layoutComponent := layout.Layout(viewCtx, brokersComponent)
- if err := layoutComponent.Render(r.Context(), w); err != nil {
- writeJSONError(w, http.StatusInternalServerError, "Failed to render template: "+err.Error())
- return
- }
-}
-
// GetClusterTopology returns the cluster topology as JSON
func (h *ClusterHandlers) GetClusterTopology(w http.ResponseWriter, r *http.Request) {
topology, err := h.adminServer.GetClusterTopology()
diff --git a/weed/admin/handlers/mq_handlers.go b/weed/admin/handlers/mq_handlers.go
index 5efa3cc3a..6c6e46e57 100644
--- a/weed/admin/handlers/mq_handlers.go
+++ b/weed/admin/handlers/mq_handlers.go
@@ -78,34 +78,6 @@ func (h *MessageQueueHandlers) ShowTopics(w http.ResponseWriter, r *http.Request
}
}
-// ShowSubscribers renders the message queue subscribers page
-func (h *MessageQueueHandlers) ShowSubscribers(w http.ResponseWriter, r *http.Request) {
- // Get subscribers data
- subscribersData, err := h.adminServer.GetSubscribers()
- if err != nil {
- writeJSONError(w, http.StatusInternalServerError, "Failed to get subscribers: "+err.Error())
- return
- }
-
- // Set username
- username := dash.UsernameFromContext(r.Context())
- if username == "" {
- username = "admin"
- }
- subscribersData.Username = username
-
- // Render HTML template
- w.Header().Set("Content-Type", "text/html")
- subscribersComponent := app.Subscribers(*subscribersData)
- viewCtx := layout.NewViewContext(r, username, dash.CSRFTokenFromContext(r.Context()))
- layoutComponent := layout.Layout(viewCtx, subscribersComponent)
- err = layoutComponent.Render(r.Context(), w)
- if err != nil {
- writeJSONError(w, http.StatusInternalServerError, "Failed to render template: "+err.Error())
- return
- }
-}
-
// ShowTopicDetails renders the topic details page
func (h *MessageQueueHandlers) ShowTopicDetails(w http.ResponseWriter, r *http.Request) {
// Get topic parameters from URL
diff --git a/weed/admin/maintenance/config_verification.go b/weed/admin/maintenance/config_verification.go
deleted file mode 100644
index 0ac40aad1..000000000
--- a/weed/admin/maintenance/config_verification.go
+++ /dev/null
@@ -1,124 +0,0 @@
-package maintenance
-
-import (
- "fmt"
-
- "github.com/seaweedfs/seaweedfs/weed/pb/worker_pb"
-)
-
-// VerifyProtobufConfig demonstrates that the protobuf configuration system is working
-func VerifyProtobufConfig() error {
- // Create configuration manager
- configManager := NewMaintenanceConfigManager()
- config := configManager.GetConfig()
-
- // Verify basic configuration
- if !config.Enabled {
- return fmt.Errorf("expected config to be enabled by default")
- }
-
- if config.ScanIntervalSeconds != 30*60 {
- return fmt.Errorf("expected scan interval to be 1800 seconds, got %d", config.ScanIntervalSeconds)
- }
-
- // Verify policy configuration
- if config.Policy == nil {
- return fmt.Errorf("expected policy to be configured")
- }
-
- if config.Policy.GlobalMaxConcurrent != 4 {
- return fmt.Errorf("expected global max concurrent to be 4, got %d", config.Policy.GlobalMaxConcurrent)
- }
-
- // Verify task policies
- vacuumPolicy := config.Policy.TaskPolicies["vacuum"]
- if vacuumPolicy == nil {
- return fmt.Errorf("expected vacuum policy to be configured")
- }
-
- if !vacuumPolicy.Enabled {
- return fmt.Errorf("expected vacuum policy to be enabled")
- }
-
- // Verify typed configuration access
- vacuumConfig := vacuumPolicy.GetVacuumConfig()
- if vacuumConfig == nil {
- return fmt.Errorf("expected vacuum config to be accessible")
- }
-
- if vacuumConfig.GarbageThreshold != 0.3 {
- return fmt.Errorf("expected garbage threshold to be 0.3, got %f", vacuumConfig.GarbageThreshold)
- }
-
- // Verify helper functions work
- if !IsTaskEnabled(config.Policy, "vacuum") {
- return fmt.Errorf("expected vacuum task to be enabled via helper function")
- }
-
- maxConcurrent := GetMaxConcurrent(config.Policy, "vacuum")
- if maxConcurrent != 2 {
- return fmt.Errorf("expected vacuum max concurrent to be 2, got %d", maxConcurrent)
- }
-
- // Verify erasure coding configuration
- ecPolicy := config.Policy.TaskPolicies["erasure_coding"]
- if ecPolicy == nil {
- return fmt.Errorf("expected EC policy to be configured")
- }
-
- ecConfig := ecPolicy.GetErasureCodingConfig()
- if ecConfig == nil {
- return fmt.Errorf("expected EC config to be accessible")
- }
-
- // Verify configurable EC fields only
- if ecConfig.FullnessRatio <= 0 || ecConfig.FullnessRatio > 1 {
- return fmt.Errorf("expected EC config to have valid fullness ratio (0-1), got %f", ecConfig.FullnessRatio)
- }
-
- return nil
-}
-
-// GetProtobufConfigSummary returns a summary of the current protobuf configuration
-func GetProtobufConfigSummary() string {
- configManager := NewMaintenanceConfigManager()
- config := configManager.GetConfig()
-
- summary := fmt.Sprintf("SeaweedFS Protobuf Maintenance Configuration:\n")
- summary += fmt.Sprintf(" Enabled: %v\n", config.Enabled)
- summary += fmt.Sprintf(" Scan Interval: %d seconds\n", config.ScanIntervalSeconds)
- summary += fmt.Sprintf(" Max Retries: %d\n", config.MaxRetries)
- summary += fmt.Sprintf(" Global Max Concurrent: %d\n", config.Policy.GlobalMaxConcurrent)
- summary += fmt.Sprintf(" Task Policies: %d configured\n", len(config.Policy.TaskPolicies))
-
- for taskType, policy := range config.Policy.TaskPolicies {
- summary += fmt.Sprintf(" %s: enabled=%v, max_concurrent=%d\n",
- taskType, policy.Enabled, policy.MaxConcurrent)
- }
-
- return summary
-}
-
-// CreateCustomConfig demonstrates creating a custom protobuf configuration
-func CreateCustomConfig() *worker_pb.MaintenanceConfig {
- return &worker_pb.MaintenanceConfig{
- Enabled: true,
- ScanIntervalSeconds: 60 * 60, // 1 hour
- MaxRetries: 5,
- Policy: &worker_pb.MaintenancePolicy{
- GlobalMaxConcurrent: 8,
- TaskPolicies: map[string]*worker_pb.TaskPolicy{
- "custom_vacuum": {
- Enabled: true,
- MaxConcurrent: 4,
- TaskConfig: &worker_pb.TaskPolicy_VacuumConfig{
- VacuumConfig: &worker_pb.VacuumTaskConfig{
- GarbageThreshold: 0.5,
- MinVolumeAgeHours: 48,
- },
- },
- },
- },
- },
- }
-}
diff --git a/weed/admin/maintenance/maintenance_config_proto.go b/weed/admin/maintenance/maintenance_config_proto.go
index 0d0bca7c6..4295a706f 100644
--- a/weed/admin/maintenance/maintenance_config_proto.go
+++ b/weed/admin/maintenance/maintenance_config_proto.go
@@ -1,24 +1,9 @@
package maintenance
import (
- "fmt"
- "time"
-
"github.com/seaweedfs/seaweedfs/weed/pb/worker_pb"
)
-// MaintenanceConfigManager handles protobuf-based configuration
-type MaintenanceConfigManager struct {
- config *worker_pb.MaintenanceConfig
-}
-
-// NewMaintenanceConfigManager creates a new config manager with defaults
-func NewMaintenanceConfigManager() *MaintenanceConfigManager {
- return &MaintenanceConfigManager{
- config: DefaultMaintenanceConfigProto(),
- }
-}
-
// DefaultMaintenanceConfigProto returns default configuration as protobuf
func DefaultMaintenanceConfigProto() *worker_pb.MaintenanceConfig {
return &worker_pb.MaintenanceConfig{
@@ -34,253 +19,3 @@ func DefaultMaintenanceConfigProto() *worker_pb.MaintenanceConfig {
Policy: nil,
}
}
-
-// GetConfig returns the current configuration
-func (mcm *MaintenanceConfigManager) GetConfig() *worker_pb.MaintenanceConfig {
- return mcm.config
-}
-
-// Type-safe configuration accessors
-
-// GetVacuumConfig returns vacuum-specific configuration for a task type
-func (mcm *MaintenanceConfigManager) GetVacuumConfig(taskType string) *worker_pb.VacuumTaskConfig {
- if policy := mcm.getTaskPolicy(taskType); policy != nil {
- if vacuumConfig := policy.GetVacuumConfig(); vacuumConfig != nil {
- return vacuumConfig
- }
- }
- // Return defaults if not configured
- return &worker_pb.VacuumTaskConfig{
- GarbageThreshold: 0.3,
- MinVolumeAgeHours: 24,
- }
-}
-
-// GetErasureCodingConfig returns EC-specific configuration for a task type
-func (mcm *MaintenanceConfigManager) GetErasureCodingConfig(taskType string) *worker_pb.ErasureCodingTaskConfig {
- if policy := mcm.getTaskPolicy(taskType); policy != nil {
- if ecConfig := policy.GetErasureCodingConfig(); ecConfig != nil {
- return ecConfig
- }
- }
- // Return defaults if not configured
- return &worker_pb.ErasureCodingTaskConfig{
- FullnessRatio: 0.95,
- QuietForSeconds: 3600,
- MinVolumeSizeMb: 100,
- CollectionFilter: "",
- }
-}
-
-// GetBalanceConfig returns balance-specific configuration for a task type
-func (mcm *MaintenanceConfigManager) GetBalanceConfig(taskType string) *worker_pb.BalanceTaskConfig {
- if policy := mcm.getTaskPolicy(taskType); policy != nil {
- if balanceConfig := policy.GetBalanceConfig(); balanceConfig != nil {
- return balanceConfig
- }
- }
- // Return defaults if not configured
- return &worker_pb.BalanceTaskConfig{
- ImbalanceThreshold: 0.2,
- MinServerCount: 2,
- }
-}
-
-// GetReplicationConfig returns replication-specific configuration for a task type
-func (mcm *MaintenanceConfigManager) GetReplicationConfig(taskType string) *worker_pb.ReplicationTaskConfig {
- if policy := mcm.getTaskPolicy(taskType); policy != nil {
- if replicationConfig := policy.GetReplicationConfig(); replicationConfig != nil {
- return replicationConfig
- }
- }
- // Return defaults if not configured
- return &worker_pb.ReplicationTaskConfig{
- TargetReplicaCount: 2,
- }
-}
-
-// Typed convenience methods for getting task configurations
-
-// GetVacuumTaskConfigForType returns vacuum configuration for a specific task type
-func (mcm *MaintenanceConfigManager) GetVacuumTaskConfigForType(taskType string) *worker_pb.VacuumTaskConfig {
- return GetVacuumTaskConfig(mcm.config.Policy, MaintenanceTaskType(taskType))
-}
-
-// GetErasureCodingTaskConfigForType returns erasure coding configuration for a specific task type
-func (mcm *MaintenanceConfigManager) GetErasureCodingTaskConfigForType(taskType string) *worker_pb.ErasureCodingTaskConfig {
- return GetErasureCodingTaskConfig(mcm.config.Policy, MaintenanceTaskType(taskType))
-}
-
-// GetBalanceTaskConfigForType returns balance configuration for a specific task type
-func (mcm *MaintenanceConfigManager) GetBalanceTaskConfigForType(taskType string) *worker_pb.BalanceTaskConfig {
- return GetBalanceTaskConfig(mcm.config.Policy, MaintenanceTaskType(taskType))
-}
-
-// GetReplicationTaskConfigForType returns replication configuration for a specific task type
-func (mcm *MaintenanceConfigManager) GetReplicationTaskConfigForType(taskType string) *worker_pb.ReplicationTaskConfig {
- return GetReplicationTaskConfig(mcm.config.Policy, MaintenanceTaskType(taskType))
-}
-
-// Helper methods
-
-func (mcm *MaintenanceConfigManager) getTaskPolicy(taskType string) *worker_pb.TaskPolicy {
- if mcm.config.Policy != nil && mcm.config.Policy.TaskPolicies != nil {
- return mcm.config.Policy.TaskPolicies[taskType]
- }
- return nil
-}
-
-// IsTaskEnabled returns whether a task type is enabled
-func (mcm *MaintenanceConfigManager) IsTaskEnabled(taskType string) bool {
- if policy := mcm.getTaskPolicy(taskType); policy != nil {
- return policy.Enabled
- }
- return false
-}
-
-// GetMaxConcurrent returns the max concurrent limit for a task type
-func (mcm *MaintenanceConfigManager) GetMaxConcurrent(taskType string) int32 {
- if policy := mcm.getTaskPolicy(taskType); policy != nil {
- return policy.MaxConcurrent
- }
- return 1 // Default
-}
-
-// GetRepeatInterval returns the repeat interval for a task type in seconds
-func (mcm *MaintenanceConfigManager) GetRepeatInterval(taskType string) int32 {
- if policy := mcm.getTaskPolicy(taskType); policy != nil {
- return policy.RepeatIntervalSeconds
- }
- return mcm.config.Policy.DefaultRepeatIntervalSeconds
-}
-
-// GetCheckInterval returns the check interval for a task type in seconds
-func (mcm *MaintenanceConfigManager) GetCheckInterval(taskType string) int32 {
- if policy := mcm.getTaskPolicy(taskType); policy != nil {
- return policy.CheckIntervalSeconds
- }
- return mcm.config.Policy.DefaultCheckIntervalSeconds
-}
-
-// Duration accessor methods
-
-// GetScanInterval returns the scan interval as a time.Duration
-func (mcm *MaintenanceConfigManager) GetScanInterval() time.Duration {
- return time.Duration(mcm.config.ScanIntervalSeconds) * time.Second
-}
-
-// GetWorkerTimeout returns the worker timeout as a time.Duration
-func (mcm *MaintenanceConfigManager) GetWorkerTimeout() time.Duration {
- return time.Duration(mcm.config.WorkerTimeoutSeconds) * time.Second
-}
-
-// GetTaskTimeout returns the task timeout as a time.Duration
-func (mcm *MaintenanceConfigManager) GetTaskTimeout() time.Duration {
- return time.Duration(mcm.config.TaskTimeoutSeconds) * time.Second
-}
-
-// GetRetryDelay returns the retry delay as a time.Duration
-func (mcm *MaintenanceConfigManager) GetRetryDelay() time.Duration {
- return time.Duration(mcm.config.RetryDelaySeconds) * time.Second
-}
-
-// GetCleanupInterval returns the cleanup interval as a time.Duration
-func (mcm *MaintenanceConfigManager) GetCleanupInterval() time.Duration {
- return time.Duration(mcm.config.CleanupIntervalSeconds) * time.Second
-}
-
-// GetTaskRetention returns the task retention period as a time.Duration
-func (mcm *MaintenanceConfigManager) GetTaskRetention() time.Duration {
- return time.Duration(mcm.config.TaskRetentionSeconds) * time.Second
-}
-
-// ValidateMaintenanceConfigWithSchema validates protobuf maintenance configuration using ConfigField rules
-func ValidateMaintenanceConfigWithSchema(config *worker_pb.MaintenanceConfig) error {
- if config == nil {
- return fmt.Errorf("configuration cannot be nil")
- }
-
- // Get the schema to access field validation rules
- schema := GetMaintenanceConfigSchema()
-
- // Validate each field individually using the ConfigField rules
- if err := validateFieldWithSchema(schema, "enabled", config.Enabled); err != nil {
- return err
- }
-
- if err := validateFieldWithSchema(schema, "scan_interval_seconds", int(config.ScanIntervalSeconds)); err != nil {
- return err
- }
-
- if err := validateFieldWithSchema(schema, "worker_timeout_seconds", int(config.WorkerTimeoutSeconds)); err != nil {
- return err
- }
-
- if err := validateFieldWithSchema(schema, "task_timeout_seconds", int(config.TaskTimeoutSeconds)); err != nil {
- return err
- }
-
- if err := validateFieldWithSchema(schema, "retry_delay_seconds", int(config.RetryDelaySeconds)); err != nil {
- return err
- }
-
- if err := validateFieldWithSchema(schema, "max_retries", int(config.MaxRetries)); err != nil {
- return err
- }
-
- if err := validateFieldWithSchema(schema, "cleanup_interval_seconds", int(config.CleanupIntervalSeconds)); err != nil {
- return err
- }
-
- if err := validateFieldWithSchema(schema, "task_retention_seconds", int(config.TaskRetentionSeconds)); err != nil {
- return err
- }
-
- // Validate policy fields if present
- if config.Policy != nil {
- // Note: These field names might need to be adjusted based on the actual schema
- if err := validatePolicyField("global_max_concurrent", int(config.Policy.GlobalMaxConcurrent)); err != nil {
- return err
- }
-
- if err := validatePolicyField("default_repeat_interval_seconds", int(config.Policy.DefaultRepeatIntervalSeconds)); err != nil {
- return err
- }
-
- if err := validatePolicyField("default_check_interval_seconds", int(config.Policy.DefaultCheckIntervalSeconds)); err != nil {
- return err
- }
- }
-
- return nil
-}
-
-// validateFieldWithSchema validates a single field using its ConfigField definition
-func validateFieldWithSchema(schema *MaintenanceConfigSchema, fieldName string, value interface{}) error {
- field := schema.GetFieldByName(fieldName)
- if field == nil {
- // Field not in schema, skip validation
- return nil
- }
-
- return field.ValidateValue(value)
-}
-
-// validatePolicyField validates policy fields (simplified validation for now)
-func validatePolicyField(fieldName string, value int) error {
- switch fieldName {
- case "global_max_concurrent":
- if value < 1 || value > 20 {
- return fmt.Errorf("Global Max Concurrent must be between 1 and 20, got %d", value)
- }
- case "default_repeat_interval":
- if value < 1 || value > 168 {
- return fmt.Errorf("Default Repeat Interval must be between 1 and 168 hours, got %d", value)
- }
- case "default_check_interval":
- if value < 1 || value > 168 {
- return fmt.Errorf("Default Check Interval must be between 1 and 168 hours, got %d", value)
- }
- }
- return nil
-}
diff --git a/weed/admin/maintenance/maintenance_queue.go b/weed/admin/maintenance/maintenance_queue.go
index 28dbc1c5c..dc6546d40 100644
--- a/weed/admin/maintenance/maintenance_queue.go
+++ b/weed/admin/maintenance/maintenance_queue.go
@@ -1055,28 +1055,6 @@ func (mq *MaintenanceQueue) getMaxConcurrentForTaskType(taskType MaintenanceTask
return 1
}
-// getRunningTasks returns all currently running tasks
-func (mq *MaintenanceQueue) getRunningTasks() []*MaintenanceTask {
- var runningTasks []*MaintenanceTask
- for _, task := range mq.tasks {
- if task.Status == TaskStatusAssigned || task.Status == TaskStatusInProgress {
- runningTasks = append(runningTasks, task)
- }
- }
- return runningTasks
-}
-
-// getAvailableWorkers returns all workers that can take more work
-func (mq *MaintenanceQueue) getAvailableWorkers() []*MaintenanceWorker {
- var availableWorkers []*MaintenanceWorker
- for _, worker := range mq.workers {
- if worker.Status == "active" && worker.CurrentLoad < worker.MaxConcurrent {
- availableWorkers = append(availableWorkers, worker)
- }
- }
- return availableWorkers
-}
-
// trackPendingOperation adds a task to the pending operations tracker
func (mq *MaintenanceQueue) trackPendingOperation(task *MaintenanceTask) {
if mq.integration == nil {
diff --git a/weed/admin/maintenance/maintenance_types.go b/weed/admin/maintenance/maintenance_types.go
index 31c797e50..bb8f0a737 100644
--- a/weed/admin/maintenance/maintenance_types.go
+++ b/weed/admin/maintenance/maintenance_types.go
@@ -2,15 +2,11 @@ package maintenance
import (
"html/template"
- "sort"
"sync"
"time"
- "github.com/seaweedfs/seaweedfs/weed/glog"
"github.com/seaweedfs/seaweedfs/weed/pb/master_pb"
"github.com/seaweedfs/seaweedfs/weed/pb/worker_pb"
- "github.com/seaweedfs/seaweedfs/weed/worker/tasks"
- "github.com/seaweedfs/seaweedfs/weed/worker/types"
)
// AdminClient interface defines what the maintenance system needs from the admin server
@@ -21,51 +17,6 @@ type AdminClient interface {
// MaintenanceTaskType represents different types of maintenance operations
type MaintenanceTaskType string
-// GetRegisteredMaintenanceTaskTypes returns all registered task types as MaintenanceTaskType values
-// sorted alphabetically for consistent menu ordering
-func GetRegisteredMaintenanceTaskTypes() []MaintenanceTaskType {
- typesRegistry := tasks.GetGlobalTypesRegistry()
- var taskTypes []MaintenanceTaskType
-
- for workerTaskType := range typesRegistry.GetAllDetectors() {
- maintenanceTaskType := MaintenanceTaskType(string(workerTaskType))
- taskTypes = append(taskTypes, maintenanceTaskType)
- }
-
- // Sort task types alphabetically to ensure consistent menu ordering
- sort.Slice(taskTypes, func(i, j int) bool {
- return string(taskTypes[i]) < string(taskTypes[j])
- })
-
- return taskTypes
-}
-
-// GetMaintenanceTaskType returns a specific task type if it's registered, or empty string if not found
-func GetMaintenanceTaskType(taskTypeName string) MaintenanceTaskType {
- typesRegistry := tasks.GetGlobalTypesRegistry()
-
- for workerTaskType := range typesRegistry.GetAllDetectors() {
- if string(workerTaskType) == taskTypeName {
- return MaintenanceTaskType(taskTypeName)
- }
- }
-
- return MaintenanceTaskType("")
-}
-
-// IsMaintenanceTaskTypeRegistered checks if a task type is registered
-func IsMaintenanceTaskTypeRegistered(taskType MaintenanceTaskType) bool {
- typesRegistry := tasks.GetGlobalTypesRegistry()
-
- for workerTaskType := range typesRegistry.GetAllDetectors() {
- if string(workerTaskType) == string(taskType) {
- return true
- }
- }
-
- return false
-}
-
// MaintenanceTaskPriority represents task execution priority
type MaintenanceTaskPriority int
@@ -200,14 +151,6 @@ func GetTaskPolicy(mp *MaintenancePolicy, taskType MaintenanceTaskType) *TaskPol
return mp.TaskPolicies[string(taskType)]
}
-// SetTaskPolicy sets the policy for a specific task type
-func SetTaskPolicy(mp *MaintenancePolicy, taskType MaintenanceTaskType, policy *TaskPolicy) {
- if mp.TaskPolicies == nil {
- mp.TaskPolicies = make(map[string]*TaskPolicy)
- }
- mp.TaskPolicies[string(taskType)] = policy
-}
-
// IsTaskEnabled returns whether a task type is enabled
func IsTaskEnabled(mp *MaintenancePolicy, taskType MaintenanceTaskType) bool {
policy := GetTaskPolicy(mp, taskType)
@@ -235,84 +178,6 @@ func GetRepeatInterval(mp *MaintenancePolicy, taskType MaintenanceTaskType) int
return int(policy.RepeatIntervalSeconds)
}
-// GetVacuumTaskConfig returns the vacuum task configuration
-func GetVacuumTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType) *worker_pb.VacuumTaskConfig {
- policy := GetTaskPolicy(mp, taskType)
- if policy == nil {
- return nil
- }
- return policy.GetVacuumConfig()
-}
-
-// GetErasureCodingTaskConfig returns the erasure coding task configuration
-func GetErasureCodingTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType) *worker_pb.ErasureCodingTaskConfig {
- policy := GetTaskPolicy(mp, taskType)
- if policy == nil {
- return nil
- }
- return policy.GetErasureCodingConfig()
-}
-
-// GetBalanceTaskConfig returns the balance task configuration
-func GetBalanceTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType) *worker_pb.BalanceTaskConfig {
- policy := GetTaskPolicy(mp, taskType)
- if policy == nil {
- return nil
- }
- return policy.GetBalanceConfig()
-}
-
-// GetReplicationTaskConfig returns the replication task configuration
-func GetReplicationTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType) *worker_pb.ReplicationTaskConfig {
- policy := GetTaskPolicy(mp, taskType)
- if policy == nil {
- return nil
- }
- return policy.GetReplicationConfig()
-}
-
-// Note: GetTaskConfig was removed - use typed getters: GetVacuumTaskConfig, GetErasureCodingTaskConfig, GetBalanceTaskConfig, or GetReplicationTaskConfig
-
-// SetVacuumTaskConfig sets the vacuum task configuration
-func SetVacuumTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType, config *worker_pb.VacuumTaskConfig) {
- policy := GetTaskPolicy(mp, taskType)
- if policy != nil {
- policy.TaskConfig = &worker_pb.TaskPolicy_VacuumConfig{
- VacuumConfig: config,
- }
- }
-}
-
-// SetErasureCodingTaskConfig sets the erasure coding task configuration
-func SetErasureCodingTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType, config *worker_pb.ErasureCodingTaskConfig) {
- policy := GetTaskPolicy(mp, taskType)
- if policy != nil {
- policy.TaskConfig = &worker_pb.TaskPolicy_ErasureCodingConfig{
- ErasureCodingConfig: config,
- }
- }
-}
-
-// SetBalanceTaskConfig sets the balance task configuration
-func SetBalanceTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType, config *worker_pb.BalanceTaskConfig) {
- policy := GetTaskPolicy(mp, taskType)
- if policy != nil {
- policy.TaskConfig = &worker_pb.TaskPolicy_BalanceConfig{
- BalanceConfig: config,
- }
- }
-}
-
-// SetReplicationTaskConfig sets the replication task configuration
-func SetReplicationTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType, config *worker_pb.ReplicationTaskConfig) {
- policy := GetTaskPolicy(mp, taskType)
- if policy != nil {
- policy.TaskConfig = &worker_pb.TaskPolicy_ReplicationConfig{
- ReplicationConfig: config,
- }
- }
-}
-
// SetTaskConfig sets a configuration value for a task type (legacy method - use typed setters above)
// Note: SetTaskConfig was removed - use typed setters: SetVacuumTaskConfig, SetErasureCodingTaskConfig, SetBalanceTaskConfig, or SetReplicationTaskConfig
@@ -475,180 +340,6 @@ type ClusterReplicationTask struct {
Metadata map[string]string `json:"metadata,omitempty"`
}
-// BuildMaintenancePolicyFromTasks creates a maintenance policy with configurations
-// from all registered tasks using their UI providers
-func BuildMaintenancePolicyFromTasks() *MaintenancePolicy {
- policy := &MaintenancePolicy{
- TaskPolicies: make(map[string]*TaskPolicy),
- GlobalMaxConcurrent: 4,
- DefaultRepeatIntervalSeconds: 6 * 3600, // 6 hours in seconds
- DefaultCheckIntervalSeconds: 12 * 3600, // 12 hours in seconds
- }
-
- // Get all registered task types from the UI registry
- uiRegistry := tasks.GetGlobalUIRegistry()
- typesRegistry := tasks.GetGlobalTypesRegistry()
-
- for taskType, provider := range uiRegistry.GetAllProviders() {
- // Convert task type to maintenance task type
- maintenanceTaskType := MaintenanceTaskType(string(taskType))
-
- // Get the default configuration from the UI provider
- defaultConfig := provider.GetCurrentConfig()
-
- // Create task policy from UI configuration
- taskPolicy := &TaskPolicy{
- Enabled: true, // Default enabled
- MaxConcurrent: 2, // Default concurrency
- RepeatIntervalSeconds: policy.DefaultRepeatIntervalSeconds,
- CheckIntervalSeconds: policy.DefaultCheckIntervalSeconds,
- }
-
- // Extract configuration using TaskConfig interface - no more map conversions!
- if taskConfig, ok := defaultConfig.(interface{ ToTaskPolicy() *worker_pb.TaskPolicy }); ok {
- // Use protobuf directly for clean, type-safe config extraction
- pbTaskPolicy := taskConfig.ToTaskPolicy()
- taskPolicy.Enabled = pbTaskPolicy.Enabled
- taskPolicy.MaxConcurrent = pbTaskPolicy.MaxConcurrent
- if pbTaskPolicy.RepeatIntervalSeconds > 0 {
- taskPolicy.RepeatIntervalSeconds = pbTaskPolicy.RepeatIntervalSeconds
- }
- if pbTaskPolicy.CheckIntervalSeconds > 0 {
- taskPolicy.CheckIntervalSeconds = pbTaskPolicy.CheckIntervalSeconds
- }
- }
-
- // Also get defaults from scheduler if available (using types.TaskScheduler explicitly)
- var scheduler types.TaskScheduler = typesRegistry.GetScheduler(taskType)
- if scheduler != nil {
- if taskPolicy.MaxConcurrent <= 0 {
- taskPolicy.MaxConcurrent = int32(scheduler.GetMaxConcurrent())
- }
- // Convert default repeat interval to seconds
- if repeatInterval := scheduler.GetDefaultRepeatInterval(); repeatInterval > 0 {
- taskPolicy.RepeatIntervalSeconds = int32(repeatInterval.Seconds())
- }
- }
-
- // Also get defaults from detector if available (using types.TaskDetector explicitly)
- var detector types.TaskDetector = typesRegistry.GetDetector(taskType)
- if detector != nil {
- // Convert scan interval to check interval (seconds)
- if scanInterval := detector.ScanInterval(); scanInterval > 0 {
- taskPolicy.CheckIntervalSeconds = int32(scanInterval.Seconds())
- }
- }
-
- policy.TaskPolicies[string(maintenanceTaskType)] = taskPolicy
- glog.V(3).Infof("Built policy for task type %s: enabled=%v, max_concurrent=%d",
- maintenanceTaskType, taskPolicy.Enabled, taskPolicy.MaxConcurrent)
- }
-
- glog.V(2).Infof("Built maintenance policy with %d task configurations", len(policy.TaskPolicies))
- return policy
-}
-
-// SetPolicyFromTasks sets the maintenance policy from registered tasks
-func SetPolicyFromTasks(policy *MaintenancePolicy) {
- if policy == nil {
- return
- }
-
- // Build new policy from tasks
- newPolicy := BuildMaintenancePolicyFromTasks()
-
- // Copy task policies
- policy.TaskPolicies = newPolicy.TaskPolicies
-
- glog.V(1).Infof("Updated maintenance policy with %d task configurations from registered tasks", len(policy.TaskPolicies))
-}
-
-// GetTaskIcon returns the icon CSS class for a task type from its UI provider
-func GetTaskIcon(taskType MaintenanceTaskType) string {
- typesRegistry := tasks.GetGlobalTypesRegistry()
- uiRegistry := tasks.GetGlobalUIRegistry()
-
- // Convert MaintenanceTaskType to TaskType
- for workerTaskType := range typesRegistry.GetAllDetectors() {
- if string(workerTaskType) == string(taskType) {
- // Get the UI provider for this task type
- provider := uiRegistry.GetProvider(workerTaskType)
- if provider != nil {
- return provider.GetIcon()
- }
- break
- }
- }
-
- // Default icon if no UI provider found
- return "fas fa-cog text-muted"
-}
-
-// GetTaskDisplayName returns the display name for a task type from its UI provider
-func GetTaskDisplayName(taskType MaintenanceTaskType) string {
- typesRegistry := tasks.GetGlobalTypesRegistry()
- uiRegistry := tasks.GetGlobalUIRegistry()
-
- // Convert MaintenanceTaskType to TaskType
- for workerTaskType := range typesRegistry.GetAllDetectors() {
- if string(workerTaskType) == string(taskType) {
- // Get the UI provider for this task type
- provider := uiRegistry.GetProvider(workerTaskType)
- if provider != nil {
- return provider.GetDisplayName()
- }
- break
- }
- }
-
- // Fallback to the task type string
- return string(taskType)
-}
-
-// GetTaskDescription returns the description for a task type from its UI provider
-func GetTaskDescription(taskType MaintenanceTaskType) string {
- typesRegistry := tasks.GetGlobalTypesRegistry()
- uiRegistry := tasks.GetGlobalUIRegistry()
-
- // Convert MaintenanceTaskType to TaskType
- for workerTaskType := range typesRegistry.GetAllDetectors() {
- if string(workerTaskType) == string(taskType) {
- // Get the UI provider for this task type
- provider := uiRegistry.GetProvider(workerTaskType)
- if provider != nil {
- return provider.GetDescription()
- }
- break
- }
- }
-
- // Fallback to a generic description
- return "Configure detailed settings for " + string(taskType) + " tasks."
-}
-
-// BuildMaintenanceMenuItems creates menu items for all registered task types
-func BuildMaintenanceMenuItems() []*MaintenanceMenuItem {
- var menuItems []*MaintenanceMenuItem
-
- // Get all registered task types
- registeredTypes := GetRegisteredMaintenanceTaskTypes()
-
- for _, taskType := range registeredTypes {
- menuItem := &MaintenanceMenuItem{
- TaskType: taskType,
- DisplayName: GetTaskDisplayName(taskType),
- Description: GetTaskDescription(taskType),
- Icon: GetTaskIcon(taskType),
- IsEnabled: IsMaintenanceTaskTypeRegistered(taskType),
- Path: "/maintenance/config/" + string(taskType),
- }
-
- menuItems = append(menuItems, menuItem)
- }
-
- return menuItems
-}
-
// Helper functions to extract configuration fields
// Note: Removed getVacuumConfigField, getErasureCodingConfigField, getBalanceConfigField, getReplicationConfigField
diff --git a/weed/admin/maintenance/maintenance_worker.go b/weed/admin/maintenance/maintenance_worker.go
deleted file mode 100644
index e4a6b4cf6..000000000
--- a/weed/admin/maintenance/maintenance_worker.go
+++ /dev/null
@@ -1,421 +0,0 @@
-package maintenance
-
-import (
- "context"
- "fmt"
- "os"
- "sync"
- "time"
-
- "github.com/seaweedfs/seaweedfs/weed/glog"
- "github.com/seaweedfs/seaweedfs/weed/worker"
- "github.com/seaweedfs/seaweedfs/weed/worker/tasks"
- "github.com/seaweedfs/seaweedfs/weed/worker/types"
-
- // Import task packages to trigger their auto-registration
- _ "github.com/seaweedfs/seaweedfs/weed/worker/tasks/balance"
- _ "github.com/seaweedfs/seaweedfs/weed/worker/tasks/erasure_coding"
- _ "github.com/seaweedfs/seaweedfs/weed/worker/tasks/vacuum"
-)
-
-// MaintenanceWorkerService manages maintenance task execution
-// TaskExecutor defines the function signature for task execution
-type TaskExecutor func(*MaintenanceWorkerService, *MaintenanceTask) error
-
-// TaskExecutorFactory creates a task executor for a given worker service
-type TaskExecutorFactory func() TaskExecutor
-
-// Global registry for task executor factories
-var taskExecutorFactories = make(map[MaintenanceTaskType]TaskExecutorFactory)
-var executorRegistryMutex sync.RWMutex
-var executorRegistryInitOnce sync.Once
-
-// initializeExecutorFactories dynamically registers executor factories for all auto-registered task types
-func initializeExecutorFactories() {
- executorRegistryInitOnce.Do(func() {
- // Get all registered task types from the global registry
- typesRegistry := tasks.GetGlobalTypesRegistry()
-
- var taskTypes []MaintenanceTaskType
- for workerTaskType := range typesRegistry.GetAllDetectors() {
- // Convert types.TaskType to MaintenanceTaskType by string conversion
- maintenanceTaskType := MaintenanceTaskType(string(workerTaskType))
- taskTypes = append(taskTypes, maintenanceTaskType)
- }
-
- // Register generic executor for all task types
- for _, taskType := range taskTypes {
- RegisterTaskExecutorFactory(taskType, createGenericTaskExecutor)
- }
-
- glog.V(1).Infof("Dynamically registered generic task executor for %d task types: %v", len(taskTypes), taskTypes)
- })
-}
-
-// RegisterTaskExecutorFactory registers a factory function for creating task executors
-func RegisterTaskExecutorFactory(taskType MaintenanceTaskType, factory TaskExecutorFactory) {
- executorRegistryMutex.Lock()
- defer executorRegistryMutex.Unlock()
- taskExecutorFactories[taskType] = factory
- glog.V(2).Infof("Registered executor factory for task type: %s", taskType)
-}
-
-// GetTaskExecutorFactory returns the factory for a task type
-func GetTaskExecutorFactory(taskType MaintenanceTaskType) (TaskExecutorFactory, bool) {
- // Ensure executor factories are initialized
- initializeExecutorFactories()
-
- executorRegistryMutex.RLock()
- defer executorRegistryMutex.RUnlock()
- factory, exists := taskExecutorFactories[taskType]
- return factory, exists
-}
-
-// GetSupportedExecutorTaskTypes returns all task types with registered executor factories
-func GetSupportedExecutorTaskTypes() []MaintenanceTaskType {
- // Ensure executor factories are initialized
- initializeExecutorFactories()
-
- executorRegistryMutex.RLock()
- defer executorRegistryMutex.RUnlock()
-
- taskTypes := make([]MaintenanceTaskType, 0, len(taskExecutorFactories))
- for taskType := range taskExecutorFactories {
- taskTypes = append(taskTypes, taskType)
- }
- return taskTypes
-}
-
-// createGenericTaskExecutor creates a generic task executor that uses the task registry
-func createGenericTaskExecutor() TaskExecutor {
- return func(mws *MaintenanceWorkerService, task *MaintenanceTask) error {
- return mws.executeGenericTask(task)
- }
-}
-
-// init does minimal initialization - actual registration happens lazily
-func init() {
- // Executor factory registration will happen lazily when first accessed
- glog.V(1).Infof("Maintenance worker initialized - executor factories will be registered on first access")
-}
-
-type MaintenanceWorkerService struct {
- workerID string
- address string
- adminServer string
- capabilities []MaintenanceTaskType
- maxConcurrent int
- currentTasks map[string]*MaintenanceTask
- queue *MaintenanceQueue
- adminClient AdminClient
- running bool
- stopChan chan struct{}
-
- // Task execution registry
- taskExecutors map[MaintenanceTaskType]TaskExecutor
-
- // Task registry for creating task instances
- taskRegistry *tasks.TaskRegistry
-}
-
-// NewMaintenanceWorkerService creates a new maintenance worker service
-func NewMaintenanceWorkerService(workerID, address, adminServer string) *MaintenanceWorkerService {
- // Get all registered maintenance task types dynamically
- capabilities := GetRegisteredMaintenanceTaskTypes()
-
- worker := &MaintenanceWorkerService{
- workerID: workerID,
- address: address,
- adminServer: adminServer,
- capabilities: capabilities,
- maxConcurrent: 2, // Default concurrent task limit
- currentTasks: make(map[string]*MaintenanceTask),
- stopChan: make(chan struct{}),
- taskExecutors: make(map[MaintenanceTaskType]TaskExecutor),
- taskRegistry: tasks.GetGlobalTaskRegistry(), // Use global registry with auto-registered tasks
- }
-
- // Initialize task executor registry
- worker.initializeTaskExecutors()
-
- glog.V(1).Infof("Created maintenance worker with %d registered task types", len(worker.taskRegistry.GetAll()))
-
- return worker
-}
-
-// executeGenericTask executes a task using the task registry instead of hardcoded methods
-func (mws *MaintenanceWorkerService) executeGenericTask(task *MaintenanceTask) error {
- glog.V(2).Infof("Executing generic task %s: %s for volume %d", task.ID, task.Type, task.VolumeID)
-
- // Validate that task has proper typed parameters
- if task.TypedParams == nil {
- return fmt.Errorf("task %s has no typed parameters - task was not properly planned (insufficient destinations)", task.ID)
- }
-
- // Convert MaintenanceTask to types.TaskType
- taskType := types.TaskType(string(task.Type))
-
- // Create task instance using the registry
- taskInstance, err := mws.taskRegistry.Get(taskType).Create(task.TypedParams)
- if err != nil {
- return fmt.Errorf("failed to create task instance: %w", err)
- }
-
- // Update progress to show task has started
- mws.updateTaskProgress(task.ID, 5)
-
- // Execute the task
- err = taskInstance.Execute(context.Background(), task.TypedParams)
- if err != nil {
- return fmt.Errorf("task execution failed: %w", err)
- }
-
- // Update progress to show completion
- mws.updateTaskProgress(task.ID, 100)
-
- glog.V(2).Infof("Generic task %s completed successfully", task.ID)
- return nil
-}
-
-// initializeTaskExecutors sets up the task execution registry dynamically
-func (mws *MaintenanceWorkerService) initializeTaskExecutors() {
- mws.taskExecutors = make(map[MaintenanceTaskType]TaskExecutor)
-
- // Get all registered executor factories and create executors
- executorRegistryMutex.RLock()
- defer executorRegistryMutex.RUnlock()
-
- for taskType, factory := range taskExecutorFactories {
- executor := factory()
- mws.taskExecutors[taskType] = executor
- glog.V(3).Infof("Initialized executor for task type: %s", taskType)
- }
-
- glog.V(2).Infof("Initialized %d task executors", len(mws.taskExecutors))
-}
-
-// RegisterTaskExecutor allows dynamic registration of new task executors
-func (mws *MaintenanceWorkerService) RegisterTaskExecutor(taskType MaintenanceTaskType, executor TaskExecutor) {
- if mws.taskExecutors == nil {
- mws.taskExecutors = make(map[MaintenanceTaskType]TaskExecutor)
- }
- mws.taskExecutors[taskType] = executor
- glog.V(1).Infof("Registered executor for task type: %s", taskType)
-}
-
-// GetSupportedTaskTypes returns all task types that this worker can execute
-func (mws *MaintenanceWorkerService) GetSupportedTaskTypes() []MaintenanceTaskType {
- return GetSupportedExecutorTaskTypes()
-}
-
-// Start begins the worker service
-func (mws *MaintenanceWorkerService) Start() error {
- mws.running = true
-
- // Register with admin server
- worker := &MaintenanceWorker{
- ID: mws.workerID,
- Address: mws.address,
- Capabilities: mws.capabilities,
- MaxConcurrent: mws.maxConcurrent,
- }
-
- if mws.queue != nil {
- mws.queue.RegisterWorker(worker)
- }
-
- // Start worker loop
- go mws.workerLoop()
-
- glog.Infof("Maintenance worker %s started at %s", mws.workerID, mws.address)
- return nil
-}
-
-// Stop terminates the worker service
-func (mws *MaintenanceWorkerService) Stop() {
- mws.running = false
- close(mws.stopChan)
-
- // Wait for current tasks to complete or timeout
- timeout := time.NewTimer(30 * time.Second)
- defer timeout.Stop()
-
- for len(mws.currentTasks) > 0 {
- select {
- case <-timeout.C:
- glog.Warningf("Worker %s stopping with %d tasks still running", mws.workerID, len(mws.currentTasks))
- return
- case <-time.After(time.Second):
- // Check again
- }
- }
-
- glog.Infof("Maintenance worker %s stopped", mws.workerID)
-}
-
-// workerLoop is the main worker event loop
-func (mws *MaintenanceWorkerService) workerLoop() {
- heartbeatTicker := time.NewTicker(30 * time.Second)
- defer heartbeatTicker.Stop()
-
- taskRequestTicker := time.NewTicker(5 * time.Second)
- defer taskRequestTicker.Stop()
-
- for mws.running {
- select {
- case <-mws.stopChan:
- return
- case <-heartbeatTicker.C:
- mws.sendHeartbeat()
- case <-taskRequestTicker.C:
- mws.requestTasks()
- }
- }
-}
-
-// sendHeartbeat sends heartbeat to admin server
-func (mws *MaintenanceWorkerService) sendHeartbeat() {
- if mws.queue != nil {
- mws.queue.UpdateWorkerHeartbeat(mws.workerID)
- }
-}
-
-// requestTasks requests new tasks from the admin server
-func (mws *MaintenanceWorkerService) requestTasks() {
- if len(mws.currentTasks) >= mws.maxConcurrent {
- return // Already at capacity
- }
-
- if mws.queue != nil {
- task := mws.queue.GetNextTask(mws.workerID, mws.capabilities)
- if task != nil {
- mws.executeTask(task)
- }
- }
-}
-
-// executeTask executes a maintenance task
-func (mws *MaintenanceWorkerService) executeTask(task *MaintenanceTask) {
- mws.currentTasks[task.ID] = task
-
- go func() {
- defer func() {
- delete(mws.currentTasks, task.ID)
- }()
-
- glog.Infof("Worker %s executing task %s: %s", mws.workerID, task.ID, task.Type)
-
- // Execute task using dynamic executor registry
- var err error
- if executor, exists := mws.taskExecutors[task.Type]; exists {
- err = executor(mws, task)
- } else {
- err = fmt.Errorf("unsupported task type: %s", task.Type)
- glog.Errorf("No executor registered for task type: %s", task.Type)
- }
-
- // Report task completion
- if mws.queue != nil {
- errorMsg := ""
- if err != nil {
- errorMsg = err.Error()
- }
- mws.queue.CompleteTask(task.ID, errorMsg)
- }
-
- if err != nil {
- glog.Errorf("Worker %s failed to execute task %s: %v", mws.workerID, task.ID, err)
- } else {
- glog.Infof("Worker %s completed task %s successfully", mws.workerID, task.ID)
- }
- }()
-}
-
-// updateTaskProgress updates the progress of a task
-func (mws *MaintenanceWorkerService) updateTaskProgress(taskID string, progress float64) {
- if mws.queue != nil {
- mws.queue.UpdateTaskProgress(taskID, progress)
- }
-}
-
-// GetStatus returns the current status of the worker
-func (mws *MaintenanceWorkerService) GetStatus() map[string]interface{} {
- return map[string]interface{}{
- "worker_id": mws.workerID,
- "address": mws.address,
- "running": mws.running,
- "capabilities": mws.capabilities,
- "max_concurrent": mws.maxConcurrent,
- "current_tasks": len(mws.currentTasks),
- "task_details": mws.currentTasks,
- }
-}
-
-// SetQueue sets the maintenance queue for the worker
-func (mws *MaintenanceWorkerService) SetQueue(queue *MaintenanceQueue) {
- mws.queue = queue
-}
-
-// SetAdminClient sets the admin client for the worker
-func (mws *MaintenanceWorkerService) SetAdminClient(client AdminClient) {
- mws.adminClient = client
-}
-
-// SetCapabilities sets the worker capabilities
-func (mws *MaintenanceWorkerService) SetCapabilities(capabilities []MaintenanceTaskType) {
- mws.capabilities = capabilities
-}
-
-// SetMaxConcurrent sets the maximum concurrent tasks
-func (mws *MaintenanceWorkerService) SetMaxConcurrent(max int) {
- mws.maxConcurrent = max
-}
-
-// SetHeartbeatInterval sets the heartbeat interval (placeholder for future use)
-func (mws *MaintenanceWorkerService) SetHeartbeatInterval(interval time.Duration) {
- // Future implementation for configurable heartbeat
-}
-
-// SetTaskRequestInterval sets the task request interval (placeholder for future use)
-func (mws *MaintenanceWorkerService) SetTaskRequestInterval(interval time.Duration) {
- // Future implementation for configurable task requests
-}
-
-// MaintenanceWorkerCommand represents a standalone maintenance worker command
-type MaintenanceWorkerCommand struct {
- workerService *MaintenanceWorkerService
-}
-
-// NewMaintenanceWorkerCommand creates a new worker command
-func NewMaintenanceWorkerCommand(workerID, address, adminServer string) *MaintenanceWorkerCommand {
- return &MaintenanceWorkerCommand{
- workerService: NewMaintenanceWorkerService(workerID, address, adminServer),
- }
-}
-
-// Run starts the maintenance worker as a standalone service
-func (mwc *MaintenanceWorkerCommand) Run() error {
- // Generate or load persistent worker ID if not provided
- if mwc.workerService.workerID == "" {
- // Get current working directory for worker ID persistence
- wd, err := os.Getwd()
- if err != nil {
- return fmt.Errorf("failed to get working directory: %w", err)
- }
-
- workerID, err := worker.GenerateOrLoadWorkerID(wd)
- if err != nil {
- return fmt.Errorf("failed to generate or load worker ID: %w", err)
- }
- mwc.workerService.workerID = workerID
- }
-
- // Start the worker service
- err := mwc.workerService.Start()
- if err != nil {
- return fmt.Errorf("failed to start maintenance worker: %w", err)
- }
-
- // Wait for interrupt signal
- select {}
-}
diff --git a/weed/admin/plugin/plugin.go b/weed/admin/plugin/plugin.go
index e14e7ae41..094a15830 100644
--- a/weed/admin/plugin/plugin.go
+++ b/weed/admin/plugin/plugin.go
@@ -122,6 +122,7 @@ type Plugin struct {
type streamSession struct {
workerID string
outgoing chan *plugin_pb.AdminToWorkerMessage
+ done chan struct{}
closeOnce sync.Once
}
@@ -274,6 +275,7 @@ func (r *Plugin) WorkerStream(stream plugin_pb.PluginControlService_WorkerStream
session := &streamSession{
workerID: workerID,
outgoing: make(chan *plugin_pb.AdminToWorkerMessage, r.outgoingBuffer),
+ done: make(chan struct{}),
}
r.putSession(session)
defer r.cleanupSession(workerID)
@@ -908,8 +910,10 @@ func (r *Plugin) sendLoop(
return nil
case <-r.shutdownCh:
return nil
- case msg, ok := <-session.outgoing:
- if !ok {
+ case <-session.done:
+ return nil
+ case msg := <-session.outgoing:
+ if msg == nil {
return nil
}
if err := stream.Send(msg); err != nil {
@@ -930,6 +934,8 @@ func (r *Plugin) sendToWorker(workerID string, message *plugin_pb.AdminToWorkerM
select {
case <-r.shutdownCh:
return fmt.Errorf("plugin is shutting down")
+ case <-session.done:
+ return fmt.Errorf("worker %s session is closed", workerID)
case session.outgoing <- message:
return nil
case <-time.After(r.sendTimeout):
@@ -1425,7 +1431,7 @@ func CloneConfigValueMap(in map[string]*plugin_pb.ConfigValue) map[string]*plugi
func (s *streamSession) close() {
s.closeOnce.Do(func() {
- close(s.outgoing)
+ close(s.done)
})
}
diff --git a/weed/admin/plugin/plugin_cancel_test.go b/weed/admin/plugin/plugin_cancel_test.go
index 2a966ae8c..ef129ea08 100644
--- a/weed/admin/plugin/plugin_cancel_test.go
+++ b/weed/admin/plugin/plugin_cancel_test.go
@@ -26,7 +26,7 @@ func TestRunDetectionSendsCancelOnContextDone(t *testing.T) {
{JobType: jobType, CanDetect: true, MaxDetectionConcurrency: 1},
},
})
- session := &streamSession{workerID: workerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 4)}
+ session := &streamSession{workerID: workerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 4), done: make(chan struct{})}
pluginSvc.putSession(session)
ctx, cancel := context.WithCancel(context.Background())
@@ -77,7 +77,7 @@ func TestExecuteJobSendsCancelOnContextDone(t *testing.T) {
{JobType: jobType, CanExecute: true, MaxExecutionConcurrency: 1},
},
})
- session := &streamSession{workerID: workerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 4)}
+ session := &streamSession{workerID: workerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 4), done: make(chan struct{})}
pluginSvc.putSession(session)
job := &plugin_pb.JobSpec{JobId: "job-1", JobType: jobType}
@@ -135,8 +135,8 @@ func TestAdminScriptExecutionBlocksOtherDetection(t *testing.T) {
{JobType: "vacuum", CanDetect: true, MaxDetectionConcurrency: 1},
},
})
- adminSession := &streamSession{workerID: adminWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8)}
- otherSession := &streamSession{workerID: otherWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8)}
+ adminSession := &streamSession{workerID: adminWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8), done: make(chan struct{})}
+ otherSession := &streamSession{workerID: otherWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8), done: make(chan struct{})}
pluginSvc.putSession(adminSession)
pluginSvc.putSession(otherSession)
@@ -214,8 +214,8 @@ func TestAdminScriptExecutionBlocksOtherExecution(t *testing.T) {
{JobType: "vacuum", CanExecute: true, MaxExecutionConcurrency: 1},
},
})
- adminSession := &streamSession{workerID: adminWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8)}
- otherSession := &streamSession{workerID: otherWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8)}
+ adminSession := &streamSession{workerID: adminWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8), done: make(chan struct{})}
+ otherSession := &streamSession{workerID: otherWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8), done: make(chan struct{})}
pluginSvc.putSession(adminSession)
pluginSvc.putSession(otherSession)
diff --git a/weed/admin/plugin/plugin_detection_test.go b/weed/admin/plugin/plugin_detection_test.go
index be2aac50c..ee86c353a 100644
--- a/weed/admin/plugin/plugin_detection_test.go
+++ b/weed/admin/plugin/plugin_detection_test.go
@@ -22,7 +22,7 @@ func TestRunDetectionIncludesLatestSuccessfulRun(t *testing.T) {
{JobType: jobType, CanDetect: true, MaxDetectionConcurrency: 1},
},
})
- session := &streamSession{workerID: "worker-a", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1)}
+ session := &streamSession{workerID: "worker-a", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1), done: make(chan struct{})}
pluginSvc.putSession(session)
oldSuccess := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
@@ -80,7 +80,7 @@ func TestRunDetectionOmitsLastSuccessfulRunWhenNoSuccessHistory(t *testing.T) {
{JobType: jobType, CanDetect: true, MaxDetectionConcurrency: 1},
},
})
- session := &streamSession{workerID: "worker-a", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1)}
+ session := &streamSession{workerID: "worker-a", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1), done: make(chan struct{})}
pluginSvc.putSession(session)
if err := pluginSvc.store.AppendRunRecord(jobType, &JobRunRecord{
@@ -130,7 +130,7 @@ func TestRunDetectionWithReportCapturesDetectionActivities(t *testing.T) {
{JobType: jobType, CanDetect: true, MaxDetectionConcurrency: 1},
},
})
- session := &streamSession{workerID: "worker-a", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1)}
+ session := &streamSession{workerID: "worker-a", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1), done: make(chan struct{})}
pluginSvc.putSession(session)
reportCh := make(chan *DetectionReport, 1)
@@ -210,7 +210,7 @@ func TestRunDetectionAdminScriptUsesLastCompletedRun(t *testing.T) {
{JobType: jobType, CanDetect: true, MaxDetectionConcurrency: 1},
},
})
- session := &streamSession{workerID: "worker-admin-script", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1)}
+ session := &streamSession{workerID: "worker-admin-script", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1), done: make(chan struct{})}
pluginSvc.putSession(session)
successCompleted := time.Date(2026, 2, 1, 10, 0, 0, 0, time.UTC)
diff --git a/weed/admin/plugin/plugin_scheduler.go b/weed/admin/plugin/plugin_scheduler.go
index 248ca8985..121074913 100644
--- a/weed/admin/plugin/plugin_scheduler.go
+++ b/weed/admin/plugin/plugin_scheduler.go
@@ -95,16 +95,6 @@ func (r *Plugin) laneSchedulerLoop(ls *schedulerLaneState) {
}
}
-// schedulerLoop is kept for backward compatibility; it delegates to
-// laneSchedulerLoop with the default lane. New code should not call this.
-func (r *Plugin) schedulerLoop() {
- ls := r.lanes[LaneDefault]
- if ls == nil {
- ls = newLaneState(LaneDefault)
- }
- r.laneSchedulerLoop(ls)
-}
-
// runLaneSchedulerIteration runs one scheduling pass for a single lane,
// processing only the job types assigned to that lane.
//
@@ -229,82 +219,6 @@ func (r *Plugin) runLaneSchedulerIterationConcurrent(ls *schedulerLaneState, job
return hadJobs.Load()
}
-// runSchedulerIteration is kept for backward compatibility. It runs a
-// single iteration across ALL job types (equivalent to the old single-loop
-// behavior). It is only used by the legacy schedulerLoop() fallback.
-func (r *Plugin) runSchedulerIteration() bool {
- ls := r.lanes[LaneDefault]
- if ls == nil {
- ls = newLaneState(LaneDefault)
- }
- // For backward compat, the old function processes all job types.
- r.expireStaleJobs(time.Now().UTC())
-
- jobTypes := r.registry.DetectableJobTypes()
- if len(jobTypes) == 0 {
- r.setSchedulerLoopState("", "idle")
- return false
- }
-
- r.setSchedulerLoopState("", "waiting_for_lock")
- releaseLock, err := r.acquireAdminLock("plugin scheduler iteration")
- if err != nil {
- glog.Warningf("Plugin scheduler failed to acquire lock: %v", err)
- r.setSchedulerLoopState("", "idle")
- return false
- }
- if releaseLock != nil {
- defer releaseLock()
- }
-
- active := make(map[string]struct{}, len(jobTypes))
- hadJobs := false
-
- for _, jobType := range jobTypes {
- active[jobType] = struct{}{}
-
- policy, enabled, err := r.loadSchedulerPolicy(jobType)
- if err != nil {
- glog.Warningf("Plugin scheduler failed to load policy for %s: %v", jobType, err)
- continue
- }
- if !enabled {
- r.clearSchedulerJobType(jobType)
- continue
- }
- initialDelay := time.Duration(0)
- if runInfo := r.snapshotSchedulerRun(jobType); runInfo.lastRunStartedAt.IsZero() {
- initialDelay = 5 * time.Second
- }
- if !r.markDetectionDue(jobType, policy.DetectionInterval, initialDelay) {
- continue
- }
-
- detected := r.runJobTypeIteration(jobType, policy)
- if detected {
- hadJobs = true
- }
- }
-
- r.pruneSchedulerState(active)
- r.pruneDetectorLeases(active)
- r.setSchedulerLoopState("", "idle")
- return hadJobs
-}
-
-// wakeLane wakes the scheduler goroutine for a specific lane.
-func (r *Plugin) wakeLane(lane SchedulerLane) {
- if r == nil {
- return
- }
- if ls, ok := r.lanes[lane]; ok {
- select {
- case ls.wakeCh <- struct{}{}:
- default:
- }
- }
-}
-
// wakeAllLanes wakes all lane scheduler goroutines.
func (r *Plugin) wakeAllLanes() {
if r == nil {
diff --git a/weed/admin/plugin/scheduler_status.go b/weed/admin/plugin/scheduler_status.go
index 19de4ea2e..d5a33069a 100644
--- a/weed/admin/plugin/scheduler_status.go
+++ b/weed/admin/plugin/scheduler_status.go
@@ -210,16 +210,6 @@ func (r *Plugin) setSchedulerLoopStateForJobType(jobType, phase string) {
}
}
-func (r *Plugin) recordSchedulerIterationComplete(hadJobs bool) {
- if r == nil {
- return
- }
- r.schedulerLoopMu.Lock()
- r.schedulerLoopState.lastIterationHadJobs = hadJobs
- r.schedulerLoopState.lastIterationCompleted = time.Now().UTC()
- r.schedulerLoopMu.Unlock()
-}
-
func (r *Plugin) snapshotSchedulerLoopState() schedulerLoopState {
if r == nil {
return schedulerLoopState{}
diff --git a/weed/admin/view/app/template_helpers.go b/weed/admin/view/app/template_helpers.go
index 14814a9bd..fff28de09 100644
--- a/weed/admin/view/app/template_helpers.go
+++ b/weed/admin/view/app/template_helpers.go
@@ -6,20 +6,6 @@ import (
"strings"
)
-// getStatusColor returns Bootstrap color class for status
-func getStatusColor(status string) string {
- switch status {
- case "active", "healthy":
- return "success"
- case "warning":
- return "warning"
- case "critical", "unreachable":
- return "danger"
- default:
- return "secondary"
- }
-}
-
// formatBytes converts bytes to human readable format
func formatBytes(bytes int64) string {
if bytes == 0 {
diff --git a/weed/cluster/cluster.go b/weed/cluster/cluster.go
index 8327065b3..4d4614fb0 100644
--- a/weed/cluster/cluster.go
+++ b/weed/cluster/cluster.go
@@ -95,18 +95,6 @@ func NewCluster() *Cluster {
}
}
-func (cluster *Cluster) getGroupMembers(filerGroup FilerGroupName, nodeType string, createIfNotFound bool) *GroupMembers {
- switch nodeType {
- case FilerType:
- return cluster.filerGroups.getGroupMembers(filerGroup, createIfNotFound)
- case BrokerType:
- return cluster.brokerGroups.getGroupMembers(filerGroup, createIfNotFound)
- case S3Type:
- return cluster.s3Groups.getGroupMembers(filerGroup, createIfNotFound)
- }
- return nil
-}
-
func (cluster *Cluster) AddClusterNode(ns, nodeType string, dataCenter DataCenter, rack Rack, address pb.ServerAddress, version string) []*master_pb.KeepConnectedResponse {
filerGroup := FilerGroupName(ns)
switch nodeType {
diff --git a/weed/command/admin.go b/weed/command/admin.go
index f5e4a8360..6d6dc7198 100644
--- a/weed/command/admin.go
+++ b/weed/command/admin.go
@@ -511,11 +511,6 @@ func recoveryMiddleware(next http.Handler) http.Handler {
})
}
-// GetAdminOptions returns the admin command options for testing
-func GetAdminOptions() *AdminOptions {
- return &AdminOptions{}
-}
-
// loadOrGenerateSessionKeys loads or creates authentication/encryption keys for session cookies.
func loadOrGenerateSessionKeys(dataDir string) ([]byte, []byte, error) {
const keyLen = 32
diff --git a/weed/command/download.go b/weed/command/download.go
index e44335097..a155ad74a 100644
--- a/weed/command/download.go
+++ b/weed/command/download.go
@@ -132,16 +132,3 @@ func fetchContent(masterFn operation.GetMasterFn, grpcDialOption grpc.DialOption
content, e = io.ReadAll(rc.Body)
return
}
-
-func WriteFile(filename string, data []byte, perm os.FileMode) error {
- f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm)
- if err != nil {
- return err
- }
- n, err := f.Write(data)
- f.Close()
- if err == nil && n < len(data) {
- err = io.ErrShortWrite
- }
- return err
-}
diff --git a/weed/credential/config_loader.go b/weed/credential/config_loader.go
index 959f1cfb4..df57b55d3 100644
--- a/weed/credential/config_loader.go
+++ b/weed/credential/config_loader.go
@@ -57,42 +57,6 @@ func LoadCredentialConfiguration() (*CredentialConfig, error) {
}, nil
}
-// GetCredentialStoreConfig extracts credential store configuration from command line flags
-// This is used when credential store is configured via command line instead of credential.toml
-func GetCredentialStoreConfig(store string, config util.Configuration, prefix string) *CredentialConfig {
- if store == "" {
- return nil
- }
-
- return &CredentialConfig{
- Store: store,
- Config: config,
- Prefix: prefix,
- }
-}
-
-// MergeCredentialConfig merges command line credential config with credential.toml config
-// Command line flags take priority over credential.toml
-func MergeCredentialConfig(cmdLineStore string, cmdLineConfig util.Configuration, cmdLinePrefix string) (*CredentialConfig, error) {
- // If command line credential store is specified, use it
- if cmdLineStore != "" {
- glog.V(0).Infof("Using command line credential configuration: store=%s", cmdLineStore)
- return GetCredentialStoreConfig(cmdLineStore, cmdLineConfig, cmdLinePrefix), nil
- }
-
- // Otherwise, try to load from credential.toml
- config, err := LoadCredentialConfiguration()
- if err != nil {
- return nil, err
- }
-
- if config == nil {
- glog.V(1).Info("No credential store configured")
- }
-
- return config, nil
-}
-
// NewCredentialManagerWithDefaults creates a credential manager with fallback to defaults
// If explicitStore is provided, it will be used regardless of credential.toml
// If explicitStore is empty, it tries credential.toml first, then defaults to "filer_etc"
diff --git a/weed/credential/filer_etc/filer_etc_policy.go b/weed/credential/filer_etc/filer_etc_policy.go
index c83e56647..98cf1e721 100644
--- a/weed/credential/filer_etc/filer_etc_policy.go
+++ b/weed/credential/filer_etc/filer_etc_policy.go
@@ -207,32 +207,6 @@ func (store *FilerEtcStore) loadPoliciesFromMultiFile(ctx context.Context, polic
})
}
-func (store *FilerEtcStore) migratePoliciesToMultiFile(ctx context.Context, policies map[string]policy_engine.PolicyDocument) error {
- glog.Infof("Migrating IAM policies to multi-file layout...")
-
- // 1. Save all policies to individual files
- for name, policy := range policies {
- if err := store.savePolicy(ctx, name, policy); err != nil {
- return err
- }
- }
-
- // 2. Rename legacy file
- return store.withFilerClient(func(client filer_pb.SeaweedFilerClient) error {
- _, err := client.AtomicRenameEntry(ctx, &filer_pb.AtomicRenameEntryRequest{
- OldDirectory: filer.IamConfigDirectory,
- OldName: filer.IamPoliciesFile,
- NewDirectory: filer.IamConfigDirectory,
- NewName: IamLegacyPoliciesOldFile,
- })
- if err != nil {
- glog.Errorf("Failed to rename legacy IAM policies file %s/%s to %s: %v",
- filer.IamConfigDirectory, filer.IamPoliciesFile, IamLegacyPoliciesOldFile, err)
- }
- return err
- })
-}
-
func (store *FilerEtcStore) savePolicy(ctx context.Context, name string, document policy_engine.PolicyDocument) error {
if err := validatePolicyName(name); err != nil {
return err
diff --git a/weed/credential/migration.go b/weed/credential/migration.go
deleted file mode 100644
index 41d0e3840..000000000
--- a/weed/credential/migration.go
+++ /dev/null
@@ -1,221 +0,0 @@
-package credential
-
-import (
- "context"
- "fmt"
-
- "github.com/seaweedfs/seaweedfs/weed/glog"
- "github.com/seaweedfs/seaweedfs/weed/pb/iam_pb"
- "github.com/seaweedfs/seaweedfs/weed/util"
-)
-
-// MigrateCredentials migrates credentials from one store to another
-func MigrateCredentials(fromStoreName, toStoreName CredentialStoreTypeName, configuration util.Configuration, fromPrefix, toPrefix string) error {
- ctx := context.Background()
-
- // Create source credential manager
- fromCM, err := NewCredentialManager(fromStoreName, configuration, fromPrefix)
- if err != nil {
- return fmt.Errorf("failed to create source credential manager (%s): %v", fromStoreName, err)
- }
- defer fromCM.Shutdown()
-
- // Create destination credential manager
- toCM, err := NewCredentialManager(toStoreName, configuration, toPrefix)
- if err != nil {
- return fmt.Errorf("failed to create destination credential manager (%s): %v", toStoreName, err)
- }
- defer toCM.Shutdown()
-
- // Load configuration from source
- glog.Infof("Loading configuration from %s store...", fromStoreName)
- config, err := fromCM.LoadConfiguration(ctx)
- if err != nil {
- return fmt.Errorf("failed to load configuration from source store: %w", err)
- }
-
- if config == nil || len(config.Identities) == 0 {
- glog.Info("No identities found in source store")
- return nil
- }
-
- glog.Infof("Found %d identities in source store", len(config.Identities))
-
- // Migrate each identity
- var migrated, failed int
- for _, identity := range config.Identities {
- glog.V(1).Infof("Migrating user: %s", identity.Name)
-
- // Check if user already exists in destination
- existingUser, err := toCM.GetUser(ctx, identity.Name)
- if err != nil && err != ErrUserNotFound {
- glog.Errorf("Failed to check if user %s exists in destination: %v", identity.Name, err)
- failed++
- continue
- }
-
- if existingUser != nil {
- glog.Warningf("User %s already exists in destination store, skipping", identity.Name)
- continue
- }
-
- // Create user in destination
- err = toCM.CreateUser(ctx, identity)
- if err != nil {
- glog.Errorf("Failed to create user %s in destination store: %v", identity.Name, err)
- failed++
- continue
- }
-
- migrated++
- glog.V(1).Infof("Successfully migrated user: %s", identity.Name)
- }
-
- glog.Infof("Migration completed: %d migrated, %d failed", migrated, failed)
-
- if failed > 0 {
- return fmt.Errorf("migration completed with %d failures", failed)
- }
-
- return nil
-}
-
-// ExportCredentials exports credentials from a store to a configuration
-func ExportCredentials(storeName CredentialStoreTypeName, configuration util.Configuration, prefix string) (*iam_pb.S3ApiConfiguration, error) {
- ctx := context.Background()
-
- // Create credential manager
- cm, err := NewCredentialManager(storeName, configuration, prefix)
- if err != nil {
- return nil, fmt.Errorf("failed to create credential manager (%s): %v", storeName, err)
- }
- defer cm.Shutdown()
-
- // Load configuration
- config, err := cm.LoadConfiguration(ctx)
- if err != nil {
- return nil, fmt.Errorf("failed to load configuration: %w", err)
- }
-
- return config, nil
-}
-
-// ImportCredentials imports credentials from a configuration to a store
-func ImportCredentials(storeName CredentialStoreTypeName, configuration util.Configuration, prefix string, config *iam_pb.S3ApiConfiguration) error {
- ctx := context.Background()
-
- // Create credential manager
- cm, err := NewCredentialManager(storeName, configuration, prefix)
- if err != nil {
- return fmt.Errorf("failed to create credential manager (%s): %v", storeName, err)
- }
- defer cm.Shutdown()
-
- // Import each identity
- var imported, failed int
- for _, identity := range config.Identities {
- glog.V(1).Infof("Importing user: %s", identity.Name)
-
- // Check if user already exists
- existingUser, err := cm.GetUser(ctx, identity.Name)
- if err != nil && err != ErrUserNotFound {
- glog.Errorf("Failed to check if user %s exists: %v", identity.Name, err)
- failed++
- continue
- }
-
- if existingUser != nil {
- glog.Warningf("User %s already exists, skipping", identity.Name)
- continue
- }
-
- // Create user
- err = cm.CreateUser(ctx, identity)
- if err != nil {
- glog.Errorf("Failed to create user %s: %v", identity.Name, err)
- failed++
- continue
- }
-
- imported++
- glog.V(1).Infof("Successfully imported user: %s", identity.Name)
- }
-
- glog.Infof("Import completed: %d imported, %d failed", imported, failed)
-
- if failed > 0 {
- return fmt.Errorf("import completed with %d failures", failed)
- }
-
- return nil
-}
-
-// ValidateCredentials validates that all credentials in a store are accessible
-func ValidateCredentials(storeName CredentialStoreTypeName, configuration util.Configuration, prefix string) error {
- ctx := context.Background()
-
- // Create credential manager
- cm, err := NewCredentialManager(storeName, configuration, prefix)
- if err != nil {
- return fmt.Errorf("failed to create credential manager (%s): %v", storeName, err)
- }
- defer cm.Shutdown()
-
- // Load configuration
- config, err := cm.LoadConfiguration(ctx)
- if err != nil {
- return fmt.Errorf("failed to load configuration: %w", err)
- }
-
- if config == nil || len(config.Identities) == 0 {
- glog.Info("No identities found in store")
- return nil
- }
-
- glog.Infof("Validating %d identities...", len(config.Identities))
-
- // Validate each identity
- var validated, failed int
- for _, identity := range config.Identities {
- // Check if user can be retrieved
- user, err := cm.GetUser(ctx, identity.Name)
- if err != nil {
- glog.Errorf("Failed to retrieve user %s: %v", identity.Name, err)
- failed++
- continue
- }
-
- if user == nil {
- glog.Errorf("User %s not found", identity.Name)
- failed++
- continue
- }
-
- // Validate access keys
- for _, credential := range identity.Credentials {
- accessKeyUser, err := cm.GetUserByAccessKey(ctx, credential.AccessKey)
- if err != nil {
- glog.Errorf("Failed to retrieve user by access key %s: %v", credential.AccessKey, err)
- failed++
- continue
- }
-
- if accessKeyUser == nil || accessKeyUser.Name != identity.Name {
- glog.Errorf("Access key %s does not map to correct user %s", credential.AccessKey, identity.Name)
- failed++
- continue
- }
- }
-
- validated++
- glog.V(1).Infof("Successfully validated user: %s", identity.Name)
- }
-
- glog.Infof("Validation completed: %d validated, %d failed", validated, failed)
-
- if failed > 0 {
- return fmt.Errorf("validation completed with %d failures", failed)
- }
-
- return nil
-}
diff --git a/weed/filer/filer_notify_read.go b/weed/filer/filer_notify_read.go
index 0cf71efe1..cf0641852 100644
--- a/weed/filer/filer_notify_read.go
+++ b/weed/filer/filer_notify_read.go
@@ -246,10 +246,6 @@ func NewLogFileEntryCollector(f *Filer, startPosition log_buffer.MessagePosition
}
}
-func (c *LogFileEntryCollector) hasMore() bool {
- return c.dayEntryQueue.Len() > 0
-}
-
func (c *LogFileEntryCollector) collectMore(v *OrderedLogVisitor) (err error) {
dayEntry := c.dayEntryQueue.Dequeue()
if dayEntry == nil {
diff --git a/weed/filer/meta_replay.go b/weed/filer/meta_replay.go
index f6b009e92..51c4e6987 100644
--- a/weed/filer/meta_replay.go
+++ b/weed/filer/meta_replay.go
@@ -2,7 +2,6 @@ package filer
import (
"context"
- "sync"
"github.com/seaweedfs/seaweedfs/weed/glog"
"github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
@@ -36,39 +35,3 @@ func Replay(filerStore FilerStore, resp *filer_pb.SubscribeMetadataResponse) err
return nil
}
-
-// ParallelProcessDirectoryStructure processes each entry in parallel, and also ensure parent directories are processed first.
-// This also assumes the parent directories are in the entryChan already.
-func ParallelProcessDirectoryStructure(entryChan chan *Entry, concurrency int, eachEntryFn func(entry *Entry) error) (firstErr error) {
-
- executors := util.NewLimitedConcurrentExecutor(concurrency)
-
- var wg sync.WaitGroup
- for entry := range entryChan {
- wg.Add(1)
- if entry.IsDirectory() {
- func() {
- defer wg.Done()
- if err := eachEntryFn(entry); err != nil {
- if firstErr == nil {
- firstErr = err
- }
- }
- }()
- } else {
- executors.Execute(func() {
- defer wg.Done()
- if err := eachEntryFn(entry); err != nil {
- if firstErr == nil {
- firstErr = err
- }
- }
- })
- }
- if firstErr != nil {
- break
- }
- }
- wg.Wait()
- return
-}
diff --git a/weed/filer/redis3/ItemList.go b/weed/filer/redis3/ItemList.go
index 05457e596..b4043d01c 100644
--- a/weed/filer/redis3/ItemList.go
+++ b/weed/filer/redis3/ItemList.go
@@ -16,15 +16,6 @@ type ItemList struct {
prefix string
}
-func newItemList(client redis.UniversalClient, prefix string, store skiplist.ListStore, batchSize int) *ItemList {
- return &ItemList{
- skipList: skiplist.New(store),
- batchSize: batchSize,
- client: client,
- prefix: prefix,
- }
-}
-
/*
Be reluctant to create new nodes. Try to fit into either previous node or next node.
Prefer to add to previous node.
diff --git a/weed/filer/redis_lua/redis_cluster_store.go b/weed/filer/redis_lua/redis_cluster_store.go
deleted file mode 100644
index b64342fc2..000000000
--- a/weed/filer/redis_lua/redis_cluster_store.go
+++ /dev/null
@@ -1,48 +0,0 @@
-package redis_lua
-
-import (
- "github.com/redis/go-redis/v9"
- "github.com/seaweedfs/seaweedfs/weed/filer"
- "github.com/seaweedfs/seaweedfs/weed/util"
-)
-
-func init() {
- filer.Stores = append(filer.Stores, &RedisLuaClusterStore{})
-}
-
-type RedisLuaClusterStore struct {
- UniversalRedisLuaStore
-}
-
-func (store *RedisLuaClusterStore) GetName() string {
- return "redis_lua_cluster"
-}
-
-func (store *RedisLuaClusterStore) Initialize(configuration util.Configuration, prefix string) (err error) {
-
- configuration.SetDefault(prefix+"useReadOnly", false)
- configuration.SetDefault(prefix+"routeByLatency", false)
-
- return store.initialize(
- configuration.GetStringSlice(prefix+"addresses"),
- configuration.GetString(prefix+"username"),
- configuration.GetString(prefix+"password"),
- configuration.GetString(prefix+"keyPrefix"),
- configuration.GetBool(prefix+"useReadOnly"),
- configuration.GetBool(prefix+"routeByLatency"),
- configuration.GetStringSlice(prefix+"superLargeDirectories"),
- )
-}
-
-func (store *RedisLuaClusterStore) initialize(addresses []string, username string, password string, keyPrefix string, readOnly, routeByLatency bool, superLargeDirectories []string) (err error) {
- store.Client = redis.NewClusterClient(&redis.ClusterOptions{
- Addrs: addresses,
- Username: username,
- Password: password,
- ReadOnly: readOnly,
- RouteByLatency: routeByLatency,
- })
- store.keyPrefix = keyPrefix
- store.loadSuperLargeDirectories(superLargeDirectories)
- return
-}
diff --git a/weed/filer/redis_lua/redis_sentinel_store.go b/weed/filer/redis_lua/redis_sentinel_store.go
deleted file mode 100644
index 6dd85dd06..000000000
--- a/weed/filer/redis_lua/redis_sentinel_store.go
+++ /dev/null
@@ -1,48 +0,0 @@
-package redis_lua
-
-import (
- "time"
-
- "github.com/redis/go-redis/v9"
- "github.com/seaweedfs/seaweedfs/weed/filer"
- "github.com/seaweedfs/seaweedfs/weed/util"
-)
-
-func init() {
- filer.Stores = append(filer.Stores, &RedisLuaSentinelStore{})
-}
-
-type RedisLuaSentinelStore struct {
- UniversalRedisLuaStore
-}
-
-func (store *RedisLuaSentinelStore) GetName() string {
- return "redis_lua_sentinel"
-}
-
-func (store *RedisLuaSentinelStore) Initialize(configuration util.Configuration, prefix string) (err error) {
- return store.initialize(
- configuration.GetStringSlice(prefix+"addresses"),
- configuration.GetString(prefix+"masterName"),
- configuration.GetString(prefix+"username"),
- configuration.GetString(prefix+"password"),
- configuration.GetInt(prefix+"database"),
- configuration.GetString(prefix+"keyPrefix"),
- )
-}
-
-func (store *RedisLuaSentinelStore) initialize(addresses []string, masterName string, username string, password string, database int, keyPrefix string) (err error) {
- store.Client = redis.NewFailoverClient(&redis.FailoverOptions{
- MasterName: masterName,
- SentinelAddrs: addresses,
- Username: username,
- Password: password,
- DB: database,
- MinRetryBackoff: time.Millisecond * 100,
- MaxRetryBackoff: time.Minute * 1,
- ReadTimeout: time.Second * 30,
- WriteTimeout: time.Second * 5,
- })
- store.keyPrefix = keyPrefix
- return
-}
diff --git a/weed/filer/redis_lua/redis_store.go b/weed/filer/redis_lua/redis_store.go
deleted file mode 100644
index 4f6354e96..000000000
--- a/weed/filer/redis_lua/redis_store.go
+++ /dev/null
@@ -1,42 +0,0 @@
-package redis_lua
-
-import (
- "github.com/redis/go-redis/v9"
- "github.com/seaweedfs/seaweedfs/weed/filer"
- "github.com/seaweedfs/seaweedfs/weed/util"
-)
-
-func init() {
- filer.Stores = append(filer.Stores, &RedisLuaStore{})
-}
-
-type RedisLuaStore struct {
- UniversalRedisLuaStore
-}
-
-func (store *RedisLuaStore) GetName() string {
- return "redis_lua"
-}
-
-func (store *RedisLuaStore) Initialize(configuration util.Configuration, prefix string) (err error) {
- return store.initialize(
- configuration.GetString(prefix+"address"),
- configuration.GetString(prefix+"username"),
- configuration.GetString(prefix+"password"),
- configuration.GetInt(prefix+"database"),
- configuration.GetString(prefix+"keyPrefix"),
- configuration.GetStringSlice(prefix+"superLargeDirectories"),
- )
-}
-
-func (store *RedisLuaStore) initialize(hostPort string, username string, password string, database int, keyPrefix string, superLargeDirectories []string) (err error) {
- store.Client = redis.NewClient(&redis.Options{
- Addr: hostPort,
- Username: username,
- Password: password,
- DB: database,
- })
- store.keyPrefix = keyPrefix
- store.loadSuperLargeDirectories(superLargeDirectories)
- return
-}
diff --git a/weed/filer/redis_lua/stored_procedure/delete_entry.lua b/weed/filer/redis_lua/stored_procedure/delete_entry.lua
deleted file mode 100644
index 445337c77..000000000
--- a/weed/filer/redis_lua/stored_procedure/delete_entry.lua
+++ /dev/null
@@ -1,19 +0,0 @@
--- KEYS[1]: full path of entry
-local fullpath = KEYS[1]
--- KEYS[2]: full path of entry
-local fullpath_list_key = KEYS[2]
--- KEYS[3]: dir of the entry
-local dir_list_key = KEYS[3]
-
--- ARGV[1]: isSuperLargeDirectory
-local isSuperLargeDirectory = ARGV[1] == "1"
--- ARGV[2]: name of the entry
-local name = ARGV[2]
-
-redis.call("DEL", fullpath, fullpath_list_key)
-
-if not isSuperLargeDirectory and name ~= "" then
- redis.call("ZREM", dir_list_key, name)
-end
-
-return 0
\ No newline at end of file
diff --git a/weed/filer/redis_lua/stored_procedure/delete_folder_children.lua b/weed/filer/redis_lua/stored_procedure/delete_folder_children.lua
deleted file mode 100644
index 77e4839f9..000000000
--- a/weed/filer/redis_lua/stored_procedure/delete_folder_children.lua
+++ /dev/null
@@ -1,15 +0,0 @@
--- KEYS[1]: full path of entry
-local fullpath = KEYS[1]
-
-if fullpath ~= "" and string.sub(fullpath, -1) == "/" then
- fullpath = string.sub(fullpath, 0, -2)
-end
-
-local files = redis.call("ZRANGE", fullpath .. "\0", "0", "-1")
-
-for _, name in ipairs(files) do
- local file_path = fullpath .. "/" .. name
- redis.call("DEL", file_path, file_path .. "\0")
-end
-
-return 0
\ No newline at end of file
diff --git a/weed/filer/redis_lua/stored_procedure/init.go b/weed/filer/redis_lua/stored_procedure/init.go
deleted file mode 100644
index 685ea364d..000000000
--- a/weed/filer/redis_lua/stored_procedure/init.go
+++ /dev/null
@@ -1,25 +0,0 @@
-package stored_procedure
-
-import (
- _ "embed"
-
- "github.com/redis/go-redis/v9"
-)
-
-func init() {
- InsertEntryScript = redis.NewScript(insertEntry)
- DeleteEntryScript = redis.NewScript(deleteEntry)
- DeleteFolderChildrenScript = redis.NewScript(deleteFolderChildren)
-}
-
-//go:embed insert_entry.lua
-var insertEntry string
-var InsertEntryScript *redis.Script
-
-//go:embed delete_entry.lua
-var deleteEntry string
-var DeleteEntryScript *redis.Script
-
-//go:embed delete_folder_children.lua
-var deleteFolderChildren string
-var DeleteFolderChildrenScript *redis.Script
diff --git a/weed/filer/redis_lua/stored_procedure/insert_entry.lua b/weed/filer/redis_lua/stored_procedure/insert_entry.lua
deleted file mode 100644
index 8deef3446..000000000
--- a/weed/filer/redis_lua/stored_procedure/insert_entry.lua
+++ /dev/null
@@ -1,27 +0,0 @@
--- KEYS[1]: full path of entry
-local full_path = KEYS[1]
--- KEYS[2]: dir of the entry
-local dir_list_key = KEYS[2]
-
--- ARGV[1]: content of the entry
-local entry = ARGV[1]
--- ARGV[2]: TTL of the entry
-local ttlSec = tonumber(ARGV[2])
--- ARGV[3]: isSuperLargeDirectory
-local isSuperLargeDirectory = ARGV[3] == "1"
--- ARGV[4]: zscore of the entry in zset
-local zscore = tonumber(ARGV[4])
--- ARGV[5]: name of the entry
-local name = ARGV[5]
-
-if ttlSec > 0 then
- redis.call("SET", full_path, entry, "EX", ttlSec)
-else
- redis.call("SET", full_path, entry)
-end
-
-if not isSuperLargeDirectory and name ~= "" then
- redis.call("ZADD", dir_list_key, "NX", zscore, name)
-end
-
-return 0
\ No newline at end of file
diff --git a/weed/filer/redis_lua/universal_redis_store.go b/weed/filer/redis_lua/universal_redis_store.go
deleted file mode 100644
index 0a02a0730..000000000
--- a/weed/filer/redis_lua/universal_redis_store.go
+++ /dev/null
@@ -1,206 +0,0 @@
-package redis_lua
-
-import (
- "context"
- "fmt"
- "time"
-
- "github.com/redis/go-redis/v9"
-
- "github.com/seaweedfs/seaweedfs/weed/filer"
- "github.com/seaweedfs/seaweedfs/weed/filer/redis_lua/stored_procedure"
- "github.com/seaweedfs/seaweedfs/weed/glog"
- "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
- "github.com/seaweedfs/seaweedfs/weed/util"
-)
-
-const (
- DIR_LIST_MARKER = "\x00"
-)
-
-type UniversalRedisLuaStore struct {
- Client redis.UniversalClient
- keyPrefix string
- superLargeDirectoryHash map[string]bool
-}
-
-func (store *UniversalRedisLuaStore) isSuperLargeDirectory(dir string) (isSuperLargeDirectory bool) {
- _, isSuperLargeDirectory = store.superLargeDirectoryHash[dir]
- return
-}
-
-func (store *UniversalRedisLuaStore) loadSuperLargeDirectories(superLargeDirectories []string) {
- // set directory hash
- store.superLargeDirectoryHash = make(map[string]bool)
- for _, dir := range superLargeDirectories {
- store.superLargeDirectoryHash[dir] = true
- }
-}
-
-func (store *UniversalRedisLuaStore) getKey(key string) string {
- if store.keyPrefix == "" {
- return key
- }
- return store.keyPrefix + key
-}
-
-func (store *UniversalRedisLuaStore) BeginTransaction(ctx context.Context) (context.Context, error) {
- return ctx, nil
-}
-func (store *UniversalRedisLuaStore) CommitTransaction(ctx context.Context) error {
- return nil
-}
-func (store *UniversalRedisLuaStore) RollbackTransaction(ctx context.Context) error {
- return nil
-}
-
-func (store *UniversalRedisLuaStore) InsertEntry(ctx context.Context, entry *filer.Entry) (err error) {
-
- value, err := entry.EncodeAttributesAndChunks()
- if err != nil {
- return fmt.Errorf("encoding %s %+v: %v", entry.FullPath, entry.Attr, err)
- }
-
- if len(entry.GetChunks()) > filer.CountEntryChunksForGzip {
- value = util.MaybeGzipData(value)
- }
-
- dir, name := entry.FullPath.DirAndName()
-
- err = stored_procedure.InsertEntryScript.Run(ctx, store.Client,
- []string{store.getKey(string(entry.FullPath)), store.getKey(genDirectoryListKey(dir))},
- value, entry.TtlSec,
- store.isSuperLargeDirectory(dir), 0, name).Err()
-
- if err != nil {
- return fmt.Errorf("persisting %s : %v", entry.FullPath, err)
- }
-
- return nil
-}
-
-func (store *UniversalRedisLuaStore) UpdateEntry(ctx context.Context, entry *filer.Entry) (err error) {
-
- return store.InsertEntry(ctx, entry)
-}
-
-func (store *UniversalRedisLuaStore) FindEntry(ctx context.Context, fullpath util.FullPath) (entry *filer.Entry, err error) {
-
- data, err := store.Client.Get(ctx, store.getKey(string(fullpath))).Result()
- if err == redis.Nil {
- return nil, filer_pb.ErrNotFound
- }
-
- if err != nil {
- return nil, fmt.Errorf("get %s : %v", fullpath, err)
- }
-
- entry = &filer.Entry{
- FullPath: fullpath,
- }
- err = entry.DecodeAttributesAndChunks(util.MaybeDecompressData([]byte(data)))
- if err != nil {
- return entry, fmt.Errorf("decode %s : %v", entry.FullPath, err)
- }
-
- return entry, nil
-}
-
-func (store *UniversalRedisLuaStore) DeleteEntry(ctx context.Context, fullpath util.FullPath) (err error) {
-
- dir, name := fullpath.DirAndName()
-
- err = stored_procedure.DeleteEntryScript.Run(ctx, store.Client,
- []string{store.getKey(string(fullpath)), store.getKey(genDirectoryListKey(string(fullpath))), store.getKey(genDirectoryListKey(dir))},
- store.isSuperLargeDirectory(dir), name).Err()
-
- if err != nil {
- return fmt.Errorf("DeleteEntry %s : %v", fullpath, err)
- }
-
- return nil
-}
-
-func (store *UniversalRedisLuaStore) DeleteFolderChildren(ctx context.Context, fullpath util.FullPath) (err error) {
-
- if store.isSuperLargeDirectory(string(fullpath)) {
- return nil
- }
-
- err = stored_procedure.DeleteFolderChildrenScript.Run(ctx, store.Client,
- []string{store.getKey(string(fullpath))}).Err()
-
- if err != nil {
- return fmt.Errorf("DeleteFolderChildren %s : %v", fullpath, err)
- }
-
- return nil
-}
-
-func (store *UniversalRedisLuaStore) ListDirectoryPrefixedEntries(ctx context.Context, dirPath util.FullPath, startFileName string, includeStartFile bool, limit int64, prefix string, eachEntryFunc filer.ListEachEntryFunc) (lastFileName string, err error) {
- return lastFileName, filer.ErrUnsupportedListDirectoryPrefixed
-}
-
-func (store *UniversalRedisLuaStore) ListDirectoryEntries(ctx context.Context, dirPath util.FullPath, startFileName string, includeStartFile bool, limit int64, eachEntryFunc filer.ListEachEntryFunc) (lastFileName string, err error) {
-
- dirListKey := store.getKey(genDirectoryListKey(string(dirPath)))
-
- min := "-"
- if startFileName != "" {
- if includeStartFile {
- min = "[" + startFileName
- } else {
- min = "(" + startFileName
- }
- }
-
- members, err := store.Client.ZRangeByLex(ctx, dirListKey, &redis.ZRangeBy{
- Min: min,
- Max: "+",
- Offset: 0,
- Count: limit,
- }).Result()
- if err != nil {
- return lastFileName, fmt.Errorf("list %s : %v", dirPath, err)
- }
-
- // fetch entry meta
- for _, fileName := range members {
- path := util.NewFullPath(string(dirPath), fileName)
- entry, err := store.FindEntry(ctx, path)
- lastFileName = fileName
- if err != nil {
- glog.V(0).InfofCtx(ctx, "list %s : %v", path, err)
- if err == filer_pb.ErrNotFound {
- continue
- }
- } else {
- if entry.TtlSec > 0 {
- if entry.Attr.Crtime.Add(time.Duration(entry.TtlSec) * time.Second).Before(time.Now()) {
- store.DeleteEntry(ctx, path)
- continue
- }
- }
-
- resEachEntryFunc, resEachEntryFuncErr := eachEntryFunc(entry)
- if resEachEntryFuncErr != nil {
- err = fmt.Errorf("failed to process eachEntryFunc: %w", resEachEntryFuncErr)
- break
- }
-
- if !resEachEntryFunc {
- break
- }
- }
- }
-
- return lastFileName, err
-}
-
-func genDirectoryListKey(dir string) (dirList string) {
- return dir + DIR_LIST_MARKER
-}
-
-func (store *UniversalRedisLuaStore) Shutdown() {
- store.Client.Close()
-}
diff --git a/weed/filer/redis_lua/universal_redis_store_kv.go b/weed/filer/redis_lua/universal_redis_store_kv.go
deleted file mode 100644
index 79b6495ce..000000000
--- a/weed/filer/redis_lua/universal_redis_store_kv.go
+++ /dev/null
@@ -1,42 +0,0 @@
-package redis_lua
-
-import (
- "context"
- "fmt"
-
- "github.com/redis/go-redis/v9"
- "github.com/seaweedfs/seaweedfs/weed/filer"
-)
-
-func (store *UniversalRedisLuaStore) KvPut(ctx context.Context, key []byte, value []byte) (err error) {
-
- _, err = store.Client.Set(ctx, string(key), value, 0).Result()
-
- if err != nil {
- return fmt.Errorf("kv put: %w", err)
- }
-
- return nil
-}
-
-func (store *UniversalRedisLuaStore) KvGet(ctx context.Context, key []byte) (value []byte, err error) {
-
- data, err := store.Client.Get(ctx, string(key)).Result()
-
- if err == redis.Nil {
- return nil, filer.ErrKvNotFound
- }
-
- return []byte(data), err
-}
-
-func (store *UniversalRedisLuaStore) KvDelete(ctx context.Context, key []byte) (err error) {
-
- _, err = store.Client.Del(ctx, string(key)).Result()
-
- if err != nil {
- return fmt.Errorf("kv delete: %w", err)
- }
-
- return nil
-}
diff --git a/weed/filer/stream.go b/weed/filer/stream.go
index c60d147e5..e49794fd2 100644
--- a/weed/filer/stream.go
+++ b/weed/filer/stream.go
@@ -102,10 +102,6 @@ func PrepareStreamContent(masterClient wdclient.HasLookupFileIdFunction, jwtFunc
type VolumeServerJwtFunction func(fileId string) string
-func noJwtFunc(string) string {
- return ""
-}
-
type CacheInvalidator interface {
InvalidateCache(fileId string)
}
@@ -276,33 +272,6 @@ func writeZero(w io.Writer, size int64) (err error) {
return
}
-func ReadAll(ctx context.Context, buffer []byte, masterClient *wdclient.MasterClient, chunks []*filer_pb.FileChunk) error {
-
- lookupFileIdFn := func(ctx context.Context, fileId string) (targetUrls []string, err error) {
- return masterClient.LookupFileId(ctx, fileId)
- }
-
- chunkViews := ViewFromChunks(ctx, lookupFileIdFn, chunks, 0, int64(len(buffer)))
-
- idx := 0
-
- for x := chunkViews.Front(); x != nil; x = x.Next {
- chunkView := x.Value
- urlStrings, err := lookupFileIdFn(ctx, chunkView.FileId)
- if err != nil {
- glog.V(1).InfofCtx(ctx, "operation LookupFileId %s failed, err: %v", chunkView.FileId, err)
- return err
- }
-
- n, err := util_http.RetriedFetchChunkData(ctx, buffer[idx:idx+int(chunkView.ViewSize)], urlStrings, chunkView.CipherKey, chunkView.IsGzipped, chunkView.IsFullChunk(), chunkView.OffsetInChunk, chunkView.FileId)
- if err != nil {
- return err
- }
- idx += n
- }
- return nil
-}
-
// ---------------- ChunkStreamReader ----------------------------------
type ChunkStreamReader struct {
head *Interval[*ChunkView]
diff --git a/weed/filer/stream_failover_test.go b/weed/filer/stream_failover_test.go
deleted file mode 100644
index aaa59c523..000000000
--- a/weed/filer/stream_failover_test.go
+++ /dev/null
@@ -1,281 +0,0 @@
-package filer
-
-import (
- "bytes"
- "context"
- "errors"
- "testing"
- "time"
-
- "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
- "github.com/seaweedfs/seaweedfs/weed/wdclient"
-)
-
-// mockMasterClient implements HasLookupFileIdFunction and CacheInvalidator
-type mockMasterClient struct {
- lookupFunc func(ctx context.Context, fileId string) ([]string, error)
- invalidatedFileIds []string
-}
-
-func (m *mockMasterClient) GetLookupFileIdFunction() wdclient.LookupFileIdFunctionType {
- return m.lookupFunc
-}
-
-func (m *mockMasterClient) InvalidateCache(fileId string) {
- m.invalidatedFileIds = append(m.invalidatedFileIds, fileId)
-}
-
-// Test urlSlicesEqual helper function
-func TestUrlSlicesEqual(t *testing.T) {
- tests := []struct {
- name string
- a []string
- b []string
- expected bool
- }{
- {
- name: "identical slices",
- a: []string{"http://server1", "http://server2"},
- b: []string{"http://server1", "http://server2"},
- expected: true,
- },
- {
- name: "same URLs different order",
- a: []string{"http://server1", "http://server2"},
- b: []string{"http://server2", "http://server1"},
- expected: true,
- },
- {
- name: "different URLs",
- a: []string{"http://server1", "http://server2"},
- b: []string{"http://server1", "http://server3"},
- expected: false,
- },
- {
- name: "different lengths",
- a: []string{"http://server1"},
- b: []string{"http://server1", "http://server2"},
- expected: false,
- },
- {
- name: "empty slices",
- a: []string{},
- b: []string{},
- expected: true,
- },
- {
- name: "duplicates in both",
- a: []string{"http://server1", "http://server1"},
- b: []string{"http://server1", "http://server1"},
- expected: true,
- },
- {
- name: "different duplicate counts",
- a: []string{"http://server1", "http://server1"},
- b: []string{"http://server1", "http://server2"},
- expected: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result := urlSlicesEqual(tt.a, tt.b)
- if result != tt.expected {
- t.Errorf("urlSlicesEqual(%v, %v) = %v; want %v", tt.a, tt.b, result, tt.expected)
- }
- })
- }
-}
-
-// Test cache invalidation when read fails
-func TestStreamContentWithCacheInvalidation(t *testing.T) {
- ctx := context.Background()
- fileId := "3,01234567890"
-
- callCount := 0
- oldUrls := []string{"http://failed-server:8080"}
- newUrls := []string{"http://working-server:8080"}
-
- mock := &mockMasterClient{
- lookupFunc: func(ctx context.Context, fid string) ([]string, error) {
- callCount++
- if callCount == 1 {
- // First call returns failing server
- return oldUrls, nil
- }
- // After invalidation, return working server
- return newUrls, nil
- },
- }
-
- // Create a simple chunk
- chunks := []*filer_pb.FileChunk{
- {
- FileId: fileId,
- Offset: 0,
- Size: 10,
- },
- }
-
- streamFn, err := PrepareStreamContentWithThrottler(ctx, mock, noJwtFunc, chunks, 0, 10, 0)
- if err != nil {
- t.Fatalf("PrepareStreamContentWithThrottler failed: %v", err)
- }
-
- // Note: This test can't fully execute streamFn because it would require actual HTTP servers
- // However, we can verify the setup was created correctly
- if streamFn == nil {
- t.Fatal("Expected non-nil stream function")
- }
-
- // Verify the lookup was called
- if callCount != 1 {
- t.Errorf("Expected 1 lookup call, got %d", callCount)
- }
-}
-
-// Test that InvalidateCache is called on read failure
-func TestCacheInvalidationInterface(t *testing.T) {
- mock := &mockMasterClient{
- lookupFunc: func(ctx context.Context, fileId string) ([]string, error) {
- return []string{"http://server:8080"}, nil
- },
- }
-
- fileId := "3,test123"
-
- // Simulate invalidation
- if invalidator, ok := interface{}(mock).(CacheInvalidator); ok {
- invalidator.InvalidateCache(fileId)
- } else {
- t.Fatal("mockMasterClient should implement CacheInvalidator")
- }
-
- // Check that the file ID was recorded as invalidated
- if len(mock.invalidatedFileIds) != 1 {
- t.Fatalf("Expected 1 invalidated file ID, got %d", len(mock.invalidatedFileIds))
- }
- if mock.invalidatedFileIds[0] != fileId {
- t.Errorf("Expected invalidated file ID %s, got %s", fileId, mock.invalidatedFileIds[0])
- }
-}
-
-// Test retry logic doesn't retry with same URLs
-func TestRetryLogicSkipsSameUrls(t *testing.T) {
- // This test verifies that the urlSlicesEqual check prevents infinite retries
- sameUrls := []string{"http://server1:8080", "http://server2:8080"}
- differentUrls := []string{"http://server3:8080", "http://server4:8080"}
-
- // Same URLs should return true (and thus skip retry)
- if !urlSlicesEqual(sameUrls, sameUrls) {
- t.Error("Expected same URLs to be equal")
- }
-
- // Different URLs should return false (and thus allow retry)
- if urlSlicesEqual(sameUrls, differentUrls) {
- t.Error("Expected different URLs to not be equal")
- }
-}
-
-func TestCanceledStreamSkipsCacheInvalidation(t *testing.T) {
- ctx, cancel := context.WithCancel(context.Background())
- fileId := "3,canceled"
-
- mock := &mockMasterClient{
- lookupFunc: func(ctx context.Context, fid string) ([]string, error) {
- return []string{"http://server:8080"}, nil
- },
- }
-
- chunks := []*filer_pb.FileChunk{
- {
- FileId: fileId,
- Offset: 0,
- Size: 10,
- },
- }
-
- streamFn, err := PrepareStreamContentWithThrottler(ctx, mock, noJwtFunc, chunks, 0, 10, 0)
- if err != nil {
- t.Fatalf("PrepareStreamContentWithThrottler failed: %v", err)
- }
-
- cancel()
-
- err = streamFn(&bytes.Buffer{})
- if err != context.Canceled {
- t.Fatalf("expected context.Canceled, got %v", err)
- }
- if len(mock.invalidatedFileIds) != 0 {
- t.Fatalf("expected no cache invalidation on cancellation, got %v", mock.invalidatedFileIds)
- }
-}
-
-func TestPrepareStreamContentSkipsLookupWhenContextAlreadyCanceled(t *testing.T) {
- oldSchedule := getLookupFileIdBackoffSchedule
- getLookupFileIdBackoffSchedule = []time.Duration{time.Millisecond}
- t.Cleanup(func() {
- getLookupFileIdBackoffSchedule = oldSchedule
- })
-
- ctx, cancel := context.WithCancel(context.Background())
- cancel()
-
- lookupCalls := 0
- mock := &mockMasterClient{
- lookupFunc: func(ctx context.Context, fileId string) ([]string, error) {
- lookupCalls++
- return nil, errors.New("lookup should not run")
- },
- }
-
- chunks := []*filer_pb.FileChunk{
- {
- FileId: "3,precanceled",
- Offset: 0,
- Size: 10,
- },
- }
-
- _, err := PrepareStreamContentWithThrottler(ctx, mock, noJwtFunc, chunks, 0, 10, 0)
- if !errors.Is(err, context.Canceled) {
- t.Fatalf("expected context.Canceled, got %v", err)
- }
- if lookupCalls != 0 {
- t.Fatalf("expected no lookup calls after cancellation, got %d", lookupCalls)
- }
-}
-
-func TestPrepareStreamContentStopsLookupRetriesAfterContextCancellation(t *testing.T) {
- oldSchedule := getLookupFileIdBackoffSchedule
- getLookupFileIdBackoffSchedule = []time.Duration{time.Millisecond, time.Millisecond, time.Millisecond}
- t.Cleanup(func() {
- getLookupFileIdBackoffSchedule = oldSchedule
- })
-
- ctx, cancel := context.WithCancel(context.Background())
- lookupCalls := 0
- mock := &mockMasterClient{
- lookupFunc: func(ctx context.Context, fileId string) ([]string, error) {
- lookupCalls++
- cancel()
- return nil, context.Canceled
- },
- }
-
- chunks := []*filer_pb.FileChunk{
- {
- FileId: "3,cancel-during-lookup",
- Offset: 0,
- Size: 10,
- },
- }
-
- _, err := PrepareStreamContentWithThrottler(ctx, mock, noJwtFunc, chunks, 0, 10, 0)
- if !errors.Is(err, context.Canceled) {
- t.Fatalf("expected context.Canceled, got %v", err)
- }
- if lookupCalls != 1 {
- t.Fatalf("expected lookup retries to stop after cancellation, got %d calls", lookupCalls)
- }
-}
diff --git a/weed/iam/helpers.go b/weed/iam/helpers.go
index ef94940af..cbbc86fb2 100644
--- a/weed/iam/helpers.go
+++ b/weed/iam/helpers.go
@@ -37,11 +37,6 @@ func GenerateRandomString(length int, charset string) (string, error) {
return string(b), nil
}
-// GenerateAccessKeyId generates a new access key ID.
-func GenerateAccessKeyId() (string, error) {
- return GenerateRandomString(AccessKeyIdLength, CharsetUpper)
-}
-
// GenerateSecretAccessKey generates a new secret access key.
func GenerateSecretAccessKey() (string, error) {
return GenerateRandomString(SecretAccessKeyLength, Charset)
@@ -179,11 +174,3 @@ func MapToIdentitiesAction(action string) string {
return ""
}
}
-
-// MaskAccessKey masks an access key for logging, showing only the first 4 characters.
-func MaskAccessKey(accessKeyId string) string {
- if len(accessKeyId) > 4 {
- return accessKeyId[:4] + "***"
- }
- return accessKeyId
-}
diff --git a/weed/iam/helpers_test.go b/weed/iam/helpers_test.go
deleted file mode 100644
index 6b39a3779..000000000
--- a/weed/iam/helpers_test.go
+++ /dev/null
@@ -1,164 +0,0 @@
-package iam
-
-import (
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
- "github.com/stretchr/testify/assert"
-)
-
-func TestHash(t *testing.T) {
- input := "test"
- result := Hash(&input)
- assert.NotEmpty(t, result)
- assert.Len(t, result, 40) // SHA1 hex is 40 chars
-
- // Same input should produce same hash
- result2 := Hash(&input)
- assert.Equal(t, result, result2)
-
- // Different input should produce different hash
- different := "different"
- result3 := Hash(&different)
- assert.NotEqual(t, result, result3)
-}
-
-func TestGenerateRandomString(t *testing.T) {
- // Valid generation
- result, err := GenerateRandomString(10, CharsetUpper)
- assert.NoError(t, err)
- assert.Len(t, result, 10)
-
- // Different calls should produce different results (with high probability)
- result2, err := GenerateRandomString(10, CharsetUpper)
- assert.NoError(t, err)
- assert.NotEqual(t, result, result2)
-
- // Invalid length
- _, err = GenerateRandomString(0, CharsetUpper)
- assert.Error(t, err)
-
- _, err = GenerateRandomString(-1, CharsetUpper)
- assert.Error(t, err)
-
- // Empty charset
- _, err = GenerateRandomString(10, "")
- assert.Error(t, err)
-}
-
-func TestGenerateAccessKeyId(t *testing.T) {
- keyId, err := GenerateAccessKeyId()
- assert.NoError(t, err)
- assert.Len(t, keyId, AccessKeyIdLength)
-}
-
-func TestGenerateSecretAccessKey(t *testing.T) {
- secretKey, err := GenerateSecretAccessKey()
- assert.NoError(t, err)
- assert.Len(t, secretKey, SecretAccessKeyLength)
-}
-
-func TestGenerateSecretAccessKey_URLSafe(t *testing.T) {
- // Generate multiple keys to increase probability of catching unsafe chars
- for i := 0; i < 100; i++ {
- secretKey, err := GenerateSecretAccessKey()
- assert.NoError(t, err)
-
- // Verify no URL-unsafe characters that would cause authentication issues
- assert.NotContains(t, secretKey, "/", "Secret key should not contain /")
- assert.NotContains(t, secretKey, "+", "Secret key should not contain +")
-
- // Verify only expected characters are present
- for _, char := range secretKey {
- assert.Contains(t, Charset, string(char), "Secret key contains unexpected character: %c", char)
- }
- }
-}
-
-func TestStringSlicesEqual(t *testing.T) {
- tests := []struct {
- a []string
- b []string
- expected bool
- }{
- {[]string{"a", "b", "c"}, []string{"a", "b", "c"}, true},
- {[]string{"c", "b", "a"}, []string{"a", "b", "c"}, true}, // Order independent
- {[]string{"a", "b"}, []string{"a", "b", "c"}, false},
- {[]string{}, []string{}, true},
- {nil, nil, true},
- {[]string{"a"}, []string{"b"}, false},
- }
-
- for _, test := range tests {
- result := StringSlicesEqual(test.a, test.b)
- assert.Equal(t, test.expected, result)
- }
-}
-
-func TestMapToStatementAction(t *testing.T) {
- tests := []struct {
- input string
- expected string
- }{
- {StatementActionAdmin, s3_constants.ACTION_ADMIN},
- {StatementActionWrite, s3_constants.ACTION_WRITE},
- {StatementActionRead, s3_constants.ACTION_READ},
- {StatementActionList, s3_constants.ACTION_LIST},
- {StatementActionDelete, s3_constants.ACTION_DELETE_BUCKET},
- // Test fine-grained S3 action mappings (Issue #7864)
- {"DeleteObject", s3_constants.ACTION_WRITE},
- {"s3:DeleteObject", s3_constants.ACTION_WRITE},
- {"PutObject", s3_constants.ACTION_WRITE},
- {"s3:PutObject", s3_constants.ACTION_WRITE},
- {"GetObject", s3_constants.ACTION_READ},
- {"s3:GetObject", s3_constants.ACTION_READ},
- {"ListBucket", s3_constants.ACTION_LIST},
- {"s3:ListBucket", s3_constants.ACTION_LIST},
- {"PutObjectAcl", s3_constants.ACTION_WRITE_ACP},
- {"s3:PutObjectAcl", s3_constants.ACTION_WRITE_ACP},
- {"GetObjectAcl", s3_constants.ACTION_READ_ACP},
- {"s3:GetObjectAcl", s3_constants.ACTION_READ_ACP},
- {"unknown", ""},
- }
-
- for _, test := range tests {
- result := MapToStatementAction(test.input)
- assert.Equal(t, test.expected, result, "Failed for input: %s", test.input)
- }
-}
-
-func TestMapToIdentitiesAction(t *testing.T) {
- tests := []struct {
- input string
- expected string
- }{
- {s3_constants.ACTION_ADMIN, StatementActionAdmin},
- {s3_constants.ACTION_WRITE, StatementActionWrite},
- {s3_constants.ACTION_READ, StatementActionRead},
- {s3_constants.ACTION_LIST, StatementActionList},
- {s3_constants.ACTION_DELETE_BUCKET, StatementActionDelete},
- {"unknown", ""},
- }
-
- for _, test := range tests {
- result := MapToIdentitiesAction(test.input)
- assert.Equal(t, test.expected, result)
- }
-}
-
-func TestMaskAccessKey(t *testing.T) {
- tests := []struct {
- input string
- expected string
- }{
- {"AKIAIOSFODNN7EXAMPLE", "AKIA***"},
- {"AKIA", "AKIA"},
- {"AKI", "AKI"},
- {"", ""},
- }
-
- for _, test := range tests {
- result := MaskAccessKey(test.input)
- assert.Equal(t, test.expected, result)
- }
-}
diff --git a/weed/iam/integration/iam_manager.go b/weed/iam/integration/iam_manager.go
index fb8a47895..bfff9cefd 100644
--- a/weed/iam/integration/iam_manager.go
+++ b/weed/iam/integration/iam_manager.go
@@ -202,32 +202,6 @@ func (m *IAMManager) getFilerAddress() string {
return "" // Fallback to empty string if no provider is set
}
-// createRoleStore creates a role store based on configuration
-func (m *IAMManager) createRoleStore(config *RoleStoreConfig) (RoleStore, error) {
- if config == nil {
- // Default to generic cached filer role store when no config provided
- return NewGenericCachedRoleStore(nil, nil)
- }
-
- switch config.StoreType {
- case "", "filer":
- // Check if caching is explicitly disabled
- if config.StoreConfig != nil {
- if noCache, ok := config.StoreConfig["noCache"].(bool); ok && noCache {
- return NewFilerRoleStore(config.StoreConfig, nil)
- }
- }
- // Default to generic cached filer store for better performance
- return NewGenericCachedRoleStore(config.StoreConfig, nil)
- case "cached-filer", "generic-cached":
- return NewGenericCachedRoleStore(config.StoreConfig, nil)
- case "memory":
- return NewMemoryRoleStore(), nil
- default:
- return nil, fmt.Errorf("unsupported role store type: %s", config.StoreType)
- }
-}
-
// createRoleStoreWithProvider creates a role store with a filer address provider function
func (m *IAMManager) createRoleStoreWithProvider(config *RoleStoreConfig, filerAddressProvider func() string) (RoleStore, error) {
if config == nil {
diff --git a/weed/iam/integration/role_store.go b/weed/iam/integration/role_store.go
index f2dc128c7..11fbbb44e 100644
--- a/weed/iam/integration/role_store.go
+++ b/weed/iam/integration/role_store.go
@@ -388,157 +388,3 @@ type CachedFilerRoleStoreConfig struct {
ListTTL string `json:"listTtl,omitempty"` // e.g., "1m", "30s"
MaxCacheSize int `json:"maxCacheSize,omitempty"` // Maximum number of cached roles
}
-
-// NewCachedFilerRoleStore creates a new cached filer-based role store
-func NewCachedFilerRoleStore(config map[string]interface{}) (*CachedFilerRoleStore, error) {
- // Create underlying filer store
- filerStore, err := NewFilerRoleStore(config, nil)
- if err != nil {
- return nil, fmt.Errorf("failed to create filer role store: %w", err)
- }
-
- // Parse cache configuration with defaults
- cacheTTL := 5 * time.Minute // Default 5 minutes for role cache
- listTTL := 1 * time.Minute // Default 1 minute for list cache
- maxCacheSize := 1000 // Default max 1000 cached roles
-
- if config != nil {
- if ttlStr, ok := config["ttl"].(string); ok && ttlStr != "" {
- if parsed, err := time.ParseDuration(ttlStr); err == nil {
- cacheTTL = parsed
- }
- }
- if listTTLStr, ok := config["listTtl"].(string); ok && listTTLStr != "" {
- if parsed, err := time.ParseDuration(listTTLStr); err == nil {
- listTTL = parsed
- }
- }
- if maxSize, ok := config["maxCacheSize"].(int); ok && maxSize > 0 {
- maxCacheSize = maxSize
- }
- }
-
- // Create ccache instances with appropriate configurations
- pruneCount := int64(maxCacheSize) >> 3
- if pruneCount <= 0 {
- pruneCount = 100
- }
-
- store := &CachedFilerRoleStore{
- filerStore: filerStore,
- cache: ccache.New(ccache.Configure().MaxSize(int64(maxCacheSize)).ItemsToPrune(uint32(pruneCount))),
- listCache: ccache.New(ccache.Configure().MaxSize(100).ItemsToPrune(10)), // Smaller cache for lists
- ttl: cacheTTL,
- listTTL: listTTL,
- }
-
- glog.V(2).Infof("Initialized CachedFilerRoleStore with TTL %v, List TTL %v, Max Cache Size %d",
- cacheTTL, listTTL, maxCacheSize)
-
- return store, nil
-}
-
-// StoreRole stores a role definition and invalidates the cache
-func (c *CachedFilerRoleStore) StoreRole(ctx context.Context, filerAddress string, roleName string, role *RoleDefinition) error {
- // Store in filer
- err := c.filerStore.StoreRole(ctx, filerAddress, roleName, role)
- if err != nil {
- return err
- }
-
- // Invalidate cache entries
- c.cache.Delete(roleName)
- c.listCache.Clear() // Invalidate list cache
-
- glog.V(3).Infof("Stored and invalidated cache for role %s", roleName)
- return nil
-}
-
-// GetRole retrieves a role definition with caching
-func (c *CachedFilerRoleStore) GetRole(ctx context.Context, filerAddress string, roleName string) (*RoleDefinition, error) {
- // Try to get from cache first
- item := c.cache.Get(roleName)
- if item != nil {
- // Cache hit - return cached role (DO NOT extend TTL)
- role := item.Value().(*RoleDefinition)
- glog.V(4).Infof("Cache hit for role %s", roleName)
- return copyRoleDefinition(role), nil
- }
-
- // Cache miss - fetch from filer
- glog.V(4).Infof("Cache miss for role %s, fetching from filer", roleName)
- role, err := c.filerStore.GetRole(ctx, filerAddress, roleName)
- if err != nil {
- return nil, err
- }
-
- // Cache the result with TTL
- c.cache.Set(roleName, copyRoleDefinition(role), c.ttl)
- glog.V(3).Infof("Cached role %s with TTL %v", roleName, c.ttl)
- return role, nil
-}
-
-// ListRoles lists all role names with caching
-func (c *CachedFilerRoleStore) ListRoles(ctx context.Context, filerAddress string) ([]string, error) {
- // Use a constant key for the role list cache
- const listCacheKey = "role_list"
-
- // Try to get from list cache first
- item := c.listCache.Get(listCacheKey)
- if item != nil {
- // Cache hit - return cached list (DO NOT extend TTL)
- roles := item.Value().([]string)
- glog.V(4).Infof("List cache hit, returning %d roles", len(roles))
- return append([]string(nil), roles...), nil // Return a copy
- }
-
- // Cache miss - fetch from filer
- glog.V(4).Infof("List cache miss, fetching from filer")
- roles, err := c.filerStore.ListRoles(ctx, filerAddress)
- if err != nil {
- return nil, err
- }
-
- // Cache the result with TTL (store a copy)
- rolesCopy := append([]string(nil), roles...)
- c.listCache.Set(listCacheKey, rolesCopy, c.listTTL)
- glog.V(3).Infof("Cached role list with %d entries, TTL %v", len(roles), c.listTTL)
- return roles, nil
-}
-
-// DeleteRole deletes a role definition and invalidates the cache
-func (c *CachedFilerRoleStore) DeleteRole(ctx context.Context, filerAddress string, roleName string) error {
- // Delete from filer
- err := c.filerStore.DeleteRole(ctx, filerAddress, roleName)
- if err != nil {
- return err
- }
-
- // Invalidate cache entries
- c.cache.Delete(roleName)
- c.listCache.Clear() // Invalidate list cache
-
- glog.V(3).Infof("Deleted and invalidated cache for role %s", roleName)
- return nil
-}
-
-// ClearCache clears all cached entries (for testing or manual cache invalidation)
-func (c *CachedFilerRoleStore) ClearCache() {
- c.cache.Clear()
- c.listCache.Clear()
- glog.V(2).Infof("Cleared all role cache entries")
-}
-
-// GetCacheStats returns cache statistics
-func (c *CachedFilerRoleStore) GetCacheStats() map[string]interface{} {
- return map[string]interface{}{
- "roleCache": map[string]interface{}{
- "size": c.cache.ItemCount(),
- "ttl": c.ttl.String(),
- },
- "listCache": map[string]interface{}{
- "size": c.listCache.ItemCount(),
- "ttl": c.listTTL.String(),
- },
- }
-}
diff --git a/weed/iam/policy/condition_set_test.go b/weed/iam/policy/condition_set_test.go
deleted file mode 100644
index 4c7e8bb67..000000000
--- a/weed/iam/policy/condition_set_test.go
+++ /dev/null
@@ -1,687 +0,0 @@
-package policy
-
-import (
- "context"
- "testing"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-func TestConditionSetOperators(t *testing.T) {
- engine := setupTestPolicyEngine(t)
-
- t.Run("ForAnyValue:StringEquals", func(t *testing.T) {
- trustPolicy := &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Sid: "AllowOIDC",
- Effect: "Allow",
- Action: []string{"sts:AssumeRoleWithWebIdentity"},
- Condition: map[string]map[string]interface{}{
- "ForAnyValue:StringEquals": {
- "oidc:roles": []string{"Dev.SeaweedFS.TestBucket.ReadWrite", "Dev.SeaweedFS.Admin"},
- },
- },
- },
- },
- }
-
- // Match: Admin is in the requested roles
- evalCtxMatch := &EvaluationContext{
- Principal: "web-identity-user",
- Action: "sts:AssumeRoleWithWebIdentity",
- Resource: "arn:aws:iam::role/test-role",
- RequestContext: map[string]interface{}{
- "oidc:roles": []string{"Dev.SeaweedFS.Admin", "OtherRole"},
- },
- }
- resultMatch, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicy, evalCtxMatch)
- require.NoError(t, err)
- assert.Equal(t, EffectAllow, resultMatch.Effect)
-
- // No Match
- evalCtxNoMatch := &EvaluationContext{
- Principal: "web-identity-user",
- Action: "sts:AssumeRoleWithWebIdentity",
- Resource: "arn:aws:iam::role/test-role",
- RequestContext: map[string]interface{}{
- "oidc:roles": []string{"OtherRole1", "OtherRole2"},
- },
- }
- resultNoMatch, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicy, evalCtxNoMatch)
- require.NoError(t, err)
- assert.Equal(t, EffectDeny, resultNoMatch.Effect)
-
- // No Match: Empty context for ForAnyValue (should deny)
- evalCtxEmpty := &EvaluationContext{
- Principal: "web-identity-user",
- Action: "sts:AssumeRoleWithWebIdentity",
- Resource: "arn:aws:iam::role/test-role",
- RequestContext: map[string]interface{}{
- "oidc:roles": []string{},
- },
- }
- resultEmpty, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicy, evalCtxEmpty)
- require.NoError(t, err)
- assert.Equal(t, EffectDeny, resultEmpty.Effect, "ForAnyValue should deny when context is empty")
- })
-
- t.Run("ForAllValues:StringEquals", func(t *testing.T) {
- trustPolicyAll := &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Sid: "AllowOIDCAll",
- Effect: "Allow",
- Action: []string{"sts:AssumeRoleWithWebIdentity"},
- Condition: map[string]map[string]interface{}{
- "ForAllValues:StringEquals": {
- "oidc:roles": []string{"RoleA", "RoleB", "RoleC"},
- },
- },
- },
- },
- }
-
- // Match: All requested roles ARE in the allowed set
- evalCtxAllMatch := &EvaluationContext{
- Principal: "web-identity-user",
- Action: "sts:AssumeRoleWithWebIdentity",
- Resource: "arn:aws:iam::role/test-role",
- RequestContext: map[string]interface{}{
- "oidc:roles": []string{"RoleA", "RoleB"},
- },
- }
- resultAllMatch, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicyAll, evalCtxAllMatch)
- require.NoError(t, err)
- assert.Equal(t, EffectAllow, resultAllMatch.Effect)
-
- // Fail: RoleD is NOT in the allowed set
- evalCtxAllFail := &EvaluationContext{
- Principal: "web-identity-user",
- Action: "sts:AssumeRoleWithWebIdentity",
- Resource: "arn:aws:iam::role/test-role",
- RequestContext: map[string]interface{}{
- "oidc:roles": []string{"RoleA", "RoleD"},
- },
- }
- resultAllFail, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicyAll, evalCtxAllFail)
- require.NoError(t, err)
- assert.Equal(t, EffectDeny, resultAllFail.Effect)
-
- // Vacuously true: Request has NO roles
- evalCtxEmpty := &EvaluationContext{
- Principal: "web-identity-user",
- Action: "sts:AssumeRoleWithWebIdentity",
- Resource: "arn:aws:iam::role/test-role",
- RequestContext: map[string]interface{}{
- "oidc:roles": []string{},
- },
- }
- resultEmpty, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicyAll, evalCtxEmpty)
- require.NoError(t, err)
- assert.Equal(t, EffectAllow, resultEmpty.Effect)
- })
-
- t.Run("ForAllValues:NumericEqualsVacuouslyTrue", func(t *testing.T) {
- policy := &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Sid: "AllowNumericAll",
- Effect: "Allow",
- Action: []string{"sts:AssumeRole"},
- Condition: map[string]map[string]interface{}{
- "ForAllValues:NumericEquals": {
- "aws:MultiFactorAuthAge": []string{"3600", "7200"},
- },
- },
- },
- },
- }
-
- // Vacuously true: Request has NO MFA age info
- evalCtxEmpty := &EvaluationContext{
- Principal: "user",
- Action: "sts:AssumeRole",
- Resource: "arn:aws:iam::role/test-role",
- RequestContext: map[string]interface{}{
- "aws:MultiFactorAuthAge": []string{},
- },
- }
- resultEmpty, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxEmpty)
- require.NoError(t, err)
- assert.Equal(t, EffectAllow, resultEmpty.Effect, "Should allow when numeric context is empty for ForAllValues")
- })
-
- t.Run("ForAllValues:BoolVacuouslyTrue", func(t *testing.T) {
- policy := &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Sid: "AllowBoolAll",
- Effect: "Allow",
- Action: []string{"sts:AssumeRole"},
- Condition: map[string]map[string]interface{}{
- "ForAllValues:Bool": {
- "aws:SecureTransport": "true",
- },
- },
- },
- },
- }
-
- // Vacuously true
- evalCtxEmpty := &EvaluationContext{
- Principal: "user",
- Action: "sts:AssumeRole",
- Resource: "arn:aws:iam::role/test-role",
- RequestContext: map[string]interface{}{
- "aws:SecureTransport": []interface{}{},
- },
- }
- resultEmpty, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxEmpty)
- require.NoError(t, err)
- assert.Equal(t, EffectAllow, resultEmpty.Effect, "Should allow when bool context is empty for ForAllValues")
- })
-
- t.Run("ForAllValues:DateVacuouslyTrue", func(t *testing.T) {
- policy := &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Sid: "AllowDateAll",
- Effect: "Allow",
- Action: []string{"sts:AssumeRole"},
- Condition: map[string]map[string]interface{}{
- "ForAllValues:DateGreaterThan": {
- "aws:CurrentTime": "2020-01-01T00:00:00Z",
- },
- },
- },
- },
- }
-
- // Vacuously true
- evalCtxEmpty := &EvaluationContext{
- Principal: "user",
- Action: "sts:AssumeRole",
- Resource: "arn:aws:iam::role/test-role",
- RequestContext: map[string]interface{}{
- "aws:CurrentTime": []interface{}{},
- },
- }
- resultEmpty, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxEmpty)
- require.NoError(t, err)
- assert.Equal(t, EffectAllow, resultEmpty.Effect, "Should allow when date context is empty for ForAllValues")
- })
-
- t.Run("ForAllValues:DateWithLabelsAsStrings", func(t *testing.T) {
- policy := &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Sid: "AllowDateStrings",
- Effect: "Allow",
- Action: []string{"sts:AssumeRole"},
- Condition: map[string]map[string]interface{}{
- "ForAllValues:DateGreaterThan": {
- "aws:CurrentTime": "2020-01-01T00:00:00Z",
- },
- },
- },
- },
- }
-
- evalCtx := &EvaluationContext{
- Principal: "user",
- Action: "sts:AssumeRole",
- Resource: "arn:aws:iam::role/test-role",
- RequestContext: map[string]interface{}{
- "aws:CurrentTime": []string{"2021-01-01T00:00:00Z", "2022-01-01T00:00:00Z"},
- },
- }
- result, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtx)
- require.NoError(t, err)
- assert.Equal(t, EffectAllow, result.Effect, "Should allow when date context is a slice of strings")
- })
-
- t.Run("ForAllValues:BoolWithLabelsAsStrings", func(t *testing.T) {
- policy := &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Sid: "AllowBoolStrings",
- Effect: "Allow",
- Action: []string{"sts:AssumeRole"},
- Condition: map[string]map[string]interface{}{
- "ForAllValues:Bool": {
- "aws:SecureTransport": "true",
- },
- },
- },
- },
- }
-
- evalCtx := &EvaluationContext{
- Principal: "user",
- Action: "sts:AssumeRole",
- Resource: "arn:aws:iam::role/test-role",
- RequestContext: map[string]interface{}{
- "aws:SecureTransport": []string{"true", "true"},
- },
- }
- result, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtx)
- require.NoError(t, err)
- assert.Equal(t, EffectAllow, result.Effect, "Should allow when bool context is a slice of strings")
- })
-
- t.Run("StringEqualsIgnoreCaseWithVariable", func(t *testing.T) {
- policyDoc := &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Sid: "AllowVar",
- Effect: "Allow",
- Action: []string{"s3:GetObject"},
- Resource: []string{"arn:aws:s3:::bucket/*"},
- Condition: map[string]map[string]interface{}{
- "StringEqualsIgnoreCase": {
- "s3:prefix": "${aws:username}/",
- },
- },
- },
- },
- }
-
- err := engine.AddPolicy("", "var-policy", policyDoc)
- require.NoError(t, err)
-
- evalCtx := &EvaluationContext{
- Principal: "user",
- Action: "s3:GetObject",
- Resource: "arn:aws:s3:::bucket/ALICE/file.txt",
- RequestContext: map[string]interface{}{
- "s3:prefix": "ALICE/",
- "aws:username": "alice",
- },
- }
-
- result, err := engine.Evaluate(context.Background(), "", evalCtx, []string{"var-policy"})
- require.NoError(t, err)
- assert.Equal(t, EffectAllow, result.Effect, "Should allow when variable expands and matches case-insensitively")
- })
-
- t.Run("StringLike:CaseSensitivity", func(t *testing.T) {
- policyDoc := &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Sid: "AllowCaseSensitiveLike",
- Effect: "Allow",
- Action: []string{"s3:GetObject"},
- Resource: []string{"arn:aws:s3:::bucket/*"},
- Condition: map[string]map[string]interface{}{
- "StringLike": {
- "s3:prefix": "Project/*",
- },
- },
- },
- },
- }
-
- err := engine.AddPolicy("", "like-policy", policyDoc)
- require.NoError(t, err)
-
- // Match: Case sensitive match
- evalCtxMatch := &EvaluationContext{
- Principal: "user",
- Action: "s3:GetObject",
- Resource: "arn:aws:s3:::bucket/Project/file.txt",
- RequestContext: map[string]interface{}{
- "s3:prefix": "Project/data",
- },
- }
- resultMatch, err := engine.Evaluate(context.Background(), "", evalCtxMatch, []string{"like-policy"})
- require.NoError(t, err)
- assert.Equal(t, EffectAllow, resultMatch.Effect, "Should allow when case matches exactly")
-
- // Fail: Case insensitive match (should fail for StringLike)
- evalCtxFail := &EvaluationContext{
- Principal: "user",
- Action: "s3:GetObject",
- Resource: "arn:aws:s3:::bucket/project/file.txt",
- RequestContext: map[string]interface{}{
- "s3:prefix": "project/data", // lowercase 'p'
- },
- }
- resultFail, err := engine.Evaluate(context.Background(), "", evalCtxFail, []string{"like-policy"})
- require.NoError(t, err)
- assert.Equal(t, EffectDeny, resultFail.Effect, "Should deny when case does not match for StringLike")
- })
-
- t.Run("NumericNotEquals:Logic", func(t *testing.T) {
- policy := &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Sid: "DenySpecificAges",
- Effect: "Allow",
- Action: []string{"sts:AssumeRole"},
- Resource: []string{"*"},
- Condition: map[string]map[string]interface{}{
- "ForAllValues:NumericNotEquals": {
- "aws:MultiFactorAuthAge": []string{"3600", "7200"},
- },
- },
- },
- },
- }
-
- err := engine.AddPolicy("", "numeric-not-equals-policy", policy)
- require.NoError(t, err)
-
- // Fail: One age matches an excluded value (3600)
- evalCtxFail := &EvaluationContext{
- Principal: "user",
- Action: "sts:AssumeRole",
- Resource: "arn:aws:iam::role/test-role",
- RequestContext: map[string]interface{}{
- "aws:MultiFactorAuthAge": []string{"3600", "1800"},
- },
- }
- resultFail, err := engine.Evaluate(context.Background(), "", evalCtxFail, []string{"numeric-not-equals-policy"})
- require.NoError(t, err)
- assert.Equal(t, EffectDeny, resultFail.Effect, "Should deny when one age matches an excluded value")
-
- // Pass: No age matches any excluded value
- evalCtxPass := &EvaluationContext{
- Principal: "user",
- Action: "sts:AssumeRole",
- Resource: "arn:aws:iam::role/test-role",
- RequestContext: map[string]interface{}{
- "aws:MultiFactorAuthAge": []string{"1800", "900"},
- },
- }
- resultPass, err := engine.Evaluate(context.Background(), "", evalCtxPass, []string{"numeric-not-equals-policy"})
- require.NoError(t, err)
- assert.Equal(t, EffectAllow, resultPass.Effect, "Should allow when no age matches excluded values")
- })
-
- t.Run("DateNotEquals:Logic", func(t *testing.T) {
- policy := &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Sid: "DenySpecificTimes",
- Effect: "Allow",
- Action: []string{"sts:AssumeRole"},
- Resource: []string{"*"},
- Condition: map[string]map[string]interface{}{
- "ForAllValues:DateNotEquals": {
- "aws:CurrentTime": []string{"2024-01-01T00:00:00Z", "2024-01-02T00:00:00Z"},
- },
- },
- },
- },
- }
-
- err := engine.AddPolicy("", "date-not-equals-policy", policy)
- require.NoError(t, err)
-
- // Fail: One time matches an excluded value
- evalCtxFail := &EvaluationContext{
- Principal: "user",
- Action: "sts:AssumeRole",
- Resource: "arn:aws:iam::role/test-role",
- RequestContext: map[string]interface{}{
- "aws:CurrentTime": []string{"2024-01-01T00:00:00Z", "2024-01-03T00:00:00Z"},
- },
- }
- resultFail, err := engine.Evaluate(context.Background(), "", evalCtxFail, []string{"date-not-equals-policy"})
- require.NoError(t, err)
- assert.Equal(t, EffectDeny, resultFail.Effect, "Should deny when one date matches an excluded value")
- })
-
- t.Run("IpAddress:SetOperators", func(t *testing.T) {
- policy := &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Sid: "AllowSpecificIPs",
- Effect: "Allow",
- Action: []string{"s3:GetObject"},
- Resource: []string{"*"},
- Condition: map[string]map[string]interface{}{
- "ForAllValues:IpAddress": {
- "aws:SourceIp": []string{"192.168.1.0/24", "10.0.0.1"},
- },
- },
- },
- },
- }
-
- err := engine.AddPolicy("", "ip-set-policy", policy)
- require.NoError(t, err)
-
- // Match: All source IPs are in allowed ranges
- evalCtxMatch := &EvaluationContext{
- Principal: "user",
- Action: "s3:GetObject",
- Resource: "arn:aws:s3:::bucket/file.txt",
- RequestContext: map[string]interface{}{
- "aws:SourceIp": []string{"192.168.1.10", "10.0.0.1"},
- },
- }
- resultMatch, err := engine.Evaluate(context.Background(), "", evalCtxMatch, []string{"ip-set-policy"})
- require.NoError(t, err)
- assert.Equal(t, EffectAllow, resultMatch.Effect)
-
- // Fail: One source IP is NOT in allowed ranges
- evalCtxFail := &EvaluationContext{
- Principal: "user",
- Action: "s3:GetObject",
- Resource: "arn:aws:s3:::bucket/file.txt",
- RequestContext: map[string]interface{}{
- "aws:SourceIp": []string{"192.168.1.10", "172.16.0.1"},
- },
- }
- resultFail, err := engine.Evaluate(context.Background(), "", evalCtxFail, []string{"ip-set-policy"})
- require.NoError(t, err)
- assert.Equal(t, EffectDeny, resultFail.Effect)
-
- // ForAnyValue: IPAddress
- policyAny := &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Sid: "AllowAnySpecificIPs",
- Effect: "Allow",
- Action: []string{"s3:GetObject"},
- Resource: []string{"*"},
- Condition: map[string]map[string]interface{}{
- "ForAnyValue:IpAddress": {
- "aws:SourceIp": []string{"192.168.1.0/24"},
- },
- },
- },
- },
- }
- err = engine.AddPolicy("", "ip-any-policy", policyAny)
- require.NoError(t, err)
-
- evalCtxAnyMatch := &EvaluationContext{
- Principal: "user",
- Action: "s3:GetObject",
- Resource: "arn:aws:s3:::bucket/file.txt",
- RequestContext: map[string]interface{}{
- "aws:SourceIp": []string{"192.168.1.10", "172.16.0.1"},
- },
- }
- resultAnyMatch, err := engine.Evaluate(context.Background(), "", evalCtxAnyMatch, []string{"ip-any-policy"})
- require.NoError(t, err)
- assert.Equal(t, EffectAllow, resultAnyMatch.Effect)
- })
-
- t.Run("IpAddress:SingleStringValue", func(t *testing.T) {
- policy := &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Sid: "AllowSingleIP",
- Effect: "Allow",
- Action: []string{"s3:GetObject"},
- Resource: []string{"*"},
- Condition: map[string]map[string]interface{}{
- "IpAddress": {
- "aws:SourceIp": "192.168.1.1",
- },
- },
- },
- },
- }
-
- err := engine.AddPolicy("", "ip-single-policy", policy)
- require.NoError(t, err)
-
- evalCtxMatch := &EvaluationContext{
- Principal: "user",
- Action: "s3:GetObject",
- Resource: "arn:aws:s3:::bucket/file.txt",
- RequestContext: map[string]interface{}{
- "aws:SourceIp": "192.168.1.1",
- },
- }
- resultMatch, err := engine.Evaluate(context.Background(), "", evalCtxMatch, []string{"ip-single-policy"})
- require.NoError(t, err)
- assert.Equal(t, EffectAllow, resultMatch.Effect)
-
- evalCtxNoMatch := &EvaluationContext{
- Principal: "user",
- Action: "s3:GetObject",
- Resource: "arn:aws:s3:::bucket/file.txt",
- RequestContext: map[string]interface{}{
- "aws:SourceIp": "10.0.0.1",
- },
- }
- resultNoMatch, err := engine.Evaluate(context.Background(), "", evalCtxNoMatch, []string{"ip-single-policy"})
- require.NoError(t, err)
- assert.Equal(t, EffectDeny, resultNoMatch.Effect)
- })
-
- t.Run("Bool:StringSlicePolicyValues", func(t *testing.T) {
- policy := &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Sid: "AllowWithBoolStrings",
- Effect: "Allow",
- Action: []string{"s3:GetObject"},
- Resource: []string{"*"},
- Condition: map[string]map[string]interface{}{
- "Bool": {
- "aws:SecureTransport": []string{"true", "false"},
- },
- },
- },
- },
- }
-
- err := engine.AddPolicy("", "bool-string-slice-policy", policy)
- require.NoError(t, err)
-
- evalCtx := &EvaluationContext{
- Principal: "user",
- Action: "s3:GetObject",
- Resource: "arn:aws:s3:::bucket/file.txt",
- RequestContext: map[string]interface{}{
- "aws:SecureTransport": "true",
- },
- }
- result, err := engine.Evaluate(context.Background(), "", evalCtx, []string{"bool-string-slice-policy"})
- require.NoError(t, err)
- assert.Equal(t, EffectAllow, result.Effect)
- })
-
- t.Run("StringEqualsIgnoreCase:StringSlicePolicyValues", func(t *testing.T) {
- policy := &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Sid: "AllowWithIgnoreCaseStrings",
- Effect: "Allow",
- Action: []string{"s3:GetObject"},
- Resource: []string{"*"},
- Condition: map[string]map[string]interface{}{
- "StringEqualsIgnoreCase": {
- "s3:x-amz-server-side-encryption": []string{"AES256", "aws:kms"},
- },
- },
- },
- },
- }
-
- err := engine.AddPolicy("", "string-ignorecase-slice-policy", policy)
- require.NoError(t, err)
-
- evalCtx := &EvaluationContext{
- Principal: "user",
- Action: "s3:GetObject",
- Resource: "arn:aws:s3:::bucket/file.txt",
- RequestContext: map[string]interface{}{
- "s3:x-amz-server-side-encryption": "aes256",
- },
- }
- result, err := engine.Evaluate(context.Background(), "", evalCtx, []string{"string-ignorecase-slice-policy"})
- require.NoError(t, err)
- assert.Equal(t, EffectAllow, result.Effect)
- })
-
- t.Run("IpAddress:CustomContextKey", func(t *testing.T) {
- policy := &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Sid: "AllowCustomIPKey",
- Effect: "Allow",
- Action: []string{"s3:GetObject"},
- Resource: []string{"*"},
- Condition: map[string]map[string]interface{}{
- "IpAddress": {
- "custom:VpcIp": "10.0.0.0/16",
- },
- },
- },
- },
- }
-
- err := engine.AddPolicy("", "ip-custom-key-policy", policy)
- require.NoError(t, err)
-
- evalCtxMatch := &EvaluationContext{
- Principal: "user",
- Action: "s3:GetObject",
- Resource: "arn:aws:s3:::bucket/file.txt",
- RequestContext: map[string]interface{}{
- "custom:VpcIp": "10.0.5.1",
- },
- }
- resultMatch, err := engine.Evaluate(context.Background(), "", evalCtxMatch, []string{"ip-custom-key-policy"})
- require.NoError(t, err)
- assert.Equal(t, EffectAllow, resultMatch.Effect)
-
- evalCtxNoMatch := &EvaluationContext{
- Principal: "user",
- Action: "s3:GetObject",
- Resource: "arn:aws:s3:::bucket/file.txt",
- RequestContext: map[string]interface{}{
- "custom:VpcIp": "192.168.1.1",
- },
- }
- resultNoMatch, err := engine.Evaluate(context.Background(), "", evalCtxNoMatch, []string{"ip-custom-key-policy"})
- require.NoError(t, err)
- assert.Equal(t, EffectDeny, resultNoMatch.Effect)
- })
-}
diff --git a/weed/iam/policy/negation_test.go b/weed/iam/policy/negation_test.go
deleted file mode 100644
index 31eed396f..000000000
--- a/weed/iam/policy/negation_test.go
+++ /dev/null
@@ -1,101 +0,0 @@
-package policy
-
-import (
- "context"
- "testing"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-func TestNegationSetOperators(t *testing.T) {
- engine := setupTestPolicyEngine(t)
-
- t.Run("ForAllValues:StringNotEquals", func(t *testing.T) {
- policy := &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Sid: "DenyAdmin",
- Effect: "Allow",
- Action: []string{"sts:AssumeRole"},
- Condition: map[string]map[string]interface{}{
- "ForAllValues:StringNotEquals": {
- "oidc:roles": []string{"Admin"},
- },
- },
- },
- },
- }
-
- // All roles are NOT "Admin" -> Should Allow
- evalCtxAllow := &EvaluationContext{
- Principal: "user",
- Action: "sts:AssumeRole",
- Resource: "arn:aws:iam::role/test-role",
- RequestContext: map[string]interface{}{
- "oidc:roles": []string{"User", "Developer"},
- },
- }
- resultAllow, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxAllow)
- require.NoError(t, err)
- assert.Equal(t, EffectAllow, resultAllow.Effect, "Should allow when ALL roles satisfy StringNotEquals Admin")
-
- // One role is "Admin" -> Should Deny
- evalCtxDeny := &EvaluationContext{
- Principal: "user",
- Action: "sts:AssumeRole",
- Resource: "arn:aws:iam::role/test-role",
- RequestContext: map[string]interface{}{
- "oidc:roles": []string{"Admin", "User"},
- },
- }
- resultDeny, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxDeny)
- require.NoError(t, err)
- assert.Equal(t, EffectDeny, resultDeny.Effect, "Should deny when one role is Admin and fails StringNotEquals")
- })
-
- t.Run("ForAnyValue:StringNotEquals", func(t *testing.T) {
- policy := &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Sid: "Requirement",
- Effect: "Allow",
- Action: []string{"sts:AssumeRole"},
- Condition: map[string]map[string]interface{}{
- "ForAnyValue:StringNotEquals": {
- "oidc:roles": []string{"Prohibited"},
- },
- },
- },
- },
- }
-
- // At least one role is NOT prohibited -> Should Allow
- evalCtxAllow := &EvaluationContext{
- Principal: "user",
- Action: "sts:AssumeRole",
- Resource: "arn:aws:iam::role/test-role",
- RequestContext: map[string]interface{}{
- "oidc:roles": []string{"Prohibited", "Allowed"},
- },
- }
- resultAllow, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxAllow)
- require.NoError(t, err)
- assert.Equal(t, EffectAllow, resultAllow.Effect, "Should allow when at least one role is NOT Prohibited")
-
- // All roles are Prohibited -> Should Deny
- evalCtxDeny := &EvaluationContext{
- Principal: "user",
- Action: "sts:AssumeRole",
- Resource: "arn:aws:iam::role/test-role",
- RequestContext: map[string]interface{}{
- "oidc:roles": []string{"Prohibited", "Prohibited"},
- },
- }
- resultDeny, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxDeny)
- require.NoError(t, err)
- assert.Equal(t, EffectDeny, resultDeny.Effect, "Should deny when ALL roles are Prohibited")
- })
-}
diff --git a/weed/iam/policy/policy_engine.go b/weed/iam/policy/policy_engine.go
index c8cd07367..7feca5c92 100644
--- a/weed/iam/policy/policy_engine.go
+++ b/weed/iam/policy/policy_engine.go
@@ -1155,11 +1155,6 @@ func ValidatePolicyDocumentWithType(policy *PolicyDocument, policyType string) e
return nil
}
-// validateStatement validates a single statement (for backward compatibility)
-func validateStatement(statement *Statement) error {
- return validateStatementWithType(statement, "resource")
-}
-
// validateStatementWithType validates a single statement based on policy type
func validateStatementWithType(statement *Statement, policyType string) error {
if statement.Effect != "Allow" && statement.Effect != "Deny" {
@@ -1198,29 +1193,6 @@ func validateStatementWithType(statement *Statement, policyType string) error {
return nil
}
-// matchResource checks if a resource pattern matches a requested resource
-// Uses hybrid approach: simple suffix wildcards for compatibility, filepath.Match for complex patterns
-func matchResource(pattern, resource string) bool {
- if pattern == resource {
- return true
- }
-
- // Handle simple suffix wildcard (backward compatibility)
- if strings.HasSuffix(pattern, "*") {
- prefix := pattern[:len(pattern)-1]
- return strings.HasPrefix(resource, prefix)
- }
-
- // For complex patterns, use filepath.Match for advanced wildcard support (*, ?, [])
- matched, err := filepath.Match(pattern, resource)
- if err != nil {
- // Fallback to exact match if pattern is malformed
- return pattern == resource
- }
-
- return matched
-}
-
// awsIAMMatch performs AWS IAM-compliant pattern matching with case-insensitivity and policy variable support
func awsIAMMatch(pattern, value string, evalCtx *EvaluationContext) bool {
// Step 1: Substitute policy variables (e.g., ${aws:username}, ${saml:username})
@@ -1274,16 +1246,6 @@ func expandPolicyVariables(pattern string, evalCtx *EvaluationContext) string {
return result
}
-// getContextValue safely gets a value from the evaluation context
-func getContextValue(evalCtx *EvaluationContext, key, defaultValue string) string {
- if value, exists := evalCtx.RequestContext[key]; exists {
- if str, ok := value.(string); ok {
- return str
- }
- }
- return defaultValue
-}
-
// AwsWildcardMatch performs case-insensitive wildcard matching like AWS IAM
func AwsWildcardMatch(pattern, value string) bool {
// Create regex pattern key for caching
@@ -1322,29 +1284,6 @@ func AwsWildcardMatch(pattern, value string) bool {
return regex.MatchString(value)
}
-// matchAction checks if an action pattern matches a requested action
-// Uses hybrid approach: simple suffix wildcards for compatibility, filepath.Match for complex patterns
-func matchAction(pattern, action string) bool {
- if pattern == action {
- return true
- }
-
- // Handle simple suffix wildcard (backward compatibility)
- if strings.HasSuffix(pattern, "*") {
- prefix := pattern[:len(pattern)-1]
- return strings.HasPrefix(action, prefix)
- }
-
- // For complex patterns, use filepath.Match for advanced wildcard support (*, ?, [])
- matched, err := filepath.Match(pattern, action)
- if err != nil {
- // Fallback to exact match if pattern is malformed
- return pattern == action
- }
-
- return matched
-}
-
// evaluateStringConditionIgnoreCase evaluates string conditions with case insensitivity
func (e *PolicyEngine) evaluateStringConditionIgnoreCase(block map[string]interface{}, evalCtx *EvaluationContext, shouldMatch bool, useWildcard bool, forAllValues bool) bool {
for key, expectedValues := range block {
diff --git a/weed/iam/policy/policy_engine_principal_test.go b/weed/iam/policy/policy_engine_principal_test.go
deleted file mode 100644
index 58714eb98..000000000
--- a/weed/iam/policy/policy_engine_principal_test.go
+++ /dev/null
@@ -1,421 +0,0 @@
-package policy
-
-import (
- "context"
- "testing"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-// TestPrincipalMatching tests the matchesPrincipal method
-func TestPrincipalMatching(t *testing.T) {
- engine := setupTestPolicyEngine(t)
-
- tests := []struct {
- name string
- principal interface{}
- evalCtx *EvaluationContext
- want bool
- }{
- {
- name: "plain wildcard principal",
- principal: "*",
- evalCtx: &EvaluationContext{
- RequestContext: map[string]interface{}{},
- },
- want: true,
- },
- {
- name: "structured wildcard federated principal",
- principal: map[string]interface{}{
- "Federated": "*",
- },
- evalCtx: &EvaluationContext{
- RequestContext: map[string]interface{}{},
- },
- want: true,
- },
- {
- name: "wildcard in array",
- principal: map[string]interface{}{
- "Federated": []interface{}{"specific-provider", "*"},
- },
- evalCtx: &EvaluationContext{
- RequestContext: map[string]interface{}{},
- },
- want: true,
- },
- {
- name: "specific federated provider match",
- principal: map[string]interface{}{
- "Federated": "https://example.com/oidc",
- },
- evalCtx: &EvaluationContext{
- RequestContext: map[string]interface{}{
- "aws:FederatedProvider": "https://example.com/oidc",
- },
- },
- want: true,
- },
- {
- name: "specific federated provider no match",
- principal: map[string]interface{}{
- "Federated": "https://example.com/oidc",
- },
- evalCtx: &EvaluationContext{
- RequestContext: map[string]interface{}{
- "aws:FederatedProvider": "https://other.com/oidc",
- },
- },
- want: false,
- },
- {
- name: "array with specific provider match",
- principal: map[string]interface{}{
- "Federated": []string{"https://provider1.com", "https://provider2.com"},
- },
- evalCtx: &EvaluationContext{
- RequestContext: map[string]interface{}{
- "aws:FederatedProvider": "https://provider2.com",
- },
- },
- want: true,
- },
- {
- name: "AWS principal match",
- principal: map[string]interface{}{
- "AWS": "arn:aws:iam::123456789012:user/alice",
- },
- evalCtx: &EvaluationContext{
- RequestContext: map[string]interface{}{
- "aws:PrincipalArn": "arn:aws:iam::123456789012:user/alice",
- },
- },
- want: true,
- },
- {
- name: "Service principal match",
- principal: map[string]interface{}{
- "Service": "s3.amazonaws.com",
- },
- evalCtx: &EvaluationContext{
- RequestContext: map[string]interface{}{
- "aws:PrincipalServiceName": "s3.amazonaws.com",
- },
- },
- want: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result := engine.matchesPrincipal(tt.principal, tt.evalCtx)
- assert.Equal(t, tt.want, result, "Principal matching failed for: %s", tt.name)
- })
- }
-}
-
-// TestEvaluatePrincipalValue tests the evaluatePrincipalValue method
-func TestEvaluatePrincipalValue(t *testing.T) {
- engine := setupTestPolicyEngine(t)
-
- tests := []struct {
- name string
- principalValue interface{}
- contextKey string
- evalCtx *EvaluationContext
- want bool
- }{
- {
- name: "wildcard string",
- principalValue: "*",
- contextKey: "aws:FederatedProvider",
- evalCtx: &EvaluationContext{
- RequestContext: map[string]interface{}{},
- },
- want: true,
- },
- {
- name: "specific string match",
- principalValue: "https://example.com",
- contextKey: "aws:FederatedProvider",
- evalCtx: &EvaluationContext{
- RequestContext: map[string]interface{}{
- "aws:FederatedProvider": "https://example.com",
- },
- },
- want: true,
- },
- {
- name: "specific string no match",
- principalValue: "https://example.com",
- contextKey: "aws:FederatedProvider",
- evalCtx: &EvaluationContext{
- RequestContext: map[string]interface{}{
- "aws:FederatedProvider": "https://other.com",
- },
- },
- want: false,
- },
- {
- name: "wildcard in array",
- principalValue: []interface{}{"provider1", "*"},
- contextKey: "aws:FederatedProvider",
- evalCtx: &EvaluationContext{
- RequestContext: map[string]interface{}{},
- },
- want: true,
- },
- {
- name: "array match",
- principalValue: []string{"provider1", "provider2", "provider3"},
- contextKey: "aws:FederatedProvider",
- evalCtx: &EvaluationContext{
- RequestContext: map[string]interface{}{
- "aws:FederatedProvider": "provider2",
- },
- },
- want: true,
- },
- {
- name: "array no match",
- principalValue: []string{"provider1", "provider2"},
- contextKey: "aws:FederatedProvider",
- evalCtx: &EvaluationContext{
- RequestContext: map[string]interface{}{
- "aws:FederatedProvider": "provider3",
- },
- },
- want: false,
- },
- {
- name: "missing context key",
- principalValue: "specific-value",
- contextKey: "aws:FederatedProvider",
- evalCtx: &EvaluationContext{
- RequestContext: map[string]interface{}{},
- },
- want: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result := engine.evaluatePrincipalValue(tt.principalValue, tt.evalCtx, tt.contextKey)
- assert.Equal(t, tt.want, result, "Principal value evaluation failed for: %s", tt.name)
- })
- }
-}
-
-// TestTrustPolicyEvaluation tests the EvaluateTrustPolicy method
-func TestTrustPolicyEvaluation(t *testing.T) {
- engine := setupTestPolicyEngine(t)
-
- tests := []struct {
- name string
- trustPolicy *PolicyDocument
- evalCtx *EvaluationContext
- wantEffect Effect
- wantErr bool
- }{
- {
- name: "wildcard federated principal allows any provider",
- trustPolicy: &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Effect: "Allow",
- Principal: map[string]interface{}{
- "Federated": "*",
- },
- Action: []string{"sts:AssumeRoleWithWebIdentity"},
- },
- },
- },
- evalCtx: &EvaluationContext{
- Action: "sts:AssumeRoleWithWebIdentity",
- RequestContext: map[string]interface{}{
- "aws:FederatedProvider": "https://any-provider.com",
- },
- },
- wantEffect: EffectAllow,
- wantErr: false,
- },
- {
- name: "specific federated principal matches",
- trustPolicy: &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Effect: "Allow",
- Principal: map[string]interface{}{
- "Federated": "https://example.com/oidc",
- },
- Action: []string{"sts:AssumeRoleWithWebIdentity"},
- },
- },
- },
- evalCtx: &EvaluationContext{
- Action: "sts:AssumeRoleWithWebIdentity",
- RequestContext: map[string]interface{}{
- "aws:FederatedProvider": "https://example.com/oidc",
- },
- },
- wantEffect: EffectAllow,
- wantErr: false,
- },
- {
- name: "specific federated principal does not match",
- trustPolicy: &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Effect: "Allow",
- Principal: map[string]interface{}{
- "Federated": "https://example.com/oidc",
- },
- Action: []string{"sts:AssumeRoleWithWebIdentity"},
- },
- },
- },
- evalCtx: &EvaluationContext{
- Action: "sts:AssumeRoleWithWebIdentity",
- RequestContext: map[string]interface{}{
- "aws:FederatedProvider": "https://other.com/oidc",
- },
- },
- wantEffect: EffectDeny,
- wantErr: false,
- },
- {
- name: "plain wildcard principal",
- trustPolicy: &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Effect: "Allow",
- Principal: "*",
- Action: []string{"sts:AssumeRoleWithWebIdentity"},
- },
- },
- },
- evalCtx: &EvaluationContext{
- Action: "sts:AssumeRoleWithWebIdentity",
- RequestContext: map[string]interface{}{
- "aws:FederatedProvider": "https://any-provider.com",
- },
- },
- wantEffect: EffectAllow,
- wantErr: false,
- },
- {
- name: "trust policy with conditions",
- trustPolicy: &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Effect: "Allow",
- Principal: map[string]interface{}{
- "Federated": "*",
- },
- Action: []string{"sts:AssumeRoleWithWebIdentity"},
- Condition: map[string]map[string]interface{}{
- "StringEquals": {
- "oidc:aud": "my-app-id",
- },
- },
- },
- },
- },
- evalCtx: &EvaluationContext{
- Action: "sts:AssumeRoleWithWebIdentity",
- RequestContext: map[string]interface{}{
- "aws:FederatedProvider": "https://provider.com",
- "oidc:aud": "my-app-id",
- },
- },
- wantEffect: EffectAllow,
- wantErr: false,
- },
- {
- name: "trust policy condition not met",
- trustPolicy: &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Effect: "Allow",
- Principal: map[string]interface{}{
- "Federated": "*",
- },
- Action: []string{"sts:AssumeRoleWithWebIdentity"},
- Condition: map[string]map[string]interface{}{
- "StringEquals": {
- "oidc:aud": "my-app-id",
- },
- },
- },
- },
- },
- evalCtx: &EvaluationContext{
- Action: "sts:AssumeRoleWithWebIdentity",
- RequestContext: map[string]interface{}{
- "aws:FederatedProvider": "https://provider.com",
- "oidc:aud": "wrong-app-id",
- },
- },
- wantEffect: EffectDeny,
- wantErr: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result, err := engine.EvaluateTrustPolicy(context.Background(), tt.trustPolicy, tt.evalCtx)
-
- if tt.wantErr {
- assert.Error(t, err)
- } else {
- require.NoError(t, err)
- assert.Equal(t, tt.wantEffect, result.Effect, "Trust policy evaluation failed for: %s", tt.name)
- }
- })
- }
-}
-
-// TestGetPrincipalContextKey tests the context key mapping
-func TestGetPrincipalContextKey(t *testing.T) {
- tests := []struct {
- name string
- principalType string
- want string
- }{
- {
- name: "Federated principal",
- principalType: "Federated",
- want: "aws:FederatedProvider",
- },
- {
- name: "AWS principal",
- principalType: "AWS",
- want: "aws:PrincipalArn",
- },
- {
- name: "Service principal",
- principalType: "Service",
- want: "aws:PrincipalServiceName",
- },
- {
- name: "Custom principal type",
- principalType: "CustomType",
- want: "aws:PrincipalCustomType",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result := getPrincipalContextKey(tt.principalType)
- assert.Equal(t, tt.want, result, "Context key mapping failed for: %s", tt.name)
- })
- }
-}
diff --git a/weed/iam/policy/policy_engine_test.go b/weed/iam/policy/policy_engine_test.go
deleted file mode 100644
index 3a150ba99..000000000
--- a/weed/iam/policy/policy_engine_test.go
+++ /dev/null
@@ -1,426 +0,0 @@
-package policy
-
-import (
- "context"
- "testing"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-// TestPolicyEngineInitialization tests policy engine initialization
-func TestPolicyEngineInitialization(t *testing.T) {
- tests := []struct {
- name string
- config *PolicyEngineConfig
- wantErr bool
- }{
- {
- name: "valid config",
- config: &PolicyEngineConfig{
- DefaultEffect: "Deny",
- StoreType: "memory",
- },
- wantErr: false,
- },
- {
- name: "invalid default effect",
- config: &PolicyEngineConfig{
- DefaultEffect: "Invalid",
- StoreType: "memory",
- },
- wantErr: true,
- },
- {
- name: "nil config",
- config: nil,
- wantErr: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- engine := NewPolicyEngine()
-
- err := engine.Initialize(tt.config)
-
- if tt.wantErr {
- assert.Error(t, err)
- } else {
- assert.NoError(t, err)
- assert.True(t, engine.IsInitialized())
- }
- })
- }
-}
-
-// TestPolicyDocumentValidation tests policy document structure validation
-func TestPolicyDocumentValidation(t *testing.T) {
- tests := []struct {
- name string
- policy *PolicyDocument
- wantErr bool
- errorMsg string
- }{
- {
- name: "valid policy document",
- policy: &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Sid: "AllowS3Read",
- Effect: "Allow",
- Action: []string{"s3:GetObject", "s3:ListBucket"},
- Resource: []string{"arn:aws:s3:::mybucket/*"},
- },
- },
- },
- wantErr: false,
- },
- {
- name: "missing version",
- policy: &PolicyDocument{
- Statement: []Statement{
- {
- Effect: "Allow",
- Action: []string{"s3:GetObject"},
- Resource: []string{"arn:aws:s3:::mybucket/*"},
- },
- },
- },
- wantErr: true,
- errorMsg: "version is required",
- },
- {
- name: "empty statements",
- policy: &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{},
- },
- wantErr: true,
- errorMsg: "at least one statement is required",
- },
- {
- name: "invalid effect",
- policy: &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Effect: "Maybe",
- Action: []string{"s3:GetObject"},
- Resource: []string{"arn:aws:s3:::mybucket/*"},
- },
- },
- },
- wantErr: true,
- errorMsg: "invalid effect",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- err := ValidatePolicyDocument(tt.policy)
-
- if tt.wantErr {
- assert.Error(t, err)
- if tt.errorMsg != "" {
- assert.Contains(t, err.Error(), tt.errorMsg)
- }
- } else {
- assert.NoError(t, err)
- }
- })
- }
-}
-
-// TestPolicyEvaluation tests policy evaluation logic
-func TestPolicyEvaluation(t *testing.T) {
- engine := setupTestPolicyEngine(t)
-
- // Add test policies
- readPolicy := &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Sid: "AllowS3Read",
- Effect: "Allow",
- Action: []string{"s3:GetObject", "s3:ListBucket"},
- Resource: []string{
- "arn:aws:s3:::public-bucket/*", // For object operations
- "arn:aws:s3:::public-bucket", // For bucket operations
- },
- },
- },
- }
-
- err := engine.AddPolicy("", "read-policy", readPolicy)
- require.NoError(t, err)
-
- denyPolicy := &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Sid: "DenyS3Delete",
- Effect: "Deny",
- Action: []string{"s3:DeleteObject"},
- Resource: []string{"arn:aws:s3:::*"},
- },
- },
- }
-
- err = engine.AddPolicy("", "deny-policy", denyPolicy)
- require.NoError(t, err)
-
- tests := []struct {
- name string
- context *EvaluationContext
- policies []string
- want Effect
- }{
- {
- name: "allow read access",
- context: &EvaluationContext{
- Principal: "user:alice",
- Action: "s3:GetObject",
- Resource: "arn:aws:s3:::public-bucket/file.txt",
- RequestContext: map[string]interface{}{
- "aws:SourceIp": "192.168.1.100",
- },
- },
- policies: []string{"read-policy"},
- want: EffectAllow,
- },
- {
- name: "deny delete access (explicit deny)",
- context: &EvaluationContext{
- Principal: "user:alice",
- Action: "s3:DeleteObject",
- Resource: "arn:aws:s3:::public-bucket/file.txt",
- },
- policies: []string{"read-policy", "deny-policy"},
- want: EffectDeny,
- },
- {
- name: "deny by default (no matching policy)",
- context: &EvaluationContext{
- Principal: "user:alice",
- Action: "s3:PutObject",
- Resource: "arn:aws:s3:::public-bucket/file.txt",
- },
- policies: []string{"read-policy"},
- want: EffectDeny,
- },
- {
- name: "allow with wildcard action",
- context: &EvaluationContext{
- Principal: "user:admin",
- Action: "s3:ListBucket",
- Resource: "arn:aws:s3:::public-bucket",
- },
- policies: []string{"read-policy"},
- want: EffectAllow,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result, err := engine.Evaluate(context.Background(), "", tt.context, tt.policies)
-
- assert.NoError(t, err)
- assert.Equal(t, tt.want, result.Effect)
-
- // Verify evaluation details
- assert.NotNil(t, result.EvaluationDetails)
- assert.Equal(t, tt.context.Action, result.EvaluationDetails.Action)
- assert.Equal(t, tt.context.Resource, result.EvaluationDetails.Resource)
- })
- }
-}
-
-// TestConditionEvaluation tests policy conditions
-func TestConditionEvaluation(t *testing.T) {
- engine := setupTestPolicyEngine(t)
-
- // Policy with IP address condition
- conditionalPolicy := &PolicyDocument{
- Version: "2012-10-17",
- Statement: []Statement{
- {
- Sid: "AllowFromOfficeIP",
- Effect: "Allow",
- Action: []string{"s3:*"},
- Resource: []string{"arn:aws:s3:::*"},
- Condition: map[string]map[string]interface{}{
- "IpAddress": {
- "aws:SourceIp": []string{"192.168.1.0/24", "10.0.0.0/8"},
- },
- },
- },
- },
- }
-
- err := engine.AddPolicy("", "ip-conditional", conditionalPolicy)
- require.NoError(t, err)
-
- tests := []struct {
- name string
- context *EvaluationContext
- want Effect
- }{
- {
- name: "allow from office IP",
- context: &EvaluationContext{
- Principal: "user:alice",
- Action: "s3:GetObject",
- Resource: "arn:aws:s3:::mybucket/file.txt",
- RequestContext: map[string]interface{}{
- "aws:SourceIp": "192.168.1.100",
- },
- },
- want: EffectAllow,
- },
- {
- name: "deny from external IP",
- context: &EvaluationContext{
- Principal: "user:alice",
- Action: "s3:GetObject",
- Resource: "arn:aws:s3:::mybucket/file.txt",
- RequestContext: map[string]interface{}{
- "aws:SourceIp": "8.8.8.8",
- },
- },
- want: EffectDeny,
- },
- {
- name: "allow from internal IP",
- context: &EvaluationContext{
- Principal: "user:alice",
- Action: "s3:PutObject",
- Resource: "arn:aws:s3:::mybucket/newfile.txt",
- RequestContext: map[string]interface{}{
- "aws:SourceIp": "10.1.2.3",
- },
- },
- want: EffectAllow,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result, err := engine.Evaluate(context.Background(), "", tt.context, []string{"ip-conditional"})
-
- assert.NoError(t, err)
- assert.Equal(t, tt.want, result.Effect)
- })
- }
-}
-
-// TestResourceMatching tests resource ARN matching
-func TestResourceMatching(t *testing.T) {
- tests := []struct {
- name string
- policyResource string
- requestResource string
- want bool
- }{
- {
- name: "exact match",
- policyResource: "arn:aws:s3:::mybucket/file.txt",
- requestResource: "arn:aws:s3:::mybucket/file.txt",
- want: true,
- },
- {
- name: "wildcard match",
- policyResource: "arn:aws:s3:::mybucket/*",
- requestResource: "arn:aws:s3:::mybucket/folder/file.txt",
- want: true,
- },
- {
- name: "bucket wildcard",
- policyResource: "arn:aws:s3:::*",
- requestResource: "arn:aws:s3:::anybucket/file.txt",
- want: true,
- },
- {
- name: "no match different bucket",
- policyResource: "arn:aws:s3:::mybucket/*",
- requestResource: "arn:aws:s3:::otherbucket/file.txt",
- want: false,
- },
- {
- name: "prefix match",
- policyResource: "arn:aws:s3:::mybucket/documents/*",
- requestResource: "arn:aws:s3:::mybucket/documents/secret.txt",
- want: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result := matchResource(tt.policyResource, tt.requestResource)
- assert.Equal(t, tt.want, result)
- })
- }
-}
-
-// TestActionMatching tests action pattern matching
-func TestActionMatching(t *testing.T) {
- tests := []struct {
- name string
- policyAction string
- requestAction string
- want bool
- }{
- {
- name: "exact match",
- policyAction: "s3:GetObject",
- requestAction: "s3:GetObject",
- want: true,
- },
- {
- name: "wildcard service",
- policyAction: "s3:*",
- requestAction: "s3:PutObject",
- want: true,
- },
- {
- name: "wildcard all",
- policyAction: "*",
- requestAction: "filer:CreateEntry",
- want: true,
- },
- {
- name: "prefix match",
- policyAction: "s3:Get*",
- requestAction: "s3:GetObject",
- want: true,
- },
- {
- name: "no match different service",
- policyAction: "s3:GetObject",
- requestAction: "filer:GetEntry",
- want: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result := matchAction(tt.policyAction, tt.requestAction)
- assert.Equal(t, tt.want, result)
- })
- }
-}
-
-// Helper function to set up test policy engine
-func setupTestPolicyEngine(t *testing.T) *PolicyEngine {
- engine := NewPolicyEngine()
- config := &PolicyEngineConfig{
- DefaultEffect: "Deny",
- StoreType: "memory",
- }
-
- err := engine.Initialize(config)
- require.NoError(t, err)
-
- return engine
-}
diff --git a/weed/iam/providers/provider_test.go b/weed/iam/providers/provider_test.go
deleted file mode 100644
index 99cf360c1..000000000
--- a/weed/iam/providers/provider_test.go
+++ /dev/null
@@ -1,246 +0,0 @@
-package providers
-
-import (
- "context"
- "testing"
- "time"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-// TestIdentityProviderInterface tests the core identity provider interface
-func TestIdentityProviderInterface(t *testing.T) {
- tests := []struct {
- name string
- provider IdentityProvider
- wantErr bool
- }{
- // We'll add test cases as we implement providers
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Test provider name
- name := tt.provider.Name()
- assert.NotEmpty(t, name, "Provider name should not be empty")
-
- // Test initialization
- err := tt.provider.Initialize(nil)
- if tt.wantErr {
- assert.Error(t, err)
- return
- }
- require.NoError(t, err)
-
- // Test authentication with invalid token
- ctx := context.Background()
- _, err = tt.provider.Authenticate(ctx, "invalid-token")
- assert.Error(t, err, "Should fail with invalid token")
- })
- }
-}
-
-// TestExternalIdentityValidation tests external identity structure validation
-func TestExternalIdentityValidation(t *testing.T) {
- tests := []struct {
- name string
- identity *ExternalIdentity
- wantErr bool
- }{
- {
- name: "valid identity",
- identity: &ExternalIdentity{
- UserID: "user123",
- Email: "user@example.com",
- DisplayName: "Test User",
- Groups: []string{"group1", "group2"},
- Attributes: map[string]string{"dept": "engineering"},
- Provider: "test-provider",
- },
- wantErr: false,
- },
- {
- name: "missing user id",
- identity: &ExternalIdentity{
- Email: "user@example.com",
- Provider: "test-provider",
- },
- wantErr: true,
- },
- {
- name: "missing provider",
- identity: &ExternalIdentity{
- UserID: "user123",
- Email: "user@example.com",
- },
- wantErr: true,
- },
- {
- name: "invalid email",
- identity: &ExternalIdentity{
- UserID: "user123",
- Email: "invalid-email",
- Provider: "test-provider",
- },
- wantErr: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- err := tt.identity.Validate()
- if tt.wantErr {
- assert.Error(t, err)
- } else {
- assert.NoError(t, err)
- }
- })
- }
-}
-
-// TestTokenClaimsValidation tests token claims structure
-func TestTokenClaimsValidation(t *testing.T) {
- tests := []struct {
- name string
- claims *TokenClaims
- valid bool
- }{
- {
- name: "valid claims",
- claims: &TokenClaims{
- Subject: "user123",
- Issuer: "https://provider.example.com",
- Audience: "seaweedfs",
- ExpiresAt: time.Now().Add(time.Hour),
- IssuedAt: time.Now().Add(-time.Minute),
- Claims: map[string]interface{}{"email": "user@example.com"},
- },
- valid: true,
- },
- {
- name: "expired token",
- claims: &TokenClaims{
- Subject: "user123",
- Issuer: "https://provider.example.com",
- Audience: "seaweedfs",
- ExpiresAt: time.Now().Add(-time.Hour), // Expired
- IssuedAt: time.Now().Add(-time.Hour * 2),
- Claims: map[string]interface{}{"email": "user@example.com"},
- },
- valid: false,
- },
- {
- name: "future issued token",
- claims: &TokenClaims{
- Subject: "user123",
- Issuer: "https://provider.example.com",
- Audience: "seaweedfs",
- ExpiresAt: time.Now().Add(time.Hour),
- IssuedAt: time.Now().Add(time.Hour), // Future
- Claims: map[string]interface{}{"email": "user@example.com"},
- },
- valid: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- valid := tt.claims.IsValid()
- assert.Equal(t, tt.valid, valid)
- })
- }
-}
-
-// TestProviderRegistry tests provider registration and discovery
-func TestProviderRegistry(t *testing.T) {
- // Clear registry for test
- registry := NewProviderRegistry()
-
- t.Run("register provider", func(t *testing.T) {
- mockProvider := &MockProvider{name: "test-provider"}
-
- err := registry.RegisterProvider(mockProvider)
- assert.NoError(t, err)
-
- // Test duplicate registration
- err = registry.RegisterProvider(mockProvider)
- assert.Error(t, err, "Should not allow duplicate registration")
- })
-
- t.Run("get provider", func(t *testing.T) {
- provider, exists := registry.GetProvider("test-provider")
- assert.True(t, exists)
- assert.Equal(t, "test-provider", provider.Name())
-
- // Test non-existent provider
- _, exists = registry.GetProvider("non-existent")
- assert.False(t, exists)
- })
-
- t.Run("list providers", func(t *testing.T) {
- providers := registry.ListProviders()
- assert.Len(t, providers, 1)
- assert.Equal(t, "test-provider", providers[0])
- })
-}
-
-// MockProvider for testing
-type MockProvider struct {
- name string
- initialized bool
- shouldError bool
-}
-
-func (m *MockProvider) Name() string {
- return m.name
-}
-
-func (m *MockProvider) Initialize(config interface{}) error {
- if m.shouldError {
- return assert.AnError
- }
- m.initialized = true
- return nil
-}
-
-func (m *MockProvider) Authenticate(ctx context.Context, token string) (*ExternalIdentity, error) {
- if !m.initialized {
- return nil, assert.AnError
- }
- if token == "invalid-token" {
- return nil, assert.AnError
- }
- return &ExternalIdentity{
- UserID: "test-user",
- Email: "test@example.com",
- DisplayName: "Test User",
- Provider: m.name,
- }, nil
-}
-
-func (m *MockProvider) GetUserInfo(ctx context.Context, userID string) (*ExternalIdentity, error) {
- if !m.initialized || userID == "" {
- return nil, assert.AnError
- }
- return &ExternalIdentity{
- UserID: userID,
- Email: userID + "@example.com",
- DisplayName: "User " + userID,
- Provider: m.name,
- }, nil
-}
-
-func (m *MockProvider) ValidateToken(ctx context.Context, token string) (*TokenClaims, error) {
- if !m.initialized || token == "invalid-token" {
- return nil, assert.AnError
- }
- return &TokenClaims{
- Subject: "test-user",
- Issuer: "test-issuer",
- Audience: "seaweedfs",
- ExpiresAt: time.Now().Add(time.Hour),
- IssuedAt: time.Now(),
- Claims: map[string]interface{}{"email": "test@example.com"},
- }, nil
-}
diff --git a/weed/iam/providers/registry.go b/weed/iam/providers/registry.go
deleted file mode 100644
index dee50df44..000000000
--- a/weed/iam/providers/registry.go
+++ /dev/null
@@ -1,109 +0,0 @@
-package providers
-
-import (
- "fmt"
- "sync"
-)
-
-// ProviderRegistry manages registered identity providers
-type ProviderRegistry struct {
- mu sync.RWMutex
- providers map[string]IdentityProvider
-}
-
-// NewProviderRegistry creates a new provider registry
-func NewProviderRegistry() *ProviderRegistry {
- return &ProviderRegistry{
- providers: make(map[string]IdentityProvider),
- }
-}
-
-// RegisterProvider registers a new identity provider
-func (r *ProviderRegistry) RegisterProvider(provider IdentityProvider) error {
- if provider == nil {
- return fmt.Errorf("provider cannot be nil")
- }
-
- name := provider.Name()
- if name == "" {
- return fmt.Errorf("provider name cannot be empty")
- }
-
- r.mu.Lock()
- defer r.mu.Unlock()
-
- if _, exists := r.providers[name]; exists {
- return fmt.Errorf("provider %s is already registered", name)
- }
-
- r.providers[name] = provider
- return nil
-}
-
-// GetProvider retrieves a provider by name
-func (r *ProviderRegistry) GetProvider(name string) (IdentityProvider, bool) {
- r.mu.RLock()
- defer r.mu.RUnlock()
-
- provider, exists := r.providers[name]
- return provider, exists
-}
-
-// ListProviders returns all registered provider names
-func (r *ProviderRegistry) ListProviders() []string {
- r.mu.RLock()
- defer r.mu.RUnlock()
-
- var names []string
- for name := range r.providers {
- names = append(names, name)
- }
- return names
-}
-
-// UnregisterProvider removes a provider from the registry
-func (r *ProviderRegistry) UnregisterProvider(name string) error {
- r.mu.Lock()
- defer r.mu.Unlock()
-
- if _, exists := r.providers[name]; !exists {
- return fmt.Errorf("provider %s is not registered", name)
- }
-
- delete(r.providers, name)
- return nil
-}
-
-// Clear removes all providers from the registry
-func (r *ProviderRegistry) Clear() {
- r.mu.Lock()
- defer r.mu.Unlock()
-
- r.providers = make(map[string]IdentityProvider)
-}
-
-// GetProviderCount returns the number of registered providers
-func (r *ProviderRegistry) GetProviderCount() int {
- r.mu.RLock()
- defer r.mu.RUnlock()
-
- return len(r.providers)
-}
-
-// Default global registry
-var defaultRegistry = NewProviderRegistry()
-
-// RegisterProvider registers a provider in the default registry
-func RegisterProvider(provider IdentityProvider) error {
- return defaultRegistry.RegisterProvider(provider)
-}
-
-// GetProvider retrieves a provider from the default registry
-func GetProvider(name string) (IdentityProvider, bool) {
- return defaultRegistry.GetProvider(name)
-}
-
-// ListProviders returns all provider names from the default registry
-func ListProviders() []string {
- return defaultRegistry.ListProviders()
-}
diff --git a/weed/iam/sts/cross_instance_token_test.go b/weed/iam/sts/cross_instance_token_test.go
deleted file mode 100644
index 8a375a885..000000000
--- a/weed/iam/sts/cross_instance_token_test.go
+++ /dev/null
@@ -1,503 +0,0 @@
-package sts
-
-import (
- "context"
- "testing"
- "time"
-
- "github.com/golang-jwt/jwt/v5"
- "github.com/seaweedfs/seaweedfs/weed/iam/oidc"
- "github.com/seaweedfs/seaweedfs/weed/iam/providers"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-// Test-only constants for mock providers
-const (
- ProviderTypeMock = "mock"
-)
-
-// createMockOIDCProvider creates a mock OIDC provider for testing
-// This is only available in test builds
-func createMockOIDCProvider(name string, config map[string]interface{}) (providers.IdentityProvider, error) {
- // Convert config to OIDC format
- factory := NewProviderFactory()
- oidcConfig, err := factory.convertToOIDCConfig(config)
- if err != nil {
- return nil, err
- }
-
- // Set default values for mock provider if not provided
- if oidcConfig.Issuer == "" {
- oidcConfig.Issuer = "http://localhost:9999"
- }
-
- provider := oidc.NewMockOIDCProvider(name)
- if err := provider.Initialize(oidcConfig); err != nil {
- return nil, err
- }
-
- // Set up default test data for the mock provider
- provider.SetupDefaultTestData()
-
- return provider, nil
-}
-
-// createMockJWT creates a test JWT token with the specified issuer for mock provider testing
-func createMockJWT(t *testing.T, issuer, subject string) string {
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
- "iss": issuer,
- "sub": subject,
- "aud": "test-client",
- "exp": time.Now().Add(time.Hour).Unix(),
- "iat": time.Now().Unix(),
- })
-
- tokenString, err := token.SignedString([]byte("test-signing-key"))
- require.NoError(t, err)
- return tokenString
-}
-
-// TestCrossInstanceTokenUsage verifies that tokens generated by one STS instance
-// can be used and validated by other STS instances in a distributed environment
-func TestCrossInstanceTokenUsage(t *testing.T) {
- ctx := context.Background()
- // Dummy filer address for testing
-
- // Common configuration that would be shared across all instances in production
- sharedConfig := &STSConfig{
- TokenDuration: FlexibleDuration{time.Hour},
- MaxSessionLength: FlexibleDuration{12 * time.Hour},
- Issuer: "distributed-sts-cluster", // SAME across all instances
- SigningKey: []byte(TestSigningKey32Chars), // SAME across all instances
- Providers: []*ProviderConfig{
- {
- Name: "company-oidc",
- Type: ProviderTypeOIDC,
- Enabled: true,
- Config: map[string]interface{}{
- ConfigFieldIssuer: "https://sso.company.com/realms/production",
- ConfigFieldClientID: "seaweedfs-cluster",
- ConfigFieldJWKSUri: "https://sso.company.com/realms/production/protocol/openid-connect/certs",
- },
- },
- },
- }
-
- // Create multiple STS instances simulating different S3 gateway instances
- instanceA := NewSTSService() // e.g., s3-gateway-1
- instanceB := NewSTSService() // e.g., s3-gateway-2
- instanceC := NewSTSService() // e.g., s3-gateway-3
-
- // Initialize all instances with IDENTICAL configuration
- err := instanceA.Initialize(sharedConfig)
- require.NoError(t, err, "Instance A should initialize")
-
- err = instanceB.Initialize(sharedConfig)
- require.NoError(t, err, "Instance B should initialize")
-
- err = instanceC.Initialize(sharedConfig)
- require.NoError(t, err, "Instance C should initialize")
-
- // Set up mock trust policy validator for all instances (required for STS testing)
- mockValidator := &MockTrustPolicyValidator{}
- instanceA.SetTrustPolicyValidator(mockValidator)
- instanceB.SetTrustPolicyValidator(mockValidator)
- instanceC.SetTrustPolicyValidator(mockValidator)
-
- // Manually register mock provider for testing (not available in production)
- mockProviderConfig := map[string]interface{}{
- ConfigFieldIssuer: "http://test-mock:9999",
- ConfigFieldClientID: TestClientID,
- }
- mockProviderA, err := createMockOIDCProvider("test-mock", mockProviderConfig)
- require.NoError(t, err)
- mockProviderB, err := createMockOIDCProvider("test-mock", mockProviderConfig)
- require.NoError(t, err)
- mockProviderC, err := createMockOIDCProvider("test-mock", mockProviderConfig)
- require.NoError(t, err)
-
- instanceA.RegisterProvider(mockProviderA)
- instanceB.RegisterProvider(mockProviderB)
- instanceC.RegisterProvider(mockProviderC)
-
- // Test 1: Token generated on Instance A can be validated on Instance B & C
- t.Run("cross_instance_token_validation", func(t *testing.T) {
- // Generate session token on Instance A
- sessionId := TestSessionID
- expiresAt := time.Now().Add(time.Hour)
-
- tokenFromA, err := instanceA.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
- require.NoError(t, err, "Instance A should generate token")
-
- // Validate token on Instance B
- claimsFromB, err := instanceB.GetTokenGenerator().ValidateSessionToken(tokenFromA)
- require.NoError(t, err, "Instance B should validate token from Instance A")
- assert.Equal(t, sessionId, claimsFromB.SessionId, "Session ID should match")
-
- // Validate same token on Instance C
- claimsFromC, err := instanceC.GetTokenGenerator().ValidateSessionToken(tokenFromA)
- require.NoError(t, err, "Instance C should validate token from Instance A")
- assert.Equal(t, sessionId, claimsFromC.SessionId, "Session ID should match")
-
- // All instances should extract identical claims
- assert.Equal(t, claimsFromB.SessionId, claimsFromC.SessionId)
- assert.Equal(t, claimsFromB.ExpiresAt.Unix(), claimsFromC.ExpiresAt.Unix())
- assert.Equal(t, claimsFromB.IssuedAt.Unix(), claimsFromC.IssuedAt.Unix())
- })
-
- // Test 2: Complete assume role flow across instances
- t.Run("cross_instance_assume_role_flow", func(t *testing.T) {
- // Step 1: User authenticates and assumes role on Instance A
- // Create a valid JWT token for the mock provider
- mockToken := createMockJWT(t, "http://test-mock:9999", "test-user")
-
- assumeRequest := &AssumeRoleWithWebIdentityRequest{
- RoleArn: "arn:aws:iam::role/CrossInstanceTestRole",
- WebIdentityToken: mockToken, // JWT token for mock provider
- RoleSessionName: "cross-instance-test-session",
- DurationSeconds: int64ToPtr(3600),
- }
-
- // Instance A processes assume role request
- responseFromA, err := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest)
- require.NoError(t, err, "Instance A should process assume role")
-
- sessionToken := responseFromA.Credentials.SessionToken
- accessKeyId := responseFromA.Credentials.AccessKeyId
- secretAccessKey := responseFromA.Credentials.SecretAccessKey
-
- // Verify response structure
- assert.NotEmpty(t, sessionToken, "Should have session token")
- assert.NotEmpty(t, accessKeyId, "Should have access key ID")
- assert.NotEmpty(t, secretAccessKey, "Should have secret access key")
- assert.NotNil(t, responseFromA.AssumedRoleUser, "Should have assumed role user")
-
- // Step 2: Use session token on Instance B (different instance)
- sessionInfoFromB, err := instanceB.ValidateSessionToken(ctx, sessionToken)
- require.NoError(t, err, "Instance B should validate session token from Instance A")
-
- assert.Equal(t, assumeRequest.RoleSessionName, sessionInfoFromB.SessionName)
- assert.Equal(t, assumeRequest.RoleArn, sessionInfoFromB.RoleArn)
-
- // Step 3: Use same session token on Instance C (yet another instance)
- sessionInfoFromC, err := instanceC.ValidateSessionToken(ctx, sessionToken)
- require.NoError(t, err, "Instance C should validate session token from Instance A")
-
- // All instances should return identical session information
- assert.Equal(t, sessionInfoFromB.SessionId, sessionInfoFromC.SessionId)
- assert.Equal(t, sessionInfoFromB.SessionName, sessionInfoFromC.SessionName)
- assert.Equal(t, sessionInfoFromB.RoleArn, sessionInfoFromC.RoleArn)
- assert.Equal(t, sessionInfoFromB.Subject, sessionInfoFromC.Subject)
- assert.Equal(t, sessionInfoFromB.Provider, sessionInfoFromC.Provider)
- })
-
- // Test 3: Session revocation across instances
- t.Run("cross_instance_session_revocation", func(t *testing.T) {
- // Create session on Instance A
- mockToken := createMockJWT(t, "http://test-mock:9999", "test-user")
-
- assumeRequest := &AssumeRoleWithWebIdentityRequest{
- RoleArn: "arn:aws:iam::role/RevocationTestRole",
- WebIdentityToken: mockToken,
- RoleSessionName: "revocation-test-session",
- }
-
- response, err := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest)
- require.NoError(t, err)
- sessionToken := response.Credentials.SessionToken
-
- // Verify token works on Instance B
- _, err = instanceB.ValidateSessionToken(ctx, sessionToken)
- require.NoError(t, err, "Token should be valid on Instance B initially")
-
- // Validate session on Instance C to verify cross-instance token compatibility
- _, err = instanceC.ValidateSessionToken(ctx, sessionToken)
- require.NoError(t, err, "Instance C should be able to validate session token")
-
- // In a stateless JWT system, tokens remain valid on all instances since they're self-contained
- // No revocation is possible without breaking the stateless architecture
- _, err = instanceA.ValidateSessionToken(ctx, sessionToken)
- assert.NoError(t, err, "Token should still be valid on Instance A (stateless system)")
-
- // Verify token is still valid on Instance B
- _, err = instanceB.ValidateSessionToken(ctx, sessionToken)
- assert.NoError(t, err, "Token should still be valid on Instance B (stateless system)")
- })
-
- // Test 4: Provider consistency across instances
- t.Run("provider_consistency_affects_token_generation", func(t *testing.T) {
- // All instances should have same providers and be able to process same OIDC tokens
- providerNamesA := instanceA.getProviderNames()
- providerNamesB := instanceB.getProviderNames()
- providerNamesC := instanceC.getProviderNames()
-
- assert.ElementsMatch(t, providerNamesA, providerNamesB, "Instance A and B should have same providers")
- assert.ElementsMatch(t, providerNamesB, providerNamesC, "Instance B and C should have same providers")
-
- // All instances should be able to process same web identity token
- testToken := createMockJWT(t, "http://test-mock:9999", "test-user")
-
- // Try to assume role with same token on different instances
- assumeRequest := &AssumeRoleWithWebIdentityRequest{
- RoleArn: "arn:aws:iam::role/ProviderTestRole",
- WebIdentityToken: testToken,
- RoleSessionName: "provider-consistency-test",
- }
-
- // Should work on any instance
- responseA, errA := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest)
- responseB, errB := instanceB.AssumeRoleWithWebIdentity(ctx, assumeRequest)
- responseC, errC := instanceC.AssumeRoleWithWebIdentity(ctx, assumeRequest)
-
- require.NoError(t, errA, "Instance A should process OIDC token")
- require.NoError(t, errB, "Instance B should process OIDC token")
- require.NoError(t, errC, "Instance C should process OIDC token")
-
- // All should return valid responses (sessions will have different IDs but same structure)
- assert.NotEmpty(t, responseA.Credentials.SessionToken)
- assert.NotEmpty(t, responseB.Credentials.SessionToken)
- assert.NotEmpty(t, responseC.Credentials.SessionToken)
- })
-}
-
-// TestSTSDistributedConfigurationRequirements tests the configuration requirements
-// for cross-instance token compatibility
-func TestSTSDistributedConfigurationRequirements(t *testing.T) {
- _ = "localhost:8888" // Dummy filer address for testing (not used in these tests)
-
- t.Run("same_signing_key_required", func(t *testing.T) {
- // Instance A with signing key 1
- configA := &STSConfig{
- TokenDuration: FlexibleDuration{time.Hour},
- MaxSessionLength: FlexibleDuration{12 * time.Hour},
- Issuer: "test-sts",
- SigningKey: []byte("signing-key-1-32-characters-long"),
- }
-
- // Instance B with different signing key
- configB := &STSConfig{
- TokenDuration: FlexibleDuration{time.Hour},
- MaxSessionLength: FlexibleDuration{12 * time.Hour},
- Issuer: "test-sts",
- SigningKey: []byte("signing-key-2-32-characters-long"), // DIFFERENT!
- }
-
- instanceA := NewSTSService()
- instanceB := NewSTSService()
-
- err := instanceA.Initialize(configA)
- require.NoError(t, err)
-
- err = instanceB.Initialize(configB)
- require.NoError(t, err)
-
- // Generate token on Instance A
- sessionId := "test-session"
- expiresAt := time.Now().Add(time.Hour)
- tokenFromA, err := instanceA.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
- require.NoError(t, err)
-
- // Instance A should validate its own token
- _, err = instanceA.GetTokenGenerator().ValidateSessionToken(tokenFromA)
- assert.NoError(t, err, "Instance A should validate own token")
-
- // Instance B should REJECT token due to different signing key
- _, err = instanceB.GetTokenGenerator().ValidateSessionToken(tokenFromA)
- assert.Error(t, err, "Instance B should reject token with different signing key")
- assert.Contains(t, err.Error(), "invalid token", "Should be signature validation error")
- })
-
- t.Run("same_issuer_required", func(t *testing.T) {
- sharedSigningKey := []byte("shared-signing-key-32-characters-lo")
-
- // Instance A with issuer 1
- configA := &STSConfig{
- TokenDuration: FlexibleDuration{time.Hour},
- MaxSessionLength: FlexibleDuration{12 * time.Hour},
- Issuer: "sts-cluster-1",
- SigningKey: sharedSigningKey,
- }
-
- // Instance B with different issuer
- configB := &STSConfig{
- TokenDuration: FlexibleDuration{time.Hour},
- MaxSessionLength: FlexibleDuration{12 * time.Hour},
- Issuer: "sts-cluster-2", // DIFFERENT!
- SigningKey: sharedSigningKey,
- }
-
- instanceA := NewSTSService()
- instanceB := NewSTSService()
-
- err := instanceA.Initialize(configA)
- require.NoError(t, err)
-
- err = instanceB.Initialize(configB)
- require.NoError(t, err)
-
- // Generate token on Instance A
- sessionId := "test-session"
- expiresAt := time.Now().Add(time.Hour)
- tokenFromA, err := instanceA.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
- require.NoError(t, err)
-
- // Instance B should REJECT token due to different issuer
- _, err = instanceB.GetTokenGenerator().ValidateSessionToken(tokenFromA)
- assert.Error(t, err, "Instance B should reject token with different issuer")
- assert.Contains(t, err.Error(), "invalid issuer", "Should be issuer validation error")
- })
-
- t.Run("identical_configuration_required", func(t *testing.T) {
- // Identical configuration
- identicalConfig := &STSConfig{
- TokenDuration: FlexibleDuration{time.Hour},
- MaxSessionLength: FlexibleDuration{12 * time.Hour},
- Issuer: "production-sts-cluster",
- SigningKey: []byte("production-signing-key-32-chars-l"),
- }
-
- // Create multiple instances with identical config
- instances := make([]*STSService, 5)
- for i := 0; i < 5; i++ {
- instances[i] = NewSTSService()
- err := instances[i].Initialize(identicalConfig)
- require.NoError(t, err, "Instance %d should initialize", i)
- }
-
- // Generate token on Instance 0
- sessionId := "multi-instance-test"
- expiresAt := time.Now().Add(time.Hour)
- token, err := instances[0].GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
- require.NoError(t, err)
-
- // All other instances should validate the token
- for i := 1; i < 5; i++ {
- claims, err := instances[i].GetTokenGenerator().ValidateSessionToken(token)
- require.NoError(t, err, "Instance %d should validate token", i)
- assert.Equal(t, sessionId, claims.SessionId, "Instance %d should extract correct session ID", i)
- }
- })
-}
-
-// TestSTSRealWorldDistributedScenarios tests realistic distributed deployment scenarios
-func TestSTSRealWorldDistributedScenarios(t *testing.T) {
- ctx := context.Background()
-
- t.Run("load_balanced_s3_gateway_scenario", func(t *testing.T) {
- // Simulate real production scenario:
- // 1. User authenticates with OIDC provider
- // 2. User calls AssumeRoleWithWebIdentity on S3 Gateway 1
- // 3. User makes S3 requests that hit S3 Gateway 2 & 3 via load balancer
- // 4. All instances should handle the session token correctly
-
- productionConfig := &STSConfig{
- TokenDuration: FlexibleDuration{2 * time.Hour},
- MaxSessionLength: FlexibleDuration{24 * time.Hour},
- Issuer: "seaweedfs-production-sts",
- SigningKey: []byte("prod-signing-key-32-characters-lon"),
-
- Providers: []*ProviderConfig{
- {
- Name: "corporate-oidc",
- Type: "oidc",
- Enabled: true,
- Config: map[string]interface{}{
- "issuer": "https://sso.company.com/realms/production",
- "clientId": "seaweedfs-prod-cluster",
- "clientSecret": "supersecret-prod-key",
- "scopes": []string{"openid", "profile", "email", "groups"},
- },
- },
- },
- }
-
- // Create 3 S3 Gateway instances behind load balancer
- gateway1 := NewSTSService()
- gateway2 := NewSTSService()
- gateway3 := NewSTSService()
-
- err := gateway1.Initialize(productionConfig)
- require.NoError(t, err)
-
- err = gateway2.Initialize(productionConfig)
- require.NoError(t, err)
-
- err = gateway3.Initialize(productionConfig)
- require.NoError(t, err)
-
- // Set up mock trust policy validator for all gateway instances
- mockValidator := &MockTrustPolicyValidator{}
- gateway1.SetTrustPolicyValidator(mockValidator)
- gateway2.SetTrustPolicyValidator(mockValidator)
- gateway3.SetTrustPolicyValidator(mockValidator)
-
- // Manually register mock provider for testing (not available in production)
- mockProviderConfig := map[string]interface{}{
- ConfigFieldIssuer: "http://test-mock:9999",
- ConfigFieldClientID: "test-client-id",
- }
- mockProvider1, err := createMockOIDCProvider("test-mock", mockProviderConfig)
- require.NoError(t, err)
- mockProvider2, err := createMockOIDCProvider("test-mock", mockProviderConfig)
- require.NoError(t, err)
- mockProvider3, err := createMockOIDCProvider("test-mock", mockProviderConfig)
- require.NoError(t, err)
-
- gateway1.RegisterProvider(mockProvider1)
- gateway2.RegisterProvider(mockProvider2)
- gateway3.RegisterProvider(mockProvider3)
-
- // Step 1: User authenticates and hits Gateway 1 for AssumeRole
- mockToken := createMockJWT(t, "http://test-mock:9999", "production-user")
-
- assumeRequest := &AssumeRoleWithWebIdentityRequest{
- RoleArn: "arn:aws:iam::role/ProductionS3User",
- WebIdentityToken: mockToken, // JWT token from mock provider
- RoleSessionName: "user-production-session",
- DurationSeconds: int64ToPtr(7200), // 2 hours
- }
-
- stsResponse, err := gateway1.AssumeRoleWithWebIdentity(ctx, assumeRequest)
- require.NoError(t, err, "Gateway 1 should handle AssumeRole")
-
- sessionToken := stsResponse.Credentials.SessionToken
- accessKey := stsResponse.Credentials.AccessKeyId
- secretKey := stsResponse.Credentials.SecretAccessKey
-
- // Step 2: User makes S3 requests that hit different gateways via load balancer
- // Simulate S3 request validation on Gateway 2
- sessionInfo2, err := gateway2.ValidateSessionToken(ctx, sessionToken)
- require.NoError(t, err, "Gateway 2 should validate session from Gateway 1")
- assert.Equal(t, "user-production-session", sessionInfo2.SessionName)
- assert.Equal(t, "arn:aws:iam::role/ProductionS3User", sessionInfo2.RoleArn)
-
- // Simulate S3 request validation on Gateway 3
- sessionInfo3, err := gateway3.ValidateSessionToken(ctx, sessionToken)
- require.NoError(t, err, "Gateway 3 should validate session from Gateway 1")
- assert.Equal(t, sessionInfo2.SessionId, sessionInfo3.SessionId, "Should be same session")
-
- // Step 3: Verify credentials are consistent
- assert.Equal(t, accessKey, stsResponse.Credentials.AccessKeyId, "Access key should be consistent")
- assert.Equal(t, secretKey, stsResponse.Credentials.SecretAccessKey, "Secret key should be consistent")
-
- // Step 4: Session expiration should be honored across all instances
- assert.True(t, sessionInfo2.ExpiresAt.After(time.Now()), "Session should not be expired")
- assert.True(t, sessionInfo3.ExpiresAt.After(time.Now()), "Session should not be expired")
-
- // Step 5: Token should be identical when parsed
- claims2, err := gateway2.GetTokenGenerator().ValidateSessionToken(sessionToken)
- require.NoError(t, err)
-
- claims3, err := gateway3.GetTokenGenerator().ValidateSessionToken(sessionToken)
- require.NoError(t, err)
-
- assert.Equal(t, claims2.SessionId, claims3.SessionId, "Session IDs should match")
- assert.Equal(t, claims2.ExpiresAt.Unix(), claims3.ExpiresAt.Unix(), "Expiration should match")
- })
-}
-
-// Helper function to convert int64 to pointer
-func int64ToPtr(i int64) *int64 {
- return &i
-}
diff --git a/weed/iam/sts/distributed_sts_test.go b/weed/iam/sts/distributed_sts_test.go
deleted file mode 100644
index 7997e7b8e..000000000
--- a/weed/iam/sts/distributed_sts_test.go
+++ /dev/null
@@ -1,340 +0,0 @@
-package sts
-
-import (
- "context"
- "testing"
- "time"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-// TestDistributedSTSService verifies that multiple STS instances with identical configurations
-// behave consistently across distributed environments
-func TestDistributedSTSService(t *testing.T) {
- ctx := context.Background()
-
- // Common configuration for all instances
- commonConfig := &STSConfig{
- TokenDuration: FlexibleDuration{time.Hour},
- MaxSessionLength: FlexibleDuration{12 * time.Hour},
- Issuer: "distributed-sts-test",
- SigningKey: []byte("test-signing-key-32-characters-long"),
-
- Providers: []*ProviderConfig{
- {
- Name: "keycloak-oidc",
- Type: "oidc",
- Enabled: true,
- Config: map[string]interface{}{
- "issuer": "http://keycloak:8080/realms/seaweedfs-test",
- "clientId": "seaweedfs-s3",
- "jwksUri": "http://keycloak:8080/realms/seaweedfs-test/protocol/openid-connect/certs",
- },
- },
-
- {
- Name: "disabled-ldap",
- Type: "oidc", // Use OIDC as placeholder since LDAP isn't implemented
- Enabled: false,
- Config: map[string]interface{}{
- "issuer": "ldap://company.com",
- "clientId": "ldap-client",
- },
- },
- },
- }
-
- // Create multiple STS instances simulating distributed deployment
- instance1 := NewSTSService()
- instance2 := NewSTSService()
- instance3 := NewSTSService()
-
- // Initialize all instances with identical configuration
- err := instance1.Initialize(commonConfig)
- require.NoError(t, err, "Instance 1 should initialize successfully")
-
- err = instance2.Initialize(commonConfig)
- require.NoError(t, err, "Instance 2 should initialize successfully")
-
- err = instance3.Initialize(commonConfig)
- require.NoError(t, err, "Instance 3 should initialize successfully")
-
- // Manually register mock providers for testing (not available in production)
- mockProviderConfig := map[string]interface{}{
- "issuer": "http://localhost:9999",
- "clientId": "test-client",
- }
- mockProvider1, err := createMockOIDCProvider("test-mock-provider", mockProviderConfig)
- require.NoError(t, err)
- mockProvider2, err := createMockOIDCProvider("test-mock-provider", mockProviderConfig)
- require.NoError(t, err)
- mockProvider3, err := createMockOIDCProvider("test-mock-provider", mockProviderConfig)
- require.NoError(t, err)
-
- instance1.RegisterProvider(mockProvider1)
- instance2.RegisterProvider(mockProvider2)
- instance3.RegisterProvider(mockProvider3)
-
- // Verify all instances have identical provider configurations
- t.Run("provider_consistency", func(t *testing.T) {
- // All instances should have same number of providers
- assert.Len(t, instance1.providers, 2, "Instance 1 should have 2 enabled providers")
- assert.Len(t, instance2.providers, 2, "Instance 2 should have 2 enabled providers")
- assert.Len(t, instance3.providers, 2, "Instance 3 should have 2 enabled providers")
-
- // All instances should have same provider names
- instance1Names := instance1.getProviderNames()
- instance2Names := instance2.getProviderNames()
- instance3Names := instance3.getProviderNames()
-
- assert.ElementsMatch(t, instance1Names, instance2Names, "Instance 1 and 2 should have same providers")
- assert.ElementsMatch(t, instance2Names, instance3Names, "Instance 2 and 3 should have same providers")
-
- // Verify specific providers exist on all instances
- expectedProviders := []string{"keycloak-oidc", "test-mock-provider"}
- assert.ElementsMatch(t, instance1Names, expectedProviders, "Instance 1 should have expected providers")
- assert.ElementsMatch(t, instance2Names, expectedProviders, "Instance 2 should have expected providers")
- assert.ElementsMatch(t, instance3Names, expectedProviders, "Instance 3 should have expected providers")
-
- // Verify disabled providers are not loaded
- assert.NotContains(t, instance1Names, "disabled-ldap", "Disabled providers should not be loaded")
- assert.NotContains(t, instance2Names, "disabled-ldap", "Disabled providers should not be loaded")
- assert.NotContains(t, instance3Names, "disabled-ldap", "Disabled providers should not be loaded")
- })
-
- // Test token generation consistency across instances
- t.Run("token_generation_consistency", func(t *testing.T) {
- sessionId := "test-session-123"
- expiresAt := time.Now().Add(time.Hour)
-
- // Generate tokens from different instances
- token1, err1 := instance1.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
- token2, err2 := instance2.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
- token3, err3 := instance3.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
-
- require.NoError(t, err1, "Instance 1 token generation should succeed")
- require.NoError(t, err2, "Instance 2 token generation should succeed")
- require.NoError(t, err3, "Instance 3 token generation should succeed")
-
- // All tokens should be different (due to timestamp variations)
- // But they should all be valid JWTs with same signing key
- assert.NotEmpty(t, token1)
- assert.NotEmpty(t, token2)
- assert.NotEmpty(t, token3)
- })
-
- // Test token validation consistency - any instance should validate tokens from any other instance
- t.Run("cross_instance_token_validation", func(t *testing.T) {
- sessionId := "cross-validation-session"
- expiresAt := time.Now().Add(time.Hour)
-
- // Generate token on instance 1
- token, err := instance1.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
- require.NoError(t, err)
-
- // Validate on all instances
- claims1, err1 := instance1.GetTokenGenerator().ValidateSessionToken(token)
- claims2, err2 := instance2.GetTokenGenerator().ValidateSessionToken(token)
- claims3, err3 := instance3.GetTokenGenerator().ValidateSessionToken(token)
-
- require.NoError(t, err1, "Instance 1 should validate token from instance 1")
- require.NoError(t, err2, "Instance 2 should validate token from instance 1")
- require.NoError(t, err3, "Instance 3 should validate token from instance 1")
-
- // All instances should extract same session ID
- assert.Equal(t, sessionId, claims1.SessionId)
- assert.Equal(t, sessionId, claims2.SessionId)
- assert.Equal(t, sessionId, claims3.SessionId)
-
- assert.Equal(t, claims1.SessionId, claims2.SessionId)
- assert.Equal(t, claims2.SessionId, claims3.SessionId)
- })
-
- // Test provider access consistency
- t.Run("provider_access_consistency", func(t *testing.T) {
- // All instances should be able to access the same providers
- provider1, exists1 := instance1.providers["test-mock-provider"]
- provider2, exists2 := instance2.providers["test-mock-provider"]
- provider3, exists3 := instance3.providers["test-mock-provider"]
-
- assert.True(t, exists1, "Instance 1 should have test-mock-provider")
- assert.True(t, exists2, "Instance 2 should have test-mock-provider")
- assert.True(t, exists3, "Instance 3 should have test-mock-provider")
-
- assert.Equal(t, provider1.Name(), provider2.Name())
- assert.Equal(t, provider2.Name(), provider3.Name())
-
- // Test authentication with the mock provider on all instances
- testToken := "valid_test_token"
-
- identity1, err1 := provider1.Authenticate(ctx, testToken)
- identity2, err2 := provider2.Authenticate(ctx, testToken)
- identity3, err3 := provider3.Authenticate(ctx, testToken)
-
- require.NoError(t, err1, "Instance 1 provider should authenticate successfully")
- require.NoError(t, err2, "Instance 2 provider should authenticate successfully")
- require.NoError(t, err3, "Instance 3 provider should authenticate successfully")
-
- // All instances should return identical identity information
- assert.Equal(t, identity1.UserID, identity2.UserID)
- assert.Equal(t, identity2.UserID, identity3.UserID)
- assert.Equal(t, identity1.Email, identity2.Email)
- assert.Equal(t, identity2.Email, identity3.Email)
- assert.Equal(t, identity1.Provider, identity2.Provider)
- assert.Equal(t, identity2.Provider, identity3.Provider)
- })
-}
-
-// TestSTSConfigurationValidation tests configuration validation for distributed deployments
-func TestSTSConfigurationValidation(t *testing.T) {
- t.Run("consistent_signing_keys_required", func(t *testing.T) {
- // Different signing keys should result in incompatible token validation
- config1 := &STSConfig{
- TokenDuration: FlexibleDuration{time.Hour},
- MaxSessionLength: FlexibleDuration{12 * time.Hour},
- Issuer: "test-sts",
- SigningKey: []byte("signing-key-1-32-characters-long"),
- }
-
- config2 := &STSConfig{
- TokenDuration: FlexibleDuration{time.Hour},
- MaxSessionLength: FlexibleDuration{12 * time.Hour},
- Issuer: "test-sts",
- SigningKey: []byte("signing-key-2-32-characters-long"), // Different key!
- }
-
- instance1 := NewSTSService()
- instance2 := NewSTSService()
-
- err1 := instance1.Initialize(config1)
- err2 := instance2.Initialize(config2)
-
- require.NoError(t, err1)
- require.NoError(t, err2)
-
- // Generate token on instance 1
- sessionId := "test-session"
- expiresAt := time.Now().Add(time.Hour)
- token, err := instance1.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
- require.NoError(t, err)
-
- // Instance 1 should validate its own token
- _, err = instance1.GetTokenGenerator().ValidateSessionToken(token)
- assert.NoError(t, err, "Instance 1 should validate its own token")
-
- // Instance 2 should reject token from instance 1 (different signing key)
- _, err = instance2.GetTokenGenerator().ValidateSessionToken(token)
- assert.Error(t, err, "Instance 2 should reject token with different signing key")
- })
-
- t.Run("consistent_issuer_required", func(t *testing.T) {
- // Different issuers should result in incompatible tokens
- commonSigningKey := []byte("shared-signing-key-32-characters-lo")
-
- config1 := &STSConfig{
- TokenDuration: FlexibleDuration{time.Hour},
- MaxSessionLength: FlexibleDuration{12 * time.Hour},
- Issuer: "sts-instance-1",
- SigningKey: commonSigningKey,
- }
-
- config2 := &STSConfig{
- TokenDuration: FlexibleDuration{time.Hour},
- MaxSessionLength: FlexibleDuration{12 * time.Hour},
- Issuer: "sts-instance-2", // Different issuer!
- SigningKey: commonSigningKey,
- }
-
- instance1 := NewSTSService()
- instance2 := NewSTSService()
-
- err1 := instance1.Initialize(config1)
- err2 := instance2.Initialize(config2)
-
- require.NoError(t, err1)
- require.NoError(t, err2)
-
- // Generate token on instance 1
- sessionId := "test-session"
- expiresAt := time.Now().Add(time.Hour)
- token, err := instance1.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
- require.NoError(t, err)
-
- // Instance 2 should reject token due to issuer mismatch
- // (Even though signing key is the same, issuer validation will fail)
- _, err = instance2.GetTokenGenerator().ValidateSessionToken(token)
- assert.Error(t, err, "Instance 2 should reject token with different issuer")
- })
-}
-
-// TestProviderFactoryDistributed tests the provider factory in distributed scenarios
-func TestProviderFactoryDistributed(t *testing.T) {
- factory := NewProviderFactory()
-
- // Simulate configuration that would be identical across all instances
- configs := []*ProviderConfig{
- {
- Name: "production-keycloak",
- Type: "oidc",
- Enabled: true,
- Config: map[string]interface{}{
- "issuer": "https://keycloak.company.com/realms/seaweedfs",
- "clientId": "seaweedfs-prod",
- "clientSecret": "super-secret-key",
- "jwksUri": "https://keycloak.company.com/realms/seaweedfs/protocol/openid-connect/certs",
- "scopes": []string{"openid", "profile", "email", "roles"},
- },
- },
- {
- Name: "backup-oidc",
- Type: "oidc",
- Enabled: false, // Disabled by default
- Config: map[string]interface{}{
- "issuer": "https://backup-oidc.company.com",
- "clientId": "seaweedfs-backup",
- },
- },
- }
-
- // Create providers multiple times (simulating multiple instances)
- providers1, err1 := factory.LoadProvidersFromConfig(configs)
- providers2, err2 := factory.LoadProvidersFromConfig(configs)
- providers3, err3 := factory.LoadProvidersFromConfig(configs)
-
- require.NoError(t, err1, "First load should succeed")
- require.NoError(t, err2, "Second load should succeed")
- require.NoError(t, err3, "Third load should succeed")
-
- // All instances should have same provider counts
- assert.Len(t, providers1, 1, "First instance should have 1 enabled provider")
- assert.Len(t, providers2, 1, "Second instance should have 1 enabled provider")
- assert.Len(t, providers3, 1, "Third instance should have 1 enabled provider")
-
- // All instances should have same provider names
- names1 := make([]string, 0, len(providers1))
- names2 := make([]string, 0, len(providers2))
- names3 := make([]string, 0, len(providers3))
-
- for name := range providers1 {
- names1 = append(names1, name)
- }
- for name := range providers2 {
- names2 = append(names2, name)
- }
- for name := range providers3 {
- names3 = append(names3, name)
- }
-
- assert.ElementsMatch(t, names1, names2, "Instance 1 and 2 should have same provider names")
- assert.ElementsMatch(t, names2, names3, "Instance 2 and 3 should have same provider names")
-
- // Verify specific providers
- expectedProviders := []string{"production-keycloak"}
- assert.ElementsMatch(t, names1, expectedProviders, "Should have expected enabled providers")
-
- // Verify disabled providers are not included
- assert.NotContains(t, names1, "backup-oidc", "Disabled providers should not be loaded")
- assert.NotContains(t, names2, "backup-oidc", "Disabled providers should not be loaded")
- assert.NotContains(t, names3, "backup-oidc", "Disabled providers should not be loaded")
-}
diff --git a/weed/iam/sts/provider_factory.go b/weed/iam/sts/provider_factory.go
index 53635c8f2..eb87d6d7e 100644
--- a/weed/iam/sts/provider_factory.go
+++ b/weed/iam/sts/provider_factory.go
@@ -274,69 +274,3 @@ func (f *ProviderFactory) convertToRoleMapping(value interface{}) (*providers.Ro
return roleMapping, nil
}
-
-// ValidateProviderConfig validates a provider configuration
-func (f *ProviderFactory) ValidateProviderConfig(config *ProviderConfig) error {
- if config == nil {
- return fmt.Errorf("provider config cannot be nil")
- }
-
- if config.Name == "" {
- return fmt.Errorf("provider name cannot be empty")
- }
-
- if config.Type == "" {
- return fmt.Errorf("provider type cannot be empty")
- }
-
- if config.Config == nil {
- return fmt.Errorf("provider config cannot be nil")
- }
-
- // Type-specific validation
- switch config.Type {
- case "oidc":
- return f.validateOIDCConfig(config.Config)
- case "ldap":
- return f.validateLDAPConfig(config.Config)
- case "saml":
- return f.validateSAMLConfig(config.Config)
- default:
- return fmt.Errorf("unsupported provider type: %s", config.Type)
- }
-}
-
-// validateOIDCConfig validates OIDC provider configuration
-func (f *ProviderFactory) validateOIDCConfig(config map[string]interface{}) error {
- if _, ok := config[ConfigFieldIssuer]; !ok {
- return fmt.Errorf("OIDC provider requires '%s' field", ConfigFieldIssuer)
- }
-
- if _, ok := config[ConfigFieldClientID]; !ok {
- return fmt.Errorf("OIDC provider requires '%s' field", ConfigFieldClientID)
- }
-
- return nil
-}
-
-// validateLDAPConfig validates LDAP provider configuration
-func (f *ProviderFactory) validateLDAPConfig(config map[string]interface{}) error {
- if _, ok := config["server"]; !ok {
- return fmt.Errorf("LDAP provider requires 'server' field")
- }
- if _, ok := config["baseDN"]; !ok {
- return fmt.Errorf("LDAP provider requires 'baseDN' field")
- }
- return nil
-}
-
-// validateSAMLConfig validates SAML provider configuration
-func (f *ProviderFactory) validateSAMLConfig(config map[string]interface{}) error {
- // TODO: Implement when SAML provider is available
- return nil
-}
-
-// GetSupportedProviderTypes returns list of supported provider types
-func (f *ProviderFactory) GetSupportedProviderTypes() []string {
- return []string{ProviderTypeOIDC}
-}
diff --git a/weed/iam/sts/provider_factory_test.go b/weed/iam/sts/provider_factory_test.go
deleted file mode 100644
index 8c36142a7..000000000
--- a/weed/iam/sts/provider_factory_test.go
+++ /dev/null
@@ -1,312 +0,0 @@
-package sts
-
-import (
- "testing"
- "time"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-func TestProviderFactory_CreateOIDCProvider(t *testing.T) {
- factory := NewProviderFactory()
-
- config := &ProviderConfig{
- Name: "test-oidc",
- Type: "oidc",
- Enabled: true,
- Config: map[string]interface{}{
- "issuer": "https://test-issuer.com",
- "clientId": "test-client",
- "clientSecret": "test-secret",
- "jwksUri": "https://test-issuer.com/.well-known/jwks.json",
- "scopes": []string{"openid", "profile", "email"},
- },
- }
-
- provider, err := factory.CreateProvider(config)
- require.NoError(t, err)
- assert.NotNil(t, provider)
- assert.Equal(t, "test-oidc", provider.Name())
-}
-
-// Note: Mock provider tests removed - mock providers are now test-only
-// and not available through the production ProviderFactory
-
-func TestProviderFactory_DisabledProvider(t *testing.T) {
- factory := NewProviderFactory()
-
- config := &ProviderConfig{
- Name: "disabled-provider",
- Type: "oidc",
- Enabled: false,
- Config: map[string]interface{}{
- "issuer": "https://test-issuer.com",
- "clientId": "test-client",
- },
- }
-
- provider, err := factory.CreateProvider(config)
- require.NoError(t, err)
- assert.Nil(t, provider) // Should return nil for disabled providers
-}
-
-func TestProviderFactory_InvalidProviderType(t *testing.T) {
- factory := NewProviderFactory()
-
- config := &ProviderConfig{
- Name: "invalid-provider",
- Type: "unsupported-type",
- Enabled: true,
- Config: map[string]interface{}{},
- }
-
- provider, err := factory.CreateProvider(config)
- assert.Error(t, err)
- assert.Nil(t, provider)
- assert.Contains(t, err.Error(), "unsupported provider type")
-}
-
-func TestProviderFactory_LoadMultipleProviders(t *testing.T) {
- factory := NewProviderFactory()
-
- configs := []*ProviderConfig{
- {
- Name: "oidc-provider",
- Type: "oidc",
- Enabled: true,
- Config: map[string]interface{}{
- "issuer": "https://oidc-issuer.com",
- "clientId": "oidc-client",
- },
- },
-
- {
- Name: "disabled-provider",
- Type: "oidc",
- Enabled: false,
- Config: map[string]interface{}{
- "issuer": "https://disabled-issuer.com",
- "clientId": "disabled-client",
- },
- },
- }
-
- providers, err := factory.LoadProvidersFromConfig(configs)
- require.NoError(t, err)
- assert.Len(t, providers, 1) // Only enabled providers should be loaded
-
- assert.Contains(t, providers, "oidc-provider")
- assert.NotContains(t, providers, "disabled-provider")
-}
-
-func TestProviderFactory_ValidateOIDCConfig(t *testing.T) {
- factory := NewProviderFactory()
-
- t.Run("valid config", func(t *testing.T) {
- config := &ProviderConfig{
- Name: "valid-oidc",
- Type: "oidc",
- Enabled: true,
- Config: map[string]interface{}{
- "issuer": "https://valid-issuer.com",
- "clientId": "valid-client",
- },
- }
-
- err := factory.ValidateProviderConfig(config)
- assert.NoError(t, err)
- })
-
- t.Run("missing issuer", func(t *testing.T) {
- config := &ProviderConfig{
- Name: "invalid-oidc",
- Type: "oidc",
- Enabled: true,
- Config: map[string]interface{}{
- "clientId": "valid-client",
- },
- }
-
- err := factory.ValidateProviderConfig(config)
- assert.Error(t, err)
- assert.Contains(t, err.Error(), "issuer")
- })
-
- t.Run("missing clientId", func(t *testing.T) {
- config := &ProviderConfig{
- Name: "invalid-oidc",
- Type: "oidc",
- Enabled: true,
- Config: map[string]interface{}{
- "issuer": "https://valid-issuer.com",
- },
- }
-
- err := factory.ValidateProviderConfig(config)
- assert.Error(t, err)
- assert.Contains(t, err.Error(), "clientId")
- })
-}
-
-func TestProviderFactory_ConvertToStringSlice(t *testing.T) {
- factory := NewProviderFactory()
-
- t.Run("string slice", func(t *testing.T) {
- input := []string{"a", "b", "c"}
- result, err := factory.convertToStringSlice(input)
- require.NoError(t, err)
- assert.Equal(t, []string{"a", "b", "c"}, result)
- })
-
- t.Run("interface slice", func(t *testing.T) {
- input := []interface{}{"a", "b", "c"}
- result, err := factory.convertToStringSlice(input)
- require.NoError(t, err)
- assert.Equal(t, []string{"a", "b", "c"}, result)
- })
-
- t.Run("invalid type", func(t *testing.T) {
- input := "not-a-slice"
- result, err := factory.convertToStringSlice(input)
- assert.Error(t, err)
- assert.Nil(t, result)
- })
-}
-
-func TestProviderFactory_ConfigConversionErrors(t *testing.T) {
- factory := NewProviderFactory()
-
- t.Run("invalid scopes type", func(t *testing.T) {
- config := &ProviderConfig{
- Name: "invalid-scopes",
- Type: "oidc",
- Enabled: true,
- Config: map[string]interface{}{
- "issuer": "https://test-issuer.com",
- "clientId": "test-client",
- "scopes": "invalid-not-array", // Should be array
- },
- }
-
- provider, err := factory.CreateProvider(config)
- assert.Error(t, err)
- assert.Nil(t, provider)
- assert.Contains(t, err.Error(), "failed to convert scopes")
- })
-
- t.Run("invalid claimsMapping type", func(t *testing.T) {
- config := &ProviderConfig{
- Name: "invalid-claims",
- Type: "oidc",
- Enabled: true,
- Config: map[string]interface{}{
- "issuer": "https://test-issuer.com",
- "clientId": "test-client",
- "claimsMapping": "invalid-not-map", // Should be map
- },
- }
-
- provider, err := factory.CreateProvider(config)
- assert.Error(t, err)
- assert.Nil(t, provider)
- assert.Contains(t, err.Error(), "failed to convert claimsMapping")
- })
-
- t.Run("invalid roleMapping type", func(t *testing.T) {
- config := &ProviderConfig{
- Name: "invalid-roles",
- Type: "oidc",
- Enabled: true,
- Config: map[string]interface{}{
- "issuer": "https://test-issuer.com",
- "clientId": "test-client",
- "roleMapping": "invalid-not-map", // Should be map
- },
- }
-
- provider, err := factory.CreateProvider(config)
- assert.Error(t, err)
- assert.Nil(t, provider)
- assert.Contains(t, err.Error(), "failed to convert roleMapping")
- })
-}
-
-func TestProviderFactory_ConvertToStringMap(t *testing.T) {
- factory := NewProviderFactory()
-
- t.Run("string map", func(t *testing.T) {
- input := map[string]string{"key1": "value1", "key2": "value2"}
- result, err := factory.convertToStringMap(input)
- require.NoError(t, err)
- assert.Equal(t, map[string]string{"key1": "value1", "key2": "value2"}, result)
- })
-
- t.Run("interface map", func(t *testing.T) {
- input := map[string]interface{}{"key1": "value1", "key2": "value2"}
- result, err := factory.convertToStringMap(input)
- require.NoError(t, err)
- assert.Equal(t, map[string]string{"key1": "value1", "key2": "value2"}, result)
- })
-
- t.Run("invalid type", func(t *testing.T) {
- input := "not-a-map"
- result, err := factory.convertToStringMap(input)
- assert.Error(t, err)
- assert.Nil(t, result)
- })
-}
-
-func TestProviderFactory_GetSupportedProviderTypes(t *testing.T) {
- factory := NewProviderFactory()
-
- supportedTypes := factory.GetSupportedProviderTypes()
- assert.Contains(t, supportedTypes, "oidc")
- assert.Len(t, supportedTypes, 1) // Currently only OIDC is supported in production
-}
-
-func TestSTSService_LoadProvidersFromConfig(t *testing.T) {
- stsConfig := &STSConfig{
- TokenDuration: FlexibleDuration{3600 * time.Second},
- MaxSessionLength: FlexibleDuration{43200 * time.Second},
- Issuer: "test-issuer",
- SigningKey: []byte("test-signing-key-32-characters-long"),
- Providers: []*ProviderConfig{
- {
- Name: "test-provider",
- Type: "oidc",
- Enabled: true,
- Config: map[string]interface{}{
- "issuer": "https://test-issuer.com",
- "clientId": "test-client",
- },
- },
- },
- }
-
- stsService := NewSTSService()
- err := stsService.Initialize(stsConfig)
- require.NoError(t, err)
-
- // Check that provider was loaded
- assert.Len(t, stsService.providers, 1)
- assert.Contains(t, stsService.providers, "test-provider")
- assert.Equal(t, "test-provider", stsService.providers["test-provider"].Name())
-}
-
-func TestSTSService_NoProvidersConfig(t *testing.T) {
- stsConfig := &STSConfig{
- TokenDuration: FlexibleDuration{3600 * time.Second},
- MaxSessionLength: FlexibleDuration{43200 * time.Second},
- Issuer: "test-issuer",
- SigningKey: []byte("test-signing-key-32-characters-long"),
- // No providers configured
- }
-
- stsService := NewSTSService()
- err := stsService.Initialize(stsConfig)
- require.NoError(t, err)
-
- // Should initialize successfully with no providers
- assert.Len(t, stsService.providers, 0)
-}
diff --git a/weed/iam/sts/security_test.go b/weed/iam/sts/security_test.go
deleted file mode 100644
index 2d230d796..000000000
--- a/weed/iam/sts/security_test.go
+++ /dev/null
@@ -1,193 +0,0 @@
-package sts
-
-import (
- "context"
- "fmt"
- "strings"
- "testing"
- "time"
-
- "github.com/golang-jwt/jwt/v5"
- "github.com/seaweedfs/seaweedfs/weed/iam/providers"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-// TestSecurityIssuerToProviderMapping tests the security fix that ensures JWT tokens
-// with specific issuer claims can only be validated by the provider registered for that issuer
-func TestSecurityIssuerToProviderMapping(t *testing.T) {
- ctx := context.Background()
-
- // Create STS service with two mock providers
- service := NewSTSService()
- config := &STSConfig{
- TokenDuration: FlexibleDuration{time.Hour},
- MaxSessionLength: FlexibleDuration{time.Hour * 12},
- Issuer: "test-sts",
- SigningKey: []byte("test-signing-key-32-characters-long"),
- }
-
- err := service.Initialize(config)
- require.NoError(t, err)
-
- // Set up mock trust policy validator
- mockValidator := &MockTrustPolicyValidator{}
- service.SetTrustPolicyValidator(mockValidator)
-
- // Create two mock providers with different issuers
- providerA := &MockIdentityProviderWithIssuer{
- name: "provider-a",
- issuer: "https://provider-a.com",
- validTokens: map[string]bool{
- "token-for-provider-a": true,
- },
- }
-
- providerB := &MockIdentityProviderWithIssuer{
- name: "provider-b",
- issuer: "https://provider-b.com",
- validTokens: map[string]bool{
- "token-for-provider-b": true,
- },
- }
-
- // Register both providers
- err = service.RegisterProvider(providerA)
- require.NoError(t, err)
- err = service.RegisterProvider(providerB)
- require.NoError(t, err)
-
- // Create JWT tokens with specific issuer claims
- tokenForProviderA := createTestJWT(t, "https://provider-a.com", "user-a")
- tokenForProviderB := createTestJWT(t, "https://provider-b.com", "user-b")
-
- t.Run("jwt_token_with_issuer_a_only_validated_by_provider_a", func(t *testing.T) {
- // This should succeed - token has issuer A and provider A is registered
- identity, provider, err := service.validateWebIdentityToken(ctx, tokenForProviderA)
- assert.NoError(t, err)
- assert.NotNil(t, identity)
- assert.Equal(t, "provider-a", provider.Name())
- })
-
- t.Run("jwt_token_with_issuer_b_only_validated_by_provider_b", func(t *testing.T) {
- // This should succeed - token has issuer B and provider B is registered
- identity, provider, err := service.validateWebIdentityToken(ctx, tokenForProviderB)
- assert.NoError(t, err)
- assert.NotNil(t, identity)
- assert.Equal(t, "provider-b", provider.Name())
- })
-
- t.Run("jwt_token_with_unregistered_issuer_fails", func(t *testing.T) {
- // Create token with unregistered issuer
- tokenWithUnknownIssuer := createTestJWT(t, "https://unknown-issuer.com", "user-x")
-
- // This should fail - no provider registered for this issuer
- identity, provider, err := service.validateWebIdentityToken(ctx, tokenWithUnknownIssuer)
- assert.Error(t, err)
- assert.Nil(t, identity)
- assert.Nil(t, provider)
- assert.Contains(t, err.Error(), "no identity provider registered for issuer: https://unknown-issuer.com")
- })
-
- t.Run("non_jwt_tokens_are_rejected", func(t *testing.T) {
- // Non-JWT tokens should be rejected - no fallback mechanism exists for security
- identity, provider, err := service.validateWebIdentityToken(ctx, "token-for-provider-a")
- assert.Error(t, err)
- assert.Nil(t, identity)
- assert.Nil(t, provider)
- assert.Contains(t, err.Error(), "web identity token must be a valid JWT token")
- })
-}
-
-// createTestJWT creates a test JWT token with the specified issuer and subject
-func createTestJWT(t *testing.T, issuer, subject string) string {
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
- "iss": issuer,
- "sub": subject,
- "aud": "test-client",
- "exp": time.Now().Add(time.Hour).Unix(),
- "iat": time.Now().Unix(),
- })
-
- tokenString, err := token.SignedString([]byte("test-signing-key"))
- require.NoError(t, err)
- return tokenString
-}
-
-// MockIdentityProviderWithIssuer is a mock provider that supports issuer mapping
-type MockIdentityProviderWithIssuer struct {
- name string
- issuer string
- validTokens map[string]bool
-}
-
-func (m *MockIdentityProviderWithIssuer) Name() string {
- return m.name
-}
-
-func (m *MockIdentityProviderWithIssuer) GetIssuer() string {
- return m.issuer
-}
-
-func (m *MockIdentityProviderWithIssuer) Initialize(config interface{}) error {
- return nil
-}
-
-func (m *MockIdentityProviderWithIssuer) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
- // For JWT tokens, parse and validate the token format
- if len(token) > 50 && strings.Contains(token, ".") {
- // This looks like a JWT - parse it to get the subject
- parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
- if err != nil {
- return nil, fmt.Errorf("invalid JWT token")
- }
-
- claims, ok := parsedToken.Claims.(jwt.MapClaims)
- if !ok {
- return nil, fmt.Errorf("invalid claims")
- }
-
- issuer, _ := claims["iss"].(string)
- subject, _ := claims["sub"].(string)
-
- // Verify the issuer matches what we expect
- if issuer != m.issuer {
- return nil, fmt.Errorf("token issuer %s does not match provider issuer %s", issuer, m.issuer)
- }
-
- return &providers.ExternalIdentity{
- UserID: subject,
- Email: subject + "@" + m.name + ".com",
- Provider: m.name,
- }, nil
- }
-
- // For non-JWT tokens, check our simple token list
- if m.validTokens[token] {
- return &providers.ExternalIdentity{
- UserID: "test-user",
- Email: "test@" + m.name + ".com",
- Provider: m.name,
- }, nil
- }
-
- return nil, fmt.Errorf("invalid token")
-}
-
-func (m *MockIdentityProviderWithIssuer) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
- return &providers.ExternalIdentity{
- UserID: userID,
- Email: userID + "@" + m.name + ".com",
- Provider: m.name,
- }, nil
-}
-
-func (m *MockIdentityProviderWithIssuer) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
- if m.validTokens[token] {
- return &providers.TokenClaims{
- Subject: "test-user",
- Issuer: m.issuer,
- }, nil
- }
- return nil, fmt.Errorf("invalid token")
-}
diff --git a/weed/iam/sts/session_policy_test.go b/weed/iam/sts/session_policy_test.go
deleted file mode 100644
index 992fde929..000000000
--- a/weed/iam/sts/session_policy_test.go
+++ /dev/null
@@ -1,168 +0,0 @@
-package sts
-
-import (
- "context"
- "testing"
- "time"
-
- "github.com/golang-jwt/jwt/v5"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-// createSessionPolicyTestJWT creates a test JWT token for session policy tests
-func createSessionPolicyTestJWT(t *testing.T, issuer, subject string) string {
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
- "iss": issuer,
- "sub": subject,
- "aud": "test-client",
- "exp": time.Now().Add(time.Hour).Unix(),
- "iat": time.Now().Unix(),
- })
-
- tokenString, err := token.SignedString([]byte("test-signing-key"))
- require.NoError(t, err)
- return tokenString
-}
-
-// TestAssumeRoleWithWebIdentity_SessionPolicy verifies inline session policies are preserved in tokens.
-func TestAssumeRoleWithWebIdentity_SessionPolicy(t *testing.T) {
- service := setupTestSTSService(t)
- ctx := context.Background()
-
- sessionPolicy := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:GetObject","Resource":"arn:aws:s3:::example-bucket/*"}]}`
- testToken := createSessionPolicyTestJWT(t, "test-issuer", "test-user")
-
- request := &AssumeRoleWithWebIdentityRequest{
- RoleArn: "arn:aws:iam::role/TestRole",
- WebIdentityToken: testToken,
- RoleSessionName: "test-session",
- Policy: &sessionPolicy,
- }
-
- response, err := service.AssumeRoleWithWebIdentity(ctx, request)
- require.NoError(t, err)
- require.NotNil(t, response)
-
- sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken)
- require.NoError(t, err)
-
- normalized, err := NormalizeSessionPolicy(sessionPolicy)
- require.NoError(t, err)
- assert.Equal(t, normalized, sessionInfo.SessionPolicy)
-
- t.Run("should_succeed_without_session_policy", func(t *testing.T) {
- request := &AssumeRoleWithWebIdentityRequest{
- RoleArn: "arn:aws:iam::role/TestRole",
- WebIdentityToken: createSessionPolicyTestJWT(t, "test-issuer", "test-user"),
- RoleSessionName: "test-session",
- }
-
- response, err := service.AssumeRoleWithWebIdentity(ctx, request)
- require.NoError(t, err)
- require.NotNil(t, response)
-
- sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken)
- require.NoError(t, err)
- assert.Empty(t, sessionInfo.SessionPolicy)
- })
-}
-
-// Test edge case scenarios for the Policy field handling
-func TestAssumeRoleWithWebIdentity_SessionPolicy_EdgeCases(t *testing.T) {
- service := setupTestSTSService(t)
- ctx := context.Background()
-
- t.Run("malformed_json_policy_rejected", func(t *testing.T) {
- malformedPolicy := `{"Version": "2012-10-17", "Statement": [` // Incomplete JSON
-
- request := &AssumeRoleWithWebIdentityRequest{
- RoleArn: "arn:aws:iam::role/TestRole",
- WebIdentityToken: createSessionPolicyTestJWT(t, "test-issuer", "test-user"),
- RoleSessionName: "test-session",
- Policy: &malformedPolicy,
- }
-
- response, err := service.AssumeRoleWithWebIdentity(ctx, request)
- assert.Error(t, err)
- assert.Nil(t, response)
- assert.Contains(t, err.Error(), "invalid session policy JSON")
- })
-
- t.Run("invalid_policy_document_rejected", func(t *testing.T) {
- invalidPolicy := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow"}]}`
-
- request := &AssumeRoleWithWebIdentityRequest{
- RoleArn: "arn:aws:iam::role/TestRole",
- WebIdentityToken: createSessionPolicyTestJWT(t, "test-issuer", "test-user"),
- RoleSessionName: "test-session",
- Policy: &invalidPolicy,
- }
-
- response, err := service.AssumeRoleWithWebIdentity(ctx, request)
- assert.Error(t, err)
- assert.Nil(t, response)
- assert.Contains(t, err.Error(), "invalid session policy document")
- })
-
- t.Run("whitespace_policy_ignored", func(t *testing.T) {
- whitespacePolicy := " \t\n "
-
- request := &AssumeRoleWithWebIdentityRequest{
- RoleArn: "arn:aws:iam::role/TestRole",
- WebIdentityToken: createSessionPolicyTestJWT(t, "test-issuer", "test-user"),
- RoleSessionName: "test-session",
- Policy: &whitespacePolicy,
- }
-
- response, err := service.AssumeRoleWithWebIdentity(ctx, request)
- require.NoError(t, err)
- require.NotNil(t, response)
-
- sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken)
- require.NoError(t, err)
- assert.Empty(t, sessionInfo.SessionPolicy)
- })
-}
-
-// TestAssumeRoleWithWebIdentity_PolicyFieldDocumentation verifies that the struct field exists and is optional.
-func TestAssumeRoleWithWebIdentity_PolicyFieldDocumentation(t *testing.T) {
- request := &AssumeRoleWithWebIdentityRequest{}
-
- assert.IsType(t, (*string)(nil), request.Policy,
- "Policy field should be *string type for optional JSON policy")
- assert.Nil(t, request.Policy,
- "Policy field should default to nil (no session policy)")
-
- policyValue := `{"Version": "2012-10-17"}`
- request.Policy = &policyValue
- assert.NotNil(t, request.Policy, "Should be able to assign policy value")
- assert.Equal(t, policyValue, *request.Policy, "Policy value should be preserved")
-}
-
-// TestAssumeRoleWithCredentials_SessionPolicy verifies session policy support for credentials-based flow.
-func TestAssumeRoleWithCredentials_SessionPolicy(t *testing.T) {
- service := setupTestSTSService(t)
- ctx := context.Background()
-
- sessionPolicy := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"filer:CreateEntry","Resource":"arn:aws:filer::path/user-docs/*"}]}`
- request := &AssumeRoleWithCredentialsRequest{
- RoleArn: "arn:aws:iam::role/TestRole",
- Username: "testuser",
- Password: "testpass",
- RoleSessionName: "test-session",
- ProviderName: "test-ldap",
- Policy: &sessionPolicy,
- }
-
- response, err := service.AssumeRoleWithCredentials(ctx, request)
- require.NoError(t, err)
- require.NotNil(t, response)
-
- sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken)
- require.NoError(t, err)
-
- normalized, err := NormalizeSessionPolicy(sessionPolicy)
- require.NoError(t, err)
- assert.Equal(t, normalized, sessionInfo.SessionPolicy)
-}
diff --git a/weed/iam/sts/sts_service.go b/weed/iam/sts/sts_service.go
index d02c82ae1..0d0481795 100644
--- a/weed/iam/sts/sts_service.go
+++ b/weed/iam/sts/sts_service.go
@@ -879,21 +879,6 @@ func (s *STSService) calculateSessionDuration(durationSeconds *int64, tokenExpir
return duration
}
-// extractSessionIdFromToken extracts session ID from JWT session token
-func (s *STSService) extractSessionIdFromToken(sessionToken string) string {
- // Validate JWT and extract session claims
- claims, err := s.tokenGenerator.ValidateJWTWithClaims(sessionToken)
- if err != nil {
- // For test compatibility, also handle direct session IDs
- if len(sessionToken) == 32 { // Typical session ID length
- return sessionToken
- }
- return ""
- }
-
- return claims.SessionId
-}
-
// validateAssumeRoleWithCredentialsRequest validates the credentials request parameters
func (s *STSService) validateAssumeRoleWithCredentialsRequest(request *AssumeRoleWithCredentialsRequest) error {
if request.RoleArn == "" {
diff --git a/weed/iam/sts/sts_service_test.go b/weed/iam/sts/sts_service_test.go
deleted file mode 100644
index e16b3209a..000000000
--- a/weed/iam/sts/sts_service_test.go
+++ /dev/null
@@ -1,778 +0,0 @@
-package sts
-
-import (
- "context"
- "fmt"
- "strings"
- "testing"
- "time"
-
- "github.com/golang-jwt/jwt/v5"
- "github.com/seaweedfs/seaweedfs/weed/iam/providers"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-// createSTSTestJWT creates a test JWT token for STS service tests
-func createSTSTestJWT(t *testing.T, issuer, subject string) string {
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
- "iss": issuer,
- "sub": subject,
- "aud": "test-client",
- "exp": time.Now().Add(time.Hour).Unix(),
- "iat": time.Now().Unix(),
- })
-
- tokenString, err := token.SignedString([]byte("test-signing-key"))
- require.NoError(t, err)
- return tokenString
-}
-
-// TestSTSServiceInitialization tests STS service initialization
-func TestSTSServiceInitialization(t *testing.T) {
- tests := []struct {
- name string
- config *STSConfig
- wantErr bool
- }{
- {
- name: "valid config",
- config: &STSConfig{
- TokenDuration: FlexibleDuration{time.Hour},
- MaxSessionLength: FlexibleDuration{time.Hour * 12},
- Issuer: "seaweedfs-sts",
- SigningKey: []byte("test-signing-key"),
- },
- wantErr: false,
- },
- {
- name: "missing signing key",
- config: &STSConfig{
- TokenDuration: FlexibleDuration{time.Hour},
- Issuer: "seaweedfs-sts",
- },
- wantErr: true,
- },
- {
- name: "invalid token duration",
- config: &STSConfig{
- TokenDuration: FlexibleDuration{-time.Hour},
- Issuer: "seaweedfs-sts",
- SigningKey: []byte("test-key"),
- },
- wantErr: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- service := NewSTSService()
-
- err := service.Initialize(tt.config)
-
- if tt.wantErr {
- assert.Error(t, err)
- } else {
- assert.NoError(t, err)
- assert.True(t, service.IsInitialized())
-
- // Verify defaults if applicable
- if tt.config.Issuer == "" {
- assert.Equal(t, DefaultIssuer, service.Config.Issuer)
- }
- if tt.config.TokenDuration.Duration == 0 {
- assert.Equal(t, time.Duration(DefaultTokenDuration)*time.Second, service.Config.TokenDuration.Duration)
- }
- }
- })
- }
-}
-
-func TestSTSServiceDefaults(t *testing.T) {
- service := NewSTSService()
- config := &STSConfig{
- SigningKey: []byte("test-signing-key"),
- // Missing duration and issuer
- }
-
- err := service.Initialize(config)
- assert.NoError(t, err)
-
- assert.Equal(t, DefaultIssuer, config.Issuer)
- assert.Equal(t, time.Duration(DefaultTokenDuration)*time.Second, config.TokenDuration.Duration)
- assert.Equal(t, time.Duration(DefaultMaxSessionLength)*time.Second, config.MaxSessionLength.Duration)
-}
-
-// TestAssumeRoleWithWebIdentity tests role assumption with OIDC tokens
-func TestAssumeRoleWithWebIdentity(t *testing.T) {
- service := setupTestSTSService(t)
-
- tests := []struct {
- name string
- roleArn string
- webIdentityToken string
- sessionName string
- durationSeconds *int64
- wantErr bool
- expectedSubject string
- }{
- {
- name: "successful role assumption",
- roleArn: "arn:aws:iam::role/TestRole",
- webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user-id"),
- sessionName: "test-session",
- durationSeconds: nil, // Use default
- wantErr: false,
- expectedSubject: "test-user-id",
- },
- {
- name: "invalid web identity token",
- roleArn: "arn:aws:iam::role/TestRole",
- webIdentityToken: "invalid-token",
- sessionName: "test-session",
- wantErr: true,
- },
- {
- name: "non-existent role",
- roleArn: "arn:aws:iam::role/NonExistentRole",
- webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
- sessionName: "test-session",
- wantErr: true,
- },
- {
- name: "custom session duration",
- roleArn: "arn:aws:iam::role/TestRole",
- webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
- sessionName: "test-session",
- durationSeconds: int64Ptr(7200), // 2 hours
- wantErr: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- ctx := context.Background()
-
- request := &AssumeRoleWithWebIdentityRequest{
- RoleArn: tt.roleArn,
- WebIdentityToken: tt.webIdentityToken,
- RoleSessionName: tt.sessionName,
- DurationSeconds: tt.durationSeconds,
- }
-
- response, err := service.AssumeRoleWithWebIdentity(ctx, request)
-
- if tt.wantErr {
- assert.Error(t, err)
- assert.Nil(t, response)
- } else {
- assert.NoError(t, err)
- assert.NotNil(t, response)
- assert.NotNil(t, response.Credentials)
- assert.NotNil(t, response.AssumedRoleUser)
-
- // Verify credentials
- creds := response.Credentials
- assert.NotEmpty(t, creds.AccessKeyId)
- assert.NotEmpty(t, creds.SecretAccessKey)
- assert.NotEmpty(t, creds.SessionToken)
- assert.True(t, creds.Expiration.After(time.Now()))
-
- // Verify assumed role user
- user := response.AssumedRoleUser
- assert.Equal(t, tt.roleArn, user.AssumedRoleId)
- assert.Contains(t, user.Arn, tt.sessionName)
-
- if tt.expectedSubject != "" {
- assert.Equal(t, tt.expectedSubject, user.Subject)
- }
- }
- })
- }
-}
-
-// TestAssumeRoleWithLDAP tests role assumption with LDAP credentials
-func TestAssumeRoleWithLDAP(t *testing.T) {
- service := setupTestSTSService(t)
-
- tests := []struct {
- name string
- roleArn string
- username string
- password string
- sessionName string
- wantErr bool
- }{
- {
- name: "successful LDAP role assumption",
- roleArn: "arn:aws:iam::role/LDAPRole",
- username: "testuser",
- password: "testpass",
- sessionName: "ldap-session",
- wantErr: false,
- },
- {
- name: "invalid LDAP credentials",
- roleArn: "arn:aws:iam::role/LDAPRole",
- username: "testuser",
- password: "wrongpass",
- sessionName: "ldap-session",
- wantErr: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- ctx := context.Background()
-
- request := &AssumeRoleWithCredentialsRequest{
- RoleArn: tt.roleArn,
- Username: tt.username,
- Password: tt.password,
- RoleSessionName: tt.sessionName,
- ProviderName: "test-ldap",
- }
-
- response, err := service.AssumeRoleWithCredentials(ctx, request)
-
- if tt.wantErr {
- assert.Error(t, err)
- assert.Nil(t, response)
- } else {
- assert.NoError(t, err)
- assert.NotNil(t, response)
- assert.NotNil(t, response.Credentials)
- }
- })
- }
-}
-
-// TestSessionTokenValidation tests session token validation
-func TestSessionTokenValidation(t *testing.T) {
- service := setupTestSTSService(t)
- ctx := context.Background()
-
- // First, create a session
- request := &AssumeRoleWithWebIdentityRequest{
- RoleArn: "arn:aws:iam::role/TestRole",
- WebIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
- RoleSessionName: "test-session",
- }
-
- response, err := service.AssumeRoleWithWebIdentity(ctx, request)
- require.NoError(t, err)
- require.NotNil(t, response)
-
- sessionToken := response.Credentials.SessionToken
-
- tests := []struct {
- name string
- token string
- wantErr bool
- }{
- {
- name: "valid session token",
- token: sessionToken,
- wantErr: false,
- },
- {
- name: "invalid session token",
- token: "invalid-session-token",
- wantErr: true,
- },
- {
- name: "empty session token",
- token: "",
- wantErr: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- session, err := service.ValidateSessionToken(ctx, tt.token)
-
- if tt.wantErr {
- assert.Error(t, err)
- assert.Nil(t, session)
- } else {
- assert.NoError(t, err)
- assert.NotNil(t, session)
- assert.Equal(t, "test-session", session.SessionName)
- assert.Equal(t, "arn:aws:iam::role/TestRole", session.RoleArn)
- }
- })
- }
-}
-
-// TestSessionTokenPersistence tests that JWT tokens remain valid throughout their lifetime
-// Note: In the stateless JWT design, tokens cannot be revoked and remain valid until expiration
-func TestSessionTokenPersistence(t *testing.T) {
- service := setupTestSTSService(t)
- ctx := context.Background()
-
- // Create a session first
- request := &AssumeRoleWithWebIdentityRequest{
- RoleArn: "arn:aws:iam::role/TestRole",
- WebIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
- RoleSessionName: "test-session",
- }
-
- response, err := service.AssumeRoleWithWebIdentity(ctx, request)
- require.NoError(t, err)
-
- sessionToken := response.Credentials.SessionToken
-
- // Verify token is valid initially
- session, err := service.ValidateSessionToken(ctx, sessionToken)
- assert.NoError(t, err)
- assert.NotNil(t, session)
- assert.Equal(t, "test-session", session.SessionName)
-
- // In a stateless JWT system, tokens remain valid throughout their lifetime
- // Multiple validations should all succeed as long as the token hasn't expired
- session2, err := service.ValidateSessionToken(ctx, sessionToken)
- assert.NoError(t, err, "Token should remain valid in stateless system")
- assert.NotNil(t, session2, "Session should be returned from JWT token")
- assert.Equal(t, session.SessionId, session2.SessionId, "Session ID should be consistent")
-}
-
-// Helper functions
-
-func setupTestSTSService(t *testing.T) *STSService {
- service := NewSTSService()
-
- config := &STSConfig{
- TokenDuration: FlexibleDuration{time.Hour},
- MaxSessionLength: FlexibleDuration{time.Hour * 12},
- Issuer: "test-sts",
- SigningKey: []byte("test-signing-key-32-characters-long"),
- }
-
- err := service.Initialize(config)
- require.NoError(t, err)
-
- // Set up mock trust policy validator (required for STS testing)
- mockValidator := &MockTrustPolicyValidator{}
- service.SetTrustPolicyValidator(mockValidator)
-
- // Register test providers
- mockOIDCProvider := &MockIdentityProvider{
- name: "test-oidc",
- validTokens: map[string]*providers.TokenClaims{
- createSTSTestJWT(t, "test-issuer", "test-user"): {
- Subject: "test-user-id",
- Issuer: "test-issuer",
- Claims: map[string]interface{}{
- "email": "test@example.com",
- "name": "Test User",
- },
- },
- },
- }
-
- mockLDAPProvider := &MockIdentityProvider{
- name: "test-ldap",
- validCredentials: map[string]string{
- "testuser": "testpass",
- },
- }
-
- service.RegisterProvider(mockOIDCProvider)
- service.RegisterProvider(mockLDAPProvider)
-
- return service
-}
-
-func int64Ptr(v int64) *int64 {
- return &v
-}
-
-// Mock identity provider for testing
-type MockIdentityProvider struct {
- name string
- validTokens map[string]*providers.TokenClaims
- validCredentials map[string]string
-}
-
-func (m *MockIdentityProvider) Name() string {
- return m.name
-}
-
-func (m *MockIdentityProvider) GetIssuer() string {
- return "test-issuer" // This matches the issuer in the token claims
-}
-
-func (m *MockIdentityProvider) Initialize(config interface{}) error {
- return nil
-}
-
-func (m *MockIdentityProvider) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
- // First try to parse as JWT token
- if len(token) > 20 && strings.Count(token, ".") >= 2 {
- parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
- if err == nil {
- if claims, ok := parsedToken.Claims.(jwt.MapClaims); ok {
- issuer, _ := claims["iss"].(string)
- subject, _ := claims["sub"].(string)
-
- // Verify the issuer matches what we expect
- if issuer == "test-issuer" && subject != "" {
- return &providers.ExternalIdentity{
- UserID: subject,
- Email: subject + "@test-domain.com",
- DisplayName: "Test User " + subject,
- Provider: m.name,
- }, nil
- }
- }
- }
- }
-
- // Handle legacy OIDC tokens (for backwards compatibility)
- if claims, exists := m.validTokens[token]; exists {
- email, _ := claims.GetClaimString("email")
- name, _ := claims.GetClaimString("name")
-
- return &providers.ExternalIdentity{
- UserID: claims.Subject,
- Email: email,
- DisplayName: name,
- Provider: m.name,
- }, nil
- }
-
- // Handle LDAP credentials (username:password format)
- if m.validCredentials != nil {
- parts := strings.Split(token, ":")
- if len(parts) == 2 {
- username, password := parts[0], parts[1]
- if expectedPassword, exists := m.validCredentials[username]; exists && expectedPassword == password {
- return &providers.ExternalIdentity{
- UserID: username,
- Email: username + "@" + m.name + ".com",
- DisplayName: "Test User " + username,
- Provider: m.name,
- }, nil
- }
- }
- }
-
- return nil, fmt.Errorf("unknown test token: %s", token)
-}
-
-func (m *MockIdentityProvider) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
- return &providers.ExternalIdentity{
- UserID: userID,
- Email: userID + "@" + m.name + ".com",
- Provider: m.name,
- }, nil
-}
-
-func (m *MockIdentityProvider) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
- if claims, exists := m.validTokens[token]; exists {
- return claims, nil
- }
- return nil, fmt.Errorf("invalid token")
-}
-
-// TestSessionDurationCappedByTokenExpiration tests that session duration is capped by the source token's exp claim
-func TestSessionDurationCappedByTokenExpiration(t *testing.T) {
- service := NewSTSService()
-
- config := &STSConfig{
- TokenDuration: FlexibleDuration{time.Hour}, // Default: 1 hour
- MaxSessionLength: FlexibleDuration{time.Hour * 12},
- Issuer: "test-sts",
- SigningKey: []byte("test-signing-key-32-characters-long"),
- }
-
- err := service.Initialize(config)
- require.NoError(t, err)
-
- tests := []struct {
- name string
- durationSeconds *int64
- tokenExpiration *time.Time
- expectedMaxSeconds int64
- description string
- }{
- {
- name: "no token expiration - use default duration",
- durationSeconds: nil,
- tokenExpiration: nil,
- expectedMaxSeconds: 3600, // 1 hour default
- description: "When no token expiration is set, use the configured default duration",
- },
- {
- name: "token expires before default duration",
- durationSeconds: nil,
- tokenExpiration: timePtr(time.Now().Add(30 * time.Minute)),
- expectedMaxSeconds: 30 * 60, // 30 minutes
- description: "When token expires in 30 min, session should be capped at 30 min",
- },
- {
- name: "token expires after default duration - use default",
- durationSeconds: nil,
- tokenExpiration: timePtr(time.Now().Add(2 * time.Hour)),
- expectedMaxSeconds: 3600, // 1 hour default, since it's less than 2 hour token expiry
- description: "When token expires after default duration, use the default duration",
- },
- {
- name: "requested duration shorter than token expiry",
- durationSeconds: int64Ptr(1800), // 30 min requested
- tokenExpiration: timePtr(time.Now().Add(time.Hour)),
- expectedMaxSeconds: 1800, // 30 minutes as requested
- description: "When requested duration is shorter than token expiry, use requested duration",
- },
- {
- name: "requested duration longer than token expiry - cap at token expiry",
- durationSeconds: int64Ptr(3600), // 1 hour requested
- tokenExpiration: timePtr(time.Now().Add(15 * time.Minute)),
- expectedMaxSeconds: 15 * 60, // Capped at 15 minutes
- description: "When requested duration exceeds token expiry, cap at token expiry",
- },
- {
- name: "GitLab CI short-lived token scenario",
- durationSeconds: nil,
- tokenExpiration: timePtr(time.Now().Add(5 * time.Minute)),
- expectedMaxSeconds: 5 * 60, // 5 minutes
- description: "GitLab CI job with 5 minute timeout should result in 5 minute session",
- },
- {
- name: "already expired token - defense in depth",
- durationSeconds: nil,
- tokenExpiration: timePtr(time.Now().Add(-5 * time.Minute)), // Expired 5 minutes ago
- expectedMaxSeconds: 60, // 1 minute minimum
- description: "Already expired token should result in minimal 1 minute session",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- duration := service.calculateSessionDuration(tt.durationSeconds, tt.tokenExpiration)
-
- // Allow 5 second tolerance for time calculations
- maxExpected := time.Duration(tt.expectedMaxSeconds+5) * time.Second
- minExpected := time.Duration(tt.expectedMaxSeconds-5) * time.Second
-
- assert.GreaterOrEqual(t, duration, minExpected,
- "%s: duration %v should be >= %v", tt.description, duration, minExpected)
- assert.LessOrEqual(t, duration, maxExpected,
- "%s: duration %v should be <= %v", tt.description, duration, maxExpected)
- })
- }
-}
-
-// TestAssumeRoleWithWebIdentityRespectsTokenExpiration tests end-to-end that session duration is capped
-func TestAssumeRoleWithWebIdentityRespectsTokenExpiration(t *testing.T) {
- service := NewSTSService()
-
- config := &STSConfig{
- TokenDuration: FlexibleDuration{time.Hour},
- MaxSessionLength: FlexibleDuration{time.Hour * 12},
- Issuer: "test-sts",
- SigningKey: []byte("test-signing-key-32-characters-long"),
- }
-
- err := service.Initialize(config)
- require.NoError(t, err)
-
- // Set up mock trust policy validator
- mockValidator := &MockTrustPolicyValidator{}
- service.SetTrustPolicyValidator(mockValidator)
-
- // Create a mock provider that returns tokens with short expiration
- shortLivedTokenExpiration := time.Now().Add(10 * time.Minute)
- mockProvider := &MockIdentityProviderWithExpiration{
- name: "short-lived-issuer",
- tokenExpiration: &shortLivedTokenExpiration,
- }
- service.RegisterProvider(mockProvider)
-
- ctx := context.Background()
-
- // Create a JWT token with short expiration
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
- "iss": "short-lived-issuer",
- "sub": "test-user",
- "aud": "test-client",
- "exp": shortLivedTokenExpiration.Unix(),
- "iat": time.Now().Unix(),
- })
- tokenString, err := token.SignedString([]byte("test-signing-key"))
- require.NoError(t, err)
-
- request := &AssumeRoleWithWebIdentityRequest{
- RoleArn: "arn:aws:iam::role/TestRole",
- WebIdentityToken: tokenString,
- RoleSessionName: "test-session",
- }
-
- response, err := service.AssumeRoleWithWebIdentity(ctx, request)
- require.NoError(t, err)
- require.NotNil(t, response)
-
- // Verify the session expires at or before the token expiration
- // Allow 5 second tolerance
- assert.True(t, response.Credentials.Expiration.Before(shortLivedTokenExpiration.Add(5*time.Second)),
- "Session expiration (%v) should not exceed token expiration (%v)",
- response.Credentials.Expiration, shortLivedTokenExpiration)
-}
-
-// MockIdentityProviderWithExpiration is a mock provider that returns tokens with configurable expiration
-type MockIdentityProviderWithExpiration struct {
- name string
- tokenExpiration *time.Time
-}
-
-func (m *MockIdentityProviderWithExpiration) Name() string {
- return m.name
-}
-
-func (m *MockIdentityProviderWithExpiration) GetIssuer() string {
- return m.name
-}
-
-func (m *MockIdentityProviderWithExpiration) Initialize(config interface{}) error {
- return nil
-}
-
-func (m *MockIdentityProviderWithExpiration) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
- // Parse the token to get subject
- parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
- if err != nil {
- return nil, fmt.Errorf("failed to parse token: %w", err)
- }
-
- claims, ok := parsedToken.Claims.(jwt.MapClaims)
- if !ok {
- return nil, fmt.Errorf("invalid claims")
- }
-
- subject, _ := claims["sub"].(string)
-
- identity := &providers.ExternalIdentity{
- UserID: subject,
- Email: subject + "@example.com",
- DisplayName: "Test User",
- Provider: m.name,
- TokenExpiration: m.tokenExpiration,
- }
-
- return identity, nil
-}
-
-func (m *MockIdentityProviderWithExpiration) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
- return &providers.ExternalIdentity{
- UserID: userID,
- Provider: m.name,
- }, nil
-}
-
-func (m *MockIdentityProviderWithExpiration) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
- claims := &providers.TokenClaims{
- Subject: "test-user",
- Issuer: m.name,
- }
- if m.tokenExpiration != nil {
- claims.ExpiresAt = *m.tokenExpiration
- }
- return claims, nil
-}
-
-func timePtr(t time.Time) *time.Time {
- return &t
-}
-
-// TestAssumeRoleWithWebIdentity_PreservesAttributes tests that attributes from the identity provider
-// are correctly propagated to the session token's request context
-func TestAssumeRoleWithWebIdentity_PreservesAttributes(t *testing.T) {
- service := setupTestSTSService(t)
-
- // Create a mock provider that returns a user with attributes
- mockProvider := &MockIdentityProviderWithAttributes{
- name: "attr-provider",
- attributes: map[string]string{
- "preferred_username": "my-user",
- "department": "engineering",
- "project": "seaweedfs",
- },
- }
- service.RegisterProvider(mockProvider)
-
- // Create a valid JWT token for the provider
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
- "iss": "attr-provider",
- "sub": "test-user-id",
- "aud": "test-client",
- "exp": time.Now().Add(time.Hour).Unix(),
- "iat": time.Now().Unix(),
- })
- tokenString, err := token.SignedString([]byte("test-signing-key"))
- require.NoError(t, err)
-
- ctx := context.Background()
- request := &AssumeRoleWithWebIdentityRequest{
- RoleArn: "arn:aws:iam::role/TestRole",
- WebIdentityToken: tokenString,
- RoleSessionName: "test-session",
- }
-
- response, err := service.AssumeRoleWithWebIdentity(ctx, request)
- require.NoError(t, err)
- require.NotNil(t, response)
-
- // Validate the session token to check claims
- sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken)
- require.NoError(t, err)
-
- // Check that attributes are present in RequestContext
- require.NotNil(t, sessionInfo.RequestContext, "RequestContext should not be nil")
- assert.Equal(t, "my-user", sessionInfo.RequestContext["preferred_username"])
- assert.Equal(t, "engineering", sessionInfo.RequestContext["department"])
- assert.Equal(t, "seaweedfs", sessionInfo.RequestContext["project"])
-
- // Check standard claims are also present
- assert.Equal(t, "test-user-id", sessionInfo.RequestContext["sub"])
- assert.Equal(t, "test@example.com", sessionInfo.RequestContext["email"])
- assert.Equal(t, "Test User", sessionInfo.RequestContext["name"])
-}
-
-// MockIdentityProviderWithAttributes is a mock provider that returns configured attributes
-type MockIdentityProviderWithAttributes struct {
- name string
- attributes map[string]string
-}
-
-func (m *MockIdentityProviderWithAttributes) Name() string {
- return m.name
-}
-
-func (m *MockIdentityProviderWithAttributes) GetIssuer() string {
- return m.name
-}
-
-func (m *MockIdentityProviderWithAttributes) Initialize(config interface{}) error {
- return nil
-}
-
-func (m *MockIdentityProviderWithAttributes) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
- return &providers.ExternalIdentity{
- UserID: "test-user-id",
- Email: "test@example.com",
- DisplayName: "Test User",
- Provider: m.name,
- Attributes: m.attributes,
- }, nil
-}
-
-func (m *MockIdentityProviderWithAttributes) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
- return nil, nil
-}
-
-func (m *MockIdentityProviderWithAttributes) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
- return &providers.TokenClaims{
- Subject: "test-user-id",
- Issuer: m.name,
- }, nil
-}
diff --git a/weed/iam/sts/test_utils.go b/weed/iam/sts/test_utils.go
index 61ef72570..61de76bbd 100644
--- a/weed/iam/sts/test_utils.go
+++ b/weed/iam/sts/test_utils.go
@@ -1,53 +1,4 @@
package sts
-import (
- "context"
- "fmt"
- "strings"
-
- "github.com/seaweedfs/seaweedfs/weed/iam/providers"
-)
-
// MockTrustPolicyValidator is a simple mock for testing STS functionality
type MockTrustPolicyValidator struct{}
-
-// ValidateTrustPolicyForWebIdentity allows valid JWT test tokens for STS testing
-func (m *MockTrustPolicyValidator) ValidateTrustPolicyForWebIdentity(ctx context.Context, roleArn string, webIdentityToken string, durationSeconds *int64) error {
- // Reject non-existent roles for testing
- if strings.Contains(roleArn, "NonExistentRole") {
- return fmt.Errorf("trust policy validation failed: role does not exist")
- }
-
- // For STS unit tests, allow JWT tokens that look valid (contain dots for JWT structure)
- // In real implementation, this would validate against actual trust policies
- if len(webIdentityToken) > 20 && strings.Count(webIdentityToken, ".") >= 2 {
- // This appears to be a JWT token - allow it for testing
- return nil
- }
-
- // Legacy support for specific test tokens during migration
- if webIdentityToken == "valid_test_token" || webIdentityToken == "valid-oidc-token" {
- return nil
- }
-
- // Reject invalid tokens
- if webIdentityToken == "invalid_token" || webIdentityToken == "expired_token" || webIdentityToken == "invalid-token" {
- return fmt.Errorf("trust policy denies token")
- }
-
- return nil
-}
-
-// ValidateTrustPolicyForCredentials allows valid test identities for STS testing
-func (m *MockTrustPolicyValidator) ValidateTrustPolicyForCredentials(ctx context.Context, roleArn string, identity *providers.ExternalIdentity) error {
- // Reject non-existent roles for testing
- if strings.Contains(roleArn, "NonExistentRole") {
- return fmt.Errorf("trust policy validation failed: role does not exist")
- }
-
- // For STS unit tests, allow test identities
- if identity != nil && identity.UserID != "" {
- return nil
- }
- return fmt.Errorf("invalid identity for role assumption")
-}
diff --git a/weed/images/preprocess.go b/weed/images/preprocess.go
deleted file mode 100644
index f6f3b554d..000000000
--- a/weed/images/preprocess.go
+++ /dev/null
@@ -1,29 +0,0 @@
-package images
-
-import (
- "bytes"
- "io"
- "path/filepath"
- "strings"
-)
-
-/*
-* Preprocess image files on client side.
-* 1. possibly adjust the orientation
-* 2. resize the image to a width or height limit
-* 3. remove the exif data
-* Call this function on any file uploaded to SeaweedFS
-*
- */
-func MaybePreprocessImage(filename string, data []byte, width, height int) (resized io.ReadSeeker, w int, h int) {
- ext := filepath.Ext(filename)
- ext = strings.ToLower(ext)
- switch ext {
- case ".png", ".gif":
- return Resized(ext, bytes.NewReader(data), width, height, "")
- case ".jpg", ".jpeg":
- data = FixJpgOrientation(data)
- return Resized(ext, bytes.NewReader(data), width, height, "")
- }
- return bytes.NewReader(data), 0, 0
-}
diff --git a/weed/kms/config_loader.go b/weed/kms/config_loader.go
index 3778c0f59..5f31259c6 100644
--- a/weed/kms/config_loader.go
+++ b/weed/kms/config_loader.go
@@ -290,15 +290,6 @@ func (loader *ConfigLoader) ValidateConfiguration() error {
return nil
}
-// LoadKMSFromFilerToml is a convenience function to load KMS configuration from filer.toml
-func LoadKMSFromFilerToml(v ViperConfig) error {
- loader := NewConfigLoader(v)
- if err := loader.LoadConfigurations(); err != nil {
- return err
- }
- return loader.ValidateConfiguration()
-}
-
// LoadKMSFromConfig loads KMS configuration directly from parsed JSON data
func LoadKMSFromConfig(kmsConfig interface{}) error {
kmsMap, ok := kmsConfig.(map[string]interface{})
@@ -415,12 +406,3 @@ func getIntFromConfig(config map[string]interface{}, key string, defaultValue in
}
return defaultValue
}
-
-func getStringFromConfig(config map[string]interface{}, key string, defaultValue string) string {
- if value, exists := config[key]; exists {
- if stringValue, ok := value.(string); ok {
- return stringValue
- }
- }
- return defaultValue
-}
diff --git a/weed/mount/filehandle.go b/weed/mount/filehandle.go
index 98ca6737f..485a00e41 100644
--- a/weed/mount/filehandle.go
+++ b/weed/mount/filehandle.go
@@ -147,13 +147,6 @@ func (fh *FileHandle) ReleaseHandle() {
}
}
-func lessThan(a, b *filer_pb.FileChunk) bool {
- if a.ModifiedTsNs == b.ModifiedTsNs {
- return a.Fid.FileKey < b.Fid.FileKey
- }
- return a.ModifiedTsNs < b.ModifiedTsNs
-}
-
// getCumulativeOffsets returns cached cumulative offsets for chunks, computing them if necessary
func (fh *FileHandle) getCumulativeOffsets(chunks []*filer_pb.FileChunk) []int64 {
fh.chunkCacheLock.RLock()
diff --git a/weed/mount/page_writer/dirty_pages.go b/weed/mount/page_writer/dirty_pages.go
index cec365231..472815dd5 100644
--- a/weed/mount/page_writer/dirty_pages.go
+++ b/weed/mount/page_writer/dirty_pages.go
@@ -21,9 +21,3 @@ func min(x, y int64) int64 {
}
return y
}
-func minInt(x, y int) int {
- if x < y {
- return x
- }
- return y
-}
diff --git a/weed/mount/rdma_client.go b/weed/mount/rdma_client.go
index e9ee802ce..6a77f8f52 100644
--- a/weed/mount/rdma_client.go
+++ b/weed/mount/rdma_client.go
@@ -119,13 +119,6 @@ func (c *RDMAMountClient) lookupVolumeLocationByFileID(ctx context.Context, file
return bestAddress, nil
}
-// lookupVolumeLocation finds the best volume server for a given volume ID (legacy method)
-func (c *RDMAMountClient) lookupVolumeLocation(ctx context.Context, volumeID uint32, needleID uint64, cookie uint32) (string, error) {
- // Create a file ID for lookup (format: volumeId,needleId,cookie)
- fileID := fmt.Sprintf("%d,%x,%d", volumeID, needleID, cookie)
- return c.lookupVolumeLocationByFileID(ctx, fileID)
-}
-
// healthCheck verifies that the RDMA sidecar is available and functioning
func (c *RDMAMountClient) healthCheck() error {
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
diff --git a/weed/mq/broker/broker_errors.go b/weed/mq/broker/broker_errors.go
index b3d4cc42c..529a03e44 100644
--- a/weed/mq/broker/broker_errors.go
+++ b/weed/mq/broker/broker_errors.go
@@ -117,11 +117,6 @@ func GetBrokerErrorInfo(code int32) BrokerErrorInfo {
}
}
-// GetKafkaErrorCode returns the corresponding Kafka protocol error code for a broker error
-func GetKafkaErrorCode(brokerErrorCode int32) int16 {
- return GetBrokerErrorInfo(brokerErrorCode).KafkaCode
-}
-
// CreateBrokerError creates a structured broker error with both error code and message
func CreateBrokerError(code int32, message string) (int32, string) {
info := GetBrokerErrorInfo(code)
diff --git a/weed/mq/broker/broker_offset_integration_test.go b/weed/mq/broker/broker_offset_integration_test.go
deleted file mode 100644
index 49df58a64..000000000
--- a/weed/mq/broker/broker_offset_integration_test.go
+++ /dev/null
@@ -1,351 +0,0 @@
-package broker
-
-import (
- "testing"
- "time"
-
- "github.com/seaweedfs/seaweedfs/weed/mq/topic"
- "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
-)
-
-func createTestTopic() topic.Topic {
- return topic.Topic{
- Namespace: "test",
- Name: "offset-test",
- }
-}
-
-func createTestPartition() topic.Partition {
- return topic.Partition{
- RingSize: 1024,
- RangeStart: 0,
- RangeStop: 31,
- UnixTimeNs: time.Now().UnixNano(),
- }
-}
-
-func TestBrokerOffsetManager_AssignOffset(t *testing.T) {
- storage := NewInMemoryOffsetStorageForTesting()
- manager := NewBrokerOffsetManagerWithStorage(storage)
- testTopic := createTestTopic()
- testPartition := createTestPartition()
-
- // Test sequential offset assignment
- for i := int64(0); i < 10; i++ {
- assignedOffset, err := manager.AssignOffset(testTopic, testPartition)
- if err != nil {
- t.Fatalf("Failed to assign offset %d: %v", i, err)
- }
-
- if assignedOffset != i {
- t.Errorf("Expected offset %d, got %d", i, assignedOffset)
- }
- }
-}
-
-func TestBrokerOffsetManager_AssignBatchOffsets(t *testing.T) {
- storage := NewInMemoryOffsetStorageForTesting()
- manager := NewBrokerOffsetManagerWithStorage(storage)
- testTopic := createTestTopic()
- testPartition := createTestPartition()
-
- // Assign batch of offsets
- baseOffset, lastOffset, err := manager.AssignBatchOffsets(testTopic, testPartition, 5)
- if err != nil {
- t.Fatalf("Failed to assign batch offsets: %v", err)
- }
-
- if baseOffset != 0 {
- t.Errorf("Expected base offset 0, got %d", baseOffset)
- }
-
- if lastOffset != 4 {
- t.Errorf("Expected last offset 4, got %d", lastOffset)
- }
-
- // Assign another batch
- baseOffset2, lastOffset2, err := manager.AssignBatchOffsets(testTopic, testPartition, 3)
- if err != nil {
- t.Fatalf("Failed to assign second batch offsets: %v", err)
- }
-
- if baseOffset2 != 5 {
- t.Errorf("Expected base offset 5, got %d", baseOffset2)
- }
-
- if lastOffset2 != 7 {
- t.Errorf("Expected last offset 7, got %d", lastOffset2)
- }
-}
-
-func TestBrokerOffsetManager_GetHighWaterMark(t *testing.T) {
- storage := NewInMemoryOffsetStorageForTesting()
- manager := NewBrokerOffsetManagerWithStorage(storage)
- testTopic := createTestTopic()
- testPartition := createTestPartition()
-
- // Initially should be 0
- hwm, err := manager.GetHighWaterMark(testTopic, testPartition)
- if err != nil {
- t.Fatalf("Failed to get initial high water mark: %v", err)
- }
-
- if hwm != 0 {
- t.Errorf("Expected initial high water mark 0, got %d", hwm)
- }
-
- // Assign some offsets
- manager.AssignBatchOffsets(testTopic, testPartition, 10)
-
- // High water mark should be updated
- hwm, err = manager.GetHighWaterMark(testTopic, testPartition)
- if err != nil {
- t.Fatalf("Failed to get high water mark after assignment: %v", err)
- }
-
- if hwm != 10 {
- t.Errorf("Expected high water mark 10, got %d", hwm)
- }
-}
-
-func TestBrokerOffsetManager_CreateSubscription(t *testing.T) {
- storage := NewInMemoryOffsetStorageForTesting()
- manager := NewBrokerOffsetManagerWithStorage(storage)
- testTopic := createTestTopic()
- testPartition := createTestPartition()
-
- // Assign some offsets first
- manager.AssignBatchOffsets(testTopic, testPartition, 5)
-
- // Create subscription
- sub, err := manager.CreateSubscription(
- "test-sub",
- testTopic,
- testPartition,
- schema_pb.OffsetType_RESET_TO_EARLIEST,
- 0,
- )
-
- if err != nil {
- t.Fatalf("Failed to create subscription: %v", err)
- }
-
- if sub.ID != "test-sub" {
- t.Errorf("Expected subscription ID 'test-sub', got %s", sub.ID)
- }
-
- if sub.StartOffset != 0 {
- t.Errorf("Expected start offset 0, got %d", sub.StartOffset)
- }
-}
-
-func TestBrokerOffsetManager_GetPartitionOffsetInfo(t *testing.T) {
- storage := NewInMemoryOffsetStorageForTesting()
- manager := NewBrokerOffsetManagerWithStorage(storage)
- testTopic := createTestTopic()
- testPartition := createTestPartition()
-
- // Test empty partition
- info, err := manager.GetPartitionOffsetInfo(testTopic, testPartition)
- if err != nil {
- t.Fatalf("Failed to get partition offset info: %v", err)
- }
-
- if info.EarliestOffset != 0 {
- t.Errorf("Expected earliest offset 0, got %d", info.EarliestOffset)
- }
-
- if info.LatestOffset != -1 {
- t.Errorf("Expected latest offset -1 for empty partition, got %d", info.LatestOffset)
- }
-
- // Assign offsets and test again
- manager.AssignBatchOffsets(testTopic, testPartition, 5)
-
- info, err = manager.GetPartitionOffsetInfo(testTopic, testPartition)
- if err != nil {
- t.Fatalf("Failed to get partition offset info after assignment: %v", err)
- }
-
- if info.LatestOffset != 4 {
- t.Errorf("Expected latest offset 4, got %d", info.LatestOffset)
- }
-
- if info.HighWaterMark != 5 {
- t.Errorf("Expected high water mark 5, got %d", info.HighWaterMark)
- }
-}
-
-func TestBrokerOffsetManager_MultiplePartitions(t *testing.T) {
- storage := NewInMemoryOffsetStorageForTesting()
- manager := NewBrokerOffsetManagerWithStorage(storage)
- testTopic := createTestTopic()
-
- // Create different partitions
- partition1 := topic.Partition{
- RingSize: 1024,
- RangeStart: 0,
- RangeStop: 31,
- UnixTimeNs: time.Now().UnixNano(),
- }
-
- partition2 := topic.Partition{
- RingSize: 1024,
- RangeStart: 32,
- RangeStop: 63,
- UnixTimeNs: time.Now().UnixNano(),
- }
-
- // Assign offsets to different partitions
- assignedOffset1, err := manager.AssignOffset(testTopic, partition1)
- if err != nil {
- t.Fatalf("Failed to assign offset to partition1: %v", err)
- }
-
- assignedOffset2, err := manager.AssignOffset(testTopic, partition2)
- if err != nil {
- t.Fatalf("Failed to assign offset to partition2: %v", err)
- }
-
- // Both should start at 0
- if assignedOffset1 != 0 {
- t.Errorf("Expected offset 0 for partition1, got %d", assignedOffset1)
- }
-
- if assignedOffset2 != 0 {
- t.Errorf("Expected offset 0 for partition2, got %d", assignedOffset2)
- }
-
- // Assign more offsets to partition1
- assignedOffset1_2, err := manager.AssignOffset(testTopic, partition1)
- if err != nil {
- t.Fatalf("Failed to assign second offset to partition1: %v", err)
- }
-
- if assignedOffset1_2 != 1 {
- t.Errorf("Expected offset 1 for partition1, got %d", assignedOffset1_2)
- }
-
- // Partition2 should still be at 0 for next assignment
- assignedOffset2_2, err := manager.AssignOffset(testTopic, partition2)
- if err != nil {
- t.Fatalf("Failed to assign second offset to partition2: %v", err)
- }
-
- if assignedOffset2_2 != 1 {
- t.Errorf("Expected offset 1 for partition2, got %d", assignedOffset2_2)
- }
-}
-
-func TestOffsetAwarePublisher(t *testing.T) {
- storage := NewInMemoryOffsetStorageForTesting()
- manager := NewBrokerOffsetManagerWithStorage(storage)
- testTopic := createTestTopic()
- testPartition := createTestPartition()
-
- // Create a mock local partition (simplified for testing)
- localPartition := &topic.LocalPartition{}
-
- // Create offset assignment function
- assignOffsetFn := func() (int64, error) {
- return manager.AssignOffset(testTopic, testPartition)
- }
-
- // Create offset-aware publisher
- publisher := topic.NewOffsetAwarePublisher(localPartition, assignOffsetFn)
-
- if publisher.GetPartition() != localPartition {
- t.Error("Publisher should return the correct partition")
- }
-
- // Test would require more setup to actually publish messages
- // This tests the basic structure
-}
-
-func TestBrokerOffsetManager_GetOffsetMetrics(t *testing.T) {
- storage := NewInMemoryOffsetStorageForTesting()
- manager := NewBrokerOffsetManagerWithStorage(storage)
- testTopic := createTestTopic()
- testPartition := createTestPartition()
-
- // Initial metrics
- metrics := manager.GetOffsetMetrics()
- if metrics.TotalOffsets != 0 {
- t.Errorf("Expected 0 total offsets initially, got %d", metrics.TotalOffsets)
- }
-
- // Assign some offsets
- manager.AssignBatchOffsets(testTopic, testPartition, 5)
-
- // Create subscription
- manager.CreateSubscription("test-sub", testTopic, testPartition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
-
- // Check updated metrics
- metrics = manager.GetOffsetMetrics()
- if metrics.PartitionCount != 1 {
- t.Errorf("Expected 1 partition, got %d", metrics.PartitionCount)
- }
-}
-
-func TestBrokerOffsetManager_AssignOffsetsWithResult(t *testing.T) {
- storage := NewInMemoryOffsetStorageForTesting()
- manager := NewBrokerOffsetManagerWithStorage(storage)
- testTopic := createTestTopic()
- testPartition := createTestPartition()
-
- // Assign offsets with result
- result := manager.AssignOffsetsWithResult(testTopic, testPartition, 3)
-
- if result.Error != nil {
- t.Fatalf("Expected no error, got: %v", result.Error)
- }
-
- if result.BaseOffset != 0 {
- t.Errorf("Expected base offset 0, got %d", result.BaseOffset)
- }
-
- if result.LastOffset != 2 {
- t.Errorf("Expected last offset 2, got %d", result.LastOffset)
- }
-
- if result.Count != 3 {
- t.Errorf("Expected count 3, got %d", result.Count)
- }
-
- if result.Topic != testTopic {
- t.Error("Topic mismatch in result")
- }
-
- if result.Partition != testPartition {
- t.Error("Partition mismatch in result")
- }
-
- if result.Timestamp <= 0 {
- t.Error("Timestamp should be set")
- }
-}
-
-func TestBrokerOffsetManager_Shutdown(t *testing.T) {
- storage := NewInMemoryOffsetStorageForTesting()
- manager := NewBrokerOffsetManagerWithStorage(storage)
- testTopic := createTestTopic()
- testPartition := createTestPartition()
-
- // Assign some offsets and create subscriptions
- manager.AssignBatchOffsets(testTopic, testPartition, 5)
- manager.CreateSubscription("test-sub", testTopic, testPartition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
-
- // Shutdown should not panic
- manager.Shutdown()
-
- // After shutdown, operations should still work (using new managers)
- offset, err := manager.AssignOffset(testTopic, testPartition)
- if err != nil {
- t.Fatalf("Operations should still work after shutdown: %v", err)
- }
-
- // Should start from 0 again (new manager)
- if offset != 0 {
- t.Errorf("Expected offset 0 after shutdown, got %d", offset)
- }
-}
diff --git a/weed/mq/broker/broker_server.go b/weed/mq/broker/broker_server.go
index 5116ff5a5..a3e658c35 100644
--- a/weed/mq/broker/broker_server.go
+++ b/weed/mq/broker/broker_server.go
@@ -203,14 +203,6 @@ func (b *MessageQueueBroker) GetDataCenter() string {
}
-func (b *MessageQueueBroker) withMasterClient(streamingMode bool, master pb.ServerAddress, fn func(client master_pb.SeaweedClient) error) error {
-
- return pb.WithMasterClient(streamingMode, master, b.grpcDialOption, false, func(client master_pb.SeaweedClient) error {
- return fn(client)
- })
-
-}
-
func (b *MessageQueueBroker) withBrokerClient(streamingMode bool, server pb.ServerAddress, fn func(client mq_pb.SeaweedMessagingClient) error) error {
return pb.WithBrokerGrpcClient(streamingMode, server.String(), b.grpcDialOption, func(client mq_pb.SeaweedMessagingClient) error {
diff --git a/weed/mq/kafka/consumer_offset/filer_storage.go b/weed/mq/kafka/consumer_offset/filer_storage.go
index 9d92ad730..967982683 100644
--- a/weed/mq/kafka/consumer_offset/filer_storage.go
+++ b/weed/mq/kafka/consumer_offset/filer_storage.go
@@ -5,7 +5,6 @@ import (
"encoding/json"
"fmt"
"io"
- "strings"
"time"
"github.com/seaweedfs/seaweedfs/weed/filer_client"
@@ -192,10 +191,6 @@ func (f *FilerStorage) getOffsetPath(group, topic string, partition int32) strin
return fmt.Sprintf("%s/offset", f.getPartitionPath(group, topic, partition))
}
-func (f *FilerStorage) getMetadataPath(group, topic string, partition int32) string {
- return fmt.Sprintf("%s/metadata", f.getPartitionPath(group, topic, partition))
-}
-
func (f *FilerStorage) writeFile(path string, data []byte) error {
fullPath := util.FullPath(path)
dir, name := fullPath.DirAndName()
@@ -311,16 +306,3 @@ func (f *FilerStorage) deleteDirectory(path string) error {
return err
})
}
-
-// normalizePath removes leading/trailing slashes and collapses multiple slashes
-func normalizePath(path string) string {
- path = strings.Trim(path, "/")
- parts := strings.Split(path, "/")
- normalized := []string{}
- for _, part := range parts {
- if part != "" {
- normalized = append(normalized, part)
- }
- }
- return "/" + strings.Join(normalized, "/")
-}
diff --git a/weed/mq/kafka/consumer_offset/filer_storage_test.go b/weed/mq/kafka/consumer_offset/filer_storage_test.go
deleted file mode 100644
index 67a0e7e09..000000000
--- a/weed/mq/kafka/consumer_offset/filer_storage_test.go
+++ /dev/null
@@ -1,65 +0,0 @@
-package consumer_offset
-
-import (
- "testing"
-
- "github.com/stretchr/testify/assert"
-)
-
-// Note: These tests require a running filer instance
-// They are marked as integration tests and should be run with:
-// go test -tags=integration
-
-func TestFilerStorageCommitAndFetch(t *testing.T) {
- t.Skip("Requires running filer - integration test")
-
- // This will be implemented once we have test infrastructure
- // Test will:
- // 1. Create filer storage
- // 2. Commit offset
- // 3. Fetch offset
- // 4. Verify values match
-}
-
-func TestFilerStoragePersistence(t *testing.T) {
- t.Skip("Requires running filer - integration test")
-
- // Test will:
- // 1. Commit offset with first storage instance
- // 2. Close first instance
- // 3. Create new storage instance
- // 4. Fetch offset and verify it persisted
-}
-
-func TestFilerStorageMultipleGroups(t *testing.T) {
- t.Skip("Requires running filer - integration test")
-
- // Test will:
- // 1. Commit offsets for multiple groups
- // 2. Fetch all offsets per group
- // 3. Verify isolation between groups
-}
-
-func TestFilerStoragePath(t *testing.T) {
- // Test path generation (doesn't require filer)
- storage := &FilerStorage{}
-
- group := "test-group"
- topic := "test-topic"
- partition := int32(5)
-
- groupPath := storage.getGroupPath(group)
- assert.Equal(t, ConsumerOffsetsBasePath+"/test-group", groupPath)
-
- topicPath := storage.getTopicPath(group, topic)
- assert.Equal(t, ConsumerOffsetsBasePath+"/test-group/test-topic", topicPath)
-
- partitionPath := storage.getPartitionPath(group, topic, partition)
- assert.Equal(t, ConsumerOffsetsBasePath+"/test-group/test-topic/5", partitionPath)
-
- offsetPath := storage.getOffsetPath(group, topic, partition)
- assert.Equal(t, ConsumerOffsetsBasePath+"/test-group/test-topic/5/offset", offsetPath)
-
- metadataPath := storage.getMetadataPath(group, topic, partition)
- assert.Equal(t, ConsumerOffsetsBasePath+"/test-group/test-topic/5/metadata", metadataPath)
-}
diff --git a/weed/mq/kafka/integration/seaweedmq_handler_topics.go b/weed/mq/kafka/integration/seaweedmq_handler_topics.go
index b635b40af..b2071fd00 100644
--- a/weed/mq/kafka/integration/seaweedmq_handler_topics.go
+++ b/weed/mq/kafka/integration/seaweedmq_handler_topics.go
@@ -278,38 +278,3 @@ func (h *SeaweedMQHandler) checkTopicInFiler(topicName string) bool {
return exists
}
-
-// listTopicsFromFiler lists all topics from the filer
-func (h *SeaweedMQHandler) listTopicsFromFiler() []string {
- if h.filerClientAccessor == nil {
- return []string{}
- }
-
- var topics []string
-
- h.filerClientAccessor.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error {
- request := &filer_pb.ListEntriesRequest{
- Directory: "/topics/kafka",
- }
-
- stream, err := client.ListEntries(context.Background(), request)
- if err != nil {
- return nil // Don't propagate error, just return empty list
- }
-
- for {
- resp, err := stream.Recv()
- if err != nil {
- break // End of stream or error
- }
-
- if resp.Entry != nil && resp.Entry.IsDirectory {
- topics = append(topics, resp.Entry.Name)
- } else if resp.Entry != nil {
- }
- }
- return nil
- })
-
- return topics
-}
diff --git a/weed/mq/kafka/partition_mapping.go b/weed/mq/kafka/partition_mapping.go
deleted file mode 100644
index a956c3cde..000000000
--- a/weed/mq/kafka/partition_mapping.go
+++ /dev/null
@@ -1,53 +0,0 @@
-package kafka
-
-import (
- "github.com/seaweedfs/seaweedfs/weed/mq/pub_balancer"
- "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
-)
-
-// Convenience functions for partition mapping used by production code
-// The full PartitionMapper implementation is in partition_mapping_test.go for testing
-
-// MapKafkaPartitionToSMQRange maps a Kafka partition to SeaweedMQ ring range
-func MapKafkaPartitionToSMQRange(kafkaPartition int32) (rangeStart, rangeStop int32) {
- // Use a range size that divides evenly into MaxPartitionCount (2520)
- // Range size 35 gives us exactly 72 Kafka partitions: 2520 / 35 = 72
- rangeSize := int32(35)
- rangeStart = kafkaPartition * rangeSize
- rangeStop = rangeStart + rangeSize - 1
- return rangeStart, rangeStop
-}
-
-// CreateSMQPartition creates a SeaweedMQ partition from a Kafka partition
-func CreateSMQPartition(kafkaPartition int32, unixTimeNs int64) *schema_pb.Partition {
- rangeStart, rangeStop := MapKafkaPartitionToSMQRange(kafkaPartition)
-
- return &schema_pb.Partition{
- RingSize: pub_balancer.MaxPartitionCount,
- RangeStart: rangeStart,
- RangeStop: rangeStop,
- UnixTimeNs: unixTimeNs,
- }
-}
-
-// ExtractKafkaPartitionFromSMQRange extracts the Kafka partition from SeaweedMQ range
-func ExtractKafkaPartitionFromSMQRange(rangeStart int32) int32 {
- rangeSize := int32(35)
- return rangeStart / rangeSize
-}
-
-// ValidateKafkaPartition validates that a Kafka partition is within supported range
-func ValidateKafkaPartition(kafkaPartition int32) bool {
- maxPartitions := int32(pub_balancer.MaxPartitionCount) / 35 // 72 partitions
- return kafkaPartition >= 0 && kafkaPartition < maxPartitions
-}
-
-// GetRangeSize returns the range size used for partition mapping
-func GetRangeSize() int32 {
- return 35
-}
-
-// GetMaxKafkaPartitions returns the maximum number of Kafka partitions supported
-func GetMaxKafkaPartitions() int32 {
- return int32(pub_balancer.MaxPartitionCount) / 35 // 72 partitions
-}
diff --git a/weed/mq/kafka/partition_mapping_test.go b/weed/mq/kafka/partition_mapping_test.go
deleted file mode 100644
index 6f41a68d4..000000000
--- a/weed/mq/kafka/partition_mapping_test.go
+++ /dev/null
@@ -1,294 +0,0 @@
-package kafka
-
-import (
- "testing"
- "time"
-
- "github.com/seaweedfs/seaweedfs/weed/mq/pub_balancer"
- "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
-)
-
-// PartitionMapper provides consistent Kafka partition to SeaweedMQ ring mapping
-// NOTE: This is test-only code and not used in the actual Kafka Gateway implementation
-type PartitionMapper struct{}
-
-// NewPartitionMapper creates a new partition mapper
-func NewPartitionMapper() *PartitionMapper {
- return &PartitionMapper{}
-}
-
-// GetRangeSize returns the consistent range size for Kafka partition mapping
-// This ensures all components use the same calculation
-func (pm *PartitionMapper) GetRangeSize() int32 {
- // Use a range size that divides evenly into MaxPartitionCount (2520)
- // Range size 35 gives us exactly 72 Kafka partitions: 2520 / 35 = 72
- // This provides a good balance between partition granularity and ring utilization
- return 35
-}
-
-// GetMaxKafkaPartitions returns the maximum number of Kafka partitions supported
-func (pm *PartitionMapper) GetMaxKafkaPartitions() int32 {
- // With range size 35, we can support: 2520 / 35 = 72 Kafka partitions
- return int32(pub_balancer.MaxPartitionCount) / pm.GetRangeSize()
-}
-
-// MapKafkaPartitionToSMQRange maps a Kafka partition to SeaweedMQ ring range
-func (pm *PartitionMapper) MapKafkaPartitionToSMQRange(kafkaPartition int32) (rangeStart, rangeStop int32) {
- rangeSize := pm.GetRangeSize()
- rangeStart = kafkaPartition * rangeSize
- rangeStop = rangeStart + rangeSize - 1
- return rangeStart, rangeStop
-}
-
-// CreateSMQPartition creates a SeaweedMQ partition from a Kafka partition
-func (pm *PartitionMapper) CreateSMQPartition(kafkaPartition int32, unixTimeNs int64) *schema_pb.Partition {
- rangeStart, rangeStop := pm.MapKafkaPartitionToSMQRange(kafkaPartition)
-
- return &schema_pb.Partition{
- RingSize: pub_balancer.MaxPartitionCount,
- RangeStart: rangeStart,
- RangeStop: rangeStop,
- UnixTimeNs: unixTimeNs,
- }
-}
-
-// ExtractKafkaPartitionFromSMQRange extracts the Kafka partition from SeaweedMQ range
-func (pm *PartitionMapper) ExtractKafkaPartitionFromSMQRange(rangeStart int32) int32 {
- rangeSize := pm.GetRangeSize()
- return rangeStart / rangeSize
-}
-
-// ValidateKafkaPartition validates that a Kafka partition is within supported range
-func (pm *PartitionMapper) ValidateKafkaPartition(kafkaPartition int32) bool {
- return kafkaPartition >= 0 && kafkaPartition < pm.GetMaxKafkaPartitions()
-}
-
-// GetPartitionMappingInfo returns debug information about the partition mapping
-func (pm *PartitionMapper) GetPartitionMappingInfo() map[string]interface{} {
- return map[string]interface{}{
- "ring_size": pub_balancer.MaxPartitionCount,
- "range_size": pm.GetRangeSize(),
- "max_kafka_partitions": pm.GetMaxKafkaPartitions(),
- "ring_utilization": float64(pm.GetMaxKafkaPartitions()*pm.GetRangeSize()) / float64(pub_balancer.MaxPartitionCount),
- }
-}
-
-// Global instance for consistent usage across the test codebase
-var DefaultPartitionMapper = NewPartitionMapper()
-
-func TestPartitionMapper_GetRangeSize(t *testing.T) {
- mapper := NewPartitionMapper()
- rangeSize := mapper.GetRangeSize()
-
- if rangeSize != 35 {
- t.Errorf("Expected range size 35, got %d", rangeSize)
- }
-
- // Verify that the range size divides evenly into available partitions
- maxPartitions := mapper.GetMaxKafkaPartitions()
- totalUsed := maxPartitions * rangeSize
-
- if totalUsed > int32(pub_balancer.MaxPartitionCount) {
- t.Errorf("Total used slots (%d) exceeds MaxPartitionCount (%d)", totalUsed, pub_balancer.MaxPartitionCount)
- }
-
- t.Logf("Range size: %d, Max Kafka partitions: %d, Ring utilization: %.2f%%",
- rangeSize, maxPartitions, float64(totalUsed)/float64(pub_balancer.MaxPartitionCount)*100)
-}
-
-func TestPartitionMapper_MapKafkaPartitionToSMQRange(t *testing.T) {
- mapper := NewPartitionMapper()
-
- tests := []struct {
- kafkaPartition int32
- expectedStart int32
- expectedStop int32
- }{
- {0, 0, 34},
- {1, 35, 69},
- {2, 70, 104},
- {10, 350, 384},
- }
-
- for _, tt := range tests {
- t.Run("", func(t *testing.T) {
- start, stop := mapper.MapKafkaPartitionToSMQRange(tt.kafkaPartition)
-
- if start != tt.expectedStart {
- t.Errorf("Kafka partition %d: expected start %d, got %d", tt.kafkaPartition, tt.expectedStart, start)
- }
-
- if stop != tt.expectedStop {
- t.Errorf("Kafka partition %d: expected stop %d, got %d", tt.kafkaPartition, tt.expectedStop, stop)
- }
-
- // Verify range size is consistent
- rangeSize := stop - start + 1
- if rangeSize != mapper.GetRangeSize() {
- t.Errorf("Inconsistent range size: expected %d, got %d", mapper.GetRangeSize(), rangeSize)
- }
- })
- }
-}
-
-func TestPartitionMapper_ExtractKafkaPartitionFromSMQRange(t *testing.T) {
- mapper := NewPartitionMapper()
-
- tests := []struct {
- rangeStart int32
- expectedKafka int32
- }{
- {0, 0},
- {35, 1},
- {70, 2},
- {350, 10},
- }
-
- for _, tt := range tests {
- t.Run("", func(t *testing.T) {
- kafkaPartition := mapper.ExtractKafkaPartitionFromSMQRange(tt.rangeStart)
-
- if kafkaPartition != tt.expectedKafka {
- t.Errorf("Range start %d: expected Kafka partition %d, got %d",
- tt.rangeStart, tt.expectedKafka, kafkaPartition)
- }
- })
- }
-}
-
-func TestPartitionMapper_RoundTrip(t *testing.T) {
- mapper := NewPartitionMapper()
-
- // Test round-trip conversion for all valid Kafka partitions
- maxPartitions := mapper.GetMaxKafkaPartitions()
-
- for kafkaPartition := int32(0); kafkaPartition < maxPartitions; kafkaPartition++ {
- // Kafka -> SMQ -> Kafka
- rangeStart, rangeStop := mapper.MapKafkaPartitionToSMQRange(kafkaPartition)
- extractedKafka := mapper.ExtractKafkaPartitionFromSMQRange(rangeStart)
-
- if extractedKafka != kafkaPartition {
- t.Errorf("Round-trip failed for partition %d: got %d", kafkaPartition, extractedKafka)
- }
-
- // Verify no overlap with next partition
- if kafkaPartition < maxPartitions-1 {
- nextStart, _ := mapper.MapKafkaPartitionToSMQRange(kafkaPartition + 1)
- if rangeStop >= nextStart {
- t.Errorf("Partition %d range [%d,%d] overlaps with partition %d start %d",
- kafkaPartition, rangeStart, rangeStop, kafkaPartition+1, nextStart)
- }
- }
- }
-}
-
-func TestPartitionMapper_CreateSMQPartition(t *testing.T) {
- mapper := NewPartitionMapper()
-
- kafkaPartition := int32(5)
- unixTimeNs := time.Now().UnixNano()
-
- partition := mapper.CreateSMQPartition(kafkaPartition, unixTimeNs)
-
- if partition.RingSize != pub_balancer.MaxPartitionCount {
- t.Errorf("Expected ring size %d, got %d", pub_balancer.MaxPartitionCount, partition.RingSize)
- }
-
- expectedStart, expectedStop := mapper.MapKafkaPartitionToSMQRange(kafkaPartition)
- if partition.RangeStart != expectedStart {
- t.Errorf("Expected range start %d, got %d", expectedStart, partition.RangeStart)
- }
-
- if partition.RangeStop != expectedStop {
- t.Errorf("Expected range stop %d, got %d", expectedStop, partition.RangeStop)
- }
-
- if partition.UnixTimeNs != unixTimeNs {
- t.Errorf("Expected timestamp %d, got %d", unixTimeNs, partition.UnixTimeNs)
- }
-}
-
-func TestPartitionMapper_ValidateKafkaPartition(t *testing.T) {
- mapper := NewPartitionMapper()
-
- tests := []struct {
- partition int32
- valid bool
- }{
- {-1, false},
- {0, true},
- {1, true},
- {mapper.GetMaxKafkaPartitions() - 1, true},
- {mapper.GetMaxKafkaPartitions(), false},
- {1000, false},
- }
-
- for _, tt := range tests {
- t.Run("", func(t *testing.T) {
- valid := mapper.ValidateKafkaPartition(tt.partition)
- if valid != tt.valid {
- t.Errorf("Partition %d: expected valid=%v, got %v", tt.partition, tt.valid, valid)
- }
- })
- }
-}
-
-func TestPartitionMapper_ConsistencyWithGlobalFunctions(t *testing.T) {
- mapper := NewPartitionMapper()
-
- kafkaPartition := int32(7)
- unixTimeNs := time.Now().UnixNano()
-
- // Test that global functions produce same results as mapper methods
- start1, stop1 := mapper.MapKafkaPartitionToSMQRange(kafkaPartition)
- start2, stop2 := MapKafkaPartitionToSMQRange(kafkaPartition)
-
- if start1 != start2 || stop1 != stop2 {
- t.Errorf("Global function inconsistent: mapper=(%d,%d), global=(%d,%d)",
- start1, stop1, start2, stop2)
- }
-
- partition1 := mapper.CreateSMQPartition(kafkaPartition, unixTimeNs)
- partition2 := CreateSMQPartition(kafkaPartition, unixTimeNs)
-
- if partition1.RangeStart != partition2.RangeStart || partition1.RangeStop != partition2.RangeStop {
- t.Errorf("Global CreateSMQPartition inconsistent")
- }
-
- extracted1 := mapper.ExtractKafkaPartitionFromSMQRange(start1)
- extracted2 := ExtractKafkaPartitionFromSMQRange(start1)
-
- if extracted1 != extracted2 {
- t.Errorf("Global ExtractKafkaPartitionFromSMQRange inconsistent: %d vs %d", extracted1, extracted2)
- }
-}
-
-func TestPartitionMapper_GetPartitionMappingInfo(t *testing.T) {
- mapper := NewPartitionMapper()
-
- info := mapper.GetPartitionMappingInfo()
-
- // Verify all expected keys are present
- expectedKeys := []string{"ring_size", "range_size", "max_kafka_partitions", "ring_utilization"}
- for _, key := range expectedKeys {
- if _, exists := info[key]; !exists {
- t.Errorf("Missing key in mapping info: %s", key)
- }
- }
-
- // Verify values are reasonable
- if info["ring_size"].(int) != pub_balancer.MaxPartitionCount {
- t.Errorf("Incorrect ring_size in info")
- }
-
- if info["range_size"].(int32) != mapper.GetRangeSize() {
- t.Errorf("Incorrect range_size in info")
- }
-
- utilization := info["ring_utilization"].(float64)
- if utilization <= 0 || utilization > 1 {
- t.Errorf("Invalid ring utilization: %f", utilization)
- }
-
- t.Logf("Partition mapping info: %+v", info)
-}
diff --git a/weed/mq/offset/benchmark_test.go b/weed/mq/offset/benchmark_test.go
index 0fdacf127..8c78cc0f8 100644
--- a/weed/mq/offset/benchmark_test.go
+++ b/weed/mq/offset/benchmark_test.go
@@ -2,11 +2,9 @@ package offset
import (
"fmt"
- "os"
"testing"
"time"
- _ "github.com/mattn/go-sqlite3"
"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
)
@@ -62,151 +60,6 @@ func BenchmarkBatchOffsetAssignment(b *testing.B) {
}
}
-// BenchmarkSQLOffsetStorage benchmarks SQL storage operations
-func BenchmarkSQLOffsetStorage(b *testing.B) {
- // Create temporary database
- tmpFile, err := os.CreateTemp("", "benchmark_*.db")
- if err != nil {
- b.Fatalf("Failed to create temp database: %v", err)
- }
- tmpFile.Close()
- defer os.Remove(tmpFile.Name())
-
- db, err := CreateDatabase(tmpFile.Name())
- if err != nil {
- b.Fatalf("Failed to create database: %v", err)
- }
- defer db.Close()
-
- storage, err := NewSQLOffsetStorage(db)
- if err != nil {
- b.Fatalf("Failed to create SQL storage: %v", err)
- }
- defer storage.Close()
-
- partition := &schema_pb.Partition{
- RingSize: 1024,
- RangeStart: 0,
- RangeStop: 31,
- UnixTimeNs: time.Now().UnixNano(),
- }
-
- partitionKey := partitionKey(partition)
-
- b.Run("SaveCheckpoint", func(b *testing.B) {
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- storage.SaveCheckpoint("test-namespace", "test-topic", partition, int64(i))
- }
- })
-
- b.Run("LoadCheckpoint", func(b *testing.B) {
- storage.SaveCheckpoint("test-namespace", "test-topic", partition, 1000)
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- storage.LoadCheckpoint("test-namespace", "test-topic", partition)
- }
- })
-
- b.Run("SaveOffsetMapping", func(b *testing.B) {
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- storage.SaveOffsetMapping(partitionKey, int64(i), int64(i*1000), 100)
- }
- })
-
- // Pre-populate for read benchmarks
- for i := 0; i < 1000; i++ {
- storage.SaveOffsetMapping(partitionKey, int64(i), int64(i*1000), 100)
- }
-
- b.Run("GetHighestOffset", func(b *testing.B) {
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- storage.GetHighestOffset("test-namespace", "test-topic", partition)
- }
- })
-
- b.Run("LoadOffsetMappings", func(b *testing.B) {
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- storage.LoadOffsetMappings(partitionKey)
- }
- })
-
- b.Run("GetOffsetMappingsByRange", func(b *testing.B) {
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- start := int64(i % 900)
- end := start + 100
- storage.GetOffsetMappingsByRange(partitionKey, start, end)
- }
- })
-
- b.Run("GetPartitionStats", func(b *testing.B) {
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- storage.GetPartitionStats(partitionKey)
- }
- })
-}
-
-// BenchmarkInMemoryVsSQL compares in-memory and SQL storage performance
-func BenchmarkInMemoryVsSQL(b *testing.B) {
- partition := &schema_pb.Partition{
- RingSize: 1024,
- RangeStart: 0,
- RangeStop: 31,
- UnixTimeNs: time.Now().UnixNano(),
- }
-
- // In-memory storage benchmark
- b.Run("InMemory", func(b *testing.B) {
- storage := NewInMemoryOffsetStorage()
- manager, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage)
- if err != nil {
- b.Fatalf("Failed to create partition manager: %v", err)
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- manager.AssignOffset()
- }
- })
-
- // SQL storage benchmark
- b.Run("SQL", func(b *testing.B) {
- tmpFile, err := os.CreateTemp("", "benchmark_sql_*.db")
- if err != nil {
- b.Fatalf("Failed to create temp database: %v", err)
- }
- tmpFile.Close()
- defer os.Remove(tmpFile.Name())
-
- db, err := CreateDatabase(tmpFile.Name())
- if err != nil {
- b.Fatalf("Failed to create database: %v", err)
- }
- defer db.Close()
-
- storage, err := NewSQLOffsetStorage(db)
- if err != nil {
- b.Fatalf("Failed to create SQL storage: %v", err)
- }
- defer storage.Close()
-
- manager, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage)
- if err != nil {
- b.Fatalf("Failed to create partition manager: %v", err)
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- manager.AssignOffset()
- }
- })
-}
-
// BenchmarkOffsetSubscription benchmarks subscription operations
func BenchmarkOffsetSubscription(b *testing.B) {
storage := NewInMemoryOffsetStorage()
diff --git a/weed/mq/offset/end_to_end_test.go b/weed/mq/offset/end_to_end_test.go
deleted file mode 100644
index f2b57b843..000000000
--- a/weed/mq/offset/end_to_end_test.go
+++ /dev/null
@@ -1,473 +0,0 @@
-package offset
-
-import (
- "fmt"
- "os"
- "testing"
- "time"
-
- _ "github.com/mattn/go-sqlite3"
- "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
-)
-
-// TestEndToEndOffsetFlow tests the complete offset management flow
-func TestEndToEndOffsetFlow(t *testing.T) {
- // Create temporary database
- tmpFile, err := os.CreateTemp("", "e2e_offset_test_*.db")
- if err != nil {
- t.Fatalf("Failed to create temp database: %v", err)
- }
- tmpFile.Close()
- defer os.Remove(tmpFile.Name())
-
- // Create database with migrations
- db, err := CreateDatabase(tmpFile.Name())
- if err != nil {
- t.Fatalf("Failed to create database: %v", err)
- }
- defer db.Close()
-
- // Create SQL storage
- storage, err := NewSQLOffsetStorage(db)
- if err != nil {
- t.Fatalf("Failed to create SQL storage: %v", err)
- }
- defer storage.Close()
-
- // Create SMQ offset integration
- integration := NewSMQOffsetIntegration(storage)
-
- // Test partition
- partition := &schema_pb.Partition{
- RingSize: 1024,
- RangeStart: 0,
- RangeStop: 31,
- UnixTimeNs: time.Now().UnixNano(),
- }
-
- t.Run("PublishAndAssignOffsets", func(t *testing.T) {
- // Simulate publishing messages with offset assignment
- records := []PublishRecordRequest{
- {Key: []byte("user1"), Value: &schema_pb.RecordValue{}},
- {Key: []byte("user2"), Value: &schema_pb.RecordValue{}},
- {Key: []byte("user3"), Value: &schema_pb.RecordValue{}},
- }
-
- response, err := integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
- if err != nil {
- t.Fatalf("Failed to publish record batch: %v", err)
- }
-
- if response.BaseOffset != 0 {
- t.Errorf("Expected base offset 0, got %d", response.BaseOffset)
- }
-
- if response.LastOffset != 2 {
- t.Errorf("Expected last offset 2, got %d", response.LastOffset)
- }
-
- // Verify high water mark
- hwm, err := integration.GetHighWaterMark("test-namespace", "test-topic", partition)
- if err != nil {
- t.Fatalf("Failed to get high water mark: %v", err)
- }
-
- if hwm != 3 {
- t.Errorf("Expected high water mark 3, got %d", hwm)
- }
- })
-
- t.Run("CreateAndUseSubscription", func(t *testing.T) {
- // Create subscription from earliest
- sub, err := integration.CreateSubscription(
- "e2e-test-sub",
- "test-namespace", "test-topic",
- partition,
- schema_pb.OffsetType_RESET_TO_EARLIEST,
- 0,
- )
- if err != nil {
- t.Fatalf("Failed to create subscription: %v", err)
- }
-
- // Subscribe to records
- responses, err := integration.SubscribeRecords(sub, 2)
- if err != nil {
- t.Fatalf("Failed to subscribe to records: %v", err)
- }
-
- if len(responses) != 2 {
- t.Errorf("Expected 2 responses, got %d", len(responses))
- }
-
- // Check subscription advancement
- if sub.CurrentOffset != 2 {
- t.Errorf("Expected current offset 2, got %d", sub.CurrentOffset)
- }
-
- // Get subscription lag
- lag, err := sub.GetLag()
- if err != nil {
- t.Fatalf("Failed to get lag: %v", err)
- }
-
- if lag != 1 { // 3 (hwm) - 2 (current) = 1
- t.Errorf("Expected lag 1, got %d", lag)
- }
- })
-
- t.Run("OffsetSeekingAndRanges", func(t *testing.T) {
- // Create subscription at specific offset
- sub, err := integration.CreateSubscription(
- "seek-test-sub",
- "test-namespace", "test-topic",
- partition,
- schema_pb.OffsetType_EXACT_OFFSET,
- 1,
- )
- if err != nil {
- t.Fatalf("Failed to create subscription at offset 1: %v", err)
- }
-
- // Verify starting position
- if sub.CurrentOffset != 1 {
- t.Errorf("Expected current offset 1, got %d", sub.CurrentOffset)
- }
-
- // Get offset range
- offsetRange, err := sub.GetOffsetRange(2)
- if err != nil {
- t.Fatalf("Failed to get offset range: %v", err)
- }
-
- if offsetRange.StartOffset != 1 {
- t.Errorf("Expected start offset 1, got %d", offsetRange.StartOffset)
- }
-
- if offsetRange.Count != 2 {
- t.Errorf("Expected count 2, got %d", offsetRange.Count)
- }
-
- // Seek to different offset
- err = sub.SeekToOffset(0)
- if err != nil {
- t.Fatalf("Failed to seek to offset 0: %v", err)
- }
-
- if sub.CurrentOffset != 0 {
- t.Errorf("Expected current offset 0 after seek, got %d", sub.CurrentOffset)
- }
- })
-
- t.Run("PartitionInformationAndMetrics", func(t *testing.T) {
- // Get partition offset info
- info, err := integration.GetPartitionOffsetInfo("test-namespace", "test-topic", partition)
- if err != nil {
- t.Fatalf("Failed to get partition offset info: %v", err)
- }
-
- if info.EarliestOffset != 0 {
- t.Errorf("Expected earliest offset 0, got %d", info.EarliestOffset)
- }
-
- if info.LatestOffset != 2 {
- t.Errorf("Expected latest offset 2, got %d", info.LatestOffset)
- }
-
- if info.HighWaterMark != 3 {
- t.Errorf("Expected high water mark 3, got %d", info.HighWaterMark)
- }
-
- if info.ActiveSubscriptions != 2 { // Two subscriptions created above
- t.Errorf("Expected 2 active subscriptions, got %d", info.ActiveSubscriptions)
- }
-
- // Get offset metrics
- metrics := integration.GetOffsetMetrics()
- if metrics.PartitionCount != 1 {
- t.Errorf("Expected 1 partition, got %d", metrics.PartitionCount)
- }
-
- if metrics.ActiveSubscriptions != 2 {
- t.Errorf("Expected 2 active subscriptions in metrics, got %d", metrics.ActiveSubscriptions)
- }
- })
-}
-
-// TestOffsetPersistenceAcrossRestarts tests that offsets persist across system restarts
-func TestOffsetPersistenceAcrossRestarts(t *testing.T) {
- // Create temporary database
- tmpFile, err := os.CreateTemp("", "persistence_test_*.db")
- if err != nil {
- t.Fatalf("Failed to create temp database: %v", err)
- }
- tmpFile.Close()
- defer os.Remove(tmpFile.Name())
-
- partition := &schema_pb.Partition{
- RingSize: 1024,
- RangeStart: 0,
- RangeStop: 31,
- UnixTimeNs: time.Now().UnixNano(),
- }
-
- var lastOffset int64
-
- // First session: Create database and assign offsets
- {
- db, err := CreateDatabase(tmpFile.Name())
- if err != nil {
- t.Fatalf("Failed to create database: %v", err)
- }
-
- storage, err := NewSQLOffsetStorage(db)
- if err != nil {
- t.Fatalf("Failed to create SQL storage: %v", err)
- }
-
- integration := NewSMQOffsetIntegration(storage)
-
- // Publish some records
- records := []PublishRecordRequest{
- {Key: []byte("msg1"), Value: &schema_pb.RecordValue{}},
- {Key: []byte("msg2"), Value: &schema_pb.RecordValue{}},
- {Key: []byte("msg3"), Value: &schema_pb.RecordValue{}},
- }
-
- response, err := integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
- if err != nil {
- t.Fatalf("Failed to publish records: %v", err)
- }
-
- lastOffset = response.LastOffset
-
- // Close connections - Close integration first to trigger final checkpoint
- integration.Close()
- storage.Close()
- db.Close()
- }
-
- // Second session: Reopen database and verify persistence
- {
- db, err := CreateDatabase(tmpFile.Name())
- if err != nil {
- t.Fatalf("Failed to reopen database: %v", err)
- }
- defer db.Close()
-
- storage, err := NewSQLOffsetStorage(db)
- if err != nil {
- t.Fatalf("Failed to create SQL storage: %v", err)
- }
- defer storage.Close()
-
- integration := NewSMQOffsetIntegration(storage)
-
- // Verify high water mark persisted
- hwm, err := integration.GetHighWaterMark("test-namespace", "test-topic", partition)
- if err != nil {
- t.Fatalf("Failed to get high water mark after restart: %v", err)
- }
-
- if hwm != lastOffset+1 {
- t.Errorf("Expected high water mark %d after restart, got %d", lastOffset+1, hwm)
- }
-
- // Assign new offsets and verify continuity
- newResponse, err := integration.PublishRecord("test-namespace", "test-topic", partition, []byte("msg4"), &schema_pb.RecordValue{})
- if err != nil {
- t.Fatalf("Failed to publish new record after restart: %v", err)
- }
-
- expectedNextOffset := lastOffset + 1
- if newResponse.BaseOffset != expectedNextOffset {
- t.Errorf("Expected next offset %d after restart, got %d", expectedNextOffset, newResponse.BaseOffset)
- }
- }
-}
-
-// TestConcurrentOffsetOperations tests concurrent offset operations
-func TestConcurrentOffsetOperations(t *testing.T) {
- // Create temporary database
- tmpFile, err := os.CreateTemp("", "concurrent_test_*.db")
- if err != nil {
- t.Fatalf("Failed to create temp database: %v", err)
- }
- tmpFile.Close()
- defer os.Remove(tmpFile.Name())
-
- db, err := CreateDatabase(tmpFile.Name())
- if err != nil {
- t.Fatalf("Failed to create database: %v", err)
- }
- defer db.Close()
-
- storage, err := NewSQLOffsetStorage(db)
- if err != nil {
- t.Fatalf("Failed to create SQL storage: %v", err)
- }
- defer storage.Close()
-
- integration := NewSMQOffsetIntegration(storage)
-
- partition := &schema_pb.Partition{
- RingSize: 1024,
- RangeStart: 0,
- RangeStop: 31,
- UnixTimeNs: time.Now().UnixNano(),
- }
-
- // Concurrent publishers
- const numPublishers = 5
- const recordsPerPublisher = 10
-
- done := make(chan bool, numPublishers)
-
- for i := 0; i < numPublishers; i++ {
- go func(publisherID int) {
- defer func() { done <- true }()
-
- for j := 0; j < recordsPerPublisher; j++ {
- key := fmt.Sprintf("publisher-%d-msg-%d", publisherID, j)
- _, err := integration.PublishRecord("test-namespace", "test-topic", partition, []byte(key), &schema_pb.RecordValue{})
- if err != nil {
- t.Errorf("Publisher %d failed to publish message %d: %v", publisherID, j, err)
- return
- }
- }
- }(i)
- }
-
- // Wait for all publishers to complete
- for i := 0; i < numPublishers; i++ {
- <-done
- }
-
- // Verify total records
- hwm, err := integration.GetHighWaterMark("test-namespace", "test-topic", partition)
- if err != nil {
- t.Fatalf("Failed to get high water mark: %v", err)
- }
-
- expectedTotal := int64(numPublishers * recordsPerPublisher)
- if hwm != expectedTotal {
- t.Errorf("Expected high water mark %d, got %d", expectedTotal, hwm)
- }
-
- // Verify no duplicate offsets
- info, err := integration.GetPartitionOffsetInfo("test-namespace", "test-topic", partition)
- if err != nil {
- t.Fatalf("Failed to get partition info: %v", err)
- }
-
- if info.RecordCount != expectedTotal {
- t.Errorf("Expected record count %d, got %d", expectedTotal, info.RecordCount)
- }
-}
-
-// TestOffsetValidationAndErrorHandling tests error conditions and validation
-func TestOffsetValidationAndErrorHandling(t *testing.T) {
- // Create temporary database
- tmpFile, err := os.CreateTemp("", "validation_test_*.db")
- if err != nil {
- t.Fatalf("Failed to create temp database: %v", err)
- }
- tmpFile.Close()
- defer os.Remove(tmpFile.Name())
-
- db, err := CreateDatabase(tmpFile.Name())
- if err != nil {
- t.Fatalf("Failed to create database: %v", err)
- }
- defer db.Close()
-
- storage, err := NewSQLOffsetStorage(db)
- if err != nil {
- t.Fatalf("Failed to create SQL storage: %v", err)
- }
- defer storage.Close()
-
- integration := NewSMQOffsetIntegration(storage)
-
- partition := &schema_pb.Partition{
- RingSize: 1024,
- RangeStart: 0,
- RangeStop: 31,
- UnixTimeNs: time.Now().UnixNano(),
- }
-
- t.Run("InvalidOffsetSubscription", func(t *testing.T) {
- // Try to create subscription with invalid offset
- _, err := integration.CreateSubscription(
- "invalid-sub",
- "test-namespace", "test-topic",
- partition,
- schema_pb.OffsetType_EXACT_OFFSET,
- 100, // Beyond any existing data
- )
- if err == nil {
- t.Error("Expected error for subscription beyond high water mark")
- }
- })
-
- t.Run("NegativeOffsetValidation", func(t *testing.T) {
- // Try to create subscription with negative offset
- _, err := integration.CreateSubscription(
- "negative-sub",
- "test-namespace", "test-topic",
- partition,
- schema_pb.OffsetType_EXACT_OFFSET,
- -1,
- )
- if err == nil {
- t.Error("Expected error for negative offset")
- }
- })
-
- t.Run("DuplicateSubscriptionID", func(t *testing.T) {
- // Create first subscription
- _, err := integration.CreateSubscription(
- "duplicate-id",
- "test-namespace", "test-topic",
- partition,
- schema_pb.OffsetType_RESET_TO_EARLIEST,
- 0,
- )
- if err != nil {
- t.Fatalf("Failed to create first subscription: %v", err)
- }
-
- // Try to create duplicate
- _, err = integration.CreateSubscription(
- "duplicate-id",
- "test-namespace", "test-topic",
- partition,
- schema_pb.OffsetType_RESET_TO_EARLIEST,
- 0,
- )
- if err == nil {
- t.Error("Expected error for duplicate subscription ID")
- }
- })
-
- t.Run("OffsetRangeValidation", func(t *testing.T) {
- // Add some data first
- integration.PublishRecord("test-namespace", "test-topic", partition, []byte("test"), &schema_pb.RecordValue{})
-
- // Test invalid range validation
- err := integration.ValidateOffsetRange("test-namespace", "test-topic", partition, 5, 10) // Beyond high water mark
- if err == nil {
- t.Error("Expected error for range beyond high water mark")
- }
-
- err = integration.ValidateOffsetRange("test-namespace", "test-topic", partition, 10, 5) // End before start
- if err == nil {
- t.Error("Expected error for end offset before start offset")
- }
-
- err = integration.ValidateOffsetRange("test-namespace", "test-topic", partition, -1, 5) // Negative start
- if err == nil {
- t.Error("Expected error for negative start offset")
- }
- })
-}
diff --git a/weed/mq/offset/filer_storage.go b/weed/mq/offset/filer_storage.go
index 6f1a71e39..10b54bae3 100644
--- a/weed/mq/offset/filer_storage.go
+++ b/weed/mq/offset/filer_storage.go
@@ -93,9 +93,3 @@ func (f *FilerOffsetStorage) getPartitionDir(namespace, topicName string, partit
return fmt.Sprintf("%s/%s/%s/%s/%s", filer.TopicsDir, namespace, topicName, version, partitionRange)
}
-
-// getPartitionKey generates a unique key for a partition
-func (f *FilerOffsetStorage) getPartitionKey(partition *schema_pb.Partition) string {
- return fmt.Sprintf("ring:%d:range:%d-%d:time:%d",
- partition.RingSize, partition.RangeStart, partition.RangeStop, partition.UnixTimeNs)
-}
diff --git a/weed/mq/offset/integration_test.go b/weed/mq/offset/integration_test.go
deleted file mode 100644
index 35299be65..000000000
--- a/weed/mq/offset/integration_test.go
+++ /dev/null
@@ -1,544 +0,0 @@
-package offset
-
-import (
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
-)
-
-func TestSMQOffsetIntegration_PublishRecord(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- integration := NewSMQOffsetIntegration(storage)
- partition := createTestPartition()
-
- // Publish a single record
- response, err := integration.PublishRecord(
- "test-namespace", "test-topic",
- partition,
- []byte("test-key"),
- &schema_pb.RecordValue{},
- )
-
- if err != nil {
- t.Fatalf("Failed to publish record: %v", err)
- }
-
- if response.Error != "" {
- t.Errorf("Expected no error, got: %s", response.Error)
- }
-
- if response.BaseOffset != 0 {
- t.Errorf("Expected base offset 0, got %d", response.BaseOffset)
- }
-
- if response.LastOffset != 0 {
- t.Errorf("Expected last offset 0, got %d", response.LastOffset)
- }
-}
-
-func TestSMQOffsetIntegration_PublishRecordBatch(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- integration := NewSMQOffsetIntegration(storage)
- partition := createTestPartition()
-
- // Create batch of records
- records := []PublishRecordRequest{
- {Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
- {Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
- {Key: []byte("key3"), Value: &schema_pb.RecordValue{}},
- }
-
- // Publish batch
- response, err := integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
- if err != nil {
- t.Fatalf("Failed to publish record batch: %v", err)
- }
-
- if response.Error != "" {
- t.Errorf("Expected no error, got: %s", response.Error)
- }
-
- if response.BaseOffset != 0 {
- t.Errorf("Expected base offset 0, got %d", response.BaseOffset)
- }
-
- if response.LastOffset != 2 {
- t.Errorf("Expected last offset 2, got %d", response.LastOffset)
- }
-
- // Verify high water mark
- hwm, err := integration.GetHighWaterMark("test-namespace", "test-topic", partition)
- if err != nil {
- t.Fatalf("Failed to get high water mark: %v", err)
- }
-
- if hwm != 3 {
- t.Errorf("Expected high water mark 3, got %d", hwm)
- }
-}
-
-func TestSMQOffsetIntegration_EmptyBatch(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- integration := NewSMQOffsetIntegration(storage)
- partition := createTestPartition()
-
- // Publish empty batch
- response, err := integration.PublishRecordBatch("test-namespace", "test-topic", partition, []PublishRecordRequest{})
- if err != nil {
- t.Fatalf("Failed to publish empty batch: %v", err)
- }
-
- if response.Error == "" {
- t.Error("Expected error for empty batch")
- }
-}
-
-func TestSMQOffsetIntegration_CreateSubscription(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- integration := NewSMQOffsetIntegration(storage)
- partition := createTestPartition()
-
- // Publish some records first
- records := []PublishRecordRequest{
- {Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
- {Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
- }
- integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
-
- // Create subscription
- sub, err := integration.CreateSubscription(
- "test-sub",
- "test-namespace", "test-topic",
- partition,
- schema_pb.OffsetType_RESET_TO_EARLIEST,
- 0,
- )
-
- if err != nil {
- t.Fatalf("Failed to create subscription: %v", err)
- }
-
- if sub.ID != "test-sub" {
- t.Errorf("Expected subscription ID 'test-sub', got %s", sub.ID)
- }
-
- if sub.StartOffset != 0 {
- t.Errorf("Expected start offset 0, got %d", sub.StartOffset)
- }
-}
-
-func TestSMQOffsetIntegration_SubscribeRecords(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- integration := NewSMQOffsetIntegration(storage)
- partition := createTestPartition()
-
- // Publish some records
- records := []PublishRecordRequest{
- {Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
- {Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
- {Key: []byte("key3"), Value: &schema_pb.RecordValue{}},
- }
- integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
-
- // Create subscription
- sub, err := integration.CreateSubscription(
- "test-sub",
- "test-namespace", "test-topic",
- partition,
- schema_pb.OffsetType_RESET_TO_EARLIEST,
- 0,
- )
- if err != nil {
- t.Fatalf("Failed to create subscription: %v", err)
- }
-
- // Subscribe to records
- responses, err := integration.SubscribeRecords(sub, 2)
- if err != nil {
- t.Fatalf("Failed to subscribe to records: %v", err)
- }
-
- if len(responses) != 2 {
- t.Errorf("Expected 2 responses, got %d", len(responses))
- }
-
- // Check offset progression
- if responses[0].Offset != 0 {
- t.Errorf("Expected first record offset 0, got %d", responses[0].Offset)
- }
-
- if responses[1].Offset != 1 {
- t.Errorf("Expected second record offset 1, got %d", responses[1].Offset)
- }
-
- // Check subscription advancement
- if sub.CurrentOffset != 2 {
- t.Errorf("Expected subscription current offset 2, got %d", sub.CurrentOffset)
- }
-}
-
-func TestSMQOffsetIntegration_SubscribeEmptyPartition(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- integration := NewSMQOffsetIntegration(storage)
- partition := createTestPartition()
-
- // Create subscription on empty partition
- sub, err := integration.CreateSubscription(
- "empty-sub",
- "test-namespace", "test-topic",
- partition,
- schema_pb.OffsetType_RESET_TO_EARLIEST,
- 0,
- )
- if err != nil {
- t.Fatalf("Failed to create subscription: %v", err)
- }
-
- // Subscribe to records (should return empty)
- responses, err := integration.SubscribeRecords(sub, 10)
- if err != nil {
- t.Fatalf("Failed to subscribe to empty partition: %v", err)
- }
-
- if len(responses) != 0 {
- t.Errorf("Expected 0 responses from empty partition, got %d", len(responses))
- }
-}
-
-func TestSMQOffsetIntegration_SeekSubscription(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- integration := NewSMQOffsetIntegration(storage)
- partition := createTestPartition()
-
- // Publish records
- records := []PublishRecordRequest{
- {Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
- {Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
- {Key: []byte("key3"), Value: &schema_pb.RecordValue{}},
- {Key: []byte("key4"), Value: &schema_pb.RecordValue{}},
- {Key: []byte("key5"), Value: &schema_pb.RecordValue{}},
- }
- integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
-
- // Create subscription
- sub, err := integration.CreateSubscription(
- "seek-sub",
- "test-namespace", "test-topic",
- partition,
- schema_pb.OffsetType_RESET_TO_EARLIEST,
- 0,
- )
- if err != nil {
- t.Fatalf("Failed to create subscription: %v", err)
- }
-
- // Seek to offset 3
- err = integration.SeekSubscription("seek-sub", 3)
- if err != nil {
- t.Fatalf("Failed to seek subscription: %v", err)
- }
-
- if sub.CurrentOffset != 3 {
- t.Errorf("Expected current offset 3 after seek, got %d", sub.CurrentOffset)
- }
-
- // Subscribe from new position
- responses, err := integration.SubscribeRecords(sub, 2)
- if err != nil {
- t.Fatalf("Failed to subscribe after seek: %v", err)
- }
-
- if len(responses) != 2 {
- t.Errorf("Expected 2 responses after seek, got %d", len(responses))
- }
-
- if responses[0].Offset != 3 {
- t.Errorf("Expected first record offset 3 after seek, got %d", responses[0].Offset)
- }
-}
-
-func TestSMQOffsetIntegration_GetSubscriptionLag(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- integration := NewSMQOffsetIntegration(storage)
- partition := createTestPartition()
-
- // Publish records
- records := []PublishRecordRequest{
- {Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
- {Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
- {Key: []byte("key3"), Value: &schema_pb.RecordValue{}},
- }
- integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
-
- // Create subscription at offset 1
- sub, err := integration.CreateSubscription(
- "lag-sub",
- "test-namespace", "test-topic",
- partition,
- schema_pb.OffsetType_EXACT_OFFSET,
- 1,
- )
- if err != nil {
- t.Fatalf("Failed to create subscription: %v", err)
- }
-
- // Get lag
- lag, err := integration.GetSubscriptionLag("lag-sub")
- if err != nil {
- t.Fatalf("Failed to get subscription lag: %v", err)
- }
-
- expectedLag := int64(3 - 1) // hwm - current
- if lag != expectedLag {
- t.Errorf("Expected lag %d, got %d", expectedLag, lag)
- }
-
- // Advance subscription and check lag again
- integration.SubscribeRecords(sub, 1)
-
- lag, err = integration.GetSubscriptionLag("lag-sub")
- if err != nil {
- t.Fatalf("Failed to get lag after advance: %v", err)
- }
-
- expectedLag = int64(3 - 2) // hwm - current
- if lag != expectedLag {
- t.Errorf("Expected lag %d after advance, got %d", expectedLag, lag)
- }
-}
-
-func TestSMQOffsetIntegration_CloseSubscription(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- integration := NewSMQOffsetIntegration(storage)
- partition := createTestPartition()
-
- // Create subscription
- _, err := integration.CreateSubscription(
- "close-sub",
- "test-namespace", "test-topic",
- partition,
- schema_pb.OffsetType_RESET_TO_EARLIEST,
- 0,
- )
- if err != nil {
- t.Fatalf("Failed to create subscription: %v", err)
- }
-
- // Close subscription
- err = integration.CloseSubscription("close-sub")
- if err != nil {
- t.Fatalf("Failed to close subscription: %v", err)
- }
-
- // Try to get lag (should fail)
- _, err = integration.GetSubscriptionLag("close-sub")
- if err == nil {
- t.Error("Expected error when getting lag for closed subscription")
- }
-}
-
-func TestSMQOffsetIntegration_ValidateOffsetRange(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- integration := NewSMQOffsetIntegration(storage)
- partition := createTestPartition()
-
- // Publish some records
- records := []PublishRecordRequest{
- {Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
- {Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
- {Key: []byte("key3"), Value: &schema_pb.RecordValue{}},
- }
- integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
-
- // Test valid range
- err := integration.ValidateOffsetRange("test-namespace", "test-topic", partition, 0, 2)
- if err != nil {
- t.Errorf("Valid range should not return error: %v", err)
- }
-
- // Test invalid range (beyond hwm)
- err = integration.ValidateOffsetRange("test-namespace", "test-topic", partition, 0, 5)
- if err == nil {
- t.Error("Expected error for range beyond high water mark")
- }
-}
-
-func TestSMQOffsetIntegration_GetAvailableOffsetRange(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- integration := NewSMQOffsetIntegration(storage)
- partition := createTestPartition()
-
- // Test empty partition
- offsetRange, err := integration.GetAvailableOffsetRange("test-namespace", "test-topic", partition)
- if err != nil {
- t.Fatalf("Failed to get available range for empty partition: %v", err)
- }
-
- if offsetRange.Count != 0 {
- t.Errorf("Expected empty range for empty partition, got count %d", offsetRange.Count)
- }
-
- // Publish records
- records := []PublishRecordRequest{
- {Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
- {Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
- }
- integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
-
- // Test with data
- offsetRange, err = integration.GetAvailableOffsetRange("test-namespace", "test-topic", partition)
- if err != nil {
- t.Fatalf("Failed to get available range: %v", err)
- }
-
- if offsetRange.StartOffset != 0 {
- t.Errorf("Expected start offset 0, got %d", offsetRange.StartOffset)
- }
-
- if offsetRange.EndOffset != 1 {
- t.Errorf("Expected end offset 1, got %d", offsetRange.EndOffset)
- }
-
- if offsetRange.Count != 2 {
- t.Errorf("Expected count 2, got %d", offsetRange.Count)
- }
-}
-
-func TestSMQOffsetIntegration_GetOffsetMetrics(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- integration := NewSMQOffsetIntegration(storage)
- partition := createTestPartition()
-
- // Initial metrics
- metrics := integration.GetOffsetMetrics()
- if metrics.TotalOffsets != 0 {
- t.Errorf("Expected 0 total offsets initially, got %d", metrics.TotalOffsets)
- }
-
- if metrics.ActiveSubscriptions != 0 {
- t.Errorf("Expected 0 active subscriptions initially, got %d", metrics.ActiveSubscriptions)
- }
-
- // Publish records
- records := []PublishRecordRequest{
- {Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
- {Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
- }
- integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
-
- // Create subscriptions
- integration.CreateSubscription("sub1", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
- integration.CreateSubscription("sub2", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
-
- // Check updated metrics
- metrics = integration.GetOffsetMetrics()
- if metrics.TotalOffsets != 2 {
- t.Errorf("Expected 2 total offsets, got %d", metrics.TotalOffsets)
- }
-
- if metrics.ActiveSubscriptions != 2 {
- t.Errorf("Expected 2 active subscriptions, got %d", metrics.ActiveSubscriptions)
- }
-
- if metrics.PartitionCount != 1 {
- t.Errorf("Expected 1 partition, got %d", metrics.PartitionCount)
- }
-}
-
-func TestSMQOffsetIntegration_GetOffsetInfo(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- integration := NewSMQOffsetIntegration(storage)
- partition := createTestPartition()
-
- // Test non-existent offset
- info, err := integration.GetOffsetInfo("test-namespace", "test-topic", partition, 0)
- if err != nil {
- t.Fatalf("Failed to get offset info: %v", err)
- }
-
- if info.Exists {
- t.Error("Offset should not exist in empty partition")
- }
-
- // Publish record
- integration.PublishRecord("test-namespace", "test-topic", partition, []byte("key1"), &schema_pb.RecordValue{})
-
- // Test existing offset
- info, err = integration.GetOffsetInfo("test-namespace", "test-topic", partition, 0)
- if err != nil {
- t.Fatalf("Failed to get offset info for existing offset: %v", err)
- }
-
- if !info.Exists {
- t.Error("Offset should exist after publishing")
- }
-
- if info.Offset != 0 {
- t.Errorf("Expected offset 0, got %d", info.Offset)
- }
-}
-
-func TestSMQOffsetIntegration_GetPartitionOffsetInfo(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- integration := NewSMQOffsetIntegration(storage)
- partition := createTestPartition()
-
- // Test empty partition
- info, err := integration.GetPartitionOffsetInfo("test-namespace", "test-topic", partition)
- if err != nil {
- t.Fatalf("Failed to get partition offset info: %v", err)
- }
-
- if info.EarliestOffset != 0 {
- t.Errorf("Expected earliest offset 0, got %d", info.EarliestOffset)
- }
-
- if info.LatestOffset != -1 {
- t.Errorf("Expected latest offset -1 for empty partition, got %d", info.LatestOffset)
- }
-
- if info.HighWaterMark != 0 {
- t.Errorf("Expected high water mark 0, got %d", info.HighWaterMark)
- }
-
- if info.RecordCount != 0 {
- t.Errorf("Expected record count 0, got %d", info.RecordCount)
- }
-
- // Publish records
- records := []PublishRecordRequest{
- {Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
- {Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
- {Key: []byte("key3"), Value: &schema_pb.RecordValue{}},
- }
- integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
-
- // Create subscription
- integration.CreateSubscription("test-sub", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
-
- // Test with data
- info, err = integration.GetPartitionOffsetInfo("test-namespace", "test-topic", partition)
- if err != nil {
- t.Fatalf("Failed to get partition offset info with data: %v", err)
- }
-
- if info.EarliestOffset != 0 {
- t.Errorf("Expected earliest offset 0, got %d", info.EarliestOffset)
- }
-
- if info.LatestOffset != 2 {
- t.Errorf("Expected latest offset 2, got %d", info.LatestOffset)
- }
-
- if info.HighWaterMark != 3 {
- t.Errorf("Expected high water mark 3, got %d", info.HighWaterMark)
- }
-
- if info.RecordCount != 3 {
- t.Errorf("Expected record count 3, got %d", info.RecordCount)
- }
-
- if info.ActiveSubscriptions != 1 {
- t.Errorf("Expected 1 active subscription, got %d", info.ActiveSubscriptions)
- }
-}
diff --git a/weed/mq/offset/manager.go b/weed/mq/offset/manager.go
index 53388d82f..b78307f3a 100644
--- a/weed/mq/offset/manager.go
+++ b/weed/mq/offset/manager.go
@@ -338,13 +338,6 @@ type OffsetAssigner struct {
registry *PartitionOffsetRegistry
}
-// NewOffsetAssigner creates a new offset assigner
-func NewOffsetAssigner(storage OffsetStorage) *OffsetAssigner {
- return &OffsetAssigner{
- registry: NewPartitionOffsetRegistry(storage),
- }
-}
-
// AssignSingleOffset assigns a single offset with timestamp
func (a *OffsetAssigner) AssignSingleOffset(namespace, topicName string, partition *schema_pb.Partition) *AssignmentResult {
offset, err := a.registry.AssignOffset(namespace, topicName, partition)
diff --git a/weed/mq/offset/manager_test.go b/weed/mq/offset/manager_test.go
deleted file mode 100644
index 0db301e84..000000000
--- a/weed/mq/offset/manager_test.go
+++ /dev/null
@@ -1,388 +0,0 @@
-package offset
-
-import (
- "testing"
- "time"
-
- "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
-)
-
-func createTestPartition() *schema_pb.Partition {
- return &schema_pb.Partition{
- RingSize: 1024,
- RangeStart: 0,
- RangeStop: 31,
- UnixTimeNs: time.Now().UnixNano(),
- }
-}
-
-func TestPartitionOffsetManager_BasicAssignment(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- partition := createTestPartition()
-
- manager, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage)
- if err != nil {
- t.Fatalf("Failed to create offset manager: %v", err)
- }
-
- // Test sequential offset assignment
- for i := int64(0); i < 10; i++ {
- offset := manager.AssignOffset()
- if offset != i {
- t.Errorf("Expected offset %d, got %d", i, offset)
- }
- }
-
- // Test high water mark
- hwm := manager.GetHighWaterMark()
- if hwm != 10 {
- t.Errorf("Expected high water mark 10, got %d", hwm)
- }
-}
-
-func TestPartitionOffsetManager_BatchAssignment(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- partition := createTestPartition()
-
- manager, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage)
- if err != nil {
- t.Fatalf("Failed to create offset manager: %v", err)
- }
-
- // Assign batch of 5 offsets
- baseOffset, lastOffset := manager.AssignOffsets(5)
- if baseOffset != 0 {
- t.Errorf("Expected base offset 0, got %d", baseOffset)
- }
- if lastOffset != 4 {
- t.Errorf("Expected last offset 4, got %d", lastOffset)
- }
-
- // Assign another batch
- baseOffset, lastOffset = manager.AssignOffsets(3)
- if baseOffset != 5 {
- t.Errorf("Expected base offset 5, got %d", baseOffset)
- }
- if lastOffset != 7 {
- t.Errorf("Expected last offset 7, got %d", lastOffset)
- }
-
- // Check high water mark
- hwm := manager.GetHighWaterMark()
- if hwm != 8 {
- t.Errorf("Expected high water mark 8, got %d", hwm)
- }
-}
-
-func TestPartitionOffsetManager_Recovery(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- partition := createTestPartition()
-
- // Create manager and assign some offsets
- manager1, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage)
- if err != nil {
- t.Fatalf("Failed to create offset manager: %v", err)
- }
-
- // Assign offsets and simulate records
- for i := 0; i < 150; i++ { // More than checkpoint interval
- offset := manager1.AssignOffset()
- storage.AddRecord("test-namespace", "test-topic", partition, offset)
- }
-
- // Wait for checkpoint to complete
- time.Sleep(100 * time.Millisecond)
-
- // Create new manager (simulates restart)
- manager2, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage)
- if err != nil {
- t.Fatalf("Failed to create offset manager after recovery: %v", err)
- }
-
- // Next offset should continue from checkpoint + 1
- // With checkpoint interval 100, checkpoint happens at offset 100
- // So recovery should start from 101, but we assigned 150 offsets (0-149)
- // The checkpoint should be at 100, so next offset should be 101
- // But since we have records up to 149, it should recover from storage scan
- nextOffset := manager2.AssignOffset()
- if nextOffset != 150 {
- t.Errorf("Expected next offset 150 after recovery, got %d", nextOffset)
- }
-}
-
-func TestPartitionOffsetManager_RecoveryFromStorage(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- partition := createTestPartition()
-
- // Simulate existing records in storage without checkpoint
- for i := int64(0); i < 50; i++ {
- storage.AddRecord("test-namespace", "test-topic", partition, i)
- }
-
- // Create manager - should recover from storage scan
- manager, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage)
- if err != nil {
- t.Fatalf("Failed to create offset manager: %v", err)
- }
-
- // Next offset should be 50
- nextOffset := manager.AssignOffset()
- if nextOffset != 50 {
- t.Errorf("Expected next offset 50 after storage recovery, got %d", nextOffset)
- }
-}
-
-func TestPartitionOffsetRegistry_MultiplePartitions(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- registry := NewPartitionOffsetRegistry(storage)
-
- // Create different partitions
- partition1 := &schema_pb.Partition{
- RingSize: 1024,
- RangeStart: 0,
- RangeStop: 31,
- UnixTimeNs: time.Now().UnixNano(),
- }
-
- partition2 := &schema_pb.Partition{
- RingSize: 1024,
- RangeStart: 32,
- RangeStop: 63,
- UnixTimeNs: time.Now().UnixNano(),
- }
-
- // Assign offsets to different partitions
- offset1, err := registry.AssignOffset("test-namespace", "test-topic", partition1)
- if err != nil {
- t.Fatalf("Failed to assign offset to partition1: %v", err)
- }
- if offset1 != 0 {
- t.Errorf("Expected offset 0 for partition1, got %d", offset1)
- }
-
- offset2, err := registry.AssignOffset("test-namespace", "test-topic", partition2)
- if err != nil {
- t.Fatalf("Failed to assign offset to partition2: %v", err)
- }
- if offset2 != 0 {
- t.Errorf("Expected offset 0 for partition2, got %d", offset2)
- }
-
- // Assign more offsets to partition1
- offset1_2, err := registry.AssignOffset("test-namespace", "test-topic", partition1)
- if err != nil {
- t.Fatalf("Failed to assign second offset to partition1: %v", err)
- }
- if offset1_2 != 1 {
- t.Errorf("Expected offset 1 for partition1, got %d", offset1_2)
- }
-
- // Partition2 should still be at 0 for next assignment
- offset2_2, err := registry.AssignOffset("test-namespace", "test-topic", partition2)
- if err != nil {
- t.Fatalf("Failed to assign second offset to partition2: %v", err)
- }
- if offset2_2 != 1 {
- t.Errorf("Expected offset 1 for partition2, got %d", offset2_2)
- }
-}
-
-func TestPartitionOffsetRegistry_BatchAssignment(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- registry := NewPartitionOffsetRegistry(storage)
- partition := createTestPartition()
-
- // Assign batch of offsets
- baseOffset, lastOffset, err := registry.AssignOffsets("test-namespace", "test-topic", partition, 10)
- if err != nil {
- t.Fatalf("Failed to assign batch offsets: %v", err)
- }
-
- if baseOffset != 0 {
- t.Errorf("Expected base offset 0, got %d", baseOffset)
- }
- if lastOffset != 9 {
- t.Errorf("Expected last offset 9, got %d", lastOffset)
- }
-
- // Get high water mark
- hwm, err := registry.GetHighWaterMark("test-namespace", "test-topic", partition)
- if err != nil {
- t.Fatalf("Failed to get high water mark: %v", err)
- }
- if hwm != 10 {
- t.Errorf("Expected high water mark 10, got %d", hwm)
- }
-}
-
-func TestOffsetAssigner_SingleAssignment(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- assigner := NewOffsetAssigner(storage)
- partition := createTestPartition()
-
- // Assign single offset
- result := assigner.AssignSingleOffset("test-namespace", "test-topic", partition)
- if result.Error != nil {
- t.Fatalf("Failed to assign single offset: %v", result.Error)
- }
-
- if result.Assignment == nil {
- t.Fatal("Assignment result is nil")
- }
-
- if result.Assignment.Offset != 0 {
- t.Errorf("Expected offset 0, got %d", result.Assignment.Offset)
- }
-
- if result.Assignment.Partition != partition {
- t.Error("Partition mismatch in assignment")
- }
-
- if result.Assignment.Timestamp <= 0 {
- t.Error("Timestamp should be set")
- }
-}
-
-func TestOffsetAssigner_BatchAssignment(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- assigner := NewOffsetAssigner(storage)
- partition := createTestPartition()
-
- // Assign batch of offsets
- result := assigner.AssignBatchOffsets("test-namespace", "test-topic", partition, 5)
- if result.Error != nil {
- t.Fatalf("Failed to assign batch offsets: %v", result.Error)
- }
-
- if result.Batch == nil {
- t.Fatal("Batch result is nil")
- }
-
- if result.Batch.BaseOffset != 0 {
- t.Errorf("Expected base offset 0, got %d", result.Batch.BaseOffset)
- }
-
- if result.Batch.LastOffset != 4 {
- t.Errorf("Expected last offset 4, got %d", result.Batch.LastOffset)
- }
-
- if result.Batch.Count != 5 {
- t.Errorf("Expected count 5, got %d", result.Batch.Count)
- }
-
- if result.Batch.Timestamp <= 0 {
- t.Error("Timestamp should be set")
- }
-}
-
-func TestOffsetAssigner_HighWaterMark(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- assigner := NewOffsetAssigner(storage)
- partition := createTestPartition()
-
- // Initially should be 0
- hwm, err := assigner.GetHighWaterMark("test-namespace", "test-topic", partition)
- if err != nil {
- t.Fatalf("Failed to get initial high water mark: %v", err)
- }
- if hwm != 0 {
- t.Errorf("Expected initial high water mark 0, got %d", hwm)
- }
-
- // Assign some offsets
- assigner.AssignBatchOffsets("test-namespace", "test-topic", partition, 10)
-
- // High water mark should be updated
- hwm, err = assigner.GetHighWaterMark("test-namespace", "test-topic", partition)
- if err != nil {
- t.Fatalf("Failed to get high water mark after assignment: %v", err)
- }
- if hwm != 10 {
- t.Errorf("Expected high water mark 10, got %d", hwm)
- }
-}
-
-func TestPartitionKey(t *testing.T) {
- partition1 := &schema_pb.Partition{
- RingSize: 1024,
- RangeStart: 0,
- RangeStop: 31,
- UnixTimeNs: 1234567890,
- }
-
- partition2 := &schema_pb.Partition{
- RingSize: 1024,
- RangeStart: 0,
- RangeStop: 31,
- UnixTimeNs: 1234567890,
- }
-
- partition3 := &schema_pb.Partition{
- RingSize: 1024,
- RangeStart: 32,
- RangeStop: 63,
- UnixTimeNs: 1234567890,
- }
-
- key1 := partitionKey(partition1)
- key2 := partitionKey(partition2)
- key3 := partitionKey(partition3)
-
- // Same partitions should have same key
- if key1 != key2 {
- t.Errorf("Same partitions should have same key: %s vs %s", key1, key2)
- }
-
- // Different partitions should have different keys
- if key1 == key3 {
- t.Errorf("Different partitions should have different keys: %s vs %s", key1, key3)
- }
-}
-
-func TestConcurrentOffsetAssignment(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- registry := NewPartitionOffsetRegistry(storage)
- partition := createTestPartition()
-
- const numGoroutines = 10
- const offsetsPerGoroutine = 100
-
- results := make(chan int64, numGoroutines*offsetsPerGoroutine)
-
- // Start concurrent offset assignments
- for i := 0; i < numGoroutines; i++ {
- go func() {
- for j := 0; j < offsetsPerGoroutine; j++ {
- offset, err := registry.AssignOffset("test-namespace", "test-topic", partition)
- if err != nil {
- t.Errorf("Failed to assign offset: %v", err)
- return
- }
- results <- offset
- }
- }()
- }
-
- // Collect all results
- offsets := make(map[int64]bool)
- for i := 0; i < numGoroutines*offsetsPerGoroutine; i++ {
- offset := <-results
- if offsets[offset] {
- t.Errorf("Duplicate offset assigned: %d", offset)
- }
- offsets[offset] = true
- }
-
- // Verify we got all expected offsets
- expectedCount := numGoroutines * offsetsPerGoroutine
- if len(offsets) != expectedCount {
- t.Errorf("Expected %d unique offsets, got %d", expectedCount, len(offsets))
- }
-
- // Verify offsets are in expected range
- for offset := range offsets {
- if offset < 0 || offset >= int64(expectedCount) {
- t.Errorf("Offset %d is out of expected range [0, %d)", offset, expectedCount)
- }
- }
-}
diff --git a/weed/mq/offset/migration.go b/weed/mq/offset/migration.go
deleted file mode 100644
index 4e0a6ab12..000000000
--- a/weed/mq/offset/migration.go
+++ /dev/null
@@ -1,302 +0,0 @@
-package offset
-
-import (
- "database/sql"
- "fmt"
- "time"
-)
-
-// MigrationVersion represents a database migration version
-type MigrationVersion struct {
- Version int
- Description string
- SQL string
-}
-
-// GetMigrations returns all available migrations for offset storage
-func GetMigrations() []MigrationVersion {
- return []MigrationVersion{
- {
- Version: 1,
- Description: "Create initial offset storage tables",
- SQL: `
- -- Partition offset checkpoints table
- -- TODO: Add _index as computed column when supported by database
- CREATE TABLE IF NOT EXISTS partition_offset_checkpoints (
- partition_key TEXT PRIMARY KEY,
- ring_size INTEGER NOT NULL,
- range_start INTEGER NOT NULL,
- range_stop INTEGER NOT NULL,
- unix_time_ns INTEGER NOT NULL,
- checkpoint_offset INTEGER NOT NULL,
- updated_at INTEGER NOT NULL
- );
-
- -- Offset mappings table for detailed tracking
- -- TODO: Add _index as computed column when supported by database
- CREATE TABLE IF NOT EXISTS offset_mappings (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- partition_key TEXT NOT NULL,
- kafka_offset INTEGER NOT NULL,
- smq_timestamp INTEGER NOT NULL,
- message_size INTEGER NOT NULL,
- created_at INTEGER NOT NULL,
- UNIQUE(partition_key, kafka_offset)
- );
-
- -- Schema migrations tracking table
- CREATE TABLE IF NOT EXISTS schema_migrations (
- version INTEGER PRIMARY KEY,
- description TEXT NOT NULL,
- applied_at INTEGER NOT NULL
- );
- `,
- },
- {
- Version: 2,
- Description: "Add indexes for performance optimization",
- SQL: `
- -- Indexes for performance
- CREATE INDEX IF NOT EXISTS idx_partition_offset_checkpoints_partition
- ON partition_offset_checkpoints(partition_key);
-
- CREATE INDEX IF NOT EXISTS idx_offset_mappings_partition_offset
- ON offset_mappings(partition_key, kafka_offset);
-
- CREATE INDEX IF NOT EXISTS idx_offset_mappings_timestamp
- ON offset_mappings(partition_key, smq_timestamp);
-
- CREATE INDEX IF NOT EXISTS idx_offset_mappings_created_at
- ON offset_mappings(created_at);
- `,
- },
- {
- Version: 3,
- Description: "Add partition metadata table for enhanced tracking",
- SQL: `
- -- Partition metadata table
- CREATE TABLE IF NOT EXISTS partition_metadata (
- partition_key TEXT PRIMARY KEY,
- ring_size INTEGER NOT NULL,
- range_start INTEGER NOT NULL,
- range_stop INTEGER NOT NULL,
- unix_time_ns INTEGER NOT NULL,
- created_at INTEGER NOT NULL,
- last_activity_at INTEGER NOT NULL,
- record_count INTEGER DEFAULT 0,
- total_size INTEGER DEFAULT 0
- );
-
- -- Index for partition metadata
- CREATE INDEX IF NOT EXISTS idx_partition_metadata_activity
- ON partition_metadata(last_activity_at);
- `,
- },
- }
-}
-
-// MigrationManager handles database schema migrations
-type MigrationManager struct {
- db *sql.DB
-}
-
-// NewMigrationManager creates a new migration manager
-func NewMigrationManager(db *sql.DB) *MigrationManager {
- return &MigrationManager{db: db}
-}
-
-// GetCurrentVersion returns the current schema version
-func (m *MigrationManager) GetCurrentVersion() (int, error) {
- // First, ensure the migrations table exists
- _, err := m.db.Exec(`
- CREATE TABLE IF NOT EXISTS schema_migrations (
- version INTEGER PRIMARY KEY,
- description TEXT NOT NULL,
- applied_at INTEGER NOT NULL
- )
- `)
- if err != nil {
- return 0, fmt.Errorf("failed to create migrations table: %w", err)
- }
-
- var version sql.NullInt64
- err = m.db.QueryRow("SELECT MAX(version) FROM schema_migrations").Scan(&version)
- if err != nil {
- return 0, fmt.Errorf("failed to get current version: %w", err)
- }
-
- if !version.Valid {
- return 0, nil // No migrations applied yet
- }
-
- return int(version.Int64), nil
-}
-
-// ApplyMigrations applies all pending migrations
-func (m *MigrationManager) ApplyMigrations() error {
- currentVersion, err := m.GetCurrentVersion()
- if err != nil {
- return fmt.Errorf("failed to get current version: %w", err)
- }
-
- migrations := GetMigrations()
-
- for _, migration := range migrations {
- if migration.Version <= currentVersion {
- continue // Already applied
- }
-
- fmt.Printf("Applying migration %d: %s\n", migration.Version, migration.Description)
-
- // Begin transaction
- tx, err := m.db.Begin()
- if err != nil {
- return fmt.Errorf("failed to begin transaction for migration %d: %w", migration.Version, err)
- }
-
- // Execute migration SQL
- _, err = tx.Exec(migration.SQL)
- if err != nil {
- tx.Rollback()
- return fmt.Errorf("failed to execute migration %d: %w", migration.Version, err)
- }
-
- // Record migration as applied
- _, err = tx.Exec(
- "INSERT INTO schema_migrations (version, description, applied_at) VALUES (?, ?, ?)",
- migration.Version,
- migration.Description,
- getCurrentTimestamp(),
- )
- if err != nil {
- tx.Rollback()
- return fmt.Errorf("failed to record migration %d: %w", migration.Version, err)
- }
-
- // Commit transaction
- err = tx.Commit()
- if err != nil {
- return fmt.Errorf("failed to commit migration %d: %w", migration.Version, err)
- }
-
- fmt.Printf("Successfully applied migration %d\n", migration.Version)
- }
-
- return nil
-}
-
-// RollbackMigration rolls back a specific migration (if supported)
-func (m *MigrationManager) RollbackMigration(version int) error {
- // TODO: Implement rollback functionality
- // ASSUMPTION: For now, rollbacks are not supported as they require careful planning
- return fmt.Errorf("migration rollbacks not implemented - manual intervention required")
-}
-
-// GetAppliedMigrations returns a list of all applied migrations
-func (m *MigrationManager) GetAppliedMigrations() ([]AppliedMigration, error) {
- rows, err := m.db.Query(`
- SELECT version, description, applied_at
- FROM schema_migrations
- ORDER BY version
- `)
- if err != nil {
- return nil, fmt.Errorf("failed to query applied migrations: %w", err)
- }
- defer rows.Close()
-
- var migrations []AppliedMigration
- for rows.Next() {
- var migration AppliedMigration
- err := rows.Scan(&migration.Version, &migration.Description, &migration.AppliedAt)
- if err != nil {
- return nil, fmt.Errorf("failed to scan migration: %w", err)
- }
- migrations = append(migrations, migration)
- }
-
- return migrations, nil
-}
-
-// ValidateSchema validates that the database schema is up to date
-func (m *MigrationManager) ValidateSchema() error {
- currentVersion, err := m.GetCurrentVersion()
- if err != nil {
- return fmt.Errorf("failed to get current version: %w", err)
- }
-
- migrations := GetMigrations()
- if len(migrations) == 0 {
- return nil
- }
-
- latestVersion := migrations[len(migrations)-1].Version
- if currentVersion < latestVersion {
- return fmt.Errorf("schema is outdated: current version %d, latest version %d", currentVersion, latestVersion)
- }
-
- return nil
-}
-
-// AppliedMigration represents a migration that has been applied
-type AppliedMigration struct {
- Version int
- Description string
- AppliedAt int64
-}
-
-// getCurrentTimestamp returns the current timestamp in nanoseconds
-func getCurrentTimestamp() int64 {
- return time.Now().UnixNano()
-}
-
-// CreateDatabase creates and initializes a new offset storage database
-func CreateDatabase(dbPath string) (*sql.DB, error) {
- // TODO: Support different database types (PostgreSQL, MySQL, etc.)
- // ASSUMPTION: Using SQLite for now, can be extended for other databases
-
- db, err := sql.Open("sqlite3", dbPath)
- if err != nil {
- return nil, fmt.Errorf("failed to open database: %w", err)
- }
-
- // Configure SQLite for better performance
- pragmas := []string{
- "PRAGMA journal_mode=WAL", // Write-Ahead Logging for better concurrency
- "PRAGMA synchronous=NORMAL", // Balance between safety and performance
- "PRAGMA cache_size=10000", // Increase cache size
- "PRAGMA foreign_keys=ON", // Enable foreign key constraints
- "PRAGMA temp_store=MEMORY", // Store temporary tables in memory
- }
-
- for _, pragma := range pragmas {
- _, err := db.Exec(pragma)
- if err != nil {
- db.Close()
- return nil, fmt.Errorf("failed to set pragma %s: %w", pragma, err)
- }
- }
-
- // Apply migrations
- migrationManager := NewMigrationManager(db)
- err = migrationManager.ApplyMigrations()
- if err != nil {
- db.Close()
- return nil, fmt.Errorf("failed to apply migrations: %w", err)
- }
-
- return db, nil
-}
-
-// BackupDatabase creates a backup of the offset storage database
-func BackupDatabase(sourceDB *sql.DB, backupPath string) error {
- // TODO: Implement database backup functionality
- // ASSUMPTION: This would use database-specific backup mechanisms
- return fmt.Errorf("database backup not implemented yet")
-}
-
-// RestoreDatabase restores a database from a backup
-func RestoreDatabase(backupPath, targetPath string) error {
- // TODO: Implement database restore functionality
- // ASSUMPTION: This would use database-specific restore mechanisms
- return fmt.Errorf("database restore not implemented yet")
-}
diff --git a/weed/mq/offset/sql_storage.go b/weed/mq/offset/sql_storage.go
deleted file mode 100644
index c3107e5a4..000000000
--- a/weed/mq/offset/sql_storage.go
+++ /dev/null
@@ -1,394 +0,0 @@
-package offset
-
-import (
- "database/sql"
- "fmt"
- "time"
-
- "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
-)
-
-// OffsetEntry represents a mapping between Kafka offset and SMQ timestamp
-type OffsetEntry struct {
- KafkaOffset int64
- SMQTimestamp int64
- MessageSize int32
-}
-
-// SQLOffsetStorage implements OffsetStorage using SQL database with _index column
-type SQLOffsetStorage struct {
- db *sql.DB
-}
-
-// NewSQLOffsetStorage creates a new SQL-based offset storage
-func NewSQLOffsetStorage(db *sql.DB) (*SQLOffsetStorage, error) {
- storage := &SQLOffsetStorage{db: db}
-
- // Initialize database schema
- if err := storage.initializeSchema(); err != nil {
- return nil, fmt.Errorf("failed to initialize schema: %w", err)
- }
-
- return storage, nil
-}
-
-// initializeSchema creates the necessary tables for offset storage
-func (s *SQLOffsetStorage) initializeSchema() error {
- // TODO: Create offset storage tables with _index as hidden column
- // ASSUMPTION: Using SQLite-compatible syntax, may need adaptation for other databases
-
- queries := []string{
- // Partition offset checkpoints table
- // TODO: Add _index as computed column when supported by database
- // ASSUMPTION: Using regular columns for now, _index concept preserved for future enhancement
- `CREATE TABLE IF NOT EXISTS partition_offset_checkpoints (
- partition_key TEXT PRIMARY KEY,
- ring_size INTEGER NOT NULL,
- range_start INTEGER NOT NULL,
- range_stop INTEGER NOT NULL,
- unix_time_ns INTEGER NOT NULL,
- checkpoint_offset INTEGER NOT NULL,
- updated_at INTEGER NOT NULL
- )`,
-
- // Offset mappings table for detailed tracking
- // TODO: Add _index as computed column when supported by database
- `CREATE TABLE IF NOT EXISTS offset_mappings (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- partition_key TEXT NOT NULL,
- kafka_offset INTEGER NOT NULL,
- smq_timestamp INTEGER NOT NULL,
- message_size INTEGER NOT NULL,
- created_at INTEGER NOT NULL,
- UNIQUE(partition_key, kafka_offset)
- )`,
-
- // Indexes for performance
- `CREATE INDEX IF NOT EXISTS idx_partition_offset_checkpoints_partition
- ON partition_offset_checkpoints(partition_key)`,
-
- `CREATE INDEX IF NOT EXISTS idx_offset_mappings_partition_offset
- ON offset_mappings(partition_key, kafka_offset)`,
-
- `CREATE INDEX IF NOT EXISTS idx_offset_mappings_timestamp
- ON offset_mappings(partition_key, smq_timestamp)`,
- }
-
- for _, query := range queries {
- if _, err := s.db.Exec(query); err != nil {
- return fmt.Errorf("failed to execute schema query: %w", err)
- }
- }
-
- return nil
-}
-
-// SaveCheckpoint saves the checkpoint for a partition
-func (s *SQLOffsetStorage) SaveCheckpoint(namespace, topicName string, partition *schema_pb.Partition, offset int64) error {
- // Use TopicPartitionKey to ensure each topic has isolated checkpoint storage
- partitionKey := TopicPartitionKey(namespace, topicName, partition)
- now := time.Now().UnixNano()
-
- // TODO: Use UPSERT for better performance
- // ASSUMPTION: SQLite REPLACE syntax, may need adaptation for other databases
- query := `
- REPLACE INTO partition_offset_checkpoints
- (partition_key, ring_size, range_start, range_stop, unix_time_ns, checkpoint_offset, updated_at)
- VALUES (?, ?, ?, ?, ?, ?, ?)
- `
-
- _, err := s.db.Exec(query,
- partitionKey,
- partition.RingSize,
- partition.RangeStart,
- partition.RangeStop,
- partition.UnixTimeNs,
- offset,
- now,
- )
-
- if err != nil {
- return fmt.Errorf("failed to save checkpoint: %w", err)
- }
-
- return nil
-}
-
-// LoadCheckpoint loads the checkpoint for a partition
-func (s *SQLOffsetStorage) LoadCheckpoint(namespace, topicName string, partition *schema_pb.Partition) (int64, error) {
- // Use TopicPartitionKey to match SaveCheckpoint
- partitionKey := TopicPartitionKey(namespace, topicName, partition)
-
- query := `
- SELECT checkpoint_offset
- FROM partition_offset_checkpoints
- WHERE partition_key = ?
- `
-
- var checkpointOffset int64
- err := s.db.QueryRow(query, partitionKey).Scan(&checkpointOffset)
-
- if err == sql.ErrNoRows {
- return -1, fmt.Errorf("no checkpoint found")
- }
-
- if err != nil {
- return -1, fmt.Errorf("failed to load checkpoint: %w", err)
- }
-
- return checkpointOffset, nil
-}
-
-// GetHighestOffset finds the highest offset in storage for a partition
-func (s *SQLOffsetStorage) GetHighestOffset(namespace, topicName string, partition *schema_pb.Partition) (int64, error) {
- // Use TopicPartitionKey to match SaveCheckpoint
- partitionKey := TopicPartitionKey(namespace, topicName, partition)
-
- // TODO: Use _index column for efficient querying
- // ASSUMPTION: kafka_offset represents the sequential offset we're tracking
- query := `
- SELECT MAX(kafka_offset)
- FROM offset_mappings
- WHERE partition_key = ?
- `
-
- var highestOffset sql.NullInt64
- err := s.db.QueryRow(query, partitionKey).Scan(&highestOffset)
-
- if err != nil {
- return -1, fmt.Errorf("failed to get highest offset: %w", err)
- }
-
- if !highestOffset.Valid {
- return -1, fmt.Errorf("no records found")
- }
-
- return highestOffset.Int64, nil
-}
-
-// SaveOffsetMapping stores an offset mapping (extends OffsetStorage interface)
-func (s *SQLOffsetStorage) SaveOffsetMapping(partitionKey string, kafkaOffset, smqTimestamp int64, size int32) error {
- now := time.Now().UnixNano()
-
- // TODO: Handle duplicate key conflicts gracefully
- // ASSUMPTION: Using INSERT OR REPLACE for conflict resolution
- query := `
- INSERT OR REPLACE INTO offset_mappings
- (partition_key, kafka_offset, smq_timestamp, message_size, created_at)
- VALUES (?, ?, ?, ?, ?)
- `
-
- _, err := s.db.Exec(query, partitionKey, kafkaOffset, smqTimestamp, size, now)
- if err != nil {
- return fmt.Errorf("failed to save offset mapping: %w", err)
- }
-
- return nil
-}
-
-// LoadOffsetMappings retrieves all offset mappings for a partition
-func (s *SQLOffsetStorage) LoadOffsetMappings(partitionKey string) ([]OffsetEntry, error) {
- // TODO: Add pagination for large result sets
- // ASSUMPTION: Loading all mappings for now, should be paginated in production
- query := `
- SELECT kafka_offset, smq_timestamp, message_size
- FROM offset_mappings
- WHERE partition_key = ?
- ORDER BY kafka_offset ASC
- `
-
- rows, err := s.db.Query(query, partitionKey)
- if err != nil {
- return nil, fmt.Errorf("failed to query offset mappings: %w", err)
- }
- defer rows.Close()
-
- var entries []OffsetEntry
- for rows.Next() {
- var entry OffsetEntry
- err := rows.Scan(&entry.KafkaOffset, &entry.SMQTimestamp, &entry.MessageSize)
- if err != nil {
- return nil, fmt.Errorf("failed to scan offset entry: %w", err)
- }
- entries = append(entries, entry)
- }
-
- if err := rows.Err(); err != nil {
- return nil, fmt.Errorf("error iterating offset mappings: %w", err)
- }
-
- return entries, nil
-}
-
-// GetOffsetMappingsByRange retrieves offset mappings within a specific range
-func (s *SQLOffsetStorage) GetOffsetMappingsByRange(partitionKey string, startOffset, endOffset int64) ([]OffsetEntry, error) {
- // TODO: Use _index column for efficient range queries
- query := `
- SELECT kafka_offset, smq_timestamp, message_size
- FROM offset_mappings
- WHERE partition_key = ? AND kafka_offset >= ? AND kafka_offset <= ?
- ORDER BY kafka_offset ASC
- `
-
- rows, err := s.db.Query(query, partitionKey, startOffset, endOffset)
- if err != nil {
- return nil, fmt.Errorf("failed to query offset range: %w", err)
- }
- defer rows.Close()
-
- var entries []OffsetEntry
- for rows.Next() {
- var entry OffsetEntry
- err := rows.Scan(&entry.KafkaOffset, &entry.SMQTimestamp, &entry.MessageSize)
- if err != nil {
- return nil, fmt.Errorf("failed to scan offset entry: %w", err)
- }
- entries = append(entries, entry)
- }
-
- return entries, nil
-}
-
-// GetPartitionStats returns statistics about a partition's offset usage
-func (s *SQLOffsetStorage) GetPartitionStats(partitionKey string) (*PartitionStats, error) {
- query := `
- SELECT
- COUNT(*) as record_count,
- MIN(kafka_offset) as earliest_offset,
- MAX(kafka_offset) as latest_offset,
- SUM(message_size) as total_size,
- MIN(created_at) as first_record_time,
- MAX(created_at) as last_record_time
- FROM offset_mappings
- WHERE partition_key = ?
- `
-
- var stats PartitionStats
- var earliestOffset, latestOffset sql.NullInt64
- var totalSize sql.NullInt64
- var firstRecordTime, lastRecordTime sql.NullInt64
-
- err := s.db.QueryRow(query, partitionKey).Scan(
- &stats.RecordCount,
- &earliestOffset,
- &latestOffset,
- &totalSize,
- &firstRecordTime,
- &lastRecordTime,
- )
-
- if err != nil {
- return nil, fmt.Errorf("failed to get partition stats: %w", err)
- }
-
- stats.PartitionKey = partitionKey
-
- if earliestOffset.Valid {
- stats.EarliestOffset = earliestOffset.Int64
- } else {
- stats.EarliestOffset = -1
- }
-
- if latestOffset.Valid {
- stats.LatestOffset = latestOffset.Int64
- stats.HighWaterMark = latestOffset.Int64 + 1
- } else {
- stats.LatestOffset = -1
- stats.HighWaterMark = 0
- }
-
- if firstRecordTime.Valid {
- stats.FirstRecordTime = firstRecordTime.Int64
- }
-
- if lastRecordTime.Valid {
- stats.LastRecordTime = lastRecordTime.Int64
- }
-
- if totalSize.Valid {
- stats.TotalSize = totalSize.Int64
- }
-
- return &stats, nil
-}
-
-// CleanupOldMappings removes offset mappings older than the specified time
-func (s *SQLOffsetStorage) CleanupOldMappings(olderThanNs int64) error {
- // TODO: Add configurable cleanup policies
- // ASSUMPTION: Simple time-based cleanup, could be enhanced with retention policies
- query := `
- DELETE FROM offset_mappings
- WHERE created_at < ?
- `
-
- result, err := s.db.Exec(query, olderThanNs)
- if err != nil {
- return fmt.Errorf("failed to cleanup old mappings: %w", err)
- }
-
- rowsAffected, _ := result.RowsAffected()
- if rowsAffected > 0 {
- // Log cleanup activity
- fmt.Printf("Cleaned up %d old offset mappings\n", rowsAffected)
- }
-
- return nil
-}
-
-// Close closes the database connection
-func (s *SQLOffsetStorage) Close() error {
- if s.db != nil {
- return s.db.Close()
- }
- return nil
-}
-
-// PartitionStats provides statistics about a partition's offset usage
-type PartitionStats struct {
- PartitionKey string
- RecordCount int64
- EarliestOffset int64
- LatestOffset int64
- HighWaterMark int64
- TotalSize int64
- FirstRecordTime int64
- LastRecordTime int64
-}
-
-// GetAllPartitions returns a list of all partitions with offset data
-func (s *SQLOffsetStorage) GetAllPartitions() ([]string, error) {
- query := `
- SELECT DISTINCT partition_key
- FROM offset_mappings
- ORDER BY partition_key
- `
-
- rows, err := s.db.Query(query)
- if err != nil {
- return nil, fmt.Errorf("failed to get all partitions: %w", err)
- }
- defer rows.Close()
-
- var partitions []string
- for rows.Next() {
- var partitionKey string
- if err := rows.Scan(&partitionKey); err != nil {
- return nil, fmt.Errorf("failed to scan partition key: %w", err)
- }
- partitions = append(partitions, partitionKey)
- }
-
- return partitions, nil
-}
-
-// Vacuum performs database maintenance operations
-func (s *SQLOffsetStorage) Vacuum() error {
- // TODO: Add database-specific optimization commands
- // ASSUMPTION: SQLite VACUUM command, may need adaptation for other databases
- _, err := s.db.Exec("VACUUM")
- if err != nil {
- return fmt.Errorf("failed to vacuum database: %w", err)
- }
-
- return nil
-}
diff --git a/weed/mq/offset/sql_storage_test.go b/weed/mq/offset/sql_storage_test.go
deleted file mode 100644
index 661f317de..000000000
--- a/weed/mq/offset/sql_storage_test.go
+++ /dev/null
@@ -1,516 +0,0 @@
-package offset
-
-import (
- "database/sql"
- "os"
- "testing"
- "time"
-
- _ "github.com/mattn/go-sqlite3" // SQLite driver
- "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
-)
-
-func createTestDB(t *testing.T) *sql.DB {
- // Create temporary database file
- tmpFile, err := os.CreateTemp("", "offset_test_*.db")
- if err != nil {
- t.Fatalf("Failed to create temp database file: %v", err)
- }
- tmpFile.Close()
-
- // Clean up the file when test completes
- t.Cleanup(func() {
- os.Remove(tmpFile.Name())
- })
-
- db, err := sql.Open("sqlite3", tmpFile.Name())
- if err != nil {
- t.Fatalf("Failed to open database: %v", err)
- }
-
- t.Cleanup(func() {
- db.Close()
- })
-
- return db
-}
-
-func createTestPartitionForSQL() *schema_pb.Partition {
- return &schema_pb.Partition{
- RingSize: 1024,
- RangeStart: 0,
- RangeStop: 31,
- UnixTimeNs: time.Now().UnixNano(),
- }
-}
-
-func TestSQLOffsetStorage_InitializeSchema(t *testing.T) {
- db := createTestDB(t)
-
- storage, err := NewSQLOffsetStorage(db)
- if err != nil {
- t.Fatalf("Failed to create SQL storage: %v", err)
- }
- defer storage.Close()
-
- // Verify tables were created
- tables := []string{
- "partition_offset_checkpoints",
- "offset_mappings",
- }
-
- for _, table := range tables {
- var count int
- err := db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", table).Scan(&count)
- if err != nil {
- t.Fatalf("Failed to check table %s: %v", table, err)
- }
-
- if count != 1 {
- t.Errorf("Table %s was not created", table)
- }
- }
-}
-
-func TestSQLOffsetStorage_SaveLoadCheckpoint(t *testing.T) {
- db := createTestDB(t)
- storage, err := NewSQLOffsetStorage(db)
- if err != nil {
- t.Fatalf("Failed to create SQL storage: %v", err)
- }
- defer storage.Close()
-
- partition := createTestPartitionForSQL()
-
- // Test saving checkpoint
- err = storage.SaveCheckpoint("test-namespace", "test-topic", partition, 100)
- if err != nil {
- t.Fatalf("Failed to save checkpoint: %v", err)
- }
-
- // Test loading checkpoint
- checkpoint, err := storage.LoadCheckpoint("test-namespace", "test-topic", partition)
- if err != nil {
- t.Fatalf("Failed to load checkpoint: %v", err)
- }
-
- if checkpoint != 100 {
- t.Errorf("Expected checkpoint 100, got %d", checkpoint)
- }
-
- // Test updating checkpoint
- err = storage.SaveCheckpoint("test-namespace", "test-topic", partition, 200)
- if err != nil {
- t.Fatalf("Failed to update checkpoint: %v", err)
- }
-
- checkpoint, err = storage.LoadCheckpoint("test-namespace", "test-topic", partition)
- if err != nil {
- t.Fatalf("Failed to load updated checkpoint: %v", err)
- }
-
- if checkpoint != 200 {
- t.Errorf("Expected updated checkpoint 200, got %d", checkpoint)
- }
-}
-
-func TestSQLOffsetStorage_LoadCheckpointNotFound(t *testing.T) {
- db := createTestDB(t)
- storage, err := NewSQLOffsetStorage(db)
- if err != nil {
- t.Fatalf("Failed to create SQL storage: %v", err)
- }
- defer storage.Close()
-
- partition := createTestPartitionForSQL()
-
- // Test loading non-existent checkpoint
- _, err = storage.LoadCheckpoint("test-namespace", "test-topic", partition)
- if err == nil {
- t.Error("Expected error for non-existent checkpoint")
- }
-}
-
-func TestSQLOffsetStorage_SaveLoadOffsetMappings(t *testing.T) {
- db := createTestDB(t)
- storage, err := NewSQLOffsetStorage(db)
- if err != nil {
- t.Fatalf("Failed to create SQL storage: %v", err)
- }
- defer storage.Close()
-
- partition := createTestPartitionForSQL()
- partitionKey := partitionKey(partition)
-
- // Save multiple offset mappings
- mappings := []struct {
- offset int64
- timestamp int64
- size int32
- }{
- {0, 1000, 100},
- {1, 2000, 150},
- {2, 3000, 200},
- }
-
- for _, mapping := range mappings {
- err := storage.SaveOffsetMapping(partitionKey, mapping.offset, mapping.timestamp, mapping.size)
- if err != nil {
- t.Fatalf("Failed to save offset mapping: %v", err)
- }
- }
-
- // Load offset mappings
- entries, err := storage.LoadOffsetMappings(partitionKey)
- if err != nil {
- t.Fatalf("Failed to load offset mappings: %v", err)
- }
-
- if len(entries) != len(mappings) {
- t.Errorf("Expected %d entries, got %d", len(mappings), len(entries))
- }
-
- // Verify entries are sorted by offset
- for i, entry := range entries {
- expected := mappings[i]
- if entry.KafkaOffset != expected.offset {
- t.Errorf("Entry %d: expected offset %d, got %d", i, expected.offset, entry.KafkaOffset)
- }
- if entry.SMQTimestamp != expected.timestamp {
- t.Errorf("Entry %d: expected timestamp %d, got %d", i, expected.timestamp, entry.SMQTimestamp)
- }
- if entry.MessageSize != expected.size {
- t.Errorf("Entry %d: expected size %d, got %d", i, expected.size, entry.MessageSize)
- }
- }
-}
-
-func TestSQLOffsetStorage_GetHighestOffset(t *testing.T) {
- db := createTestDB(t)
- storage, err := NewSQLOffsetStorage(db)
- if err != nil {
- t.Fatalf("Failed to create SQL storage: %v", err)
- }
- defer storage.Close()
-
- partition := createTestPartitionForSQL()
- partitionKey := TopicPartitionKey("test-namespace", "test-topic", partition)
-
- // Test empty partition
- _, err = storage.GetHighestOffset("test-namespace", "test-topic", partition)
- if err == nil {
- t.Error("Expected error for empty partition")
- }
-
- // Add some offset mappings
- offsets := []int64{5, 1, 3, 2, 4}
- for _, offset := range offsets {
- err := storage.SaveOffsetMapping(partitionKey, offset, offset*1000, 100)
- if err != nil {
- t.Fatalf("Failed to save offset mapping: %v", err)
- }
- }
-
- // Get highest offset
- highest, err := storage.GetHighestOffset("test-namespace", "test-topic", partition)
- if err != nil {
- t.Fatalf("Failed to get highest offset: %v", err)
- }
-
- if highest != 5 {
- t.Errorf("Expected highest offset 5, got %d", highest)
- }
-}
-
-func TestSQLOffsetStorage_GetOffsetMappingsByRange(t *testing.T) {
- db := createTestDB(t)
- storage, err := NewSQLOffsetStorage(db)
- if err != nil {
- t.Fatalf("Failed to create SQL storage: %v", err)
- }
- defer storage.Close()
-
- partition := createTestPartitionForSQL()
- partitionKey := partitionKey(partition)
-
- // Add offset mappings
- for i := int64(0); i < 10; i++ {
- err := storage.SaveOffsetMapping(partitionKey, i, i*1000, 100)
- if err != nil {
- t.Fatalf("Failed to save offset mapping: %v", err)
- }
- }
-
- // Get range of offsets
- entries, err := storage.GetOffsetMappingsByRange(partitionKey, 3, 7)
- if err != nil {
- t.Fatalf("Failed to get offset range: %v", err)
- }
-
- expectedCount := 5 // offsets 3, 4, 5, 6, 7
- if len(entries) != expectedCount {
- t.Errorf("Expected %d entries, got %d", expectedCount, len(entries))
- }
-
- // Verify range
- for i, entry := range entries {
- expectedOffset := int64(3 + i)
- if entry.KafkaOffset != expectedOffset {
- t.Errorf("Entry %d: expected offset %d, got %d", i, expectedOffset, entry.KafkaOffset)
- }
- }
-}
-
-func TestSQLOffsetStorage_GetPartitionStats(t *testing.T) {
- db := createTestDB(t)
- storage, err := NewSQLOffsetStorage(db)
- if err != nil {
- t.Fatalf("Failed to create SQL storage: %v", err)
- }
- defer storage.Close()
-
- partition := createTestPartitionForSQL()
- partitionKey := partitionKey(partition)
-
- // Test empty partition stats
- stats, err := storage.GetPartitionStats(partitionKey)
- if err != nil {
- t.Fatalf("Failed to get empty partition stats: %v", err)
- }
-
- if stats.RecordCount != 0 {
- t.Errorf("Expected record count 0, got %d", stats.RecordCount)
- }
-
- if stats.EarliestOffset != -1 {
- t.Errorf("Expected earliest offset -1, got %d", stats.EarliestOffset)
- }
-
- // Add some data
- sizes := []int32{100, 150, 200}
- for i, size := range sizes {
- err := storage.SaveOffsetMapping(partitionKey, int64(i), int64(i*1000), size)
- if err != nil {
- t.Fatalf("Failed to save offset mapping: %v", err)
- }
- }
-
- // Get stats with data
- stats, err = storage.GetPartitionStats(partitionKey)
- if err != nil {
- t.Fatalf("Failed to get partition stats: %v", err)
- }
-
- if stats.RecordCount != 3 {
- t.Errorf("Expected record count 3, got %d", stats.RecordCount)
- }
-
- if stats.EarliestOffset != 0 {
- t.Errorf("Expected earliest offset 0, got %d", stats.EarliestOffset)
- }
-
- if stats.LatestOffset != 2 {
- t.Errorf("Expected latest offset 2, got %d", stats.LatestOffset)
- }
-
- if stats.HighWaterMark != 3 {
- t.Errorf("Expected high water mark 3, got %d", stats.HighWaterMark)
- }
-
- expectedTotalSize := int64(100 + 150 + 200)
- if stats.TotalSize != expectedTotalSize {
- t.Errorf("Expected total size %d, got %d", expectedTotalSize, stats.TotalSize)
- }
-}
-
-func TestSQLOffsetStorage_GetAllPartitions(t *testing.T) {
- db := createTestDB(t)
- storage, err := NewSQLOffsetStorage(db)
- if err != nil {
- t.Fatalf("Failed to create SQL storage: %v", err)
- }
- defer storage.Close()
-
- // Test empty database
- partitions, err := storage.GetAllPartitions()
- if err != nil {
- t.Fatalf("Failed to get all partitions: %v", err)
- }
-
- if len(partitions) != 0 {
- t.Errorf("Expected 0 partitions, got %d", len(partitions))
- }
-
- // Add data for multiple partitions
- partition1 := createTestPartitionForSQL()
- partition2 := &schema_pb.Partition{
- RingSize: 1024,
- RangeStart: 32,
- RangeStop: 63,
- UnixTimeNs: time.Now().UnixNano(),
- }
-
- partitionKey1 := partitionKey(partition1)
- partitionKey2 := partitionKey(partition2)
-
- storage.SaveOffsetMapping(partitionKey1, 0, 1000, 100)
- storage.SaveOffsetMapping(partitionKey2, 0, 2000, 150)
-
- // Get all partitions
- partitions, err = storage.GetAllPartitions()
- if err != nil {
- t.Fatalf("Failed to get all partitions: %v", err)
- }
-
- if len(partitions) != 2 {
- t.Errorf("Expected 2 partitions, got %d", len(partitions))
- }
-
- // Verify partition keys are present
- partitionMap := make(map[string]bool)
- for _, p := range partitions {
- partitionMap[p] = true
- }
-
- if !partitionMap[partitionKey1] {
- t.Errorf("Partition key %s not found", partitionKey1)
- }
-
- if !partitionMap[partitionKey2] {
- t.Errorf("Partition key %s not found", partitionKey2)
- }
-}
-
-func TestSQLOffsetStorage_CleanupOldMappings(t *testing.T) {
- db := createTestDB(t)
- storage, err := NewSQLOffsetStorage(db)
- if err != nil {
- t.Fatalf("Failed to create SQL storage: %v", err)
- }
- defer storage.Close()
-
- partition := createTestPartitionForSQL()
- partitionKey := partitionKey(partition)
-
- // Add mappings with different timestamps
- now := time.Now().UnixNano()
-
- // Add old mapping by directly inserting with old timestamp
- oldTime := now - (24 * time.Hour).Nanoseconds() // 24 hours ago
- _, err = db.Exec(`
- INSERT INTO offset_mappings
- (partition_key, kafka_offset, smq_timestamp, message_size, created_at)
- VALUES (?, ?, ?, ?, ?)
- `, partitionKey, 0, oldTime, 100, oldTime)
- if err != nil {
- t.Fatalf("Failed to insert old mapping: %v", err)
- }
-
- // Add recent mapping
- storage.SaveOffsetMapping(partitionKey, 1, now, 150)
-
- // Verify both mappings exist
- entries, err := storage.LoadOffsetMappings(partitionKey)
- if err != nil {
- t.Fatalf("Failed to load mappings: %v", err)
- }
-
- if len(entries) != 2 {
- t.Errorf("Expected 2 mappings before cleanup, got %d", len(entries))
- }
-
- // Cleanup old mappings (older than 12 hours)
- cutoffTime := now - (12 * time.Hour).Nanoseconds()
- err = storage.CleanupOldMappings(cutoffTime)
- if err != nil {
- t.Fatalf("Failed to cleanup old mappings: %v", err)
- }
-
- // Verify only recent mapping remains
- entries, err = storage.LoadOffsetMappings(partitionKey)
- if err != nil {
- t.Fatalf("Failed to load mappings after cleanup: %v", err)
- }
-
- if len(entries) != 1 {
- t.Errorf("Expected 1 mapping after cleanup, got %d", len(entries))
- }
-
- if entries[0].KafkaOffset != 1 {
- t.Errorf("Expected remaining mapping offset 1, got %d", entries[0].KafkaOffset)
- }
-}
-
-func TestSQLOffsetStorage_Vacuum(t *testing.T) {
- db := createTestDB(t)
- storage, err := NewSQLOffsetStorage(db)
- if err != nil {
- t.Fatalf("Failed to create SQL storage: %v", err)
- }
- defer storage.Close()
-
- // Vacuum should not fail on empty database
- err = storage.Vacuum()
- if err != nil {
- t.Fatalf("Failed to vacuum database: %v", err)
- }
-
- // Add some data and vacuum again
- partition := createTestPartitionForSQL()
- partitionKey := partitionKey(partition)
- storage.SaveOffsetMapping(partitionKey, 0, 1000, 100)
-
- err = storage.Vacuum()
- if err != nil {
- t.Fatalf("Failed to vacuum database with data: %v", err)
- }
-}
-
-func TestSQLOffsetStorage_ConcurrentAccess(t *testing.T) {
- db := createTestDB(t)
- storage, err := NewSQLOffsetStorage(db)
- if err != nil {
- t.Fatalf("Failed to create SQL storage: %v", err)
- }
- defer storage.Close()
-
- partition := createTestPartitionForSQL()
- partitionKey := partitionKey(partition)
-
- // Test concurrent writes
- const numGoroutines = 10
- const offsetsPerGoroutine = 10
-
- done := make(chan bool, numGoroutines)
-
- for i := 0; i < numGoroutines; i++ {
- go func(goroutineID int) {
- defer func() { done <- true }()
-
- for j := 0; j < offsetsPerGoroutine; j++ {
- offset := int64(goroutineID*offsetsPerGoroutine + j)
- err := storage.SaveOffsetMapping(partitionKey, offset, offset*1000, 100)
- if err != nil {
- t.Errorf("Failed to save offset mapping %d: %v", offset, err)
- return
- }
- }
- }(i)
- }
-
- // Wait for all goroutines to complete
- for i := 0; i < numGoroutines; i++ {
- <-done
- }
-
- // Verify all mappings were saved
- entries, err := storage.LoadOffsetMappings(partitionKey)
- if err != nil {
- t.Fatalf("Failed to load mappings: %v", err)
- }
-
- expectedCount := numGoroutines * offsetsPerGoroutine
- if len(entries) != expectedCount {
- t.Errorf("Expected %d mappings, got %d", expectedCount, len(entries))
- }
-}
diff --git a/weed/mq/offset/subscriber_test.go b/weed/mq/offset/subscriber_test.go
deleted file mode 100644
index 1ab97dadc..000000000
--- a/weed/mq/offset/subscriber_test.go
+++ /dev/null
@@ -1,457 +0,0 @@
-package offset
-
-import (
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
-)
-
-func TestOffsetSubscriber_CreateSubscription(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- registry := NewPartitionOffsetRegistry(storage)
- subscriber := NewOffsetSubscriber(registry)
- partition := createTestPartition()
-
- // Assign some offsets first
- registry.AssignOffsets("test-namespace", "test-topic", partition, 10)
-
- // Test EXACT_OFFSET subscription
- sub, err := subscriber.CreateSubscription("test-sub-1", "test-namespace", "test-topic", partition, schema_pb.OffsetType_EXACT_OFFSET, 5)
- if err != nil {
- t.Fatalf("Failed to create EXACT_OFFSET subscription: %v", err)
- }
-
- if sub.StartOffset != 5 {
- t.Errorf("Expected start offset 5, got %d", sub.StartOffset)
- }
- if sub.CurrentOffset != 5 {
- t.Errorf("Expected current offset 5, got %d", sub.CurrentOffset)
- }
-
- // Test RESET_TO_LATEST subscription
- sub2, err := subscriber.CreateSubscription("test-sub-2", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_LATEST, 0)
- if err != nil {
- t.Fatalf("Failed to create RESET_TO_LATEST subscription: %v", err)
- }
-
- if sub2.StartOffset != 10 { // Should be at high water mark
- t.Errorf("Expected start offset 10, got %d", sub2.StartOffset)
- }
-}
-
-func TestOffsetSubscriber_InvalidSubscription(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- registry := NewPartitionOffsetRegistry(storage)
- subscriber := NewOffsetSubscriber(registry)
- partition := createTestPartition()
-
- // Assign some offsets
- registry.AssignOffsets("test-namespace", "test-topic", partition, 5)
-
- // Test invalid offset (beyond high water mark)
- _, err := subscriber.CreateSubscription("invalid-sub", "test-namespace", "test-topic", partition, schema_pb.OffsetType_EXACT_OFFSET, 10)
- if err == nil {
- t.Error("Expected error for offset beyond high water mark")
- }
-
- // Test negative offset
- _, err = subscriber.CreateSubscription("invalid-sub-2", "test-namespace", "test-topic", partition, schema_pb.OffsetType_EXACT_OFFSET, -1)
- if err == nil {
- t.Error("Expected error for negative offset")
- }
-}
-
-func TestOffsetSubscriber_DuplicateSubscription(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- registry := NewPartitionOffsetRegistry(storage)
- subscriber := NewOffsetSubscriber(registry)
- partition := createTestPartition()
-
- // Create first subscription
- _, err := subscriber.CreateSubscription("duplicate-sub", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
- if err != nil {
- t.Fatalf("Failed to create first subscription: %v", err)
- }
-
- // Try to create duplicate
- _, err = subscriber.CreateSubscription("duplicate-sub", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
- if err == nil {
- t.Error("Expected error for duplicate subscription ID")
- }
-}
-
-func TestOffsetSubscription_SeekToOffset(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- registry := NewPartitionOffsetRegistry(storage)
- subscriber := NewOffsetSubscriber(registry)
- partition := createTestPartition()
-
- // Assign offsets
- registry.AssignOffsets("test-namespace", "test-topic", partition, 20)
-
- // Create subscription
- sub, err := subscriber.CreateSubscription("seek-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
- if err != nil {
- t.Fatalf("Failed to create subscription: %v", err)
- }
-
- // Test valid seek
- err = sub.SeekToOffset(10)
- if err != nil {
- t.Fatalf("Failed to seek to offset 10: %v", err)
- }
-
- if sub.CurrentOffset != 10 {
- t.Errorf("Expected current offset 10, got %d", sub.CurrentOffset)
- }
-
- // Test invalid seek (beyond high water mark)
- err = sub.SeekToOffset(25)
- if err == nil {
- t.Error("Expected error for seek beyond high water mark")
- }
-
- // Test negative seek
- err = sub.SeekToOffset(-1)
- if err == nil {
- t.Error("Expected error for negative seek offset")
- }
-}
-
-func TestOffsetSubscription_AdvanceOffset(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- registry := NewPartitionOffsetRegistry(storage)
- subscriber := NewOffsetSubscriber(registry)
- partition := createTestPartition()
-
- // Create subscription
- sub, err := subscriber.CreateSubscription("advance-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
- if err != nil {
- t.Fatalf("Failed to create subscription: %v", err)
- }
-
- // Test single advance
- initialOffset := sub.GetNextOffset()
- sub.AdvanceOffset()
-
- if sub.GetNextOffset() != initialOffset+1 {
- t.Errorf("Expected offset %d, got %d", initialOffset+1, sub.GetNextOffset())
- }
-
- // Test batch advance
- sub.AdvanceOffsetBy(5)
-
- if sub.GetNextOffset() != initialOffset+6 {
- t.Errorf("Expected offset %d, got %d", initialOffset+6, sub.GetNextOffset())
- }
-}
-
-func TestOffsetSubscription_GetLag(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- registry := NewPartitionOffsetRegistry(storage)
- subscriber := NewOffsetSubscriber(registry)
- partition := createTestPartition()
-
- // Assign offsets
- registry.AssignOffsets("test-namespace", "test-topic", partition, 15)
-
- // Create subscription at offset 5
- sub, err := subscriber.CreateSubscription("lag-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_EXACT_OFFSET, 5)
- if err != nil {
- t.Fatalf("Failed to create subscription: %v", err)
- }
-
- // Check initial lag
- lag, err := sub.GetLag()
- if err != nil {
- t.Fatalf("Failed to get lag: %v", err)
- }
-
- expectedLag := int64(15 - 5) // hwm - current
- if lag != expectedLag {
- t.Errorf("Expected lag %d, got %d", expectedLag, lag)
- }
-
- // Advance and check lag again
- sub.AdvanceOffsetBy(3)
-
- lag, err = sub.GetLag()
- if err != nil {
- t.Fatalf("Failed to get lag after advance: %v", err)
- }
-
- expectedLag = int64(15 - 8) // hwm - current
- if lag != expectedLag {
- t.Errorf("Expected lag %d after advance, got %d", expectedLag, lag)
- }
-}
-
-func TestOffsetSubscription_IsAtEnd(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- registry := NewPartitionOffsetRegistry(storage)
- subscriber := NewOffsetSubscriber(registry)
- partition := createTestPartition()
-
- // Assign offsets
- registry.AssignOffsets("test-namespace", "test-topic", partition, 10)
-
- // Create subscription at end
- sub, err := subscriber.CreateSubscription("end-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_LATEST, 0)
- if err != nil {
- t.Fatalf("Failed to create subscription: %v", err)
- }
-
- // Should be at end
- atEnd, err := sub.IsAtEnd()
- if err != nil {
- t.Fatalf("Failed to check if at end: %v", err)
- }
-
- if !atEnd {
- t.Error("Expected subscription to be at end")
- }
-
- // Seek to middle and check again
- sub.SeekToOffset(5)
-
- atEnd, err = sub.IsAtEnd()
- if err != nil {
- t.Fatalf("Failed to check if at end after seek: %v", err)
- }
-
- if atEnd {
- t.Error("Expected subscription not to be at end after seek")
- }
-}
-
-func TestOffsetSubscription_GetOffsetRange(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- registry := NewPartitionOffsetRegistry(storage)
- subscriber := NewOffsetSubscriber(registry)
- partition := createTestPartition()
-
- // Assign offsets
- registry.AssignOffsets("test-namespace", "test-topic", partition, 20)
-
- // Create subscription
- sub, err := subscriber.CreateSubscription("range-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_EXACT_OFFSET, 5)
- if err != nil {
- t.Fatalf("Failed to create subscription: %v", err)
- }
-
- // Test normal range
- offsetRange, err := sub.GetOffsetRange(10)
- if err != nil {
- t.Fatalf("Failed to get offset range: %v", err)
- }
-
- if offsetRange.StartOffset != 5 {
- t.Errorf("Expected start offset 5, got %d", offsetRange.StartOffset)
- }
- if offsetRange.EndOffset != 14 {
- t.Errorf("Expected end offset 14, got %d", offsetRange.EndOffset)
- }
- if offsetRange.Count != 10 {
- t.Errorf("Expected count 10, got %d", offsetRange.Count)
- }
-
- // Test range that exceeds high water mark
- sub.SeekToOffset(15)
- offsetRange, err = sub.GetOffsetRange(10)
- if err != nil {
- t.Fatalf("Failed to get offset range near end: %v", err)
- }
-
- if offsetRange.StartOffset != 15 {
- t.Errorf("Expected start offset 15, got %d", offsetRange.StartOffset)
- }
- if offsetRange.EndOffset != 19 { // Should be capped at hwm-1
- t.Errorf("Expected end offset 19, got %d", offsetRange.EndOffset)
- }
- if offsetRange.Count != 5 {
- t.Errorf("Expected count 5, got %d", offsetRange.Count)
- }
-}
-
-func TestOffsetSubscription_EmptyRange(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- registry := NewPartitionOffsetRegistry(storage)
- subscriber := NewOffsetSubscriber(registry)
- partition := createTestPartition()
-
- // Assign offsets
- registry.AssignOffsets("test-namespace", "test-topic", partition, 10)
-
- // Create subscription at end
- sub, err := subscriber.CreateSubscription("empty-range-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_LATEST, 0)
- if err != nil {
- t.Fatalf("Failed to create subscription: %v", err)
- }
-
- // Request range when at end
- offsetRange, err := sub.GetOffsetRange(5)
- if err != nil {
- t.Fatalf("Failed to get offset range at end: %v", err)
- }
-
- if offsetRange.Count != 0 {
- t.Errorf("Expected empty range (count 0), got count %d", offsetRange.Count)
- }
-
- if offsetRange.StartOffset != 10 {
- t.Errorf("Expected start offset 10, got %d", offsetRange.StartOffset)
- }
-
- if offsetRange.EndOffset != 9 { // Empty range: end < start
- t.Errorf("Expected end offset 9 (empty range), got %d", offsetRange.EndOffset)
- }
-}
-
-func TestOffsetSeeker_ValidateOffsetRange(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- registry := NewPartitionOffsetRegistry(storage)
- seeker := NewOffsetSeeker(registry)
- partition := createTestPartition()
-
- // Assign offsets
- registry.AssignOffsets("test-namespace", "test-topic", partition, 15)
-
- // Test valid range
- err := seeker.ValidateOffsetRange("test-namespace", "test-topic", partition, 5, 10)
- if err != nil {
- t.Errorf("Valid range should not return error: %v", err)
- }
-
- // Test invalid ranges
- testCases := []struct {
- name string
- startOffset int64
- endOffset int64
- expectError bool
- }{
- {"negative start", -1, 5, true},
- {"end before start", 10, 5, true},
- {"start beyond hwm", 20, 25, true},
- {"valid range", 0, 14, false},
- {"single offset", 5, 5, false},
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- err := seeker.ValidateOffsetRange("test-namespace", "test-topic", partition, tc.startOffset, tc.endOffset)
- if tc.expectError && err == nil {
- t.Error("Expected error but got none")
- }
- if !tc.expectError && err != nil {
- t.Errorf("Expected no error but got: %v", err)
- }
- })
- }
-}
-
-func TestOffsetSeeker_GetAvailableOffsetRange(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- registry := NewPartitionOffsetRegistry(storage)
- seeker := NewOffsetSeeker(registry)
- partition := createTestPartition()
-
- // Test empty partition
- offsetRange, err := seeker.GetAvailableOffsetRange("test-namespace", "test-topic", partition)
- if err != nil {
- t.Fatalf("Failed to get available range for empty partition: %v", err)
- }
-
- if offsetRange.Count != 0 {
- t.Errorf("Expected empty range for empty partition, got count %d", offsetRange.Count)
- }
-
- // Assign offsets and test again
- registry.AssignOffsets("test-namespace", "test-topic", partition, 25)
-
- offsetRange, err = seeker.GetAvailableOffsetRange("test-namespace", "test-topic", partition)
- if err != nil {
- t.Fatalf("Failed to get available range: %v", err)
- }
-
- if offsetRange.StartOffset != 0 {
- t.Errorf("Expected start offset 0, got %d", offsetRange.StartOffset)
- }
- if offsetRange.EndOffset != 24 {
- t.Errorf("Expected end offset 24, got %d", offsetRange.EndOffset)
- }
- if offsetRange.Count != 25 {
- t.Errorf("Expected count 25, got %d", offsetRange.Count)
- }
-}
-
-func TestOffsetSubscriber_CloseSubscription(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- registry := NewPartitionOffsetRegistry(storage)
- subscriber := NewOffsetSubscriber(registry)
- partition := createTestPartition()
-
- // Create subscription
- sub, err := subscriber.CreateSubscription("close-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
- if err != nil {
- t.Fatalf("Failed to create subscription: %v", err)
- }
-
- // Verify subscription exists
- _, err = subscriber.GetSubscription("close-test")
- if err != nil {
- t.Fatalf("Subscription should exist: %v", err)
- }
-
- // Close subscription
- err = subscriber.CloseSubscription("close-test")
- if err != nil {
- t.Fatalf("Failed to close subscription: %v", err)
- }
-
- // Verify subscription is gone
- _, err = subscriber.GetSubscription("close-test")
- if err == nil {
- t.Error("Subscription should not exist after close")
- }
-
- // Verify subscription is marked inactive
- if sub.IsActive {
- t.Error("Subscription should be marked inactive after close")
- }
-}
-
-func TestOffsetSubscription_InactiveOperations(t *testing.T) {
- storage := NewInMemoryOffsetStorage()
- registry := NewPartitionOffsetRegistry(storage)
- subscriber := NewOffsetSubscriber(registry)
- partition := createTestPartition()
-
- // Create and close subscription
- sub, err := subscriber.CreateSubscription("inactive-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
- if err != nil {
- t.Fatalf("Failed to create subscription: %v", err)
- }
-
- subscriber.CloseSubscription("inactive-test")
-
- // Test operations on inactive subscription
- err = sub.SeekToOffset(5)
- if err == nil {
- t.Error("Expected error for seek on inactive subscription")
- }
-
- _, err = sub.GetLag()
- if err == nil {
- t.Error("Expected error for GetLag on inactive subscription")
- }
-
- _, err = sub.IsAtEnd()
- if err == nil {
- t.Error("Expected error for IsAtEnd on inactive subscription")
- }
-
- _, err = sub.GetOffsetRange(10)
- if err == nil {
- t.Error("Expected error for GetOffsetRange on inactive subscription")
- }
-}
diff --git a/weed/mq/pub_balancer/repair.go b/weed/mq/pub_balancer/repair.go
index 9af81d27f..549843978 100644
--- a/weed/mq/pub_balancer/repair.go
+++ b/weed/mq/pub_balancer/repair.go
@@ -1,13 +1,6 @@
package pub_balancer
-import (
- "math/rand/v2"
- "sort"
-
- cmap "github.com/orcaman/concurrent-map/v2"
- "github.com/seaweedfs/seaweedfs/weed/mq/topic"
- "modernc.org/mathutil"
-)
+import ()
func (balancer *PubBalancer) RepairTopics() []BalanceAction {
action := BalanceTopicPartitionOnBrokers(balancer.Brokers)
@@ -17,107 +10,3 @@ func (balancer *PubBalancer) RepairTopics() []BalanceAction {
type TopicPartitionInfo struct {
Broker string
}
-
-// RepairMissingTopicPartitions check the stats of all brokers,
-// and repair the missing topic partitions on the brokers.
-func RepairMissingTopicPartitions(brokers cmap.ConcurrentMap[string, *BrokerStats]) (actions []BalanceAction) {
-
- // find all topic partitions
- topicToTopicPartitions := make(map[topic.Topic]map[topic.Partition]*TopicPartitionInfo)
- for brokerStatsItem := range brokers.IterBuffered() {
- broker, brokerStats := brokerStatsItem.Key, brokerStatsItem.Val
- for topicPartitionStatsItem := range brokerStats.TopicPartitionStats.IterBuffered() {
- topicPartitionStat := topicPartitionStatsItem.Val
- topicPartitionToInfo, found := topicToTopicPartitions[topicPartitionStat.Topic]
- if !found {
- topicPartitionToInfo = make(map[topic.Partition]*TopicPartitionInfo)
- topicToTopicPartitions[topicPartitionStat.Topic] = topicPartitionToInfo
- }
- tpi, found := topicPartitionToInfo[topicPartitionStat.Partition]
- if !found {
- tpi = &TopicPartitionInfo{}
- topicPartitionToInfo[topicPartitionStat.Partition] = tpi
- }
- tpi.Broker = broker
- }
- }
-
- // collect all brokers as candidates
- candidates := make([]string, 0, brokers.Count())
- for brokerStatsItem := range brokers.IterBuffered() {
- candidates = append(candidates, brokerStatsItem.Key)
- }
-
- // find the missing topic partitions
- for t, topicPartitionToInfo := range topicToTopicPartitions {
- missingPartitions := EachTopicRepairMissingTopicPartitions(t, topicPartitionToInfo)
- for _, partition := range missingPartitions {
- actions = append(actions, BalanceActionCreate{
- TopicPartition: topic.TopicPartition{
- Topic: t,
- Partition: partition,
- },
- TargetBroker: candidates[rand.IntN(len(candidates))],
- })
- }
- }
-
- return actions
-}
-
-func EachTopicRepairMissingTopicPartitions(t topic.Topic, info map[topic.Partition]*TopicPartitionInfo) (missingPartitions []topic.Partition) {
-
- // find the missing topic partitions
- var partitions []topic.Partition
- for partition := range info {
- partitions = append(partitions, partition)
- }
- return findMissingPartitions(partitions, MaxPartitionCount)
-}
-
-// findMissingPartitions find the missing partitions
-func findMissingPartitions(partitions []topic.Partition, ringSize int32) (missingPartitions []topic.Partition) {
- // sort the partitions by range start
- sort.Slice(partitions, func(i, j int) bool {
- return partitions[i].RangeStart < partitions[j].RangeStart
- })
-
- // calculate the average partition size
- var covered int32
- for _, partition := range partitions {
- covered += partition.RangeStop - partition.RangeStart
- }
- averagePartitionSize := covered / int32(len(partitions))
-
- // find the missing partitions
- var coveredWatermark int32
- i := 0
- for i < len(partitions) {
- partition := partitions[i]
- if partition.RangeStart > coveredWatermark {
- upperBound := mathutil.MinInt32(coveredWatermark+averagePartitionSize, partition.RangeStart)
- missingPartitions = append(missingPartitions, topic.Partition{
- RangeStart: coveredWatermark,
- RangeStop: upperBound,
- RingSize: ringSize,
- })
- coveredWatermark = upperBound
- if coveredWatermark == partition.RangeStop {
- i++
- }
- } else {
- coveredWatermark = partition.RangeStop
- i++
- }
- }
- for coveredWatermark < ringSize {
- upperBound := mathutil.MinInt32(coveredWatermark+averagePartitionSize, ringSize)
- missingPartitions = append(missingPartitions, topic.Partition{
- RangeStart: coveredWatermark,
- RangeStop: upperBound,
- RingSize: ringSize,
- })
- coveredWatermark = upperBound
- }
- return missingPartitions
-}
diff --git a/weed/mq/pub_balancer/repair_test.go b/weed/mq/pub_balancer/repair_test.go
deleted file mode 100644
index 4ccf59e13..000000000
--- a/weed/mq/pub_balancer/repair_test.go
+++ /dev/null
@@ -1,98 +0,0 @@
-package pub_balancer
-
-import (
- "reflect"
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/mq/topic"
-)
-
-func Test_findMissingPartitions(t *testing.T) {
- type args struct {
- partitions []topic.Partition
- }
- tests := []struct {
- name string
- args args
- wantMissingPartitions []topic.Partition
- }{
- {
- name: "one partition",
- args: args{
- partitions: []topic.Partition{
- {RingSize: 1024, RangeStart: 0, RangeStop: 1024},
- },
- },
- wantMissingPartitions: nil,
- },
- {
- name: "two partitions",
- args: args{
- partitions: []topic.Partition{
- {RingSize: 1024, RangeStart: 0, RangeStop: 512},
- {RingSize: 1024, RangeStart: 512, RangeStop: 1024},
- },
- },
- wantMissingPartitions: nil,
- },
- {
- name: "four partitions, missing last two",
- args: args{
- partitions: []topic.Partition{
- {RingSize: 1024, RangeStart: 0, RangeStop: 256},
- {RingSize: 1024, RangeStart: 256, RangeStop: 512},
- },
- },
- wantMissingPartitions: []topic.Partition{
- {RingSize: 1024, RangeStart: 512, RangeStop: 768},
- {RingSize: 1024, RangeStart: 768, RangeStop: 1024},
- },
- },
- {
- name: "four partitions, missing first two",
- args: args{
- partitions: []topic.Partition{
- {RingSize: 1024, RangeStart: 512, RangeStop: 768},
- {RingSize: 1024, RangeStart: 768, RangeStop: 1024},
- },
- },
- wantMissingPartitions: []topic.Partition{
- {RingSize: 1024, RangeStart: 0, RangeStop: 256},
- {RingSize: 1024, RangeStart: 256, RangeStop: 512},
- },
- },
- {
- name: "four partitions, missing middle two",
- args: args{
- partitions: []topic.Partition{
- {RingSize: 1024, RangeStart: 0, RangeStop: 256},
- {RingSize: 1024, RangeStart: 768, RangeStop: 1024},
- },
- },
- wantMissingPartitions: []topic.Partition{
- {RingSize: 1024, RangeStart: 256, RangeStop: 512},
- {RingSize: 1024, RangeStart: 512, RangeStop: 768},
- },
- },
- {
- name: "four partitions, missing three",
- args: args{
- partitions: []topic.Partition{
- {RingSize: 1024, RangeStart: 512, RangeStop: 768},
- },
- },
- wantMissingPartitions: []topic.Partition{
- {RingSize: 1024, RangeStart: 0, RangeStop: 256},
- {RingSize: 1024, RangeStart: 256, RangeStop: 512},
- {RingSize: 1024, RangeStart: 768, RangeStop: 1024},
- },
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if gotMissingPartitions := findMissingPartitions(tt.args.partitions, 1024); !reflect.DeepEqual(gotMissingPartitions, tt.wantMissingPartitions) {
- t.Errorf("findMissingPartitions() = %v, want %v", gotMissingPartitions, tt.wantMissingPartitions)
- }
- })
- }
-}
diff --git a/weed/mq/segment/message_serde.go b/weed/mq/segment/message_serde.go
deleted file mode 100644
index 66a76c57d..000000000
--- a/weed/mq/segment/message_serde.go
+++ /dev/null
@@ -1,109 +0,0 @@
-package segment
-
-import (
- flatbuffers "github.com/google/flatbuffers/go"
- "github.com/seaweedfs/seaweedfs/weed/pb/message_fbs"
-)
-
-type MessageBatchBuilder struct {
- b *flatbuffers.Builder
- producerId int32
- producerEpoch int32
- segmentId int32
- flags int32
- messageOffsets []flatbuffers.UOffsetT
- segmentSeqBase int64
- segmentSeqLast int64
- tsMsBase int64
- tsMsLast int64
-}
-
-func NewMessageBatchBuilder(b *flatbuffers.Builder,
- producerId int32,
- producerEpoch int32,
- segmentId int32,
- flags int32) *MessageBatchBuilder {
-
- b.Reset()
-
- return &MessageBatchBuilder{
- b: b,
- producerId: producerId,
- producerEpoch: producerEpoch,
- segmentId: segmentId,
- flags: flags,
- }
-}
-
-func (builder *MessageBatchBuilder) AddMessage(segmentSeq int64, tsMs int64, properties map[string][]byte, key []byte, value []byte) {
- if builder.segmentSeqBase == 0 {
- builder.segmentSeqBase = segmentSeq
- }
- builder.segmentSeqLast = segmentSeq
- if builder.tsMsBase == 0 {
- builder.tsMsBase = tsMs
- }
- builder.tsMsLast = tsMs
-
- var names, values, pairs []flatbuffers.UOffsetT
- for k, v := range properties {
- names = append(names, builder.b.CreateString(k))
- values = append(values, builder.b.CreateByteVector(v))
- }
- for i, _ := range names {
- message_fbs.NameValueStart(builder.b)
- message_fbs.NameValueAddName(builder.b, names[i])
- message_fbs.NameValueAddValue(builder.b, values[i])
- pair := message_fbs.NameValueEnd(builder.b)
- pairs = append(pairs, pair)
- }
-
- message_fbs.MessageStartPropertiesVector(builder.b, len(properties))
- for i := len(pairs) - 1; i >= 0; i-- {
- builder.b.PrependUOffsetT(pairs[i])
- }
- propOffset := builder.b.EndVector(len(properties))
-
- keyOffset := builder.b.CreateByteVector(key)
- valueOffset := builder.b.CreateByteVector(value)
-
- message_fbs.MessageStart(builder.b)
- message_fbs.MessageAddSeqDelta(builder.b, int32(segmentSeq-builder.segmentSeqBase))
- message_fbs.MessageAddTsMsDelta(builder.b, int32(tsMs-builder.tsMsBase))
-
- message_fbs.MessageAddProperties(builder.b, propOffset)
- message_fbs.MessageAddKey(builder.b, keyOffset)
- message_fbs.MessageAddData(builder.b, valueOffset)
- messageOffset := message_fbs.MessageEnd(builder.b)
-
- builder.messageOffsets = append(builder.messageOffsets, messageOffset)
-
-}
-
-func (builder *MessageBatchBuilder) BuildMessageBatch() {
- message_fbs.MessageBatchStartMessagesVector(builder.b, len(builder.messageOffsets))
- for i := len(builder.messageOffsets) - 1; i >= 0; i-- {
- builder.b.PrependUOffsetT(builder.messageOffsets[i])
- }
- messagesOffset := builder.b.EndVector(len(builder.messageOffsets))
-
- message_fbs.MessageBatchStart(builder.b)
- message_fbs.MessageBatchAddProducerId(builder.b, builder.producerId)
- message_fbs.MessageBatchAddProducerEpoch(builder.b, builder.producerEpoch)
- message_fbs.MessageBatchAddSegmentId(builder.b, builder.segmentId)
- message_fbs.MessageBatchAddFlags(builder.b, builder.flags)
- message_fbs.MessageBatchAddSegmentSeqBase(builder.b, builder.segmentSeqBase)
- message_fbs.MessageBatchAddSegmentSeqMaxDelta(builder.b, int32(builder.segmentSeqLast-builder.segmentSeqBase))
- message_fbs.MessageBatchAddTsMsBase(builder.b, builder.tsMsBase)
- message_fbs.MessageBatchAddTsMsMaxDelta(builder.b, int32(builder.tsMsLast-builder.tsMsBase))
-
- message_fbs.MessageBatchAddMessages(builder.b, messagesOffset)
-
- messageBatch := message_fbs.MessageBatchEnd(builder.b)
-
- builder.b.Finish(messageBatch)
-}
-
-func (builder *MessageBatchBuilder) GetBytes() []byte {
- return builder.b.FinishedBytes()
-}
diff --git a/weed/mq/segment/message_serde_test.go b/weed/mq/segment/message_serde_test.go
deleted file mode 100644
index 52c9d8e55..000000000
--- a/weed/mq/segment/message_serde_test.go
+++ /dev/null
@@ -1,61 +0,0 @@
-package segment
-
-import (
- "testing"
-
- flatbuffers "github.com/google/flatbuffers/go"
- "github.com/seaweedfs/seaweedfs/weed/pb/message_fbs"
- "github.com/stretchr/testify/assert"
-)
-
-func TestMessageSerde(t *testing.T) {
- b := flatbuffers.NewBuilder(1024)
-
- prop := make(map[string][]byte)
- prop["n1"] = []byte("v1")
- prop["n2"] = []byte("v2")
-
- bb := NewMessageBatchBuilder(b, 1, 2, 3, 4)
-
- bb.AddMessage(5, 6, prop, []byte("the primary key"), []byte("body is here"))
- bb.AddMessage(5, 7, prop, []byte("the primary 2"), []byte("body is 2"))
-
- bb.BuildMessageBatch()
-
- buf := bb.GetBytes()
-
- println("serialized size", len(buf))
-
- mb := message_fbs.GetRootAsMessageBatch(buf, 0)
-
- assert.Equal(t, int32(1), mb.ProducerId())
- assert.Equal(t, int32(2), mb.ProducerEpoch())
- assert.Equal(t, int32(3), mb.SegmentId())
- assert.Equal(t, int32(4), mb.Flags())
- assert.Equal(t, int64(5), mb.SegmentSeqBase())
- assert.Equal(t, int32(0), mb.SegmentSeqMaxDelta())
- assert.Equal(t, int64(6), mb.TsMsBase())
- assert.Equal(t, int32(1), mb.TsMsMaxDelta())
-
- assert.Equal(t, 2, mb.MessagesLength())
-
- m := &message_fbs.Message{}
- mb.Messages(m, 0)
-
- /*
- // the vector seems not consistent
- nv := &message_fbs.NameValue{}
- m.Properties(nv, 0)
- assert.Equal(t, "n1", string(nv.Name()))
- assert.Equal(t, "v1", string(nv.Value()))
- m.Properties(nv, 1)
- assert.Equal(t, "n2", string(nv.Name()))
- assert.Equal(t, "v2", string(nv.Value()))
- */
- assert.Equal(t, []byte("the primary key"), m.Key())
- assert.Equal(t, []byte("body is here"), m.Data())
-
- assert.Equal(t, int32(0), m.SeqDelta())
- assert.Equal(t, int32(0), m.TsMsDelta())
-
-}
diff --git a/weed/mq/sub_coordinator/inflight_message_tracker.go b/weed/mq/sub_coordinator/inflight_message_tracker.go
index 8ecbb2ccd..c78e7883e 100644
--- a/weed/mq/sub_coordinator/inflight_message_tracker.go
+++ b/weed/mq/sub_coordinator/inflight_message_tracker.go
@@ -28,28 +28,6 @@ func (imt *InflightMessageTracker) EnflightMessage(key []byte, tsNs int64) {
imt.timestamps.EnflightTimestamp(tsNs)
}
-// IsMessageAcknowledged returns true if the message has been acknowledged.
-// If the message is older than the oldest inflight messages, returns false.
-// returns false if the message is inflight.
-// Otherwise, returns false if the message is old and can be ignored.
-func (imt *InflightMessageTracker) IsMessageAcknowledged(key []byte, tsNs int64) bool {
- imt.mu.Lock()
- defer imt.mu.Unlock()
-
- if tsNs <= imt.timestamps.OldestAckedTimestamp() {
- return true
- }
- if tsNs > imt.timestamps.Latest() {
- return false
- }
-
- if _, found := imt.messages[string(key)]; found {
- return false
- }
-
- return true
-}
-
// AcknowledgeMessage acknowledges the message with the key and timestamp.
func (imt *InflightMessageTracker) AcknowledgeMessage(key []byte, tsNs int64) bool {
// fmt.Printf("AcknowledgeMessage(%s,%d)\n", string(key), tsNs)
@@ -164,8 +142,3 @@ func (rb *RingBuffer) AckTimestamp(timestamp int64) {
func (rb *RingBuffer) OldestAckedTimestamp() int64 {
return rb.maxAllAckedTs
}
-
-// Latest returns the most recently known timestamp in the ring buffer.
-func (rb *RingBuffer) Latest() int64 {
- return rb.maxTimestamp
-}
diff --git a/weed/mq/sub_coordinator/inflight_message_tracker_test.go b/weed/mq/sub_coordinator/inflight_message_tracker_test.go
deleted file mode 100644
index a5c63d561..000000000
--- a/weed/mq/sub_coordinator/inflight_message_tracker_test.go
+++ /dev/null
@@ -1,134 +0,0 @@
-package sub_coordinator
-
-import (
- "testing"
-
- "github.com/stretchr/testify/assert"
-)
-
-func TestRingBuffer(t *testing.T) {
- // Initialize a RingBuffer with capacity 5
- rb := NewRingBuffer(5)
-
- // Add timestamps to the buffer
- timestamps := []int64{100, 200, 300, 400, 500}
- for _, ts := range timestamps {
- rb.EnflightTimestamp(ts)
- }
-
- // Test Add method and buffer size
- expectedSize := 5
- if rb.size != expectedSize {
- t.Errorf("Expected buffer size %d, got %d", expectedSize, rb.size)
- }
-
- assert.Equal(t, int64(0), rb.OldestAckedTimestamp())
- assert.Equal(t, int64(500), rb.Latest())
-
- rb.AckTimestamp(200)
- assert.Equal(t, int64(0), rb.OldestAckedTimestamp())
- rb.AckTimestamp(100)
- assert.Equal(t, int64(200), rb.OldestAckedTimestamp())
-
- rb.EnflightTimestamp(int64(600))
- rb.EnflightTimestamp(int64(700))
-
- rb.AckTimestamp(500)
- assert.Equal(t, int64(200), rb.OldestAckedTimestamp())
- rb.AckTimestamp(400)
- assert.Equal(t, int64(200), rb.OldestAckedTimestamp())
- rb.AckTimestamp(300)
- assert.Equal(t, int64(500), rb.OldestAckedTimestamp())
-
- assert.Equal(t, int64(700), rb.Latest())
-}
-
-func TestInflightMessageTracker(t *testing.T) {
- // Initialize an InflightMessageTracker with capacity 5
- tracker := NewInflightMessageTracker(5)
-
- // Add inflight messages
- key := []byte("1")
- timestamp := int64(1)
- tracker.EnflightMessage(key, timestamp)
-
- // Test IsMessageAcknowledged method
- isOld := tracker.IsMessageAcknowledged(key, timestamp-10)
- if !isOld {
- t.Error("Expected message to be old")
- }
-
- // Test AcknowledgeMessage method
- acked := tracker.AcknowledgeMessage(key, timestamp)
- if !acked {
- t.Error("Expected message to be acked")
- }
- if _, exists := tracker.messages[string(key)]; exists {
- t.Error("Expected message to be deleted after ack")
- }
- if tracker.timestamps.size != 0 {
- t.Error("Expected buffer size to be 0 after ack")
- }
- assert.Equal(t, timestamp, tracker.GetOldestAckedTimestamp())
-}
-
-func TestInflightMessageTracker2(t *testing.T) {
- // Initialize an InflightMessageTracker with initial capacity 1
- tracker := NewInflightMessageTracker(1)
-
- tracker.EnflightMessage([]byte("1"), int64(1))
- tracker.EnflightMessage([]byte("2"), int64(2))
- tracker.EnflightMessage([]byte("3"), int64(3))
- tracker.EnflightMessage([]byte("4"), int64(4))
- tracker.EnflightMessage([]byte("5"), int64(5))
- assert.True(t, tracker.AcknowledgeMessage([]byte("1"), int64(1)))
- assert.Equal(t, int64(1), tracker.GetOldestAckedTimestamp())
-
- // Test IsMessageAcknowledged method
- isAcked := tracker.IsMessageAcknowledged([]byte("2"), int64(2))
- if isAcked {
- t.Error("Expected message to be not acked")
- }
-
- // Test AcknowledgeMessage method
- assert.True(t, tracker.AcknowledgeMessage([]byte("2"), int64(2)))
- assert.Equal(t, int64(2), tracker.GetOldestAckedTimestamp())
-
-}
-
-func TestInflightMessageTracker3(t *testing.T) {
- // Initialize an InflightMessageTracker with initial capacity 1
- tracker := NewInflightMessageTracker(1)
-
- tracker.EnflightMessage([]byte("1"), int64(1))
- tracker.EnflightMessage([]byte("2"), int64(2))
- tracker.EnflightMessage([]byte("3"), int64(3))
- assert.True(t, tracker.AcknowledgeMessage([]byte("1"), int64(1)))
- tracker.EnflightMessage([]byte("4"), int64(4))
- tracker.EnflightMessage([]byte("5"), int64(5))
- assert.True(t, tracker.AcknowledgeMessage([]byte("2"), int64(2)))
- assert.True(t, tracker.AcknowledgeMessage([]byte("3"), int64(3)))
- tracker.EnflightMessage([]byte("6"), int64(6))
- tracker.EnflightMessage([]byte("7"), int64(7))
- assert.True(t, tracker.AcknowledgeMessage([]byte("4"), int64(4)))
- assert.True(t, tracker.AcknowledgeMessage([]byte("5"), int64(5)))
- assert.True(t, tracker.AcknowledgeMessage([]byte("6"), int64(6)))
- assert.Equal(t, int64(6), tracker.GetOldestAckedTimestamp())
- assert.True(t, tracker.AcknowledgeMessage([]byte("7"), int64(7)))
- assert.Equal(t, int64(7), tracker.GetOldestAckedTimestamp())
-
-}
-
-func TestInflightMessageTracker4(t *testing.T) {
- // Initialize an InflightMessageTracker with initial capacity 1
- tracker := NewInflightMessageTracker(1)
-
- tracker.EnflightMessage([]byte("1"), int64(1))
- tracker.EnflightMessage([]byte("2"), int64(2))
- assert.True(t, tracker.AcknowledgeMessage([]byte("1"), int64(1)))
- assert.True(t, tracker.AcknowledgeMessage([]byte("2"), int64(2)))
- tracker.EnflightMessage([]byte("3"), int64(3))
- assert.True(t, tracker.AcknowledgeMessage([]byte("3"), int64(3)))
- assert.Equal(t, int64(3), tracker.GetOldestAckedTimestamp())
-
-}
diff --git a/weed/mq/sub_coordinator/partition_consumer_mapping.go b/weed/mq/sub_coordinator/partition_consumer_mapping.go
index e4d00a0dd..ec3a6582f 100644
--- a/weed/mq/sub_coordinator/partition_consumer_mapping.go
+++ b/weed/mq/sub_coordinator/partition_consumer_mapping.go
@@ -1,130 +1,6 @@
package sub_coordinator
-import (
- "fmt"
- "time"
-
- "github.com/seaweedfs/seaweedfs/weed/mq/pub_balancer"
-)
-
type PartitionConsumerMapping struct {
currentMapping *PartitionSlotToConsumerInstanceList
prevMappings []*PartitionSlotToConsumerInstanceList
}
-
-// Balance goal:
-// 1. max processing power utilization
-// 2. allow one consumer instance to be down unexpectedly
-// without affecting the processing power utilization
-
-func (pcm *PartitionConsumerMapping) BalanceToConsumerInstances(partitionSlotToBrokerList *pub_balancer.PartitionSlotToBrokerList, consumerInstances []*ConsumerGroupInstance) {
- if len(partitionSlotToBrokerList.PartitionSlots) == 0 || len(consumerInstances) == 0 {
- return
- }
- newMapping := NewPartitionSlotToConsumerInstanceList(partitionSlotToBrokerList.RingSize, time.Now())
- var prevMapping *PartitionSlotToConsumerInstanceList
- if len(pcm.prevMappings) > 0 {
- prevMapping = pcm.prevMappings[len(pcm.prevMappings)-1]
- } else {
- prevMapping = nil
- }
- newMapping.PartitionSlots = doBalanceSticky(partitionSlotToBrokerList.PartitionSlots, consumerInstances, prevMapping)
- if pcm.currentMapping != nil {
- pcm.prevMappings = append(pcm.prevMappings, pcm.currentMapping)
- if len(pcm.prevMappings) > 10 {
- pcm.prevMappings = pcm.prevMappings[1:]
- }
- }
- pcm.currentMapping = newMapping
-}
-
-func doBalanceSticky(partitions []*pub_balancer.PartitionSlotToBroker, consumerInstances []*ConsumerGroupInstance, prevMapping *PartitionSlotToConsumerInstanceList) (partitionSlots []*PartitionSlotToConsumerInstance) {
- // collect previous consumer instance ids
- prevConsumerInstanceIds := make(map[ConsumerGroupInstanceId]struct{})
- if prevMapping != nil {
- for _, prevPartitionSlot := range prevMapping.PartitionSlots {
- if prevPartitionSlot.AssignedInstanceId != "" {
- prevConsumerInstanceIds[prevPartitionSlot.AssignedInstanceId] = struct{}{}
- }
- }
- }
- // collect current consumer instance ids
- currConsumerInstanceIds := make(map[ConsumerGroupInstanceId]struct{})
- for _, consumerInstance := range consumerInstances {
- currConsumerInstanceIds[consumerInstance.InstanceId] = struct{}{}
- }
-
- // check deleted consumer instances
- deletedConsumerInstanceIds := make(map[ConsumerGroupInstanceId]struct{})
- for consumerInstanceId := range prevConsumerInstanceIds {
- if _, ok := currConsumerInstanceIds[consumerInstanceId]; !ok {
- deletedConsumerInstanceIds[consumerInstanceId] = struct{}{}
- }
- }
-
- // convert partition slots from list to a map
- prevPartitionSlotMap := make(map[string]*PartitionSlotToConsumerInstance)
- if prevMapping != nil {
- for _, partitionSlot := range prevMapping.PartitionSlots {
- key := fmt.Sprintf("%d-%d", partitionSlot.RangeStart, partitionSlot.RangeStop)
- prevPartitionSlotMap[key] = partitionSlot
- }
- }
-
- // make a copy of old mapping, skipping the deleted consumer instances
- newPartitionSlots := make([]*PartitionSlotToConsumerInstance, 0, len(partitions))
- for _, partition := range partitions {
- newPartitionSlots = append(newPartitionSlots, &PartitionSlotToConsumerInstance{
- RangeStart: partition.RangeStart,
- RangeStop: partition.RangeStop,
- UnixTimeNs: partition.UnixTimeNs,
- Broker: partition.AssignedBroker,
- FollowerBroker: partition.FollowerBroker,
- })
- }
- for _, newPartitionSlot := range newPartitionSlots {
- key := fmt.Sprintf("%d-%d", newPartitionSlot.RangeStart, newPartitionSlot.RangeStop)
- if prevPartitionSlot, ok := prevPartitionSlotMap[key]; ok {
- if _, ok := deletedConsumerInstanceIds[prevPartitionSlot.AssignedInstanceId]; !ok {
- newPartitionSlot.AssignedInstanceId = prevPartitionSlot.AssignedInstanceId
- }
- }
- }
-
- // for all consumer instances, count the average number of partitions
- // that are assigned to them
- consumerInstancePartitionCount := make(map[ConsumerGroupInstanceId]int)
- for _, newPartitionSlot := range newPartitionSlots {
- if newPartitionSlot.AssignedInstanceId != "" {
- consumerInstancePartitionCount[newPartitionSlot.AssignedInstanceId]++
- }
- }
- // average number of partitions that are assigned to each consumer instance
- averageConsumerInstanceLoad := float32(len(partitions)) / float32(len(consumerInstances))
-
- // assign unassigned partition slots to consumer instances that is underloaded
- consumerInstanceIdsIndex := 0
- for _, newPartitionSlot := range newPartitionSlots {
- if newPartitionSlot.AssignedInstanceId == "" {
- for avoidDeadLoop := len(consumerInstances); avoidDeadLoop > 0; avoidDeadLoop-- {
- consumerInstance := consumerInstances[consumerInstanceIdsIndex]
- if float32(consumerInstancePartitionCount[consumerInstance.InstanceId]) < averageConsumerInstanceLoad {
- newPartitionSlot.AssignedInstanceId = consumerInstance.InstanceId
- consumerInstancePartitionCount[consumerInstance.InstanceId]++
- consumerInstanceIdsIndex++
- if consumerInstanceIdsIndex >= len(consumerInstances) {
- consumerInstanceIdsIndex = 0
- }
- break
- } else {
- consumerInstanceIdsIndex++
- if consumerInstanceIdsIndex >= len(consumerInstances) {
- consumerInstanceIdsIndex = 0
- }
- }
- }
- }
- }
-
- return newPartitionSlots
-}
diff --git a/weed/mq/sub_coordinator/partition_consumer_mapping_test.go b/weed/mq/sub_coordinator/partition_consumer_mapping_test.go
deleted file mode 100644
index ccc4e8601..000000000
--- a/weed/mq/sub_coordinator/partition_consumer_mapping_test.go
+++ /dev/null
@@ -1,385 +0,0 @@
-package sub_coordinator
-
-import (
- "reflect"
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/mq/pub_balancer"
-)
-
-func Test_doBalanceSticky(t *testing.T) {
- type args struct {
- partitions []*pub_balancer.PartitionSlotToBroker
- consumerInstanceIds []*ConsumerGroupInstance
- prevMapping *PartitionSlotToConsumerInstanceList
- }
- tests := []struct {
- name string
- args args
- wantPartitionSlots []*PartitionSlotToConsumerInstance
- }{
- {
- name: "1 consumer instance, 1 partition",
- args: args{
- partitions: []*pub_balancer.PartitionSlotToBroker{
- {
- RangeStart: 0,
- RangeStop: 100,
- },
- },
- consumerInstanceIds: []*ConsumerGroupInstance{
- {
- InstanceId: "consumer-instance-1",
- MaxPartitionCount: 1,
- },
- },
- prevMapping: nil,
- },
- wantPartitionSlots: []*PartitionSlotToConsumerInstance{
- {
- RangeStart: 0,
- RangeStop: 100,
- AssignedInstanceId: "consumer-instance-1",
- },
- },
- },
- {
- name: "2 consumer instances, 1 partition",
- args: args{
- partitions: []*pub_balancer.PartitionSlotToBroker{
- {
- RangeStart: 0,
- RangeStop: 100,
- },
- },
- consumerInstanceIds: []*ConsumerGroupInstance{
- {
- InstanceId: "consumer-instance-1",
- MaxPartitionCount: 1,
- },
- {
- InstanceId: "consumer-instance-2",
- MaxPartitionCount: 1,
- },
- },
- prevMapping: nil,
- },
- wantPartitionSlots: []*PartitionSlotToConsumerInstance{
- {
- RangeStart: 0,
- RangeStop: 100,
- AssignedInstanceId: "consumer-instance-1",
- },
- },
- },
- {
- name: "1 consumer instance, 2 partitions",
- args: args{
- partitions: []*pub_balancer.PartitionSlotToBroker{
- {
- RangeStart: 0,
- RangeStop: 50,
- },
- {
- RangeStart: 50,
- RangeStop: 100,
- },
- },
- consumerInstanceIds: []*ConsumerGroupInstance{
- {
- InstanceId: "consumer-instance-1",
- MaxPartitionCount: 1,
- },
- },
- prevMapping: nil,
- },
- wantPartitionSlots: []*PartitionSlotToConsumerInstance{
- {
- RangeStart: 0,
- RangeStop: 50,
- AssignedInstanceId: "consumer-instance-1",
- },
- {
- RangeStart: 50,
- RangeStop: 100,
- AssignedInstanceId: "consumer-instance-1",
- },
- },
- },
- {
- name: "2 consumer instances, 2 partitions",
- args: args{
- partitions: []*pub_balancer.PartitionSlotToBroker{
- {
- RangeStart: 0,
- RangeStop: 50,
- },
- {
- RangeStart: 50,
- RangeStop: 100,
- },
- },
- consumerInstanceIds: []*ConsumerGroupInstance{
- {
- InstanceId: "consumer-instance-1",
- MaxPartitionCount: 1,
- },
- {
- InstanceId: "consumer-instance-2",
- MaxPartitionCount: 1,
- },
- },
- prevMapping: nil,
- },
- wantPartitionSlots: []*PartitionSlotToConsumerInstance{
- {
- RangeStart: 0,
- RangeStop: 50,
- AssignedInstanceId: "consumer-instance-1",
- },
- {
- RangeStart: 50,
- RangeStop: 100,
- AssignedInstanceId: "consumer-instance-2",
- },
- },
- },
- {
- name: "2 consumer instances, 2 partitions, 1 deleted consumer instance",
- args: args{
- partitions: []*pub_balancer.PartitionSlotToBroker{
- {
- RangeStart: 0,
- RangeStop: 50,
- },
- {
- RangeStart: 50,
- RangeStop: 100,
- },
- },
- consumerInstanceIds: []*ConsumerGroupInstance{
- {
- InstanceId: "consumer-instance-1",
- MaxPartitionCount: 1,
- },
- {
- InstanceId: "consumer-instance-2",
- MaxPartitionCount: 1,
- },
- },
- prevMapping: &PartitionSlotToConsumerInstanceList{
- PartitionSlots: []*PartitionSlotToConsumerInstance{
- {
- RangeStart: 0,
- RangeStop: 50,
- AssignedInstanceId: "consumer-instance-3",
- },
- {
- RangeStart: 50,
- RangeStop: 100,
- AssignedInstanceId: "consumer-instance-2",
- },
- },
- },
- },
- wantPartitionSlots: []*PartitionSlotToConsumerInstance{
- {
- RangeStart: 0,
- RangeStop: 50,
- AssignedInstanceId: "consumer-instance-1",
- },
- {
- RangeStart: 50,
- RangeStop: 100,
- AssignedInstanceId: "consumer-instance-2",
- },
- },
- },
- {
- name: "2 consumer instances, 2 partitions, 1 new consumer instance",
- args: args{
- partitions: []*pub_balancer.PartitionSlotToBroker{
- {
- RangeStart: 0,
- RangeStop: 50,
- },
- {
- RangeStart: 50,
- RangeStop: 100,
- },
- },
- consumerInstanceIds: []*ConsumerGroupInstance{
- {
- InstanceId: "consumer-instance-1",
- MaxPartitionCount: 1,
- },
- {
- InstanceId: "consumer-instance-2",
- MaxPartitionCount: 1,
- },
- {
- InstanceId: "consumer-instance-3",
- MaxPartitionCount: 1,
- },
- },
- prevMapping: &PartitionSlotToConsumerInstanceList{
- PartitionSlots: []*PartitionSlotToConsumerInstance{
- {
- RangeStart: 0,
- RangeStop: 50,
- AssignedInstanceId: "consumer-instance-3",
- },
- {
- RangeStart: 50,
- RangeStop: 100,
- AssignedInstanceId: "consumer-instance-2",
- },
- },
- },
- },
- wantPartitionSlots: []*PartitionSlotToConsumerInstance{
- {
- RangeStart: 0,
- RangeStop: 50,
- AssignedInstanceId: "consumer-instance-3",
- },
- {
- RangeStart: 50,
- RangeStop: 100,
- AssignedInstanceId: "consumer-instance-2",
- },
- },
- },
- {
- name: "2 consumer instances, 2 partitions, 1 new partition",
- args: args{
- partitions: []*pub_balancer.PartitionSlotToBroker{
- {
- RangeStart: 0,
- RangeStop: 50,
- },
- {
- RangeStart: 50,
- RangeStop: 100,
- },
- {
- RangeStart: 100,
- RangeStop: 150,
- },
- },
- consumerInstanceIds: []*ConsumerGroupInstance{
- {
- InstanceId: "consumer-instance-1",
- MaxPartitionCount: 1,
- },
- {
- InstanceId: "consumer-instance-2",
- MaxPartitionCount: 1,
- },
- },
- prevMapping: &PartitionSlotToConsumerInstanceList{
- PartitionSlots: []*PartitionSlotToConsumerInstance{
- {
- RangeStart: 0,
- RangeStop: 50,
- AssignedInstanceId: "consumer-instance-1",
- },
- {
- RangeStart: 50,
- RangeStop: 100,
- AssignedInstanceId: "consumer-instance-2",
- },
- },
- },
- },
- wantPartitionSlots: []*PartitionSlotToConsumerInstance{
- {
- RangeStart: 0,
- RangeStop: 50,
- AssignedInstanceId: "consumer-instance-1",
- },
- {
- RangeStart: 50,
- RangeStop: 100,
- AssignedInstanceId: "consumer-instance-2",
- },
- {
- RangeStart: 100,
- RangeStop: 150,
- AssignedInstanceId: "consumer-instance-1",
- },
- },
- },
- {
- name: "2 consumer instances, 2 partitions, 1 new partition, 1 new consumer instance",
- args: args{
- partitions: []*pub_balancer.PartitionSlotToBroker{
- {
- RangeStart: 0,
- RangeStop: 50,
- },
- {
- RangeStart: 50,
- RangeStop: 100,
- },
- {
- RangeStart: 100,
- RangeStop: 150,
- },
- },
- consumerInstanceIds: []*ConsumerGroupInstance{
- {
- InstanceId: "consumer-instance-1",
- MaxPartitionCount: 1,
- },
- {
- InstanceId: "consumer-instance-2",
- MaxPartitionCount: 1,
- },
- {
- InstanceId: "consumer-instance-3",
- MaxPartitionCount: 1,
- },
- },
- prevMapping: &PartitionSlotToConsumerInstanceList{
- PartitionSlots: []*PartitionSlotToConsumerInstance{
- {
- RangeStart: 0,
- RangeStop: 50,
- AssignedInstanceId: "consumer-instance-1",
- },
- {
- RangeStart: 50,
- RangeStop: 100,
- AssignedInstanceId: "consumer-instance-2",
- },
- },
- },
- },
- wantPartitionSlots: []*PartitionSlotToConsumerInstance{
- {
- RangeStart: 0,
- RangeStop: 50,
- AssignedInstanceId: "consumer-instance-1",
- },
- {
- RangeStart: 50,
- RangeStop: 100,
- AssignedInstanceId: "consumer-instance-2",
- },
- {
- RangeStart: 100,
- RangeStop: 150,
- AssignedInstanceId: "consumer-instance-3",
- },
- },
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if gotPartitionSlots := doBalanceSticky(tt.args.partitions, tt.args.consumerInstanceIds, tt.args.prevMapping); !reflect.DeepEqual(gotPartitionSlots, tt.wantPartitionSlots) {
- t.Errorf("doBalanceSticky() = %v, want %v", gotPartitionSlots, tt.wantPartitionSlots)
- }
- })
- }
-}
diff --git a/weed/mq/sub_coordinator/partition_list.go b/weed/mq/sub_coordinator/partition_list.go
index 16bf1ff0c..38c130598 100644
--- a/weed/mq/sub_coordinator/partition_list.go
+++ b/weed/mq/sub_coordinator/partition_list.go
@@ -1,7 +1,5 @@
package sub_coordinator
-import "time"
-
type PartitionSlotToConsumerInstance struct {
RangeStart int32
RangeStop int32
@@ -16,10 +14,3 @@ type PartitionSlotToConsumerInstanceList struct {
RingSize int32
Version int64
}
-
-func NewPartitionSlotToConsumerInstanceList(ringSize int32, version time.Time) *PartitionSlotToConsumerInstanceList {
- return &PartitionSlotToConsumerInstanceList{
- RingSize: ringSize,
- Version: version.UnixNano(),
- }
-}
diff --git a/weed/mq/topic/local_partition_offset.go b/weed/mq/topic/local_partition_offset.go
index 9c8a2dac4..ef7da3606 100644
--- a/weed/mq/topic/local_partition_offset.go
+++ b/weed/mq/topic/local_partition_offset.go
@@ -90,22 +90,3 @@ type OffsetAwarePublisher struct {
partition *LocalPartition
assignOffsetFn OffsetAssignmentFunc
}
-
-// NewOffsetAwarePublisher creates a new offset-aware publisher
-func NewOffsetAwarePublisher(partition *LocalPartition, assignOffsetFn OffsetAssignmentFunc) *OffsetAwarePublisher {
- return &OffsetAwarePublisher{
- partition: partition,
- assignOffsetFn: assignOffsetFn,
- }
-}
-
-// Publish publishes a message with automatic offset assignment
-func (oap *OffsetAwarePublisher) Publish(message *mq_pb.DataMessage) error {
- _, err := oap.partition.PublishWithOffset(message, oap.assignOffsetFn)
- return err
-}
-
-// GetPartition returns the underlying partition
-func (oap *OffsetAwarePublisher) GetPartition() *LocalPartition {
- return oap.partition
-}
diff --git a/weed/mq/topic/partition.go b/weed/mq/topic/partition.go
index 658ec85c4..fc3b71aac 100644
--- a/weed/mq/topic/partition.go
+++ b/weed/mq/topic/partition.go
@@ -16,15 +16,6 @@ type Partition struct {
UnixTimeNs int64 // in nanoseconds
}
-func NewPartition(rangeStart, rangeStop, ringSize int32, unixTimeNs int64) *Partition {
- return &Partition{
- RangeStart: rangeStart,
- RangeStop: rangeStop,
- RingSize: ringSize,
- UnixTimeNs: unixTimeNs,
- }
-}
-
func (partition Partition) Equals(other Partition) bool {
if partition.RangeStart != other.RangeStart {
return false
@@ -57,24 +48,6 @@ func FromPbPartition(partition *schema_pb.Partition) Partition {
}
}
-func SplitPartitions(targetCount int32, ts int64) []*Partition {
- partitions := make([]*Partition, 0, targetCount)
- partitionSize := PartitionCount / targetCount
- for i := int32(0); i < targetCount; i++ {
- partitionStop := (i + 1) * partitionSize
- if i == targetCount-1 {
- partitionStop = PartitionCount
- }
- partitions = append(partitions, &Partition{
- RangeStart: i * partitionSize,
- RangeStop: partitionStop,
- RingSize: PartitionCount,
- UnixTimeNs: ts,
- })
- }
- return partitions
-}
-
func (partition Partition) ToPbPartition() *schema_pb.Partition {
return &schema_pb.Partition{
RangeStart: partition.RangeStart,
diff --git a/weed/operation/assign_file_id.go b/weed/operation/assign_file_id.go
index 7c2c71074..5609bf8ac 100644
--- a/weed/operation/assign_file_id.go
+++ b/weed/operation/assign_file_id.go
@@ -3,8 +3,6 @@ package operation
import (
"context"
"fmt"
- "strings"
- "sync"
"time"
"github.com/seaweedfs/seaweedfs/weed/pb"
@@ -41,118 +39,6 @@ type AssignResult struct {
Replicas []Location `json:"replicas,omitempty"`
}
-// This is a proxy to the master server, only for assigning volume ids.
-// It runs via grpc to the master server in streaming mode.
-// The connection to the master would only be re-established when the last connection has error.
-type AssignProxy struct {
- grpcConnection *grpc.ClientConn
- pool chan *singleThreadAssignProxy
-}
-
-func NewAssignProxy(masterFn GetMasterFn, grpcDialOption grpc.DialOption, concurrency int) (ap *AssignProxy, err error) {
- ap = &AssignProxy{
- pool: make(chan *singleThreadAssignProxy, concurrency),
- }
- ap.grpcConnection, err = pb.GrpcDial(context.Background(), masterFn(context.Background()).ToGrpcAddress(), true, grpcDialOption)
- if err != nil {
- return nil, fmt.Errorf("fail to dial %s: %v", masterFn(context.Background()).ToGrpcAddress(), err)
- }
- for i := 0; i < concurrency; i++ {
- ap.pool <- &singleThreadAssignProxy{}
- }
- return ap, nil
-}
-
-func (ap *AssignProxy) Assign(primaryRequest *VolumeAssignRequest, alternativeRequests ...*VolumeAssignRequest) (ret *AssignResult, err error) {
- p := <-ap.pool
- defer func() {
- ap.pool <- p
- }()
-
- return p.doAssign(ap.grpcConnection, primaryRequest, alternativeRequests...)
-}
-
-type singleThreadAssignProxy struct {
- assignClient master_pb.Seaweed_StreamAssignClient
- sync.Mutex
-}
-
-func (ap *singleThreadAssignProxy) doAssign(grpcConnection *grpc.ClientConn, primaryRequest *VolumeAssignRequest, alternativeRequests ...*VolumeAssignRequest) (ret *AssignResult, err error) {
- ap.Lock()
- defer ap.Unlock()
-
- if ap.assignClient == nil {
- client := master_pb.NewSeaweedClient(grpcConnection)
- ap.assignClient, err = client.StreamAssign(context.Background())
- if err != nil {
- ap.assignClient = nil
- return nil, fmt.Errorf("fail to create stream assign client: %w", err)
- }
- }
-
- var requests []*VolumeAssignRequest
- requests = append(requests, primaryRequest)
- requests = append(requests, alternativeRequests...)
- ret = &AssignResult{}
-
- for _, request := range requests {
- if request == nil {
- continue
- }
- req := &master_pb.AssignRequest{
- Count: request.Count,
- Replication: request.Replication,
- Collection: request.Collection,
- Ttl: request.Ttl,
- DiskType: request.DiskType,
- DataCenter: request.DataCenter,
- Rack: request.Rack,
- DataNode: request.DataNode,
- WritableVolumeCount: request.WritableVolumeCount,
- }
- if err = ap.assignClient.Send(req); err != nil {
- ap.assignClient = nil
- return nil, fmt.Errorf("StreamAssignSend: %w", err)
- }
- resp, grpcErr := ap.assignClient.Recv()
- if grpcErr != nil {
- ap.assignClient = nil
- return nil, grpcErr
- }
- if resp.Error != "" {
- // StreamAssign returns transient warmup errors as in-band responses.
- // Wrap them as codes.Unavailable so the caller's retry logic can
- // classify them as retriable.
- if strings.Contains(resp.Error, "warming up") {
- return nil, status.Errorf(codes.Unavailable, "StreamAssignRecv: %s", resp.Error)
- }
- return nil, fmt.Errorf("StreamAssignRecv: %v", resp.Error)
- }
-
- ret.Count = resp.Count
- ret.Fid = resp.Fid
- ret.Url = resp.Location.Url
- ret.PublicUrl = resp.Location.PublicUrl
- ret.GrpcPort = int(resp.Location.GrpcPort)
- ret.Error = resp.Error
- ret.Auth = security.EncodedJwt(resp.Auth)
- for _, r := range resp.Replicas {
- ret.Replicas = append(ret.Replicas, Location{
- Url: r.Url,
- PublicUrl: r.PublicUrl,
- DataCenter: r.DataCenter,
- })
- }
-
- if ret.Count <= 0 {
- continue
- }
- break
- }
-
- return
-}
-
func Assign(ctx context.Context, masterFn GetMasterFn, grpcDialOption grpc.DialOption, primaryRequest *VolumeAssignRequest, alternativeRequests ...*VolumeAssignRequest) (*AssignResult, error) {
var requests []*VolumeAssignRequest
diff --git a/weed/operation/assign_file_id_test.go b/weed/operation/assign_file_id_test.go
deleted file mode 100644
index ecfa7d6d0..000000000
--- a/weed/operation/assign_file_id_test.go
+++ /dev/null
@@ -1,70 +0,0 @@
-package operation
-
-import (
- "context"
- "fmt"
- "testing"
- "time"
-
- "github.com/seaweedfs/seaweedfs/weed/pb"
- "google.golang.org/grpc"
-)
-
-func BenchmarkWithConcurrency(b *testing.B) {
- concurrencyLevels := []int{1, 10, 100, 1000}
-
- ap, _ := NewAssignProxy(func(_ context.Context) pb.ServerAddress {
- return pb.ServerAddress("localhost:9333")
- }, grpc.WithInsecure(), 16)
-
- for _, concurrency := range concurrencyLevels {
- b.Run(
- fmt.Sprintf("Concurrency-%d", concurrency),
- func(b *testing.B) {
- for i := 0; i < b.N; i++ {
- done := make(chan struct{})
- startTime := time.Now()
-
- for j := 0; j < concurrency; j++ {
- go func() {
-
- ap.Assign(&VolumeAssignRequest{
- Count: 1,
- })
-
- done <- struct{}{}
- }()
- }
-
- for j := 0; j < concurrency; j++ {
- <-done
- }
-
- duration := time.Since(startTime)
- b.Logf("Concurrency: %d, Duration: %v", concurrency, duration)
- }
- },
- )
- }
-}
-
-func BenchmarkStreamAssign(b *testing.B) {
- ap, _ := NewAssignProxy(func(_ context.Context) pb.ServerAddress {
- return pb.ServerAddress("localhost:9333")
- }, grpc.WithInsecure(), 16)
- for i := 0; i < b.N; i++ {
- ap.Assign(&VolumeAssignRequest{
- Count: 1,
- })
- }
-}
-
-func BenchmarkUnaryAssign(b *testing.B) {
- for i := 0; i < b.N; i++ {
- Assign(context.Background(), func(_ context.Context) pb.ServerAddress {
- return pb.ServerAddress("localhost:9333")
- }, grpc.WithInsecure(), &VolumeAssignRequest{
- Count: 1,
- })
- }
-}
diff --git a/weed/pb/filer_pb/filer_client.go b/weed/pb/filer_pb/filer_client.go
index c93417eee..3e7f9859e 100644
--- a/weed/pb/filer_pb/filer_client.go
+++ b/weed/pb/filer_pb/filer_client.go
@@ -93,11 +93,6 @@ func List(ctx context.Context, filerClient FilerClient, parentDirectoryPath, pre
})
}
-func doList(ctx context.Context, filerClient FilerClient, fullDirPath util.FullPath, prefix string, fn EachEntryFunction, startFrom string, inclusive bool, limit uint32) (err error) {
- _, err = doListWithSnapshot(ctx, filerClient, fullDirPath, prefix, fn, startFrom, inclusive, limit, 0)
- return err
-}
-
func doListWithSnapshot(ctx context.Context, filerClient FilerClient, fullDirPath util.FullPath, prefix string, fn EachEntryFunction, startFrom string, inclusive bool, limit uint32, snapshotTsNs int64) (actualSnapshotTsNs int64, err error) {
err = filerClient.WithFilerClient(false, func(client SeaweedFilerClient) error {
actualSnapshotTsNs, err = DoSeaweedListWithSnapshot(ctx, client, fullDirPath, prefix, fn, startFrom, inclusive, limit, snapshotTsNs)
@@ -212,26 +207,6 @@ func Exists(ctx context.Context, filerClient FilerClient, parentDirectoryPath st
return
}
-func Touch(ctx context.Context, filerClient FilerClient, parentDirectoryPath string, entryName string, entry *Entry) (err error) {
-
- return filerClient.WithFilerClient(false, func(client SeaweedFilerClient) error {
-
- request := &UpdateEntryRequest{
- Directory: parentDirectoryPath,
- Entry: entry,
- }
-
- glog.V(4).InfofCtx(ctx, "touch entry %v/%v: %v", parentDirectoryPath, entryName, request)
- if err := UpdateEntry(ctx, client, request); err != nil {
- glog.V(0).InfofCtx(ctx, "touch exists entry %v: %v", request, err)
- return fmt.Errorf("touch exists entry %s/%s: %v", parentDirectoryPath, entryName, err)
- }
-
- return nil
- })
-
-}
-
func Mkdir(ctx context.Context, filerClient FilerClient, parentDirectoryPath string, dirName string, fn func(entry *Entry)) error {
return filerClient.WithFilerClient(false, func(client SeaweedFilerClient) error {
return DoMkdir(ctx, client, parentDirectoryPath, dirName, fn)
@@ -349,59 +324,3 @@ func DoRemoveWithResponse(ctx context.Context, client SeaweedFilerClient, parent
return resp, nil
}
}
-
-// DoDeleteEmptyParentDirectories recursively deletes empty parent directories.
-// It stops at root "/" or at stopAtPath.
-// For safety, dirPath must be under stopAtPath (when stopAtPath is provided).
-// The checked map tracks already-processed directories to avoid redundant work in batch operations.
-func DoDeleteEmptyParentDirectories(ctx context.Context, client SeaweedFilerClient, dirPath util.FullPath, stopAtPath util.FullPath, checked map[string]bool) {
- if dirPath == "/" || dirPath == stopAtPath {
- return
- }
-
- // Skip if already checked (for batch delete optimization)
- dirPathStr := string(dirPath)
- if checked != nil {
- if checked[dirPathStr] {
- return
- }
- checked[dirPathStr] = true
- }
-
- // Safety check: if stopAtPath is provided, dirPath must be under it (root "/" allows everything)
- stopStr := string(stopAtPath)
- if stopAtPath != "" && stopStr != "/" && !strings.HasPrefix(dirPathStr+"/", stopStr+"/") {
- glog.V(1).InfofCtx(ctx, "DoDeleteEmptyParentDirectories: %s is not under %s, skipping", dirPath, stopAtPath)
- return
- }
-
- // Check if directory is empty by listing with limit 1
- isEmpty := true
- err := SeaweedList(ctx, client, dirPathStr, "", func(entry *Entry, isLast bool) error {
- isEmpty = false
- return io.EOF // Use sentinel error to explicitly stop iteration
- }, "", false, 1)
-
- if err != nil && err != io.EOF {
- glog.V(3).InfofCtx(ctx, "DoDeleteEmptyParentDirectories: error checking %s: %v", dirPath, err)
- return
- }
-
- if !isEmpty {
- // Directory is not empty, stop checking upward
- glog.V(3).InfofCtx(ctx, "DoDeleteEmptyParentDirectories: directory %s is not empty, stopping cleanup", dirPath)
- return
- }
-
- // Directory is empty, try to delete it
- glog.V(2).InfofCtx(ctx, "DoDeleteEmptyParentDirectories: deleting empty directory %s", dirPath)
- parentDir, dirName := dirPath.DirAndName()
-
- if err := DoRemove(ctx, client, parentDir, dirName, false, false, false, false, nil); err == nil {
- // Successfully deleted, continue checking upwards
- DoDeleteEmptyParentDirectories(ctx, client, util.FullPath(parentDir), stopAtPath, checked)
- } else {
- // Failed to delete, stop cleanup
- glog.V(3).InfofCtx(ctx, "DoDeleteEmptyParentDirectories: failed to delete %s: %v", dirPath, err)
- }
-}
diff --git a/weed/pb/filer_pb/filer_pb_helper.go b/weed/pb/filer_pb/filer_pb_helper.go
index 05d5f602a..b621e366a 100644
--- a/weed/pb/filer_pb/filer_pb_helper.go
+++ b/weed/pb/filer_pb/filer_pb_helper.go
@@ -111,15 +111,6 @@ func BeforeEntrySerialization(chunks []*FileChunk) {
}
}
-func EnsureFid(chunk *FileChunk) {
- if chunk.Fid != nil {
- return
- }
- if fid, err := ToFileIdObject(chunk.FileId); err == nil {
- chunk.Fid = fid
- }
-}
-
func AfterEntryDeserialization(chunks []*FileChunk) {
for _, chunk := range chunks {
@@ -309,16 +300,6 @@ func MetadataEventTouchesDirectory(event *SubscribeMetadataResponse, dir string)
MetadataEventTargetDirectory(event) == dir
}
-func MetadataEventTouchesDirectoryPrefix(event *SubscribeMetadataResponse, prefix string) bool {
- if strings.HasPrefix(MetadataEventSourceDirectory(event), prefix) {
- return true
- }
- return event != nil &&
- event.EventNotification != nil &&
- event.EventNotification.NewEntry != nil &&
- strings.HasPrefix(MetadataEventTargetDirectory(event), prefix)
-}
-
func MetadataEventMatchesSubscription(event *SubscribeMetadataResponse, pathPrefix string, pathPrefixes []string, directories []string) bool {
if event == nil {
return false
diff --git a/weed/pb/filer_pb/filer_pb_helper_test.go b/weed/pb/filer_pb/filer_pb_helper_test.go
deleted file mode 100644
index b38b094e3..000000000
--- a/weed/pb/filer_pb/filer_pb_helper_test.go
+++ /dev/null
@@ -1,87 +0,0 @@
-package filer_pb
-
-import (
- "testing"
-
- "google.golang.org/protobuf/proto"
-)
-
-func TestFileIdSize(t *testing.T) {
- fileIdStr := "11745,0293434534cbb9892b"
-
- fid, _ := ToFileIdObject(fileIdStr)
- bytes, _ := proto.Marshal(fid)
-
- println(len(fileIdStr))
- println(len(bytes))
-}
-
-func TestMetadataEventMatchesSubscription(t *testing.T) {
- event := &SubscribeMetadataResponse{
- Directory: "/tmp",
- EventNotification: &EventNotification{
- OldEntry: &Entry{Name: "old-name"},
- NewEntry: &Entry{Name: "new-name"},
- NewParentPath: "/watched",
- },
- }
-
- tests := []struct {
- name string
- pathPrefix string
- pathPrefixes []string
- directories []string
- }{
- {
- name: "primary path prefix matches rename target",
- pathPrefix: "/watched/new-name",
- },
- {
- name: "additional path prefix matches rename target",
- pathPrefix: "/data",
- pathPrefixes: []string{"/watched"},
- },
- {
- name: "directory watch matches rename target directory",
- pathPrefix: "/data",
- directories: []string{"/watched"},
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if !MetadataEventMatchesSubscription(event, tt.pathPrefix, tt.pathPrefixes, tt.directories) {
- t.Fatalf("MetadataEventMatchesSubscription returned false")
- }
- })
- }
-}
-
-func TestMetadataEventTouchesDirectoryHelpers(t *testing.T) {
- renameInto := &SubscribeMetadataResponse{
- Directory: "/tmp",
- EventNotification: &EventNotification{
- OldEntry: &Entry{Name: "filer.conf"},
- NewEntry: &Entry{Name: "filer.conf"},
- NewParentPath: "/etc/seaweedfs",
- },
- }
- if got := MetadataEventTargetDirectory(renameInto); got != "/etc/seaweedfs" {
- t.Fatalf("MetadataEventTargetDirectory = %q, want /etc/seaweedfs", got)
- }
- if !MetadataEventTouchesDirectory(renameInto, "/etc/seaweedfs") {
- t.Fatalf("expected rename target to touch /etc/seaweedfs")
- }
-
- renameOut := &SubscribeMetadataResponse{
- Directory: "/etc/remote",
- EventNotification: &EventNotification{
- OldEntry: &Entry{Name: "remote.conf"},
- NewEntry: &Entry{Name: "remote.conf"},
- NewParentPath: "/tmp",
- },
- }
- if !MetadataEventTouchesDirectoryPrefix(renameOut, "/etc/remote") {
- t.Fatalf("expected rename source to touch /etc/remote")
- }
-}
diff --git a/weed/pb/grpc_client_server.go b/weed/pb/grpc_client_server.go
index 4b7c0852d..82b9a23f5 100644
--- a/weed/pb/grpc_client_server.go
+++ b/weed/pb/grpc_client_server.go
@@ -28,7 +28,6 @@ import (
"github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
"github.com/seaweedfs/seaweedfs/weed/pb/master_pb"
"github.com/seaweedfs/seaweedfs/weed/pb/mq_pb"
- "github.com/seaweedfs/seaweedfs/weed/pb/worker_pb"
)
const (
@@ -318,18 +317,6 @@ func WithGrpcClient(streamingMode bool, signature int32, fn func(*grpc.ClientCon
}
-func ParseServerAddress(server string, deltaPort int) (newServerAddress string, err error) {
-
- host, port, parseErr := hostAndPort(server)
- if parseErr != nil {
- return "", fmt.Errorf("server port parse error: %w", parseErr)
- }
-
- newPort := int(port) + deltaPort
-
- return util.JoinHostPort(host, newPort), nil
-}
-
func hostAndPort(address string) (host string, port uint64, err error) {
colonIndex := strings.LastIndex(address, ":")
if colonIndex < 0 {
@@ -457,10 +444,3 @@ func WithOneOfGrpcFilerClients(streamingMode bool, filerAddresses []ServerAddres
return err
}
-
-func WithWorkerClient(streamingMode bool, workerAddress string, grpcDialOption grpc.DialOption, fn func(client worker_pb.WorkerServiceClient) error) error {
- return WithGrpcClient(streamingMode, 0, func(grpcConnection *grpc.ClientConn) error {
- client := worker_pb.NewWorkerServiceClient(grpcConnection)
- return fn(client)
- }, workerAddress, false, grpcDialOption)
-}
diff --git a/weed/pb/server_address.go b/weed/pb/server_address.go
index 151323b03..7cdece5eb 100644
--- a/weed/pb/server_address.go
+++ b/weed/pb/server_address.go
@@ -157,14 +157,6 @@ func (sa ServerAddresses) ToAddressMap() (addresses map[string]ServerAddress) {
return
}
-func (sa ServerAddresses) ToAddressStrings() (addresses []string) {
- parts := strings.Split(string(sa), ",")
- for _, address := range parts {
- addresses = append(addresses, address)
- }
- return
-}
-
func ToAddressStrings(addresses []ServerAddress) []string {
var strings []string
for _, addr := range addresses {
@@ -172,20 +164,6 @@ func ToAddressStrings(addresses []ServerAddress) []string {
}
return strings
}
-func ToAddressStringsFromMap(addresses map[string]ServerAddress) []string {
- var strings []string
- for _, addr := range addresses {
- strings = append(strings, string(addr))
- }
- return strings
-}
-func FromAddressStrings(strings []string) []ServerAddress {
- var addresses []ServerAddress
- for _, addr := range strings {
- addresses = append(addresses, ServerAddress(addr))
- }
- return addresses
-}
func ParseUrl(input string) (address ServerAddress, path string, err error) {
if !strings.HasPrefix(input, "http://") {
diff --git a/weed/plugin/worker/iceberg/detection.go b/weed/plugin/worker/iceberg/detection.go
index 80e7fcd61..8a279287c 100644
--- a/weed/plugin/worker/iceberg/detection.go
+++ b/weed/plugin/worker/iceberg/detection.go
@@ -449,58 +449,6 @@ func hasEligibleCompaction(
return len(bins) > 0, nil
}
-func countDataManifestsForRewrite(
- ctx context.Context,
- filerClient filer_pb.SeaweedFilerClient,
- bucketName, tablePath string,
- manifests []iceberg.ManifestFile,
- meta table.Metadata,
- predicate *partitionPredicate,
-) (int64, error) {
- if predicate == nil {
- return countDataManifests(manifests), nil
- }
-
- specsByID := specByID(meta)
-
- var count int64
- for _, mf := range manifests {
- if mf.ManifestContent() != iceberg.ManifestContentData {
- continue
- }
- manifestData, err := loadFileByIcebergPath(ctx, filerClient, bucketName, tablePath, mf.FilePath())
- if err != nil {
- return 0, fmt.Errorf("read manifest %s: %w", mf.FilePath(), err)
- }
- entries, err := iceberg.ReadManifest(mf, bytes.NewReader(manifestData), true)
- if err != nil {
- return 0, fmt.Errorf("parse manifest %s: %w", mf.FilePath(), err)
- }
- if len(entries) == 0 {
- continue
- }
- spec, ok := specsByID[int(mf.PartitionSpecID())]
- if !ok {
- continue
- }
- allMatch := len(entries) > 0
- for _, entry := range entries {
- match, err := predicate.Matches(spec, entry.DataFile().Partition())
- if err != nil {
- return 0, err
- }
- if !match {
- allMatch = false
- break
- }
- }
- if allMatch {
- count++
- }
- }
- return count, nil
-}
-
func compactionMinInputFiles(minInputFiles int64) (int, error) {
// Ensure the configured value is positive and fits into the platform's int type
if minInputFiles <= 0 {
diff --git a/weed/plugin/worker/iceberg/planning_index.go b/weed/plugin/worker/iceberg/planning_index.go
index 74e019354..a2401836a 100644
--- a/weed/plugin/worker/iceberg/planning_index.go
+++ b/weed/plugin/worker/iceberg/planning_index.go
@@ -137,26 +137,6 @@ func mergePlanningIndexSections(index, existing *planningIndex) *planningIndex {
return index
}
-func buildPlanningIndex(
- ctx context.Context,
- filerClient filer_pb.SeaweedFilerClient,
- bucketName, tablePath string,
- meta table.Metadata,
- config Config,
- ops []string,
-) (*planningIndex, error) {
- currentSnap := meta.CurrentSnapshot()
- if currentSnap == nil || currentSnap.ManifestList == "" {
- return nil, nil
- }
-
- manifests, err := loadCurrentManifests(ctx, filerClient, bucketName, tablePath, meta)
- if err != nil {
- return nil, err
- }
- return buildPlanningIndexFromManifests(ctx, filerClient, bucketName, tablePath, meta, config, ops, manifests)
-}
-
func buildPlanningIndexFromManifests(
ctx context.Context,
filerClient filer_pb.SeaweedFilerClient,
diff --git a/weed/plugin/worker/lifecycle/config.go b/weed/plugin/worker/lifecycle/config.go
deleted file mode 100644
index 62e0b4dbf..000000000
--- a/weed/plugin/worker/lifecycle/config.go
+++ /dev/null
@@ -1,131 +0,0 @@
-package lifecycle
-
-import (
- "strconv"
- "strings"
-
- "github.com/seaweedfs/seaweedfs/weed/glog"
- "github.com/seaweedfs/seaweedfs/weed/pb/plugin_pb"
-)
-
-const (
- jobType = "s3_lifecycle"
-
- defaultBatchSize = 1000
- defaultMaxDeletesPerBucket = 10000
- defaultDryRun = false
- defaultDeleteMarkerCleanup = true
- defaultAbortMPUDaysDefault = 7
-
- MetricObjectsExpired = "objects_expired"
- MetricObjectsScanned = "objects_scanned"
- MetricBucketsScanned = "buckets_scanned"
- MetricBucketsWithRules = "buckets_with_rules"
- MetricDeleteMarkersClean = "delete_markers_cleaned"
- MetricMPUAborted = "mpu_aborted"
- MetricErrors = "errors"
- MetricDurationMs = "duration_ms"
-)
-
-// Config holds parsed worker config values for lifecycle management.
-type Config struct {
- BatchSize int64
- MaxDeletesPerBucket int64
- DryRun bool
- DeleteMarkerCleanup bool
- AbortMPUDays int64
-}
-
-// ParseConfig extracts a lifecycle Config from plugin config values.
-func ParseConfig(values map[string]*plugin_pb.ConfigValue) Config {
- cfg := Config{
- BatchSize: readInt64Config(values, "batch_size", defaultBatchSize),
- MaxDeletesPerBucket: readInt64Config(values, "max_deletes_per_bucket", defaultMaxDeletesPerBucket),
- DryRun: readBoolConfig(values, "dry_run", defaultDryRun),
- DeleteMarkerCleanup: readBoolConfig(values, "delete_marker_cleanup", defaultDeleteMarkerCleanup),
- AbortMPUDays: readInt64Config(values, "abort_mpu_days", defaultAbortMPUDaysDefault),
- }
-
- if cfg.BatchSize <= 0 {
- cfg.BatchSize = defaultBatchSize
- }
- if cfg.MaxDeletesPerBucket <= 0 {
- cfg.MaxDeletesPerBucket = defaultMaxDeletesPerBucket
- }
- if cfg.AbortMPUDays < 0 {
- cfg.AbortMPUDays = defaultAbortMPUDaysDefault
- }
-
- return cfg
-}
-
-func readStringConfig(values map[string]*plugin_pb.ConfigValue, field string, fallback string) string {
- if values == nil {
- return fallback
- }
- value := values[field]
- if value == nil {
- return fallback
- }
- switch kind := value.Kind.(type) {
- case *plugin_pb.ConfigValue_StringValue:
- return kind.StringValue
- case *plugin_pb.ConfigValue_Int64Value:
- return strconv.FormatInt(kind.Int64Value, 10)
- default:
- glog.V(1).Infof("readStringConfig: unexpected type %T for field %q", value.Kind, field)
- }
- return fallback
-}
-
-func readBoolConfig(values map[string]*plugin_pb.ConfigValue, field string, fallback bool) bool {
- if values == nil {
- return fallback
- }
- value := values[field]
- if value == nil {
- return fallback
- }
- switch kind := value.Kind.(type) {
- case *plugin_pb.ConfigValue_BoolValue:
- return kind.BoolValue
- case *plugin_pb.ConfigValue_StringValue:
- s := strings.TrimSpace(strings.ToLower(kind.StringValue))
- if s == "true" || s == "1" || s == "yes" {
- return true
- }
- if s == "false" || s == "0" || s == "no" {
- return false
- }
- glog.V(1).Infof("readBoolConfig: unrecognized string value %q for field %q, using fallback %v", kind.StringValue, field, fallback)
- case *plugin_pb.ConfigValue_Int64Value:
- return kind.Int64Value != 0
- default:
- glog.V(1).Infof("readBoolConfig: unexpected config value type %T for field %q, using fallback %v", value.Kind, field, fallback)
- }
- return fallback
-}
-
-func readInt64Config(values map[string]*plugin_pb.ConfigValue, field string, fallback int64) int64 {
- if values == nil {
- return fallback
- }
- value := values[field]
- if value == nil {
- return fallback
- }
- switch kind := value.Kind.(type) {
- case *plugin_pb.ConfigValue_Int64Value:
- return kind.Int64Value
- case *plugin_pb.ConfigValue_DoubleValue:
- return int64(kind.DoubleValue)
- case *plugin_pb.ConfigValue_StringValue:
- parsed, err := strconv.ParseInt(strings.TrimSpace(kind.StringValue), 10, 64)
- if err == nil {
- return parsed
- }
- default:
- glog.V(1).Infof("readInt64Config: unexpected config value type %T for field %q, using fallback %d", value.Kind, field, fallback)
- }
- return fallback
-}
diff --git a/weed/plugin/worker/lifecycle/detection.go b/weed/plugin/worker/lifecycle/detection.go
deleted file mode 100644
index d8267b2f0..000000000
--- a/weed/plugin/worker/lifecycle/detection.go
+++ /dev/null
@@ -1,221 +0,0 @@
-package lifecycle
-
-import (
- "context"
- "fmt"
- "path"
- "strings"
-
- "github.com/seaweedfs/seaweedfs/weed/filer"
- "github.com/seaweedfs/seaweedfs/weed/glog"
- "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
- "github.com/seaweedfs/seaweedfs/weed/pb/plugin_pb"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
- "github.com/seaweedfs/seaweedfs/weed/util/wildcard"
-)
-
-const lifecycleXMLKey = "s3-bucket-lifecycle-configuration-xml"
-
-// detectBucketsWithLifecycleRules scans all S3 buckets to find those
-// with lifecycle rules, either TTL entries in filer.conf or lifecycle
-// XML stored in bucket metadata.
-func (h *Handler) detectBucketsWithLifecycleRules(
- ctx context.Context,
- filerClient filer_pb.SeaweedFilerClient,
- config Config,
- bucketFilter string,
- maxResults int,
-) ([]*plugin_pb.JobProposal, error) {
- // Load filer configuration to find TTL rules.
- fc, err := loadFilerConf(ctx, filerClient)
- if err != nil {
- return nil, fmt.Errorf("load filer conf: %w", err)
- }
-
- bucketsPath := defaultBucketsPath
- bucketMatchers := wildcard.CompileWildcardMatchers(bucketFilter)
-
- // List all buckets.
- bucketEntries, err := listFilerEntries(ctx, filerClient, bucketsPath, "")
- if err != nil {
- return nil, fmt.Errorf("list buckets at %s: %w", bucketsPath, err)
- }
-
- var proposals []*plugin_pb.JobProposal
- for _, entry := range bucketEntries {
- select {
- case <-ctx.Done():
- return proposals, ctx.Err()
- default:
- }
-
- if !entry.IsDirectory {
- continue
- }
- bucketName := entry.Name
- if !wildcard.MatchesAnyWildcard(bucketMatchers, bucketName) {
- continue
- }
-
- // Check for lifecycle rules from two sources:
- // 1. filer.conf TTLs (legacy Expiration.Days fast path)
- // 2. Stored lifecycle XML in bucket metadata (full rule support)
- collection := bucketName
- ttls := fc.GetCollectionTtls(collection)
-
- hasLifecycleXML := entry.Extended != nil && len(entry.Extended[lifecycleXMLKey]) > 0
- versioningStatus := ""
- if entry.Extended != nil {
- versioningStatus = string(entry.Extended[s3_constants.ExtVersioningKey])
- }
-
- ruleCount := int64(len(ttls))
- if !hasLifecycleXML && ruleCount == 0 {
- continue
- }
-
- glog.V(2).Infof("s3_lifecycle: bucket %s has %d TTL rule(s), lifecycle_xml=%v, versioning=%s",
- bucketName, ruleCount, hasLifecycleXML, versioningStatus)
-
- proposal := &plugin_pb.JobProposal{
- ProposalId: fmt.Sprintf("s3_lifecycle:%s", bucketName),
- JobType: jobType,
- Summary: fmt.Sprintf("Lifecycle management for bucket %s", bucketName),
- DedupeKey: fmt.Sprintf("s3_lifecycle:%s", bucketName),
- Parameters: map[string]*plugin_pb.ConfigValue{
- "bucket": {Kind: &plugin_pb.ConfigValue_StringValue{StringValue: bucketName}},
- "buckets_path": {Kind: &plugin_pb.ConfigValue_StringValue{StringValue: bucketsPath}},
- "collection": {Kind: &plugin_pb.ConfigValue_StringValue{StringValue: collection}},
- "rule_count": {Kind: &plugin_pb.ConfigValue_Int64Value{Int64Value: ruleCount}},
- "has_lifecycle_xml": {Kind: &plugin_pb.ConfigValue_BoolValue{BoolValue: hasLifecycleXML}},
- "versioning_status": {Kind: &plugin_pb.ConfigValue_StringValue{StringValue: versioningStatus}},
- },
- Labels: map[string]string{
- "bucket": bucketName,
- },
- }
-
- proposals = append(proposals, proposal)
- if maxResults > 0 && len(proposals) >= maxResults {
- break
- }
- }
-
- return proposals, nil
-}
-
-const defaultBucketsPath = "/buckets"
-
-// loadFilerConf reads the filer configuration from the filer.
-func loadFilerConf(ctx context.Context, client filer_pb.SeaweedFilerClient) (*filer.FilerConf, error) {
- fc := filer.NewFilerConf()
-
- content, err := filer.ReadInsideFiler(ctx, client, filer.DirectoryEtcSeaweedFS, filer.FilerConfName)
- if err != nil {
- // filer.conf may not exist yet - return empty config.
- glog.V(1).Infof("s3_lifecycle: filer.conf not found or unreadable: %v (using empty config)", err)
- return fc, nil
- }
- if err := fc.LoadFromBytes(content); err != nil {
- return nil, fmt.Errorf("parse filer.conf: %w", err)
- }
-
- return fc, nil
-}
-
-// listFilerEntries lists directory entries from the filer.
-func listFilerEntries(ctx context.Context, client filer_pb.SeaweedFilerClient, dir, startFrom string) ([]*filer_pb.Entry, error) {
- var entries []*filer_pb.Entry
- err := filer_pb.SeaweedList(ctx, client, dir, "", func(entry *filer_pb.Entry, isLast bool) error {
- entries = append(entries, entry)
- return nil
- }, startFrom, false, 10000)
- return entries, err
-}
-
-type expiredObject struct {
- dir string
- name string
-}
-
-// listExpiredObjects scans a bucket directory tree for objects whose TTL
-// has expired based on their TtlSec attribute set by PutBucketLifecycle.
-func listExpiredObjects(
- ctx context.Context,
- client filer_pb.SeaweedFilerClient,
- bucketsPath, bucket string,
- limit int64,
-) ([]expiredObject, int64, error) {
- var expired []expiredObject
- var scanned int64
-
- bucketPath := path.Join(bucketsPath, bucket)
-
- // Walk the bucket directory tree using breadth-first traversal.
- dirsToProcess := []string{bucketPath}
- for len(dirsToProcess) > 0 {
- select {
- case <-ctx.Done():
- return expired, scanned, ctx.Err()
- default:
- }
-
- dir := dirsToProcess[0]
- dirsToProcess = dirsToProcess[1:]
-
- limitReached := false
- err := filer_pb.SeaweedList(ctx, client, dir, "", func(entry *filer_pb.Entry, isLast bool) error {
- if entry.IsDirectory {
- dirsToProcess = append(dirsToProcess, path.Join(dir, entry.Name))
- return nil
- }
- scanned++
-
- if isExpiredByTTL(entry) {
- expired = append(expired, expiredObject{
- dir: dir,
- name: entry.Name,
- })
- }
-
- if limit > 0 && int64(len(expired)) >= limit {
- limitReached = true
- return fmt.Errorf("limit reached")
- }
- return nil
- }, "", false, 10000)
-
- if err != nil && !strings.Contains(err.Error(), "limit reached") {
- return expired, scanned, fmt.Errorf("list %s: %w", dir, err)
- }
-
- if limitReached || (limit > 0 && int64(len(expired)) >= limit) {
- break
- }
- }
-
- return expired, scanned, nil
-}
-
-// isExpiredByTTL checks if an entry is expired based on its TTL attribute.
-// SeaweedFS sets TtlSec on entries when lifecycle rules are applied via
-// PutBucketLifecycleConfiguration. An entry is expired when
-// creation_time + TTL < now.
-func isExpiredByTTL(entry *filer_pb.Entry) bool {
- if entry == nil || entry.Attributes == nil {
- return false
- }
-
- ttlSec := entry.Attributes.TtlSec
- if ttlSec <= 0 {
- return false
- }
-
- crTime := entry.Attributes.Crtime
- if crTime <= 0 {
- return false
- }
-
- expirationUnix := crTime + int64(ttlSec)
- return expirationUnix < nowUnix()
-}
diff --git a/weed/plugin/worker/lifecycle/detection_test.go b/weed/plugin/worker/lifecycle/detection_test.go
deleted file mode 100644
index d9ff86688..000000000
--- a/weed/plugin/worker/lifecycle/detection_test.go
+++ /dev/null
@@ -1,132 +0,0 @@
-package lifecycle
-
-import (
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
-)
-
-func TestBucketHasLifecycleXML(t *testing.T) {
- tests := []struct {
- name string
- extended map[string][]byte
- want bool
- }{
- {
- name: "has_lifecycle_xml",
- extended: map[string][]byte{lifecycleXMLKey: []byte("")},
- want: true,
- },
- {
- name: "empty_lifecycle_xml",
- extended: map[string][]byte{lifecycleXMLKey: {}},
- want: false,
- },
- {
- name: "no_lifecycle_xml",
- extended: map[string][]byte{"other-key": []byte("value")},
- want: false,
- },
- {
- name: "nil_extended",
- extended: nil,
- want: false,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got := tt.extended != nil && len(tt.extended[lifecycleXMLKey]) > 0
- if got != tt.want {
- t.Errorf("hasLifecycleXML = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func TestBucketVersioningStatus(t *testing.T) {
- tests := []struct {
- name string
- extended map[string][]byte
- want string
- }{
- {
- name: "versioning_enabled",
- extended: map[string][]byte{
- s3_constants.ExtVersioningKey: []byte("Enabled"),
- },
- want: "Enabled",
- },
- {
- name: "versioning_suspended",
- extended: map[string][]byte{
- s3_constants.ExtVersioningKey: []byte("Suspended"),
- },
- want: "Suspended",
- },
- {
- name: "no_versioning",
- extended: map[string][]byte{},
- want: "",
- },
- {
- name: "nil_extended",
- extended: nil,
- want: "",
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- var got string
- if tt.extended != nil {
- got = string(tt.extended[s3_constants.ExtVersioningKey])
- }
- if got != tt.want {
- t.Errorf("versioningStatus = %q, want %q", got, tt.want)
- }
- })
- }
-}
-
-func TestDetectionProposalParameters(t *testing.T) {
- // Verify that bucket entries with lifecycle XML or TTL rules produce
- // proposals with the expected parameters.
- t.Run("bucket_with_lifecycle_xml_and_versioning", func(t *testing.T) {
- entry := &filer_pb.Entry{
- Name: "my-bucket",
- IsDirectory: true,
- Extended: map[string][]byte{
- lifecycleXMLKey: []byte(`Enabled`),
- s3_constants.ExtVersioningKey: []byte("Enabled"),
- },
- }
-
- hasXML := entry.Extended != nil && len(entry.Extended[lifecycleXMLKey]) > 0
- versioning := ""
- if entry.Extended != nil {
- versioning = string(entry.Extended[s3_constants.ExtVersioningKey])
- }
-
- if !hasXML {
- t.Error("expected hasLifecycleXML=true")
- }
- if versioning != "Enabled" {
- t.Errorf("expected versioning=Enabled, got %q", versioning)
- }
- })
-
- t.Run("bucket_without_lifecycle_or_ttl_is_skipped", func(t *testing.T) {
- entry := &filer_pb.Entry{
- Name: "empty-bucket",
- IsDirectory: true,
- Extended: map[string][]byte{},
- }
-
- hasXML := entry.Extended != nil && len(entry.Extended[lifecycleXMLKey]) > 0
- ttlCount := 0 // simulated: no TTL rules in filer.conf
-
- if hasXML || ttlCount > 0 {
- t.Error("expected bucket to be skipped (no lifecycle XML, no TTLs)")
- }
- })
-}
diff --git a/weed/plugin/worker/lifecycle/execution.go b/weed/plugin/worker/lifecycle/execution.go
deleted file mode 100644
index 183e7648e..000000000
--- a/weed/plugin/worker/lifecycle/execution.go
+++ /dev/null
@@ -1,878 +0,0 @@
-package lifecycle
-
-import (
- "context"
- "errors"
- "fmt"
- "math"
- "path"
- "sort"
- "strconv"
- "strings"
- "time"
-
- "github.com/seaweedfs/seaweedfs/weed/glog"
- "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
- "github.com/seaweedfs/seaweedfs/weed/pb/plugin_pb"
- pluginworker "github.com/seaweedfs/seaweedfs/weed/plugin/worker"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3lifecycle"
-)
-
-var errLimitReached = errors.New("limit reached")
-
-type executionResult struct {
- objectsExpired int64
- objectsScanned int64
- deleteMarkersClean int64
- mpuAborted int64
- errors int64
-}
-
-// executeLifecycleForBucket processes lifecycle rules for a single bucket:
-// 1. Reads filer.conf to get TTL rules for the bucket's collection
-// 2. Walks the bucket directory tree to find expired objects
-// 3. Deletes expired objects (unless dry run)
-func (h *Handler) executeLifecycleForBucket(
- ctx context.Context,
- filerClient filer_pb.SeaweedFilerClient,
- config Config,
- bucket, bucketsPath string,
- sender pluginworker.ExecutionSender,
- jobID string,
-) (*executionResult, error) {
- result := &executionResult{}
-
- // Try to load lifecycle rules from stored XML first (full rule evaluation).
- // Fall back to filer.conf TTL-only evaluation only if no XML is configured.
- // If XML exists but is malformed, fail closed (don't fall back to TTL,
- // which could apply broader rules and delete objects the XML rules would keep).
- // Transient filer errors fall back to TTL with a warning.
- lifecycleRules, xmlErr := loadLifecycleRulesFromBucket(ctx, filerClient, bucketsPath, bucket)
- if xmlErr != nil && errors.Is(xmlErr, errMalformedLifecycleXML) {
- glog.Errorf("s3_lifecycle: bucket %s: %v (skipping bucket)", bucket, xmlErr)
- return result, xmlErr
- }
- if xmlErr != nil {
- glog.V(1).Infof("s3_lifecycle: bucket %s: transient error loading lifecycle XML: %v, falling back to TTL", bucket, xmlErr)
- }
- // lifecycleRules is non-nil when XML was present (even if empty/all disabled).
- // Only fall back to TTL when XML was truly absent (nil).
- xmlPresent := xmlErr == nil && lifecycleRules != nil
- useRuleEval := xmlPresent && len(lifecycleRules) > 0
-
- if !useRuleEval && !xmlPresent {
- // Fall back to filer.conf TTL rules only when no lifecycle XML exists.
- // When XML is present but has no effective rules, skip TTL fallback.
- fc, err := loadFilerConf(ctx, filerClient)
- if err != nil {
- return result, fmt.Errorf("load filer conf: %w", err)
- }
- collection := bucket
- ttlRules := fc.GetCollectionTtls(collection)
- if len(ttlRules) == 0 {
- glog.V(1).Infof("s3_lifecycle: bucket %s has no lifecycle rules, skipping", bucket)
- return result, nil
- }
- }
-
- _ = sender.SendProgress(&plugin_pb.JobProgressUpdate{
- JobId: jobID,
- JobType: jobType,
- State: plugin_pb.JobState_JOB_STATE_RUNNING,
- ProgressPercent: 10,
- Stage: "scanning",
- Message: fmt.Sprintf("scanning bucket %s for expired objects", bucket),
- })
-
- // Shared budget across all phases so we don't exceed MaxDeletesPerBucket.
- remaining := config.MaxDeletesPerBucket
-
- // Find expired objects using rule-based evaluation or TTL fallback.
- var expired []expiredObject
- var scanned int64
- var err error
- if useRuleEval {
- expired, scanned, err = listExpiredObjectsByRules(ctx, filerClient, bucketsPath, bucket, lifecycleRules, remaining)
- } else if !xmlPresent {
- // TTL-only scan when no lifecycle XML exists.
- expired, scanned, err = listExpiredObjects(ctx, filerClient, bucketsPath, bucket, remaining)
- }
- // When xmlPresent but no effective rules (all disabled), skip object scanning.
- result.objectsScanned = scanned
- if err != nil {
- return result, fmt.Errorf("list expired objects: %w", err)
- }
-
- if len(expired) > 0 {
- glog.V(1).Infof("s3_lifecycle: bucket %s: found %d expired objects out of %d scanned", bucket, len(expired), scanned)
- } else {
- glog.V(1).Infof("s3_lifecycle: bucket %s: scanned %d objects, none expired", bucket, scanned)
- }
-
- if config.DryRun && len(expired) > 0 {
- result.objectsExpired = int64(len(expired))
- _ = sender.SendProgress(&plugin_pb.JobProgressUpdate{
- JobId: jobID,
- JobType: jobType,
- State: plugin_pb.JobState_JOB_STATE_RUNNING,
- ProgressPercent: 100,
- Stage: "dry_run",
- Message: fmt.Sprintf("dry run: would delete %d expired objects", len(expired)),
- })
- return result, nil
- }
-
- // Delete expired objects in batches.
- if len(expired) > 0 {
- _ = sender.SendProgress(&plugin_pb.JobProgressUpdate{
- JobId: jobID,
- JobType: jobType,
- State: plugin_pb.JobState_JOB_STATE_RUNNING,
- ProgressPercent: 50,
- Stage: "deleting",
- Message: fmt.Sprintf("deleting %d expired objects", len(expired)),
- })
-
- var batchSize int
- if config.BatchSize <= 0 {
- batchSize = defaultBatchSize
- } else if config.BatchSize > math.MaxInt {
- batchSize = math.MaxInt
- } else {
- batchSize = int(config.BatchSize)
- }
-
- for i := 0; i < len(expired); i += batchSize {
- select {
- case <-ctx.Done():
- return result, ctx.Err()
- default:
- }
-
- end := i + batchSize
- if end > len(expired) {
- end = len(expired)
- }
- batch := expired[i:end]
-
- deleted, errs, batchErr := deleteExpiredObjects(ctx, filerClient, batch)
- result.objectsExpired += int64(deleted)
- result.errors += int64(errs)
-
- if batchErr != nil {
- return result, batchErr
- }
-
- progress := float64(end)/float64(len(expired))*50 + 50 // 50-100%
- _ = sender.SendProgress(&plugin_pb.JobProgressUpdate{
- JobId: jobID,
- JobType: jobType,
- State: plugin_pb.JobState_JOB_STATE_RUNNING,
- ProgressPercent: progress,
- Stage: "deleting",
- Message: fmt.Sprintf("deleted %d/%d expired objects", result.objectsExpired, len(expired)),
- })
- }
-
- // Clean up .versions directories left empty after version deletion.
- cleanupEmptyVersionsDirectories(ctx, filerClient, expired)
-
- remaining -= result.objectsExpired + result.errors
- if remaining < 0 {
- remaining = 0
- }
- }
-
- // Delete marker cleanup.
- if config.DeleteMarkerCleanup && remaining > 0 {
- _ = sender.SendProgress(&plugin_pb.JobProgressUpdate{
- JobId: jobID, JobType: jobType,
- State: plugin_pb.JobState_JOB_STATE_RUNNING,
- Stage: "cleaning_delete_markers", Message: "cleaning expired delete markers",
- })
- cleaned, cleanErrs, cleanCtxErr := cleanupDeleteMarkers(ctx, filerClient, bucketsPath, bucket, lifecycleRules, remaining)
- result.deleteMarkersClean = int64(cleaned)
- result.errors += int64(cleanErrs)
- if cleanCtxErr != nil {
- return result, cleanCtxErr
- }
- remaining -= int64(cleaned + cleanErrs)
- if remaining < 0 {
- remaining = 0
- }
- }
-
- // Abort incomplete multipart uploads.
- // When lifecycle XML exists, evaluate each upload against the rules
- // (respecting per-rule prefix filters and DaysAfterInitiation).
- // Fall back to worker config abort_mpu_days only when no lifecycle
- // XML is configured for the bucket.
- if xmlPresent && remaining > 0 {
- _ = sender.SendProgress(&plugin_pb.JobProgressUpdate{
- JobId: jobID, JobType: jobType,
- State: plugin_pb.JobState_JOB_STATE_RUNNING,
- Stage: "aborting_mpus", Message: "evaluating MPU abort rules",
- })
- aborted, abortErrs, abortCtxErr := abortMPUsByRules(ctx, filerClient, bucketsPath, bucket, lifecycleRules, remaining)
- result.mpuAborted = int64(aborted)
- result.errors += int64(abortErrs)
- if abortCtxErr != nil {
- return result, abortCtxErr
- }
- } else if !xmlPresent && config.AbortMPUDays > 0 && remaining > 0 {
- _ = sender.SendProgress(&plugin_pb.JobProgressUpdate{
- JobId: jobID, JobType: jobType,
- State: plugin_pb.JobState_JOB_STATE_RUNNING,
- Stage: "aborting_mpus", Message: fmt.Sprintf("aborting multipart uploads older than %d days", config.AbortMPUDays),
- })
- aborted, abortErrs, abortCtxErr := abortIncompleteMPUs(ctx, filerClient, bucketsPath, bucket, config.AbortMPUDays, remaining)
- result.mpuAborted = int64(aborted)
- result.errors += int64(abortErrs)
- if abortCtxErr != nil {
- return result, abortCtxErr
- }
- }
-
- return result, nil
-}
-
-// cleanupDeleteMarkers scans versioned objects and removes delete markers
-// that are the sole remaining version. This matches AWS S3
-// ExpiredObjectDeleteMarker semantics: a delete marker is only removed when
-// it is the only version of an object (no non-current versions behind it).
-//
-// This phase should run AFTER NoncurrentVersionExpiration (PR 4) so that
-// non-current versions have already been cleaned up, potentially leaving
-// delete markers as sole versions eligible for removal.
-func cleanupDeleteMarkers(
- ctx context.Context,
- client filer_pb.SeaweedFilerClient,
- bucketsPath, bucket string,
- rules []s3lifecycle.Rule,
- limit int64,
-) (cleaned, errors int, ctxErr error) {
- bucketPath := path.Join(bucketsPath, bucket)
-
- dirsToProcess := []string{bucketPath}
- for len(dirsToProcess) > 0 {
- if ctx.Err() != nil {
- return cleaned, errors, ctx.Err()
- }
-
- dir := dirsToProcess[0]
- dirsToProcess = dirsToProcess[1:]
-
- listErr := filer_pb.SeaweedList(ctx, client, dir, "", func(entry *filer_pb.Entry, isLast bool) error {
- if entry.IsDirectory {
- if dir == bucketPath && entry.Name == s3_constants.MultipartUploadsFolder {
- return nil
- }
- if strings.HasSuffix(entry.Name, s3_constants.VersionsFolder) {
- versionsDir := path.Join(dir, entry.Name)
- // Check if the latest version is a delete marker.
- latestIsMarker := string(entry.Extended[s3_constants.ExtLatestVersionIsDeleteMarker]) == "true"
- if !latestIsMarker {
- return nil
- }
- // Count versions in the directory.
- versionCount := 0
- countErr := filer_pb.SeaweedList(ctx, client, versionsDir, "", func(ve *filer_pb.Entry, _ bool) error {
- if !ve.IsDirectory {
- versionCount++
- }
- return nil
- }, "", false, 10000)
- if countErr != nil {
- glog.V(1).Infof("s3_lifecycle: failed to count versions in %s: %v", versionsDir, countErr)
- errors++
- return nil
- }
- // Only remove if the delete marker is the sole version.
- if versionCount != 1 {
- return nil
- }
- // Check that a matching ExpiredObjectDeleteMarker rule exists.
- // The rule's prefix filter must match this object's key.
- relDir := strings.TrimPrefix(versionsDir, bucketPath+"/")
- objKey := strings.TrimSuffix(relDir, s3_constants.VersionsFolder)
- if len(rules) > 0 && !matchesDeleteMarkerRule(rules, objKey) {
- return nil
- }
- // Find and remove the sole delete marker entry.
- removedHere := false
- removeErr := filer_pb.SeaweedList(ctx, client, versionsDir, "", func(ve *filer_pb.Entry, _ bool) error {
- if !ve.IsDirectory && isDeleteMarker(ve) {
- if err := filer_pb.DoRemove(ctx, client, versionsDir, ve.Name, true, false, false, false, nil); err != nil {
- glog.V(1).Infof("s3_lifecycle: failed to remove delete marker %s/%s: %v", versionsDir, ve.Name, err)
- errors++
- } else {
- cleaned++
- removedHere = true
- }
- }
- return nil
- }, "", false, 10)
- if removeErr != nil {
- glog.V(1).Infof("s3_lifecycle: failed to scan for delete marker in %s: %v", versionsDir, removeErr)
- }
- // Remove the now-empty .versions directory only if we
- // actually deleted the marker in this specific directory.
- if removedHere {
- _ = filer_pb.DoRemove(ctx, client, dir, entry.Name, true, true, true, false, nil)
- }
- return nil
- }
- dirsToProcess = append(dirsToProcess, path.Join(dir, entry.Name))
- return nil
- }
-
- // For non-versioned objects: only clean up if explicitly a delete marker
- // and a matching rule exists.
- relKey := strings.TrimPrefix(path.Join(dir, entry.Name), bucketPath+"/")
- if isDeleteMarker(entry) && matchesDeleteMarkerRule(rules, relKey) {
- if err := filer_pb.DoRemove(ctx, client, dir, entry.Name, true, false, false, false, nil); err != nil {
- glog.V(1).Infof("s3_lifecycle: failed to remove delete marker %s/%s: %v", dir, entry.Name, err)
- errors++
- } else {
- cleaned++
- }
- }
-
- if limit > 0 && int64(cleaned+errors) >= limit {
- return fmt.Errorf("limit reached")
- }
- return nil
- }, "", false, 10000)
-
- if listErr != nil && !strings.Contains(listErr.Error(), "limit reached") {
- return cleaned, errors, fmt.Errorf("list %s: %w", dir, listErr)
- }
-
- if limit > 0 && int64(cleaned+errors) >= limit {
- break
- }
- }
- return cleaned, errors, nil
-}
-
-// isDeleteMarker checks if an entry is an S3 delete marker.
-func isDeleteMarker(entry *filer_pb.Entry) bool {
- if entry == nil || entry.Extended == nil {
- return false
- }
- return string(entry.Extended[s3_constants.ExtDeleteMarkerKey]) == "true"
-}
-
-// matchesDeleteMarkerRule checks if any enabled ExpiredObjectDeleteMarker rule
-// matches the given object key using the full filter model (prefix, tags, size).
-// When no lifecycle rules are provided (nil means no XML configured),
-// falls back to legacy behavior (returns true to allow cleanup).
-// A non-nil empty slice means XML was present but had no matching rules,
-// so cleanup is not allowed.
-func matchesDeleteMarkerRule(rules []s3lifecycle.Rule, objKey string) bool {
- if rules == nil {
- return true // legacy fallback: no lifecycle XML configured
- }
- // Delete markers have no size or tags, so build a minimal ObjectInfo.
- obj := s3lifecycle.ObjectInfo{Key: objKey}
- for _, r := range rules {
- if r.Status == "Enabled" && r.ExpiredObjectDeleteMarker && s3lifecycle.MatchesFilter(r, obj) {
- return true
- }
- }
- return false
-}
-
-// abortMPUsByRules scans the .uploads directory and evaluates each upload
-// against lifecycle rules using EvaluateMPUAbort, which respects per-rule
-// prefix filters and DaysAfterInitiation thresholds.
-func abortMPUsByRules(
- ctx context.Context,
- client filer_pb.SeaweedFilerClient,
- bucketsPath, bucket string,
- rules []s3lifecycle.Rule,
- limit int64,
-) (aborted, errs int, ctxErr error) {
- uploadsDir := path.Join(bucketsPath, bucket, ".uploads")
- now := time.Now()
-
- listErr := filer_pb.SeaweedList(ctx, client, uploadsDir, "", func(entry *filer_pb.Entry, isLast bool) error {
- if ctx.Err() != nil {
- return ctx.Err()
- }
- if !entry.IsDirectory {
- return nil
- }
- if entry.Attributes == nil || entry.Attributes.Crtime <= 0 {
- return nil
- }
-
- createdAt := time.Unix(entry.Attributes.Crtime, 0)
- result := s3lifecycle.EvaluateMPUAbort(rules, entry.Name, createdAt, now)
- if result.Action == s3lifecycle.ActionAbortMultipartUpload {
- uploadPath := path.Join(uploadsDir, entry.Name)
- if err := filer_pb.DoRemove(ctx, client, uploadsDir, entry.Name, true, true, true, false, nil); err != nil {
- glog.V(1).Infof("s3_lifecycle: failed to abort MPU %s: %v", uploadPath, err)
- errs++
- } else {
- aborted++
- }
- }
-
- if limit > 0 && int64(aborted+errs) >= limit {
- return errLimitReached
- }
- return nil
- }, "", false, 10000)
-
- if listErr != nil && !errors.Is(listErr, errLimitReached) {
- return aborted, errs, fmt.Errorf("list uploads in %s: %w", uploadsDir, listErr)
- }
- return aborted, errs, nil
-}
-
-// abortIncompleteMPUs scans the .uploads directory under a bucket and
-// removes multipart upload entries older than the specified number of days.
-func abortIncompleteMPUs(
- ctx context.Context,
- client filer_pb.SeaweedFilerClient,
- bucketsPath, bucket string,
- olderThanDays, limit int64,
-) (aborted, errors int, ctxErr error) {
- uploadsDir := path.Join(bucketsPath, bucket, ".uploads")
- cutoff := time.Now().Add(-time.Duration(olderThanDays) * 24 * time.Hour)
-
- listErr := filer_pb.SeaweedList(ctx, client, uploadsDir, "", func(entry *filer_pb.Entry, isLast bool) error {
- if ctx.Err() != nil {
- return ctx.Err()
- }
-
- if !entry.IsDirectory {
- return nil
- }
-
- // Each subdirectory under .uploads is one multipart upload.
- // Check the directory creation time.
- if entry.Attributes != nil && entry.Attributes.Crtime > 0 {
- created := time.Unix(entry.Attributes.Crtime, 0)
- if created.Before(cutoff) {
- uploadPath := path.Join(uploadsDir, entry.Name)
- if err := filer_pb.DoRemove(ctx, client, uploadsDir, entry.Name, true, true, true, false, nil); err != nil {
- glog.V(1).Infof("s3_lifecycle: failed to abort MPU %s: %v", uploadPath, err)
- errors++
- } else {
- aborted++
- }
- }
- }
-
- if limit > 0 && int64(aborted+errors) >= limit {
- return fmt.Errorf("limit reached")
- }
- return nil
- }, "", false, 10000)
-
- if listErr != nil && !strings.Contains(listErr.Error(), "limit reached") {
- return aborted, errors, fmt.Errorf("list uploads in %s: %w", uploadsDir, listErr)
- }
-
- return aborted, errors, nil
-}
-
-// deleteExpiredObjects deletes a batch of expired objects from the filer.
-// Returns a non-nil error when the context is canceled mid-batch.
-func deleteExpiredObjects(
- ctx context.Context,
- client filer_pb.SeaweedFilerClient,
- objects []expiredObject,
-) (deleted, errors int, ctxErr error) {
- for _, obj := range objects {
- if ctx.Err() != nil {
- return deleted, errors, ctx.Err()
- }
-
- err := filer_pb.DoRemove(ctx, client, obj.dir, obj.name, true, false, false, false, nil)
- if err != nil {
- glog.V(1).Infof("s3_lifecycle: failed to delete %s/%s: %v", obj.dir, obj.name, err)
- errors++
- continue
- }
- deleted++
- }
- return deleted, errors, nil
-}
-
-// nowUnix returns the current time as a Unix timestamp.
-func nowUnix() int64 {
- return time.Now().Unix()
-}
-
-// listExpiredObjectsByRules scans a bucket directory tree and evaluates
-// lifecycle rules against each object using the s3lifecycle evaluator.
-// This function handles non-versioned objects (IsLatest=true). Versioned
-// objects in .versions directories are handled by processVersionsDirectory
-// (added in a separate change for NoncurrentVersionExpiration support).
-func listExpiredObjectsByRules(
- ctx context.Context,
- client filer_pb.SeaweedFilerClient,
- bucketsPath, bucket string,
- rules []s3lifecycle.Rule,
- limit int64,
-) ([]expiredObject, int64, error) {
- var expired []expiredObject
- var scanned int64
-
- bucketPath := path.Join(bucketsPath, bucket)
- now := time.Now()
- needTags := s3lifecycle.HasTagRules(rules)
-
- dirsToProcess := []string{bucketPath}
- for len(dirsToProcess) > 0 {
- select {
- case <-ctx.Done():
- return expired, scanned, ctx.Err()
- default:
- }
-
- dir := dirsToProcess[0]
- dirsToProcess = dirsToProcess[1:]
-
- limitReached := false
- err := filer_pb.SeaweedList(ctx, client, dir, "", func(entry *filer_pb.Entry, isLast bool) error {
- if entry.IsDirectory {
- if dir == bucketPath && entry.Name == s3_constants.MultipartUploadsFolder {
- return nil // skip .uploads at bucket root only
- }
- if strings.HasSuffix(entry.Name, s3_constants.VersionsFolder) {
- versionsDir := path.Join(dir, entry.Name)
-
- // Evaluate Expiration rules against the latest version.
- // In versioned buckets, data lives in .versions/ directories,
- // so we must evaluate the latest version here — it is never
- // seen as a regular file entry in the parent directory.
- if obj, ok := latestVersionExpiredByRules(ctx, client, entry, versionsDir, bucketPath, rules, now, needTags); ok {
- expired = append(expired, obj)
- scanned++
- if limit > 0 && int64(len(expired)) >= limit {
- limitReached = true
- return errLimitReached
- }
- }
-
- // Process noncurrent versions.
- vExpired, vScanned, vErr := processVersionsDirectory(ctx, client, versionsDir, bucketPath, rules, now, needTags, limit-int64(len(expired)))
- if vErr != nil {
- glog.V(1).Infof("s3_lifecycle: %v", vErr)
- return vErr
- }
- expired = append(expired, vExpired...)
- scanned += vScanned
- if limit > 0 && int64(len(expired)) >= limit {
- limitReached = true
- return errLimitReached
- }
- return nil
- }
- dirsToProcess = append(dirsToProcess, path.Join(dir, entry.Name))
- return nil
- }
- scanned++
-
- // Skip objects already handled by TTL fast path.
- if entry.Attributes != nil && entry.Attributes.TtlSec > 0 {
- expirationUnix := entry.Attributes.Crtime + int64(entry.Attributes.TtlSec)
- if expirationUnix > nowUnix() {
- return nil // will be expired by RocksDB compaction
- }
- }
-
- // Build ObjectInfo for the evaluator.
- relKey := strings.TrimPrefix(path.Join(dir, entry.Name), bucketPath+"/")
- objInfo := s3lifecycle.ObjectInfo{
- Key: relKey,
- IsLatest: true, // non-versioned objects are always "latest"
- }
- if entry.Attributes != nil {
- objInfo.Size = int64(entry.Attributes.GetFileSize())
- if entry.Attributes.Mtime > 0 {
- objInfo.ModTime = time.Unix(entry.Attributes.Mtime, 0)
- } else if entry.Attributes.Crtime > 0 {
- objInfo.ModTime = time.Unix(entry.Attributes.Crtime, 0)
- }
- }
- if needTags {
- objInfo.Tags = s3lifecycle.ExtractTags(entry.Extended)
- }
-
- result := s3lifecycle.Evaluate(rules, objInfo, now)
- if result.Action == s3lifecycle.ActionDeleteObject {
- expired = append(expired, expiredObject{dir: dir, name: entry.Name})
- }
-
- if limit > 0 && int64(len(expired)) >= limit {
- limitReached = true
- return errLimitReached
- }
- return nil
- }, "", false, 10000)
-
- if err != nil && !errors.Is(err, errLimitReached) {
- return expired, scanned, fmt.Errorf("list %s: %w", dir, err)
- }
-
- if limitReached || (limit > 0 && int64(len(expired)) >= limit) {
- break
- }
- }
-
- return expired, scanned, nil
-}
-
-// processVersionsDirectory evaluates NoncurrentVersionExpiration rules
-// against all versions in a .versions directory.
-func processVersionsDirectory(
- ctx context.Context,
- client filer_pb.SeaweedFilerClient,
- versionsDir, bucketPath string,
- rules []s3lifecycle.Rule,
- now time.Time,
- needTags bool,
- limit int64,
-) ([]expiredObject, int64, error) {
- var expired []expiredObject
- var scanned int64
-
- // Check if any rule has NoncurrentVersionExpiration.
- hasNoncurrentRules := false
- for _, r := range rules {
- if r.Status == "Enabled" && r.NoncurrentVersionExpirationDays > 0 {
- hasNoncurrentRules = true
- break
- }
- }
- if !hasNoncurrentRules {
- return nil, 0, nil
- }
-
- // List all versions in this directory.
- var versions []*filer_pb.Entry
- listErr := filer_pb.SeaweedList(ctx, client, versionsDir, "", func(entry *filer_pb.Entry, isLast bool) error {
- if !entry.IsDirectory {
- versions = append(versions, entry)
- }
- return nil
- }, "", false, 10000)
- if listErr != nil {
- return nil, 0, fmt.Errorf("list versions in %s: %w", versionsDir, listErr)
- }
- if len(versions) <= 1 {
- return nil, 0, nil // only one version (the latest), nothing to expire
- }
-
- // Sort by version timestamp, newest first.
- sortVersionsByVersionId(versions)
-
- // Derive the object key from the .versions directory path.
- // e.g., /buckets/mybucket/path/to/key.versions -> path/to/key
- relDir := strings.TrimPrefix(versionsDir, bucketPath+"/")
- objKey := strings.TrimSuffix(relDir, s3_constants.VersionsFolder)
-
- // Walk versions: first is latest, rest are non-current.
- noncurrentIndex := 0
- for i := 1; i < len(versions); i++ {
- entry := versions[i]
- scanned++
-
- // Skip delete markers from expiration evaluation, but count
- // them toward NewerNoncurrentVersions so data versions get
- // the correct noncurrent index.
- if isDeleteMarker(entry) {
- noncurrentIndex++
- continue
- }
-
- // Determine successor's timestamp (the version that replaced this one).
- successorEntry := versions[i-1]
- successorVersionId := strings.TrimPrefix(successorEntry.Name, "v_")
- successorTime := s3lifecycle.GetVersionTimestamp(successorVersionId)
- if successorTime.IsZero() && successorEntry.Attributes != nil && successorEntry.Attributes.Mtime > 0 {
- successorTime = time.Unix(successorEntry.Attributes.Mtime, 0)
- }
-
- objInfo := s3lifecycle.ObjectInfo{
- Key: objKey,
- IsLatest: false,
- SuccessorModTime: successorTime,
- NumVersions: len(versions),
- NoncurrentIndex: noncurrentIndex,
- }
- if entry.Attributes != nil {
- objInfo.Size = int64(entry.Attributes.GetFileSize())
- if entry.Attributes.Mtime > 0 {
- objInfo.ModTime = time.Unix(entry.Attributes.Mtime, 0)
- }
- }
- if needTags {
- objInfo.Tags = s3lifecycle.ExtractTags(entry.Extended)
- }
-
- // Evaluate using the detailed ShouldExpireNoncurrentVersion which
- // handles NewerNoncurrentVersions.
- for _, rule := range rules {
- if s3lifecycle.ShouldExpireNoncurrentVersion(rule, objInfo, noncurrentIndex, now) {
- expired = append(expired, expiredObject{dir: versionsDir, name: entry.Name})
- break
- }
- }
-
- noncurrentIndex++
-
- if limit > 0 && int64(len(expired)) >= limit {
- break
- }
- }
-
- return expired, scanned, nil
-}
-
-// latestVersionExpiredByRules evaluates Expiration rules (Days/Date) against
-// the latest version in a .versions directory. In versioned buckets all data
-// lives inside .versions/ directories, so the latest version is never seen as
-// a regular file entry during the bucket walk. Without this check, Expiration
-// rules would never fire for versioned objects (issue #8757).
-//
-// The .versions directory entry caches metadata about the latest version in
-// its Extended attributes, so we can evaluate expiration without an extra
-// filer round-trip.
-func latestVersionExpiredByRules(
- ctx context.Context,
- client filer_pb.SeaweedFilerClient,
- dirEntry *filer_pb.Entry,
- versionsDir, bucketPath string,
- rules []s3lifecycle.Rule,
- now time.Time,
- needTags bool,
-) (expiredObject, bool) {
- if dirEntry.Extended == nil {
- return expiredObject{}, false
- }
-
- // Skip if the latest version is a delete marker — those are handled
- // by the ExpiredObjectDeleteMarker rule in cleanupDeleteMarkers.
- if string(dirEntry.Extended[s3_constants.ExtLatestVersionIsDeleteMarker]) == "true" {
- return expiredObject{}, false
- }
-
- latestFileName := string(dirEntry.Extended[s3_constants.ExtLatestVersionFileNameKey])
- if latestFileName == "" {
- return expiredObject{}, false
- }
-
- // Derive the object key: /buckets/b/path/key.versions → path/key
- relDir := strings.TrimPrefix(versionsDir, bucketPath+"/")
- objKey := strings.TrimSuffix(relDir, s3_constants.VersionsFolder)
-
- objInfo := s3lifecycle.ObjectInfo{
- Key: objKey,
- IsLatest: true,
- }
-
- // Populate ModTime from cached metadata.
- if mtimeStr := string(dirEntry.Extended[s3_constants.ExtLatestVersionMtimeKey]); mtimeStr != "" {
- if mtime, err := strconv.ParseInt(mtimeStr, 10, 64); err == nil {
- objInfo.ModTime = time.Unix(mtime, 0)
- }
- }
- if objInfo.ModTime.IsZero() && dirEntry.Attributes != nil && dirEntry.Attributes.Mtime > 0 {
- objInfo.ModTime = time.Unix(dirEntry.Attributes.Mtime, 0)
- }
-
- // Populate Size from cached metadata.
- if sizeStr := string(dirEntry.Extended[s3_constants.ExtLatestVersionSizeKey]); sizeStr != "" {
- if size, err := strconv.ParseInt(sizeStr, 10, 64); err == nil {
- objInfo.Size = size
- }
- }
-
- if needTags {
- // Tags are stored on the version file entry, not the .versions
- // directory. Fetch the actual version file to get them.
- resp, err := client.LookupDirectoryEntry(ctx, &filer_pb.LookupDirectoryEntryRequest{
- Directory: versionsDir,
- Name: latestFileName,
- })
- if err == nil && resp.Entry != nil {
- objInfo.Tags = s3lifecycle.ExtractTags(resp.Entry.Extended)
- }
- }
-
- result := s3lifecycle.Evaluate(rules, objInfo, now)
- if result.Action == s3lifecycle.ActionDeleteObject {
- return expiredObject{dir: versionsDir, name: latestFileName}, true
- }
-
- return expiredObject{}, false
-}
-
-// cleanupEmptyVersionsDirectories removes .versions directories that became
-// empty after their contents were deleted. This is called after
-// deleteExpiredObjects to avoid leaving orphaned directories.
-func cleanupEmptyVersionsDirectories(
- ctx context.Context,
- client filer_pb.SeaweedFilerClient,
- deleted []expiredObject,
-) int {
- // Collect unique .versions directories that had entries deleted.
- versionsDirs := map[string]struct{}{}
- for _, obj := range deleted {
- if strings.HasSuffix(obj.dir, s3_constants.VersionsFolder) {
- versionsDirs[obj.dir] = struct{}{}
- }
- }
-
- cleaned := 0
- for vDir := range versionsDirs {
- if ctx.Err() != nil {
- break
- }
- // Check if the directory is now empty.
- empty := true
- listErr := filer_pb.SeaweedList(ctx, client, vDir, "", func(entry *filer_pb.Entry, isLast bool) error {
- empty = false
- return errLimitReached // stop after first entry
- }, "", false, 1)
-
- if listErr != nil && !errors.Is(listErr, errLimitReached) {
- glog.V(1).Infof("s3_lifecycle: failed to check if versions dir %s is empty: %v", vDir, listErr)
- continue
- }
-
- if !empty {
- continue
- }
-
- // Remove the empty .versions directory.
- parentDir, dirName := path.Split(vDir)
- parentDir = strings.TrimSuffix(parentDir, "/")
- if err := filer_pb.DoRemove(ctx, client, parentDir, dirName, false, true, true, false, nil); err != nil {
- glog.V(1).Infof("s3_lifecycle: failed to clean up empty versions dir %s: %v", vDir, err)
- } else {
- cleaned++
- }
- }
- return cleaned
-}
-
-// sortVersionsByVersionId sorts version entries newest-first using full
-// version ID comparison (matching compareVersionIds in s3api_version_id.go).
-// This uses the complete version ID string, not just the decoded timestamp,
-// so entries with the same timestamp prefix are correctly ordered by their
-// random suffix.
-func sortVersionsByVersionId(versions []*filer_pb.Entry) {
- sort.Slice(versions, func(i, j int) bool {
- vidI := strings.TrimPrefix(versions[i].Name, "v_")
- vidJ := strings.TrimPrefix(versions[j].Name, "v_")
- return s3lifecycle.CompareVersionIds(vidI, vidJ) < 0
- })
-}
diff --git a/weed/plugin/worker/lifecycle/execution_test.go b/weed/plugin/worker/lifecycle/execution_test.go
deleted file mode 100644
index cfcae7613..000000000
--- a/weed/plugin/worker/lifecycle/execution_test.go
+++ /dev/null
@@ -1,72 +0,0 @@
-package lifecycle
-
-import (
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3lifecycle"
-)
-
-func TestMatchesDeleteMarkerRule(t *testing.T) {
- t.Run("nil_rules_legacy_fallback", func(t *testing.T) {
- if !matchesDeleteMarkerRule(nil, "any/key") {
- t.Error("nil rules should return true (legacy fallback)")
- }
- })
-
- t.Run("empty_rules_xml_present_no_match", func(t *testing.T) {
- rules := []s3lifecycle.Rule{}
- if matchesDeleteMarkerRule(rules, "any/key") {
- t.Error("empty rules (XML present) should return false")
- }
- })
-
- t.Run("matching_prefix_rule", func(t *testing.T) {
- rules := []s3lifecycle.Rule{
- {ID: "cleanup", Status: "Enabled", Prefix: "logs/", ExpiredObjectDeleteMarker: true},
- }
- if !matchesDeleteMarkerRule(rules, "logs/app.log") {
- t.Error("should match rule with matching prefix")
- }
- })
-
- t.Run("non_matching_prefix", func(t *testing.T) {
- rules := []s3lifecycle.Rule{
- {ID: "cleanup", Status: "Enabled", Prefix: "logs/", ExpiredObjectDeleteMarker: true},
- }
- if matchesDeleteMarkerRule(rules, "data/file.txt") {
- t.Error("should not match rule with non-matching prefix")
- }
- })
-
- t.Run("disabled_rule", func(t *testing.T) {
- rules := []s3lifecycle.Rule{
- {ID: "cleanup", Status: "Disabled", ExpiredObjectDeleteMarker: true},
- }
- if matchesDeleteMarkerRule(rules, "any/key") {
- t.Error("disabled rule should not match")
- }
- })
-
- t.Run("rule_without_delete_marker_flag", func(t *testing.T) {
- rules := []s3lifecycle.Rule{
- {ID: "expire", Status: "Enabled", ExpirationDays: 30},
- }
- if matchesDeleteMarkerRule(rules, "any/key") {
- t.Error("rule without ExpiredObjectDeleteMarker should not match")
- }
- })
-
- t.Run("tag_filtered_rule_no_tags_on_marker", func(t *testing.T) {
- rules := []s3lifecycle.Rule{
- {
- ID: "tagged", Status: "Enabled",
- ExpiredObjectDeleteMarker: true,
- FilterTags: map[string]string{"env": "dev"},
- },
- }
- // Delete markers have no tags, so a tag-filtered rule should not match.
- if matchesDeleteMarkerRule(rules, "any/key") {
- t.Error("tag-filtered rule should not match delete marker (no tags)")
- }
- })
-}
diff --git a/weed/plugin/worker/lifecycle/handler.go b/weed/plugin/worker/lifecycle/handler.go
deleted file mode 100644
index 22ab4d1ff..000000000
--- a/weed/plugin/worker/lifecycle/handler.go
+++ /dev/null
@@ -1,380 +0,0 @@
-package lifecycle
-
-import (
- "context"
- "fmt"
- "strings"
- "time"
-
- "github.com/seaweedfs/seaweedfs/weed/glog"
- "github.com/seaweedfs/seaweedfs/weed/pb"
- "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
- "github.com/seaweedfs/seaweedfs/weed/pb/plugin_pb"
- pluginworker "github.com/seaweedfs/seaweedfs/weed/plugin/worker"
- "google.golang.org/grpc"
- "google.golang.org/protobuf/types/known/timestamppb"
-)
-
-func init() {
- pluginworker.RegisterHandler(pluginworker.HandlerFactory{
- JobType: jobType,
- Category: pluginworker.CategoryHeavy,
- Aliases: []string{"lifecycle", "s3-lifecycle", "s3.lifecycle"},
- Build: func(opts pluginworker.HandlerBuildOptions) (pluginworker.JobHandler, error) {
- return NewHandler(opts.GrpcDialOption), nil
- },
- })
-}
-
-// Handler implements the JobHandler interface for S3 lifecycle management:
-// object expiration, delete marker cleanup, and abort incomplete multipart uploads.
-type Handler struct {
- grpcDialOption grpc.DialOption
-}
-
-const filerConnectTimeout = 5 * time.Second
-
-// NewHandler creates a new handler for S3 lifecycle management.
-func NewHandler(grpcDialOption grpc.DialOption) *Handler {
- return &Handler{grpcDialOption: grpcDialOption}
-}
-
-func (h *Handler) Capability() *plugin_pb.JobTypeCapability {
- return &plugin_pb.JobTypeCapability{
- JobType: jobType,
- CanDetect: true,
- CanExecute: true,
- MaxDetectionConcurrency: 1,
- MaxExecutionConcurrency: 4,
- DisplayName: "S3 Lifecycle",
- Description: "Manages S3 object lifecycle: expiration of objects based on TTL rules, delete marker cleanup, and abort of incomplete multipart uploads",
- Weight: 40,
- }
-}
-
-func (h *Handler) Descriptor() *plugin_pb.JobTypeDescriptor {
- return &plugin_pb.JobTypeDescriptor{
- JobType: jobType,
- DisplayName: "S3 Lifecycle Management",
- Description: "Automated S3 object lifecycle management: expire objects by TTL rules, clean up expired delete markers, and abort stale multipart uploads",
- Icon: "fas fa-hourglass-half",
- DescriptorVersion: 1,
- AdminConfigForm: &plugin_pb.ConfigForm{
- FormId: "s3-lifecycle-admin",
- Title: "S3 Lifecycle Admin Config",
- Description: "Admin-side controls for S3 lifecycle management scope.",
- Sections: []*plugin_pb.ConfigSection{
- {
- SectionId: "scope",
- Title: "Scope",
- Description: "Which buckets to include in lifecycle management.",
- Fields: []*plugin_pb.ConfigField{
- {
- Name: "bucket_filter",
- Label: "Bucket Filter",
- Description: "Wildcard pattern for bucket names to include (e.g. \"prod-*\"). Empty means all buckets.",
- FieldType: plugin_pb.ConfigFieldType_CONFIG_FIELD_TYPE_STRING,
- Widget: plugin_pb.ConfigWidget_CONFIG_WIDGET_TEXT,
- },
- },
- },
- },
- },
- WorkerConfigForm: &plugin_pb.ConfigForm{
- FormId: "s3-lifecycle-worker",
- Title: "S3 Lifecycle Worker Config",
- Description: "Worker-side controls for lifecycle execution behavior.",
- Sections: []*plugin_pb.ConfigSection{
- {
- SectionId: "execution",
- Title: "Execution",
- Description: "Controls for lifecycle rule execution.",
- Fields: []*plugin_pb.ConfigField{
- {
- Name: "batch_size",
- Label: "Batch Size",
- Description: "Number of entries to process per filer listing page.",
- FieldType: plugin_pb.ConfigFieldType_CONFIG_FIELD_TYPE_INT64,
- Widget: plugin_pb.ConfigWidget_CONFIG_WIDGET_NUMBER,
- MinValue: configInt64(100),
- MaxValue: configInt64(10000),
- },
- {
- Name: "max_deletes_per_bucket",
- Label: "Max Deletes Per Bucket",
- Description: "Maximum number of expired objects to delete per bucket in one execution run.",
- FieldType: plugin_pb.ConfigFieldType_CONFIG_FIELD_TYPE_INT64,
- Widget: plugin_pb.ConfigWidget_CONFIG_WIDGET_NUMBER,
- MinValue: configInt64(100),
- MaxValue: configInt64(1000000),
- },
- {
- Name: "dry_run",
- Label: "Dry Run",
- Description: "When enabled, detect expired objects but do not delete them.",
- FieldType: plugin_pb.ConfigFieldType_CONFIG_FIELD_TYPE_BOOL,
- Widget: plugin_pb.ConfigWidget_CONFIG_WIDGET_TOGGLE,
- },
- {
- Name: "delete_marker_cleanup",
- Label: "Delete Marker Cleanup",
- Description: "Remove expired delete markers that have no non-current versions.",
- FieldType: plugin_pb.ConfigFieldType_CONFIG_FIELD_TYPE_BOOL,
- Widget: plugin_pb.ConfigWidget_CONFIG_WIDGET_TOGGLE,
- },
- {
- Name: "abort_mpu_days",
- Label: "Abort Incomplete MPU (days)",
- Description: "Abort incomplete multipart uploads older than this many days. 0 disables.",
- FieldType: plugin_pb.ConfigFieldType_CONFIG_FIELD_TYPE_INT64,
- Widget: plugin_pb.ConfigWidget_CONFIG_WIDGET_NUMBER,
- MinValue: configInt64(0),
- MaxValue: configInt64(365),
- },
- },
- },
- },
- },
- AdminRuntimeDefaults: &plugin_pb.AdminRuntimeDefaults{
- Enabled: true,
- DetectionIntervalSeconds: 300, // 5 minutes
- DetectionTimeoutSeconds: 60,
- MaxJobsPerDetection: 100,
- GlobalExecutionConcurrency: 2,
- PerWorkerExecutionConcurrency: 2,
- RetryLimit: 1,
- RetryBackoffSeconds: 10,
- },
- WorkerDefaultValues: map[string]*plugin_pb.ConfigValue{
- "batch_size": {Kind: &plugin_pb.ConfigValue_Int64Value{Int64Value: defaultBatchSize}},
- "max_deletes_per_bucket": {Kind: &plugin_pb.ConfigValue_Int64Value{Int64Value: defaultMaxDeletesPerBucket}},
- "dry_run": {Kind: &plugin_pb.ConfigValue_BoolValue{BoolValue: defaultDryRun}},
- "delete_marker_cleanup": {Kind: &plugin_pb.ConfigValue_BoolValue{BoolValue: defaultDeleteMarkerCleanup}},
- "abort_mpu_days": {Kind: &plugin_pb.ConfigValue_Int64Value{Int64Value: defaultAbortMPUDaysDefault}},
- },
- }
-}
-
-func (h *Handler) Detect(ctx context.Context, req *plugin_pb.RunDetectionRequest, sender pluginworker.DetectionSender) error {
- if req == nil {
- return fmt.Errorf("nil detection request")
- }
-
- config := ParseConfig(req.WorkerConfigValues)
-
- bucketFilter := readStringConfig(req.AdminConfigValues, "bucket_filter", "")
-
- filerAddresses := filerAddressesFromCluster(req.ClusterContext)
- if len(filerAddresses) == 0 {
- _ = sender.SendActivity(pluginworker.BuildDetectorActivity("skipped", "no filer addresses in cluster context", nil))
- return sendEmptyDetection(sender)
- }
-
- _ = sender.SendActivity(pluginworker.BuildDetectorActivity("connecting", "connecting to filer", nil))
-
- filerClient, filerConn, err := connectToFiler(ctx, filerAddresses, h.grpcDialOption)
- if err != nil {
- return fmt.Errorf("failed to connect to any filer: %v", err)
- }
- defer filerConn.Close()
-
- maxResults := int(req.MaxResults)
- if maxResults <= 0 {
- maxResults = 100
- }
-
- _ = sender.SendActivity(pluginworker.BuildDetectorActivity("scanning", "scanning buckets for lifecycle rules", nil))
- proposals, err := h.detectBucketsWithLifecycleRules(ctx, filerClient, config, bucketFilter, maxResults)
- if err != nil {
- _ = sender.SendActivity(pluginworker.BuildDetectorActivity("scan_error", fmt.Sprintf("error scanning buckets: %v", err), nil))
- return fmt.Errorf("detect lifecycle rules: %w", err)
- }
-
- _ = sender.SendActivity(pluginworker.BuildDetectorActivity("scan_complete",
- fmt.Sprintf("found %d bucket(s) with lifecycle rules", len(proposals)),
- map[string]*plugin_pb.ConfigValue{
- "buckets_found": {Kind: &plugin_pb.ConfigValue_Int64Value{Int64Value: int64(len(proposals))}},
- }))
-
- if err := sender.SendProposals(&plugin_pb.DetectionProposals{
- JobType: jobType,
- Proposals: proposals,
- HasMore: len(proposals) >= maxResults,
- }); err != nil {
- return err
- }
-
- return sender.SendComplete(&plugin_pb.DetectionComplete{
- JobType: jobType,
- Success: true,
- TotalProposals: int32(len(proposals)),
- })
-}
-
-func (h *Handler) Execute(ctx context.Context, req *plugin_pb.ExecuteJobRequest, sender pluginworker.ExecutionSender) error {
- if req == nil || req.Job == nil {
- return fmt.Errorf("nil execution request")
- }
-
- job := req.Job
- config := ParseConfig(req.WorkerConfigValues)
-
- bucket := readParamString(job.Parameters, "bucket")
- bucketsPath := readParamString(job.Parameters, "buckets_path")
- if bucket == "" || bucketsPath == "" {
- return fmt.Errorf("missing bucket or buckets_path parameter")
- }
-
- filerAddresses := filerAddressesFromCluster(req.ClusterContext)
- if len(filerAddresses) == 0 {
- return fmt.Errorf("no filer addresses in cluster context")
- }
-
- filerClient, filerConn, err := connectToFiler(ctx, filerAddresses, h.grpcDialOption)
- if err != nil {
- return fmt.Errorf("failed to connect to any filer: %v", err)
- }
- defer filerConn.Close()
-
- _ = sender.SendProgress(&plugin_pb.JobProgressUpdate{
- JobId: job.JobId,
- JobType: jobType,
- State: plugin_pb.JobState_JOB_STATE_ASSIGNED,
- ProgressPercent: 0,
- Stage: "starting",
- Message: fmt.Sprintf("executing lifecycle rules for bucket %s", bucket),
- })
-
- start := time.Now()
- result, execErr := h.executeLifecycleForBucket(ctx, filerClient, config, bucket, bucketsPath, sender, job.JobId)
- elapsed := time.Since(start)
-
- metrics := map[string]*plugin_pb.ConfigValue{
- MetricDurationMs: {Kind: &plugin_pb.ConfigValue_Int64Value{Int64Value: elapsed.Milliseconds()}},
- }
- if result != nil {
- metrics[MetricObjectsExpired] = &plugin_pb.ConfigValue{Kind: &plugin_pb.ConfigValue_Int64Value{Int64Value: result.objectsExpired}}
- metrics[MetricObjectsScanned] = &plugin_pb.ConfigValue{Kind: &plugin_pb.ConfigValue_Int64Value{Int64Value: result.objectsScanned}}
- metrics[MetricDeleteMarkersClean] = &plugin_pb.ConfigValue{Kind: &plugin_pb.ConfigValue_Int64Value{Int64Value: result.deleteMarkersClean}}
- metrics[MetricMPUAborted] = &plugin_pb.ConfigValue{Kind: &plugin_pb.ConfigValue_Int64Value{Int64Value: result.mpuAborted}}
- metrics[MetricErrors] = &plugin_pb.ConfigValue{Kind: &plugin_pb.ConfigValue_Int64Value{Int64Value: result.errors}}
- }
-
- var scanned, expired int64
- if result != nil {
- scanned = result.objectsScanned
- expired = result.objectsExpired
- }
-
- success := execErr == nil && (result == nil || result.errors == 0)
- message := fmt.Sprintf("bucket %s: scanned %d objects, expired %d", bucket, scanned, expired)
- if result != nil && result.deleteMarkersClean > 0 {
- message += fmt.Sprintf(", delete markers cleaned %d", result.deleteMarkersClean)
- }
- if result != nil && result.mpuAborted > 0 {
- message += fmt.Sprintf(", MPUs aborted %d", result.mpuAborted)
- }
- if config.DryRun {
- message += " (dry run)"
- }
- if result != nil && result.errors > 0 {
- message += fmt.Sprintf(" (%d errors)", result.errors)
- }
- if execErr != nil {
- message = fmt.Sprintf("lifecycle execution failed for bucket %s: %v", bucket, execErr)
- }
-
- errMsg := ""
- if execErr != nil {
- errMsg = execErr.Error()
- } else if result != nil && result.errors > 0 {
- errMsg = fmt.Sprintf("%d objects failed to process", result.errors)
- }
-
- return sender.SendCompleted(&plugin_pb.JobCompleted{
- JobId: job.JobId,
- JobType: jobType,
- Success: success,
- ErrorMessage: errMsg,
- Result: &plugin_pb.JobResult{
- Summary: message,
- OutputValues: metrics,
- },
- CompletedAt: timestamppb.Now(),
- })
-}
-
-func connectToFiler(ctx context.Context, addresses []string, dialOption grpc.DialOption) (filer_pb.SeaweedFilerClient, *grpc.ClientConn, error) {
- var lastErr error
- for _, addr := range addresses {
- grpcAddr := pb.ServerAddress(addr).ToGrpcAddress()
- connCtx, cancel := context.WithTimeout(ctx, filerConnectTimeout)
- conn, err := pb.GrpcDial(connCtx, grpcAddr, false, dialOption)
- cancel()
- if err != nil {
- lastErr = err
- glog.V(1).Infof("s3_lifecycle: failed to connect to filer %s (grpc %s): %v", addr, grpcAddr, err)
- continue
- }
- // Verify the connection with a ping.
- client := filer_pb.NewSeaweedFilerClient(conn)
- pingCtx, pingCancel := context.WithTimeout(ctx, filerConnectTimeout)
- _, pingErr := client.Ping(pingCtx, &filer_pb.PingRequest{})
- pingCancel()
- if pingErr != nil {
- _ = conn.Close()
- lastErr = pingErr
- glog.V(1).Infof("s3_lifecycle: filer %s ping failed: %v", grpcAddr, pingErr)
- continue
- }
- return client, conn, nil
- }
- return nil, nil, lastErr
-}
-
-func sendEmptyDetection(sender pluginworker.DetectionSender) error {
- if err := sender.SendProposals(&plugin_pb.DetectionProposals{
- JobType: jobType,
- Proposals: []*plugin_pb.JobProposal{},
- HasMore: false,
- }); err != nil {
- return err
- }
- return sender.SendComplete(&plugin_pb.DetectionComplete{
- JobType: jobType,
- Success: true,
- TotalProposals: 0,
- })
-}
-
-func filerAddressesFromCluster(cc *plugin_pb.ClusterContext) []string {
- if cc == nil {
- return nil
- }
- var addrs []string
- for _, addr := range cc.FilerGrpcAddresses {
- trimmed := strings.TrimSpace(addr)
- if trimmed != "" {
- addrs = append(addrs, trimmed)
- }
- }
- return addrs
-}
-
-func readParamString(params map[string]*plugin_pb.ConfigValue, key string) string {
- if params == nil {
- return ""
- }
- v := params[key]
- if v == nil {
- return ""
- }
- if sv, ok := v.Kind.(*plugin_pb.ConfigValue_StringValue); ok {
- return sv.StringValue
- }
- return ""
-}
-
-func configInt64(v int64) *plugin_pb.ConfigValue {
- return &plugin_pb.ConfigValue{Kind: &plugin_pb.ConfigValue_Int64Value{Int64Value: v}}
-}
diff --git a/weed/plugin/worker/lifecycle/integration_test.go b/weed/plugin/worker/lifecycle/integration_test.go
deleted file mode 100644
index 60b11175c..000000000
--- a/weed/plugin/worker/lifecycle/integration_test.go
+++ /dev/null
@@ -1,781 +0,0 @@
-package lifecycle
-
-import (
- "context"
- "fmt"
- "math"
- "net"
- "sort"
- "strconv"
- "strings"
- "sync"
- "testing"
- "time"
-
- "github.com/seaweedfs/seaweedfs/weed/pb"
- "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3lifecycle"
- "google.golang.org/grpc"
- "google.golang.org/grpc/codes"
- "google.golang.org/grpc/credentials/insecure"
- "google.golang.org/grpc/status"
- "google.golang.org/protobuf/proto"
-)
-
-// testFilerServer is an in-memory filer gRPC server for integration tests.
-type testFilerServer struct {
- filer_pb.UnimplementedSeaweedFilerServer
- mu sync.RWMutex
- entries map[string]*filer_pb.Entry // key: "dir\x00name"
-}
-
-func newTestFilerServer() *testFilerServer {
- return &testFilerServer{entries: make(map[string]*filer_pb.Entry)}
-}
-
-func (s *testFilerServer) key(dir, name string) string { return dir + "\x00" + name }
-
-func (s *testFilerServer) splitKey(key string) (string, string) {
- for i := range key {
- if key[i] == '\x00' {
- return key[:i], key[i+1:]
- }
- }
- return key, ""
-}
-
-func (s *testFilerServer) putEntry(dir string, entry *filer_pb.Entry) {
- s.mu.Lock()
- defer s.mu.Unlock()
- s.entries[s.key(dir, entry.Name)] = proto.Clone(entry).(*filer_pb.Entry)
-}
-
-func (s *testFilerServer) getEntry(dir, name string) *filer_pb.Entry {
- s.mu.RLock()
- defer s.mu.RUnlock()
- e := s.entries[s.key(dir, name)]
- if e == nil {
- return nil
- }
- return proto.Clone(e).(*filer_pb.Entry)
-}
-
-func (s *testFilerServer) hasEntry(dir, name string) bool {
- s.mu.RLock()
- defer s.mu.RUnlock()
- _, ok := s.entries[s.key(dir, name)]
- return ok
-}
-
-func (s *testFilerServer) LookupDirectoryEntry(_ context.Context, req *filer_pb.LookupDirectoryEntryRequest) (*filer_pb.LookupDirectoryEntryResponse, error) {
- s.mu.RLock()
- defer s.mu.RUnlock()
- entry, found := s.entries[s.key(req.Directory, req.Name)]
- if !found {
- return nil, status.Error(codes.NotFound, filer_pb.ErrNotFound.Error())
- }
- return &filer_pb.LookupDirectoryEntryResponse{Entry: proto.Clone(entry).(*filer_pb.Entry)}, nil
-}
-
-func (s *testFilerServer) ListEntries(req *filer_pb.ListEntriesRequest, stream grpc.ServerStreamingServer[filer_pb.ListEntriesResponse]) error {
- // Snapshot entries under lock, then stream without holding the lock
- // (streaming callbacks may trigger DeleteEntry which needs a write lock).
- s.mu.RLock()
- var names []string
- for key := range s.entries {
- dir, name := s.splitKey(key)
- if dir == req.Directory {
- if req.StartFromFileName != "" && name <= req.StartFromFileName {
- continue
- }
- if req.Prefix != "" && !strings.HasPrefix(name, req.Prefix) {
- continue
- }
- names = append(names, name)
- }
- }
- sort.Strings(names)
-
- // Clone entries while still holding the lock.
- type namedEntry struct {
- name string
- entry *filer_pb.Entry
- }
- snapshot := make([]namedEntry, 0, len(names))
- for _, name := range names {
- if req.Limit > 0 && uint32(len(snapshot)) >= req.Limit {
- break
- }
- snapshot = append(snapshot, namedEntry{
- name: name,
- entry: proto.Clone(s.entries[s.key(req.Directory, name)]).(*filer_pb.Entry),
- })
- }
- s.mu.RUnlock()
-
- // Stream responses without holding any lock.
- for _, ne := range snapshot {
- if err := stream.Send(&filer_pb.ListEntriesResponse{Entry: ne.entry}); err != nil {
- return err
- }
- }
- return nil
-}
-
-func (s *testFilerServer) CreateEntry(_ context.Context, req *filer_pb.CreateEntryRequest) (*filer_pb.CreateEntryResponse, error) {
- s.mu.Lock()
- defer s.mu.Unlock()
- s.entries[s.key(req.Directory, req.Entry.Name)] = proto.Clone(req.Entry).(*filer_pb.Entry)
- return &filer_pb.CreateEntryResponse{}, nil
-}
-
-func (s *testFilerServer) DeleteEntry(_ context.Context, req *filer_pb.DeleteEntryRequest) (*filer_pb.DeleteEntryResponse, error) {
- s.mu.Lock()
- defer s.mu.Unlock()
- k := s.key(req.Directory, req.Name)
- if _, found := s.entries[k]; !found {
- return nil, status.Error(codes.NotFound, filer_pb.ErrNotFound.Error())
- }
- delete(s.entries, k)
- if req.IsRecursive {
- // Delete all descendants: any entry whose directory starts with
- // the deleted path (handles nested subdirectories).
- deletedPath := req.Directory + "/" + req.Name
- for key := range s.entries {
- dir, _ := s.splitKey(key)
- if dir == deletedPath || strings.HasPrefix(dir, deletedPath+"/") {
- delete(s.entries, key)
- }
- }
- }
- return &filer_pb.DeleteEntryResponse{}, nil
-}
-
-// startTestFiler starts an in-memory filer gRPC server and returns a client.
-func startTestFiler(t *testing.T) (*testFilerServer, filer_pb.SeaweedFilerClient) {
- t.Helper()
-
- lis, err := net.Listen("tcp", "127.0.0.1:0")
- if err != nil {
- t.Fatalf("listen: %v", err)
- }
-
- server := newTestFilerServer()
- grpcServer := pb.NewGrpcServer()
- filer_pb.RegisterSeaweedFilerServer(grpcServer, server)
- go func() { _ = grpcServer.Serve(lis) }()
-
- t.Cleanup(func() {
- grpcServer.Stop()
- _ = lis.Close()
- })
-
- host, portStr, err := net.SplitHostPort(lis.Addr().String())
- if err != nil {
- t.Fatalf("split host port: %v", err)
- }
- port, err := strconv.Atoi(portStr)
- if err != nil {
- t.Fatalf("parse port: %v", err)
- }
- addr := pb.NewServerAddress(host, 1, port)
-
- conn, err := pb.GrpcDial(context.Background(), addr.ToGrpcAddress(), false, grpc.WithTransportCredentials(insecure.NewCredentials()))
- if err != nil {
- t.Fatalf("dial: %v", err)
- }
- t.Cleanup(func() { _ = conn.Close() })
-
- return server, filer_pb.NewSeaweedFilerClient(conn)
-}
-
-// Helper to create a version ID from a timestamp.
-func testVersionId(ts time.Time) string {
- inverted := math.MaxInt64 - ts.UnixNano()
- return fmt.Sprintf("%016x", inverted) + "0000000000000000"
-}
-
-func TestIntegration_ListExpiredObjectsByRules(t *testing.T) {
- server, client := startTestFiler(t)
- bucketsPath := "/buckets"
- bucket := "test-bucket"
- bucketDir := bucketsPath + "/" + bucket
-
- now := time.Now()
- old := now.Add(-60 * 24 * time.Hour) // 60 days ago
- recent := now.Add(-5 * 24 * time.Hour) // 5 days ago
-
- // Create bucket directory.
- server.putEntry(bucketsPath, &filer_pb.Entry{Name: bucket, IsDirectory: true})
-
- // Create objects.
- server.putEntry(bucketDir, &filer_pb.Entry{
- Name: "old-file.txt",
- Attributes: &filer_pb.FuseAttributes{Mtime: old.Unix(), FileSize: 1024},
- })
- server.putEntry(bucketDir, &filer_pb.Entry{
- Name: "recent-file.txt",
- Attributes: &filer_pb.FuseAttributes{Mtime: recent.Unix(), FileSize: 1024},
- })
-
- rules := []s3lifecycle.Rule{{
- ID: "expire-30d", Status: "Enabled",
- ExpirationDays: 30,
- }}
-
- expired, scanned, err := listExpiredObjectsByRules(context.Background(), client, bucketsPath, bucket, rules, 100)
- if err != nil {
- t.Fatalf("listExpiredObjectsByRules: %v", err)
- }
-
- if scanned != 2 {
- t.Errorf("expected 2 scanned, got %d", scanned)
- }
- if len(expired) != 1 {
- t.Fatalf("expected 1 expired, got %d", len(expired))
- }
- if expired[0].name != "old-file.txt" {
- t.Errorf("expected old-file.txt expired, got %s", expired[0].name)
- }
-}
-
-func TestIntegration_ListExpiredObjectsByRules_TagFilter(t *testing.T) {
- server, client := startTestFiler(t)
- bucketsPath := "/buckets"
- bucket := "tag-bucket"
- bucketDir := bucketsPath + "/" + bucket
-
- old := time.Now().Add(-60 * 24 * time.Hour)
-
- server.putEntry(bucketsPath, &filer_pb.Entry{Name: bucket, IsDirectory: true})
-
- // Object with matching tag.
- server.putEntry(bucketDir, &filer_pb.Entry{
- Name: "tagged.txt",
- Attributes: &filer_pb.FuseAttributes{Mtime: old.Unix(), FileSize: 100},
- Extended: map[string][]byte{"X-Amz-Tagging-env": []byte("dev")},
- })
- // Object without tag.
- server.putEntry(bucketDir, &filer_pb.Entry{
- Name: "untagged.txt",
- Attributes: &filer_pb.FuseAttributes{Mtime: old.Unix(), FileSize: 100},
- })
-
- rules := []s3lifecycle.Rule{{
- ID: "tag-expire", Status: "Enabled",
- ExpirationDays: 30,
- FilterTags: map[string]string{"env": "dev"},
- }}
-
- expired, _, err := listExpiredObjectsByRules(context.Background(), client, bucketsPath, bucket, rules, 100)
- if err != nil {
- t.Fatalf("listExpiredObjectsByRules: %v", err)
- }
-
- if len(expired) != 1 {
- t.Fatalf("expected 1 expired (tagged only), got %d", len(expired))
- }
- if expired[0].name != "tagged.txt" {
- t.Errorf("expected tagged.txt, got %s", expired[0].name)
- }
-}
-
-func TestIntegration_ProcessVersionsDirectory(t *testing.T) {
- server, client := startTestFiler(t)
- bucketsPath := "/buckets"
- bucket := "versioned-bucket"
- bucketDir := bucketsPath + "/" + bucket
- versionsDir := bucketDir + "/key.versions"
-
- now := time.Now()
- t1 := now.Add(-90 * 24 * time.Hour) // oldest
- t2 := now.Add(-60 * 24 * time.Hour)
- t3 := now.Add(-1 * 24 * time.Hour) // newest (latest)
-
- vid1 := testVersionId(t1)
- vid2 := testVersionId(t2)
- vid3 := testVersionId(t3)
-
- server.putEntry(bucketsPath, &filer_pb.Entry{Name: bucket, IsDirectory: true})
- server.putEntry(bucketDir, &filer_pb.Entry{
- Name: "key.versions", IsDirectory: true,
- Extended: map[string][]byte{
- s3_constants.ExtLatestVersionIdKey: []byte(vid3),
- },
- })
-
- // Three versions: vid3 (latest), vid2 (noncurrent), vid1 (noncurrent)
- server.putEntry(versionsDir, &filer_pb.Entry{
- Name: "v_" + vid3,
- Attributes: &filer_pb.FuseAttributes{Mtime: t3.Unix(), FileSize: 100},
- Extended: map[string][]byte{
- s3_constants.ExtVersionIdKey: []byte(vid3),
- },
- })
- server.putEntry(versionsDir, &filer_pb.Entry{
- Name: "v_" + vid2,
- Attributes: &filer_pb.FuseAttributes{Mtime: t2.Unix(), FileSize: 100},
- Extended: map[string][]byte{
- s3_constants.ExtVersionIdKey: []byte(vid2),
- },
- })
- server.putEntry(versionsDir, &filer_pb.Entry{
- Name: "v_" + vid1,
- Attributes: &filer_pb.FuseAttributes{Mtime: t1.Unix(), FileSize: 100},
- Extended: map[string][]byte{
- s3_constants.ExtVersionIdKey: []byte(vid1),
- },
- })
-
- rules := []s3lifecycle.Rule{{
- ID: "noncurrent-30d", Status: "Enabled",
- NoncurrentVersionExpirationDays: 30,
- }}
-
- expired, scanned, err := processVersionsDirectory(
- context.Background(), client, versionsDir, bucketDir,
- rules, now, false, 100,
- )
- if err != nil {
- t.Fatalf("processVersionsDirectory: %v", err)
- }
-
- // vid3 is latest (not expired). vid2 became noncurrent when vid3 was created
- // (1 day ago), so vid2 is NOT old enough (< 30 days noncurrent).
- // vid1 became noncurrent when vid2 was created (60 days ago), so vid1 IS expired.
- if scanned != 2 {
- t.Errorf("expected 2 scanned (non-current versions), got %d", scanned)
- }
- if len(expired) != 1 {
- t.Fatalf("expected 1 expired (only vid1), got %d", len(expired))
- }
- if expired[0].name != "v_"+vid1 {
- t.Errorf("expected v_%s expired, got %s", vid1, expired[0].name)
- }
-}
-
-func TestIntegration_ProcessVersionsDirectory_NewerNoncurrentVersions(t *testing.T) {
- server, client := startTestFiler(t)
- bucketsPath := "/buckets"
- bucket := "keep-n-bucket"
- bucketDir := bucketsPath + "/" + bucket
- versionsDir := bucketDir + "/obj.versions"
-
- now := time.Now()
- // Create 5 versions, all old enough to expire by days alone.
- versions := make([]time.Time, 5)
- vids := make([]string, 5)
- for i := 0; i < 5; i++ {
- versions[i] = now.Add(time.Duration(-(90 - i*10)) * 24 * time.Hour)
- vids[i] = testVersionId(versions[i])
- }
- // vids[4] is newest (latest), vids[0] is oldest
-
- server.putEntry(bucketsPath, &filer_pb.Entry{Name: bucket, IsDirectory: true})
- server.putEntry(bucketDir, &filer_pb.Entry{
- Name: "obj.versions", IsDirectory: true,
- Extended: map[string][]byte{
- s3_constants.ExtLatestVersionIdKey: []byte(vids[4]),
- },
- })
-
- for i, vid := range vids {
- server.putEntry(versionsDir, &filer_pb.Entry{
- Name: "v_" + vid,
- Attributes: &filer_pb.FuseAttributes{Mtime: versions[i].Unix(), FileSize: 100},
- Extended: map[string][]byte{
- s3_constants.ExtVersionIdKey: []byte(vid),
- },
- })
- }
-
- rules := []s3lifecycle.Rule{{
- ID: "keep-2", Status: "Enabled",
- NoncurrentVersionExpirationDays: 7,
- NewerNoncurrentVersions: 2,
- }}
-
- expired, _, err := processVersionsDirectory(
- context.Background(), client, versionsDir, bucketDir,
- rules, now, false, 100,
- )
- if err != nil {
- t.Fatalf("processVersionsDirectory: %v", err)
- }
-
- // 4 noncurrent versions (vids[0..3]). Keep newest 2 (vids[3], vids[2]).
- // Expire vids[1] and vids[0].
- if len(expired) != 2 {
- t.Fatalf("expected 2 expired (keep 2 newest noncurrent), got %d", len(expired))
- }
- expiredNames := map[string]bool{}
- for _, e := range expired {
- expiredNames[e.name] = true
- }
- if !expiredNames["v_"+vids[0]] {
- t.Errorf("expected vids[0] (oldest) to be expired")
- }
- if !expiredNames["v_"+vids[1]] {
- t.Errorf("expected vids[1] to be expired")
- }
-}
-
-func TestIntegration_AbortMPUsByRules(t *testing.T) {
- server, client := startTestFiler(t)
- bucketsPath := "/buckets"
- bucket := "mpu-bucket"
- uploadsDir := bucketsPath + "/" + bucket + "/.uploads"
-
- now := time.Now()
- old := now.Add(-10 * 24 * time.Hour)
- recent := now.Add(-2 * 24 * time.Hour)
-
- server.putEntry(bucketsPath, &filer_pb.Entry{Name: bucket, IsDirectory: true})
- server.putEntry(bucketsPath+"/"+bucket, &filer_pb.Entry{Name: ".uploads", IsDirectory: true})
-
- // Old upload under logs/ prefix.
- server.putEntry(uploadsDir, &filer_pb.Entry{
- Name: "logs_upload1", IsDirectory: true,
- Attributes: &filer_pb.FuseAttributes{Crtime: old.Unix()},
- })
- // Recent upload under logs/ prefix.
- server.putEntry(uploadsDir, &filer_pb.Entry{
- Name: "logs_upload2", IsDirectory: true,
- Attributes: &filer_pb.FuseAttributes{Crtime: recent.Unix()},
- })
- // Old upload under data/ prefix (should not match logs/ rule).
- server.putEntry(uploadsDir, &filer_pb.Entry{
- Name: "data_upload1", IsDirectory: true,
- Attributes: &filer_pb.FuseAttributes{Crtime: old.Unix()},
- })
-
- rules := []s3lifecycle.Rule{{
- ID: "abort-logs", Status: "Enabled",
- Prefix: "logs",
- AbortMPUDaysAfterInitiation: 7,
- }}
-
- aborted, errs, err := abortMPUsByRules(context.Background(), client, bucketsPath, bucket, rules, 100)
- if err != nil {
- t.Fatalf("abortMPUsByRules: %v", err)
- }
- if errs != 0 {
- t.Errorf("expected 0 errors, got %d", errs)
- }
-
- // Only logs_upload1 should be aborted (old + matches prefix).
- // logs_upload2 is too recent, data_upload1 doesn't match prefix.
- if aborted != 1 {
- t.Errorf("expected 1 aborted, got %d", aborted)
- }
-
- // Verify the correct upload was removed.
- if server.hasEntry(uploadsDir, "logs_upload1") {
- t.Error("logs_upload1 should have been removed")
- }
- if !server.hasEntry(uploadsDir, "logs_upload2") {
- t.Error("logs_upload2 should still exist")
- }
- if !server.hasEntry(uploadsDir, "data_upload1") {
- t.Error("data_upload1 should still exist (wrong prefix)")
- }
-}
-
-func TestIntegration_DeleteExpiredObjects(t *testing.T) {
- server, client := startTestFiler(t)
- bucketsPath := "/buckets"
- bucket := "delete-bucket"
- bucketDir := bucketsPath + "/" + bucket
-
- now := time.Now()
- old := now.Add(-60 * 24 * time.Hour)
-
- server.putEntry(bucketsPath, &filer_pb.Entry{Name: bucket, IsDirectory: true})
- server.putEntry(bucketDir, &filer_pb.Entry{
- Name: "to-delete.txt",
- Attributes: &filer_pb.FuseAttributes{Mtime: old.Unix(), FileSize: 100},
- })
- server.putEntry(bucketDir, &filer_pb.Entry{
- Name: "to-keep.txt",
- Attributes: &filer_pb.FuseAttributes{Mtime: now.Unix(), FileSize: 100},
- })
-
- rules := []s3lifecycle.Rule{{
- ID: "expire", Status: "Enabled",
- ExpirationDays: 30,
- }}
-
- expired, _, err := listExpiredObjectsByRules(context.Background(), client, bucketsPath, bucket, rules, 100)
- if err != nil {
- t.Fatalf("list: %v", err)
- }
-
- // Actually delete them.
- deleted, errs, err := deleteExpiredObjects(context.Background(), client, expired)
- if err != nil {
- t.Fatalf("delete: %v", err)
- }
- if deleted != 1 || errs != 0 {
- t.Errorf("expected 1 deleted 0 errors, got %d deleted %d errors", deleted, errs)
- }
-
- if server.hasEntry(bucketDir, "to-delete.txt") {
- t.Error("to-delete.txt should have been removed")
- }
- if !server.hasEntry(bucketDir, "to-keep.txt") {
- t.Error("to-keep.txt should still exist")
- }
-}
-
-// TestIntegration_VersionedBucket_ExpirationDays verifies that Expiration.Days
-// rules correctly detect and delete the latest version in a versioned bucket
-// where all data lives in .versions/ directories (issue #8757).
-func TestIntegration_VersionedBucket_ExpirationDays(t *testing.T) {
- server, client := startTestFiler(t)
- bucketsPath := "/buckets"
- bucket := "versioned-expire"
- bucketDir := bucketsPath + "/" + bucket
-
- now := time.Now()
- old := now.Add(-60 * 24 * time.Hour) // 60 days ago — should expire
- recent := now.Add(-5 * 24 * time.Hour) // 5 days ago — should NOT expire
-
- vidOld := testVersionId(old)
- vidRecent := testVersionId(recent)
-
- server.putEntry(bucketsPath, &filer_pb.Entry{Name: bucket, IsDirectory: true})
-
- // --- Single-version object (old, should expire) ---
- server.putEntry(bucketDir, &filer_pb.Entry{
- Name: "old-file.txt" + s3_constants.VersionsFolder, IsDirectory: true,
- Extended: map[string][]byte{
- s3_constants.ExtLatestVersionIdKey: []byte(vidOld),
- s3_constants.ExtLatestVersionFileNameKey: []byte("v_" + vidOld),
- s3_constants.ExtLatestVersionMtimeKey: []byte(strconv.FormatInt(old.Unix(), 10)),
- s3_constants.ExtLatestVersionSizeKey: []byte("3400000000"),
- s3_constants.ExtLatestVersionIsDeleteMarker: []byte("false"),
- },
- })
- oldVersionsDir := bucketDir + "/old-file.txt" + s3_constants.VersionsFolder
- server.putEntry(oldVersionsDir, &filer_pb.Entry{
- Name: "v_" + vidOld,
- Attributes: &filer_pb.FuseAttributes{Mtime: old.Unix(), FileSize: 3400000000},
- Extended: map[string][]byte{
- s3_constants.ExtVersionIdKey: []byte(vidOld),
- },
- })
-
- // --- Single-version object (recent, should NOT expire) ---
- server.putEntry(bucketDir, &filer_pb.Entry{
- Name: "recent-file.txt" + s3_constants.VersionsFolder, IsDirectory: true,
- Extended: map[string][]byte{
- s3_constants.ExtLatestVersionIdKey: []byte(vidRecent),
- s3_constants.ExtLatestVersionFileNameKey: []byte("v_" + vidRecent),
- s3_constants.ExtLatestVersionMtimeKey: []byte(strconv.FormatInt(recent.Unix(), 10)),
- s3_constants.ExtLatestVersionSizeKey: []byte("3400000000"),
- s3_constants.ExtLatestVersionIsDeleteMarker: []byte("false"),
- },
- })
- recentVersionsDir := bucketDir + "/recent-file.txt" + s3_constants.VersionsFolder
- server.putEntry(recentVersionsDir, &filer_pb.Entry{
- Name: "v_" + vidRecent,
- Attributes: &filer_pb.FuseAttributes{Mtime: recent.Unix(), FileSize: 3400000000},
- Extended: map[string][]byte{
- s3_constants.ExtVersionIdKey: []byte(vidRecent),
- },
- })
-
- // --- Object with delete marker as latest (should NOT be expired by Expiration.Days) ---
- vidMarker := testVersionId(old)
- server.putEntry(bucketDir, &filer_pb.Entry{
- Name: "deleted-obj.txt" + s3_constants.VersionsFolder, IsDirectory: true,
- Extended: map[string][]byte{
- s3_constants.ExtLatestVersionIdKey: []byte(vidMarker),
- s3_constants.ExtLatestVersionFileNameKey: []byte("v_" + vidMarker),
- s3_constants.ExtLatestVersionMtimeKey: []byte(strconv.FormatInt(old.Unix(), 10)),
- s3_constants.ExtLatestVersionIsDeleteMarker: []byte("true"),
- },
- })
-
- rules := []s3lifecycle.Rule{{
- ID: "expire-30d", Status: "Enabled",
- ExpirationDays: 30,
- }}
-
- expired, scanned, err := listExpiredObjectsByRules(context.Background(), client, bucketsPath, bucket, rules, 100)
- if err != nil {
- t.Fatalf("listExpiredObjectsByRules: %v", err)
- }
-
- // Only old-file.txt's latest version should be expired.
- // recent-file.txt is too young; deleted-obj.txt is a delete marker.
- if len(expired) != 1 {
- t.Fatalf("expected 1 expired, got %d: %+v", len(expired), expired)
- }
- if expired[0].dir != oldVersionsDir {
- t.Errorf("expected dir=%s, got %s", oldVersionsDir, expired[0].dir)
- }
- if expired[0].name != "v_"+vidOld {
- t.Errorf("expected name=v_%s, got %s", vidOld, expired[0].name)
- }
- // The old-file.txt latest version should count as scanned.
- if scanned < 1 {
- t.Errorf("expected at least 1 scanned, got %d", scanned)
- }
-}
-
-// TestIntegration_VersionedBucket_ExpirationDays_DeleteAndCleanup verifies
-// end-to-end deletion and .versions directory cleanup for a single-version
-// versioned object expired by Expiration.Days.
-func TestIntegration_VersionedBucket_ExpirationDays_DeleteAndCleanup(t *testing.T) {
- server, client := startTestFiler(t)
- bucketsPath := "/buckets"
- bucket := "versioned-cleanup"
- bucketDir := bucketsPath + "/" + bucket
-
- now := time.Now()
- old := now.Add(-60 * 24 * time.Hour)
- vidOld := testVersionId(old)
-
- server.putEntry(bucketsPath, &filer_pb.Entry{Name: bucket, IsDirectory: true})
-
- // Single-version object that should expire.
- versionsDir := bucketDir + "/data.bin" + s3_constants.VersionsFolder
- server.putEntry(bucketDir, &filer_pb.Entry{
- Name: "data.bin" + s3_constants.VersionsFolder, IsDirectory: true,
- Extended: map[string][]byte{
- s3_constants.ExtLatestVersionIdKey: []byte(vidOld),
- s3_constants.ExtLatestVersionFileNameKey: []byte("v_" + vidOld),
- s3_constants.ExtLatestVersionMtimeKey: []byte(strconv.FormatInt(old.Unix(), 10)),
- s3_constants.ExtLatestVersionSizeKey: []byte("1024"),
- s3_constants.ExtLatestVersionIsDeleteMarker: []byte("false"),
- },
- })
- server.putEntry(versionsDir, &filer_pb.Entry{
- Name: "v_" + vidOld,
- Attributes: &filer_pb.FuseAttributes{Mtime: old.Unix(), FileSize: 1024},
- Extended: map[string][]byte{
- s3_constants.ExtVersionIdKey: []byte(vidOld),
- },
- })
-
- rules := []s3lifecycle.Rule{{
- ID: "expire-30d", Status: "Enabled",
- ExpirationDays: 30,
- }}
-
- // Step 1: Detect expired.
- expired, _, err := listExpiredObjectsByRules(context.Background(), client, bucketsPath, bucket, rules, 100)
- if err != nil {
- t.Fatalf("list: %v", err)
- }
- if len(expired) != 1 {
- t.Fatalf("expected 1 expired, got %d", len(expired))
- }
-
- // Step 2: Delete the expired version file.
- deleted, errs, delErr := deleteExpiredObjects(context.Background(), client, expired)
- if delErr != nil {
- t.Fatalf("delete: %v", delErr)
- }
- if deleted != 1 || errs != 0 {
- t.Errorf("expected 1 deleted 0 errors, got %d deleted %d errors", deleted, errs)
- }
-
- // Version file should be gone.
- if server.hasEntry(versionsDir, "v_"+vidOld) {
- t.Error("version file should have been removed")
- }
-
- // Step 3: Cleanup empty .versions directory.
- cleaned := cleanupEmptyVersionsDirectories(context.Background(), client, expired)
- if cleaned != 1 {
- t.Errorf("expected 1 directory cleaned, got %d", cleaned)
- }
-
- // The .versions directory itself should be gone.
- if server.hasEntry(bucketDir, "data.bin"+s3_constants.VersionsFolder) {
- t.Error(".versions directory should have been removed after cleanup")
- }
-}
-
-// TestIntegration_VersionedBucket_MultiVersion_ExpirationDays verifies that
-// when a multi-version object's latest version expires, only the latest
-// version is deleted and noncurrent versions remain.
-func TestIntegration_VersionedBucket_MultiVersion_ExpirationDays(t *testing.T) {
- server, client := startTestFiler(t)
- bucketsPath := "/buckets"
- bucket := "versioned-multi"
- bucketDir := bucketsPath + "/" + bucket
-
- now := time.Now()
- tOld := now.Add(-60 * 24 * time.Hour)
- tNoncurrent := now.Add(-90 * 24 * time.Hour)
- vidLatest := testVersionId(tOld)
- vidNoncurrent := testVersionId(tNoncurrent)
-
- server.putEntry(bucketsPath, &filer_pb.Entry{Name: bucket, IsDirectory: true})
-
- versionsDir := bucketDir + "/multi.txt" + s3_constants.VersionsFolder
- server.putEntry(bucketDir, &filer_pb.Entry{
- Name: "multi.txt" + s3_constants.VersionsFolder, IsDirectory: true,
- Extended: map[string][]byte{
- s3_constants.ExtLatestVersionIdKey: []byte(vidLatest),
- s3_constants.ExtLatestVersionFileNameKey: []byte("v_" + vidLatest),
- s3_constants.ExtLatestVersionMtimeKey: []byte(strconv.FormatInt(tOld.Unix(), 10)),
- s3_constants.ExtLatestVersionSizeKey: []byte("500"),
- s3_constants.ExtLatestVersionIsDeleteMarker: []byte("false"),
- },
- })
- server.putEntry(versionsDir, &filer_pb.Entry{
- Name: "v_" + vidLatest,
- Attributes: &filer_pb.FuseAttributes{Mtime: tOld.Unix(), FileSize: 500},
- Extended: map[string][]byte{s3_constants.ExtVersionIdKey: []byte(vidLatest)},
- })
- server.putEntry(versionsDir, &filer_pb.Entry{
- Name: "v_" + vidNoncurrent,
- Attributes: &filer_pb.FuseAttributes{Mtime: tNoncurrent.Unix(), FileSize: 500},
- Extended: map[string][]byte{s3_constants.ExtVersionIdKey: []byte(vidNoncurrent)},
- })
-
- rules := []s3lifecycle.Rule{{
- ID: "expire-30d", Status: "Enabled",
- ExpirationDays: 30,
- }}
-
- expired, _, err := listExpiredObjectsByRules(context.Background(), client, bucketsPath, bucket, rules, 100)
- if err != nil {
- t.Fatalf("list: %v", err)
- }
- // Only the latest version should be detected as expired.
- if len(expired) != 1 {
- t.Fatalf("expected 1 expired (latest only), got %d", len(expired))
- }
- if expired[0].name != "v_"+vidLatest {
- t.Errorf("expected latest version expired, got %s", expired[0].name)
- }
-
- // Delete it.
- deleted, errs, delErr := deleteExpiredObjects(context.Background(), client, expired)
- if delErr != nil {
- t.Fatalf("delete: %v", delErr)
- }
- if deleted != 1 || errs != 0 {
- t.Errorf("expected 1 deleted 0 errors, got %d deleted %d errors", deleted, errs)
- }
-
- // Noncurrent version should still exist.
- if !server.hasEntry(versionsDir, "v_"+vidNoncurrent) {
- t.Error("noncurrent version should still exist")
- }
-
- // .versions directory should NOT be cleaned up (not empty).
- cleaned := cleanupEmptyVersionsDirectories(context.Background(), client, expired)
- if cleaned != 0 {
- t.Errorf("expected 0 directories cleaned (not empty), got %d", cleaned)
- }
- if !server.hasEntry(bucketDir, "multi.txt"+s3_constants.VersionsFolder) {
- t.Error(".versions directory should still exist (has noncurrent version)")
- }
-}
diff --git a/weed/plugin/worker/lifecycle/rules.go b/weed/plugin/worker/lifecycle/rules.go
deleted file mode 100644
index c3855f22c..000000000
--- a/weed/plugin/worker/lifecycle/rules.go
+++ /dev/null
@@ -1,199 +0,0 @@
-package lifecycle
-
-import (
- "bytes"
- "context"
- "encoding/xml"
- "errors"
- "fmt"
- "time"
-
- "github.com/seaweedfs/seaweedfs/weed/glog"
- "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3lifecycle"
-)
-
-// lifecycleConfig mirrors the XML structure just enough to parse rules.
-// We define a minimal local struct to avoid importing the s3api package
-// (which would create a circular dependency if s3api ever imports the worker).
-type lifecycleConfig struct {
- XMLName xml.Name `xml:"LifecycleConfiguration"`
- Rules []lifecycleConfigRule `xml:"Rule"`
-}
-
-type lifecycleConfigRule struct {
- ID string `xml:"ID"`
- Status string `xml:"Status"`
- Filter lifecycleFilter `xml:"Filter"`
- Prefix string `xml:"Prefix"`
- Expiration lifecycleExpiration `xml:"Expiration"`
- NoncurrentVersionExpiration noncurrentVersionExpiration `xml:"NoncurrentVersionExpiration"`
- AbortIncompleteMultipartUpload abortMPU `xml:"AbortIncompleteMultipartUpload"`
-}
-
-type lifecycleFilter struct {
- Prefix string `xml:"Prefix"`
- Tag lifecycleTag `xml:"Tag"`
- And lifecycleAnd `xml:"And"`
- ObjectSizeGreaterThan int64 `xml:"ObjectSizeGreaterThan"`
- ObjectSizeLessThan int64 `xml:"ObjectSizeLessThan"`
-}
-
-type lifecycleAnd struct {
- Prefix string `xml:"Prefix"`
- Tags []lifecycleTag `xml:"Tag"`
- ObjectSizeGreaterThan int64 `xml:"ObjectSizeGreaterThan"`
- ObjectSizeLessThan int64 `xml:"ObjectSizeLessThan"`
-}
-
-type lifecycleTag struct {
- Key string `xml:"Key"`
- Value string `xml:"Value"`
-}
-
-type lifecycleExpiration struct {
- Days int `xml:"Days"`
- Date string `xml:"Date"`
- ExpiredObjectDeleteMarker bool `xml:"ExpiredObjectDeleteMarker"`
-}
-
-type noncurrentVersionExpiration struct {
- NoncurrentDays int `xml:"NoncurrentDays"`
- NewerNoncurrentVersions int `xml:"NewerNoncurrentVersions"`
-}
-
-type abortMPU struct {
- DaysAfterInitiation int `xml:"DaysAfterInitiation"`
-}
-
-// errMalformedLifecycleXML indicates the lifecycle XML exists but could not be parsed.
-// Callers should fail closed (not fall back to TTL) to avoid broader deletions.
-var errMalformedLifecycleXML = errors.New("malformed lifecycle XML")
-
-// loadLifecycleRulesFromBucket reads the lifecycle XML from a bucket's
-// metadata and converts it to evaluator-friendly rules.
-//
-// Returns:
-// - (rules, nil) when lifecycle XML is configured and parseable
-// - (nil, nil) when no lifecycle XML is configured (caller should use TTL fallback)
-// - (nil, errMalformedLifecycleXML) when XML exists but is malformed (fail closed)
-// - (nil, err) for transient filer errors (caller should use TTL fallback with warning)
-func loadLifecycleRulesFromBucket(
- ctx context.Context,
- client filer_pb.SeaweedFilerClient,
- bucketsPath, bucket string,
-) ([]s3lifecycle.Rule, error) {
- bucketDir := bucketsPath
- resp, err := filer_pb.LookupEntry(ctx, client, &filer_pb.LookupDirectoryEntryRequest{
- Directory: bucketDir,
- Name: bucket,
- })
- if err != nil {
- // Transient filer error — not the same as malformed XML.
- return nil, fmt.Errorf("lookup bucket %s: %w", bucket, err)
- }
- if resp.Entry == nil || resp.Entry.Extended == nil {
- return nil, nil
- }
- xmlData := resp.Entry.Extended[lifecycleXMLKey]
- if len(xmlData) == 0 {
- return nil, nil
- }
- rules, parseErr := parseLifecycleXML(xmlData)
- if parseErr != nil {
- return nil, fmt.Errorf("%w: bucket %s: %v", errMalformedLifecycleXML, bucket, parseErr)
- }
- // Return non-nil empty slice when XML was present but yielded no rules
- // (e.g., all rules disabled). This lets callers distinguish "no XML" (nil)
- // from "XML present, no effective rules" (empty slice).
- if rules == nil {
- rules = []s3lifecycle.Rule{}
- }
- return rules, nil
-}
-
-// parseLifecycleXML parses lifecycle configuration XML and converts it
-// to evaluator-friendly rules.
-func parseLifecycleXML(data []byte) ([]s3lifecycle.Rule, error) {
- var config lifecycleConfig
- if err := xml.NewDecoder(bytes.NewReader(data)).Decode(&config); err != nil {
- return nil, fmt.Errorf("decode lifecycle XML: %w", err)
- }
-
- var rules []s3lifecycle.Rule
- for _, r := range config.Rules {
- rule := s3lifecycle.Rule{
- ID: r.ID,
- Status: r.Status,
- }
-
- // Resolve prefix: Filter.And.Prefix > Filter.Prefix > Rule.Prefix
- switch {
- case r.Filter.And.Prefix != "" || len(r.Filter.And.Tags) > 0 ||
- r.Filter.And.ObjectSizeGreaterThan > 0 || r.Filter.And.ObjectSizeLessThan > 0:
- rule.Prefix = r.Filter.And.Prefix
- rule.FilterTags = tagsToMap(r.Filter.And.Tags)
- rule.FilterSizeGreaterThan = r.Filter.And.ObjectSizeGreaterThan
- rule.FilterSizeLessThan = r.Filter.And.ObjectSizeLessThan
- case r.Filter.Tag.Key != "":
- rule.Prefix = r.Filter.Prefix
- rule.FilterTags = map[string]string{r.Filter.Tag.Key: r.Filter.Tag.Value}
- rule.FilterSizeGreaterThan = r.Filter.ObjectSizeGreaterThan
- rule.FilterSizeLessThan = r.Filter.ObjectSizeLessThan
- default:
- if r.Filter.Prefix != "" {
- rule.Prefix = r.Filter.Prefix
- } else {
- rule.Prefix = r.Prefix
- }
- rule.FilterSizeGreaterThan = r.Filter.ObjectSizeGreaterThan
- rule.FilterSizeLessThan = r.Filter.ObjectSizeLessThan
- }
-
- rule.ExpirationDays = r.Expiration.Days
- rule.ExpiredObjectDeleteMarker = r.Expiration.ExpiredObjectDeleteMarker
- rule.NoncurrentVersionExpirationDays = r.NoncurrentVersionExpiration.NoncurrentDays
- rule.NewerNoncurrentVersions = r.NoncurrentVersionExpiration.NewerNoncurrentVersions
- rule.AbortMPUDaysAfterInitiation = r.AbortIncompleteMultipartUpload.DaysAfterInitiation
-
- // Parse Date if present.
- if r.Expiration.Date != "" {
- // Date may be RFC3339 or ISO 8601 date-only.
- parsed, parseErr := parseExpirationDate(r.Expiration.Date)
- if parseErr != nil {
- glog.V(1).Infof("s3_lifecycle: skipping rule %s: invalid expiration date %q: %v", r.ID, r.Expiration.Date, parseErr)
- continue
- }
- rule.ExpirationDate = parsed
- }
-
- rules = append(rules, rule)
- }
- return rules, nil
-}
-
-func tagsToMap(tags []lifecycleTag) map[string]string {
- if len(tags) == 0 {
- return nil
- }
- m := make(map[string]string, len(tags))
- for _, t := range tags {
- m[t.Key] = t.Value
- }
- return m
-}
-
-func parseExpirationDate(s string) (time.Time, error) {
- // Try RFC3339 first, then ISO 8601 date-only.
- formats := []string{
- "2006-01-02T15:04:05Z07:00",
- "2006-01-02",
- }
- for _, f := range formats {
- t, err := time.Parse(f, s)
- if err == nil {
- return t, nil
- }
- }
- return time.Time{}, fmt.Errorf("unrecognized date format: %s", s)
-}
diff --git a/weed/plugin/worker/lifecycle/rules_test.go b/weed/plugin/worker/lifecycle/rules_test.go
deleted file mode 100644
index ab57137a7..000000000
--- a/weed/plugin/worker/lifecycle/rules_test.go
+++ /dev/null
@@ -1,256 +0,0 @@
-package lifecycle
-
-import (
- "testing"
- "time"
-)
-
-func TestParseLifecycleXML_CompleteConfig(t *testing.T) {
- xml := []byte(`
-
- rotation
-
- Enabled
- 30
-
- 7
- 2
-
-
- 3
-
-
-`)
-
- rules, err := parseLifecycleXML(xml)
- if err != nil {
- t.Fatalf("parseLifecycleXML: %v", err)
- }
- if len(rules) != 1 {
- t.Fatalf("expected 1 rule, got %d", len(rules))
- }
-
- r := rules[0]
- if r.ID != "rotation" {
- t.Errorf("expected ID 'rotation', got %q", r.ID)
- }
- if r.Status != "Enabled" {
- t.Errorf("expected Status 'Enabled', got %q", r.Status)
- }
- if r.ExpirationDays != 30 {
- t.Errorf("expected ExpirationDays=30, got %d", r.ExpirationDays)
- }
- if r.NoncurrentVersionExpirationDays != 7 {
- t.Errorf("expected NoncurrentVersionExpirationDays=7, got %d", r.NoncurrentVersionExpirationDays)
- }
- if r.NewerNoncurrentVersions != 2 {
- t.Errorf("expected NewerNoncurrentVersions=2, got %d", r.NewerNoncurrentVersions)
- }
- if r.AbortMPUDaysAfterInitiation != 3 {
- t.Errorf("expected AbortMPUDaysAfterInitiation=3, got %d", r.AbortMPUDaysAfterInitiation)
- }
-}
-
-func TestParseLifecycleXML_PrefixFilter(t *testing.T) {
- xml := []byte(`
-
- logs
- Enabled
- logs/
- 7
-
-`)
-
- rules, err := parseLifecycleXML(xml)
- if err != nil {
- t.Fatalf("parseLifecycleXML: %v", err)
- }
- if len(rules) != 1 {
- t.Fatalf("expected 1 rule, got %d", len(rules))
- }
- if rules[0].Prefix != "logs/" {
- t.Errorf("expected Prefix='logs/', got %q", rules[0].Prefix)
- }
-}
-
-func TestParseLifecycleXML_LegacyPrefix(t *testing.T) {
- // Old-style at rule level instead of inside .
- xml := []byte(`
-
- old
- Enabled
- archive/
- 90
-
-`)
-
- rules, err := parseLifecycleXML(xml)
- if err != nil {
- t.Fatalf("parseLifecycleXML: %v", err)
- }
- if len(rules) != 1 {
- t.Fatalf("expected 1 rule, got %d", len(rules))
- }
- if rules[0].Prefix != "archive/" {
- t.Errorf("expected Prefix='archive/', got %q", rules[0].Prefix)
- }
-}
-
-func TestParseLifecycleXML_TagFilter(t *testing.T) {
- xml := []byte(`
-
- tag-rule
- Enabled
-
- envdev
-
- 1
-
-`)
-
- rules, err := parseLifecycleXML(xml)
- if err != nil {
- t.Fatalf("parseLifecycleXML: %v", err)
- }
- if len(rules) != 1 {
- t.Fatalf("expected 1 rule, got %d", len(rules))
- }
- r := rules[0]
- if len(r.FilterTags) != 1 || r.FilterTags["env"] != "dev" {
- t.Errorf("expected FilterTags={env:dev}, got %v", r.FilterTags)
- }
-}
-
-func TestParseLifecycleXML_AndFilter(t *testing.T) {
- xml := []byte(`
-
- and-rule
- Enabled
-
-
- data/
- envstaging
- 1024
-
-
- 14
-
-`)
-
- rules, err := parseLifecycleXML(xml)
- if err != nil {
- t.Fatalf("parseLifecycleXML: %v", err)
- }
- if len(rules) != 1 {
- t.Fatalf("expected 1 rule, got %d", len(rules))
- }
- r := rules[0]
- if r.Prefix != "data/" {
- t.Errorf("expected Prefix='data/', got %q", r.Prefix)
- }
- if r.FilterTags["env"] != "staging" {
- t.Errorf("expected tag env=staging, got %v", r.FilterTags)
- }
- if r.FilterSizeGreaterThan != 1024 {
- t.Errorf("expected FilterSizeGreaterThan=1024, got %d", r.FilterSizeGreaterThan)
- }
-}
-
-func TestParseLifecycleXML_ExpirationDate(t *testing.T) {
- xml := []byte(`
-
- date-rule
- Enabled
-
- 2026-06-01T00:00:00Z
-
-`)
-
- rules, err := parseLifecycleXML(xml)
- if err != nil {
- t.Fatalf("parseLifecycleXML: %v", err)
- }
- expected := time.Date(2026, 6, 1, 0, 0, 0, 0, time.UTC)
- if !rules[0].ExpirationDate.Equal(expected) {
- t.Errorf("expected ExpirationDate=%v, got %v", expected, rules[0].ExpirationDate)
- }
-}
-
-func TestParseLifecycleXML_ExpiredObjectDeleteMarker(t *testing.T) {
- xml := []byte(`
-
- marker-cleanup
- Enabled
-
- true
-
-`)
-
- rules, err := parseLifecycleXML(xml)
- if err != nil {
- t.Fatalf("parseLifecycleXML: %v", err)
- }
- if !rules[0].ExpiredObjectDeleteMarker {
- t.Error("expected ExpiredObjectDeleteMarker=true")
- }
-}
-
-func TestParseLifecycleXML_MultipleRules(t *testing.T) {
- xml := []byte(`
-
- rule1
- Enabled
- logs/
- 7
-
-
- rule2
- Disabled
- temp/
- 1
-
-
- rule3
- Enabled
-
- 365
-
-`)
-
- rules, err := parseLifecycleXML(xml)
- if err != nil {
- t.Fatalf("parseLifecycleXML: %v", err)
- }
- if len(rules) != 3 {
- t.Fatalf("expected 3 rules, got %d", len(rules))
- }
- if rules[1].Status != "Disabled" {
- t.Errorf("expected rule2 Status=Disabled, got %q", rules[1].Status)
- }
-}
-
-func TestParseExpirationDate(t *testing.T) {
- tests := []struct {
- name string
- input string
- want time.Time
- wantErr bool
- }{
- {"rfc3339_utc", "2026-06-01T00:00:00Z", time.Date(2026, 6, 1, 0, 0, 0, 0, time.UTC), false},
- {"rfc3339_offset", "2026-06-01T00:00:00+05:00", time.Date(2026, 6, 1, 0, 0, 0, 0, time.FixedZone("", 5*3600)), false},
- {"date_only", "2026-06-01", time.Date(2026, 6, 1, 0, 0, 0, 0, time.UTC), false},
- {"invalid", "not-a-date", time.Time{}, true},
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got, err := parseExpirationDate(tt.input)
- if (err != nil) != tt.wantErr {
- t.Errorf("parseExpirationDate(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
- return
- }
- if !tt.wantErr && !got.Equal(tt.want) {
- t.Errorf("parseExpirationDate(%q) = %v, want %v", tt.input, got, tt.want)
- }
- })
- }
-}
diff --git a/weed/plugin/worker/lifecycle/version_test.go b/weed/plugin/worker/lifecycle/version_test.go
deleted file mode 100644
index 43cc0d93b..000000000
--- a/weed/plugin/worker/lifecycle/version_test.go
+++ /dev/null
@@ -1,112 +0,0 @@
-package lifecycle
-
-import (
- "fmt"
- "math"
- "strings"
- "testing"
- "time"
-
- "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3lifecycle"
-)
-
-// makeVersionId creates a new-format version ID from a timestamp.
-func makeVersionId(t time.Time) string {
- inverted := math.MaxInt64 - t.UnixNano()
- return fmt.Sprintf("%016x", inverted) + "0000000000000000"
-}
-
-func TestSortVersionsByVersionId(t *testing.T) {
- t1 := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
- t2 := time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC)
- t3 := time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC)
-
- vid1 := makeVersionId(t1)
- vid2 := makeVersionId(t2)
- vid3 := makeVersionId(t3)
-
- entries := []*filer_pb.Entry{
- {Name: "v_" + vid1},
- {Name: "v_" + vid3},
- {Name: "v_" + vid2},
- }
-
- sortVersionsByVersionId(entries)
-
- // Should be sorted newest first: t3, t2, t1.
- expected := []string{"v_" + vid3, "v_" + vid2, "v_" + vid1}
- for i, want := range expected {
- if entries[i].Name != want {
- t.Errorf("entries[%d].Name = %s, want %s", i, entries[i].Name, want)
- }
- }
-}
-
-func TestSortVersionsByVersionId_SameTimestampDifferentSuffix(t *testing.T) {
- // Two versions with the same timestamp prefix but different random suffix.
- // The sort must still produce a deterministic order.
- base := makeVersionId(time.Date(2026, 6, 1, 0, 0, 0, 0, time.UTC))
- vid1 := base[:16] + "aaaaaaaaaaaaaaaa"
- vid2 := base[:16] + "bbbbbbbbbbbbbbbb"
-
- entries := []*filer_pb.Entry{
- {Name: "v_" + vid2},
- {Name: "v_" + vid1},
- }
-
- sortVersionsByVersionId(entries)
-
- // New format: smaller hex = newer. vid1 ("aaa...") < vid2 ("bbb...") so vid1 is newer.
- if strings.TrimPrefix(entries[0].Name, "v_") != vid1 {
- t.Errorf("expected vid1 (newer) first, got %s", entries[0].Name)
- }
-}
-
-func TestCompareVersionIdsMixedFormats(t *testing.T) {
- // Old format: raw nanosecond timestamp (below threshold ~0x17...).
- // New format: inverted timestamp (above threshold ~0x68...).
- oldTs := time.Date(2023, 6, 15, 12, 0, 0, 0, time.UTC)
- newTs := time.Date(2026, 3, 1, 0, 0, 0, 0, time.UTC)
-
- oldFormatId := fmt.Sprintf("%016x", oldTs.UnixNano()) + "abcdef0123456789"
- newFormatId := makeVersionId(newTs) // uses inverted timestamp
-
- // newTs is more recent, so newFormatId should sort as "newer".
- cmp := s3lifecycle.CompareVersionIds(newFormatId, oldFormatId)
- if cmp >= 0 {
- t.Errorf("expected new-format ID (2026) to be newer than old-format ID (2023), got cmp=%d", cmp)
- }
-
- // Reverse comparison.
- cmp2 := s3lifecycle.CompareVersionIds(oldFormatId, newFormatId)
- if cmp2 <= 0 {
- t.Errorf("expected old-format ID (2023) to be older than new-format ID (2026), got cmp=%d", cmp2)
- }
-
- // Sort a mixed slice: should be newest-first.
- entries := []*filer_pb.Entry{
- {Name: "v_" + oldFormatId},
- {Name: "v_" + newFormatId},
- }
- sortVersionsByVersionId(entries)
-
- if strings.TrimPrefix(entries[0].Name, "v_") != newFormatId {
- t.Errorf("expected new-format (newer) entry first after sort")
- }
-}
-
-func TestVersionsDirectoryNaming(t *testing.T) {
- if s3_constants.VersionsFolder != ".versions" {
- t.Fatalf("unexpected VersionsFolder constant: %q", s3_constants.VersionsFolder)
- }
-
- versionsDir := "/buckets/mybucket/path/to/key.versions"
- bucketPath := "/buckets/mybucket"
- relDir := strings.TrimPrefix(versionsDir, bucketPath+"/")
- objKey := strings.TrimSuffix(relDir, s3_constants.VersionsFolder)
- if objKey != "path/to/key" {
- t.Errorf("expected 'path/to/key', got %q", objKey)
- }
-}
diff --git a/weed/query/engine/aggregations.go b/weed/query/engine/aggregations.go
index 6b58517e1..54130212e 100644
--- a/weed/query/engine/aggregations.go
+++ b/weed/query/engine/aggregations.go
@@ -74,11 +74,6 @@ func (opt *FastPathOptimizer) DetermineStrategy(aggregations []AggregationSpec)
return strategy
}
-// CollectDataSources gathers information about available data sources for a topic
-func (opt *FastPathOptimizer) CollectDataSources(ctx context.Context, hybridScanner *HybridMessageScanner) (*TopicDataSources, error) {
- return opt.CollectDataSourcesWithTimeFilter(ctx, hybridScanner, 0, 0)
-}
-
// CollectDataSourcesWithTimeFilter gathers information about available data sources for a topic
// with optional time filtering to skip irrelevant parquet files
func (opt *FastPathOptimizer) CollectDataSourcesWithTimeFilter(ctx context.Context, hybridScanner *HybridMessageScanner, startTimeNs, stopTimeNs int64) (*TopicDataSources, error) {
diff --git a/weed/query/engine/engine.go b/weed/query/engine/engine.go
index ac66a7453..7a1c783ba 100644
--- a/weed/query/engine/engine.go
+++ b/weed/query/engine/engine.go
@@ -539,20 +539,6 @@ func NewSQLEngine(masterAddress string) *SQLEngine {
}
}
-// NewSQLEngineWithCatalog creates a new SQL execution engine with a custom catalog
-// Used for testing or when you want to provide a pre-configured catalog
-func NewSQLEngineWithCatalog(catalog *SchemaCatalog) *SQLEngine {
- // Initialize global HTTP client if not already done
- // This is needed for reading partition data from the filer
- if util_http.GetGlobalHttpClient() == nil {
- util_http.InitGlobalHttpClient()
- }
-
- return &SQLEngine{
- catalog: catalog,
- }
-}
-
// GetCatalog returns the schema catalog for external access
func (e *SQLEngine) GetCatalog() *SchemaCatalog {
return e.catalog
@@ -3682,11 +3668,6 @@ type ExecutionPlanBuilder struct {
engine *SQLEngine
}
-// NewExecutionPlanBuilder creates a new execution plan builder
-func NewExecutionPlanBuilder(engine *SQLEngine) *ExecutionPlanBuilder {
- return &ExecutionPlanBuilder{engine: engine}
-}
-
// BuildAggregationPlan builds an execution plan for aggregation queries
func (builder *ExecutionPlanBuilder) BuildAggregationPlan(
stmt *SelectStatement,
diff --git a/weed/query/engine/engine_test.go b/weed/query/engine/engine_test.go
deleted file mode 100644
index 42a5f4911..000000000
--- a/weed/query/engine/engine_test.go
+++ /dev/null
@@ -1,1329 +0,0 @@
-package engine
-
-import (
- "context"
- "encoding/binary"
- "errors"
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/mq/topic"
- "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
- "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/mock"
- "google.golang.org/protobuf/proto"
-)
-
-// Mock implementations for testing
-type MockHybridMessageScanner struct {
- mock.Mock
- topic topic.Topic
-}
-
-func (m *MockHybridMessageScanner) ReadParquetStatistics(partitionPath string) ([]*ParquetFileStats, error) {
- args := m.Called(partitionPath)
- return args.Get(0).([]*ParquetFileStats), args.Error(1)
-}
-
-type MockSQLEngine struct {
- *SQLEngine
- mockPartitions map[string][]string
- mockParquetSourceFiles map[string]map[string]bool
- mockLiveLogRowCounts map[string]int64
- mockColumnStats map[string]map[string]*ParquetColumnStats
-}
-
-func NewMockSQLEngine() *MockSQLEngine {
- return &MockSQLEngine{
- SQLEngine: &SQLEngine{
- catalog: &SchemaCatalog{
- databases: make(map[string]*DatabaseInfo),
- currentDatabase: "test",
- },
- },
- mockPartitions: make(map[string][]string),
- mockParquetSourceFiles: make(map[string]map[string]bool),
- mockLiveLogRowCounts: make(map[string]int64),
- mockColumnStats: make(map[string]map[string]*ParquetColumnStats),
- }
-}
-
-func (m *MockSQLEngine) discoverTopicPartitions(namespace, topicName string) ([]string, error) {
- key := namespace + "." + topicName
- if partitions, exists := m.mockPartitions[key]; exists {
- return partitions, nil
- }
- return []string{"partition-1", "partition-2"}, nil
-}
-
-func (m *MockSQLEngine) extractParquetSourceFiles(fileStats []*ParquetFileStats) map[string]bool {
- if len(fileStats) == 0 {
- return make(map[string]bool)
- }
- return map[string]bool{"converted-log-1": true}
-}
-
-func (m *MockSQLEngine) countLiveLogRowsExcludingParquetSources(ctx context.Context, partition string, parquetSources map[string]bool) (int64, error) {
- if count, exists := m.mockLiveLogRowCounts[partition]; exists {
- return count, nil
- }
- return 25, nil
-}
-
-func (m *MockSQLEngine) computeLiveLogMinMax(partition, column string, parquetSources map[string]bool) (interface{}, interface{}, error) {
- switch column {
- case "id":
- return int64(1), int64(50), nil
- case "value":
- return 10.5, 99.9, nil
- default:
- return nil, nil, nil
- }
-}
-
-func (m *MockSQLEngine) getSystemColumnGlobalMin(column string, allFileStats map[string][]*ParquetFileStats) interface{} {
- return int64(1000000000)
-}
-
-func (m *MockSQLEngine) getSystemColumnGlobalMax(column string, allFileStats map[string][]*ParquetFileStats) interface{} {
- return int64(2000000000)
-}
-
-func createMockColumnStats(column string, minVal, maxVal interface{}) *ParquetColumnStats {
- return &ParquetColumnStats{
- ColumnName: column,
- MinValue: convertToSchemaValue(minVal),
- MaxValue: convertToSchemaValue(maxVal),
- NullCount: 0,
- }
-}
-
-func convertToSchemaValue(val interface{}) *schema_pb.Value {
- switch v := val.(type) {
- case int64:
- return &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: v}}
- case float64:
- return &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: v}}
- case string:
- return &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: v}}
- }
- return nil
-}
-
-// Test FastPathOptimizer
-func TestFastPathOptimizer_DetermineStrategy(t *testing.T) {
- engine := NewMockSQLEngine()
- optimizer := NewFastPathOptimizer(engine.SQLEngine)
-
- tests := []struct {
- name string
- aggregations []AggregationSpec
- expected AggregationStrategy
- }{
- {
- name: "Supported aggregations",
- aggregations: []AggregationSpec{
- {Function: FuncCOUNT, Column: "*"},
- {Function: FuncMAX, Column: "id"},
- {Function: FuncMIN, Column: "value"},
- },
- expected: AggregationStrategy{
- CanUseFastPath: true,
- Reason: "all_aggregations_supported",
- UnsupportedSpecs: []AggregationSpec{},
- },
- },
- {
- name: "Unsupported aggregation",
- aggregations: []AggregationSpec{
- {Function: FuncCOUNT, Column: "*"},
- {Function: FuncAVG, Column: "value"}, // Not supported
- },
- expected: AggregationStrategy{
- CanUseFastPath: false,
- Reason: "unsupported_aggregation_functions",
- },
- },
- {
- name: "Empty aggregations",
- aggregations: []AggregationSpec{},
- expected: AggregationStrategy{
- CanUseFastPath: true,
- Reason: "all_aggregations_supported",
- UnsupportedSpecs: []AggregationSpec{},
- },
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- strategy := optimizer.DetermineStrategy(tt.aggregations)
-
- assert.Equal(t, tt.expected.CanUseFastPath, strategy.CanUseFastPath)
- assert.Equal(t, tt.expected.Reason, strategy.Reason)
- if !tt.expected.CanUseFastPath {
- assert.NotEmpty(t, strategy.UnsupportedSpecs)
- }
- })
- }
-}
-
-// Test AggregationComputer
-func TestAggregationComputer_ComputeFastPathAggregations(t *testing.T) {
- engine := NewMockSQLEngine()
- computer := NewAggregationComputer(engine.SQLEngine)
-
- dataSources := &TopicDataSources{
- ParquetFiles: map[string][]*ParquetFileStats{
- "/topics/test/topic1/partition-1": {
- {
- RowCount: 30,
- ColumnStats: map[string]*ParquetColumnStats{
- "id": createMockColumnStats("id", int64(10), int64(40)),
- },
- },
- },
- },
- ParquetRowCount: 30,
- LiveLogRowCount: 25,
- PartitionsCount: 1,
- }
-
- partitions := []string{"/topics/test/topic1/partition-1"}
-
- tests := []struct {
- name string
- aggregations []AggregationSpec
- validate func(t *testing.T, results []AggregationResult)
- }{
- {
- name: "COUNT aggregation",
- aggregations: []AggregationSpec{
- {Function: FuncCOUNT, Column: "*"},
- },
- validate: func(t *testing.T, results []AggregationResult) {
- assert.Len(t, results, 1)
- assert.Equal(t, int64(55), results[0].Count) // 30 + 25
- },
- },
- {
- name: "MAX aggregation",
- aggregations: []AggregationSpec{
- {Function: FuncMAX, Column: "id"},
- },
- validate: func(t *testing.T, results []AggregationResult) {
- assert.Len(t, results, 1)
- // Should be max of parquet stats (40) - mock doesn't combine with live log
- assert.Equal(t, int64(40), results[0].Max)
- },
- },
- {
- name: "MIN aggregation",
- aggregations: []AggregationSpec{
- {Function: FuncMIN, Column: "id"},
- },
- validate: func(t *testing.T, results []AggregationResult) {
- assert.Len(t, results, 1)
- // Should be min of parquet stats (10) - mock doesn't combine with live log
- assert.Equal(t, int64(10), results[0].Min)
- },
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- ctx := context.Background()
- results, err := computer.ComputeFastPathAggregations(ctx, tt.aggregations, dataSources, partitions)
-
- assert.NoError(t, err)
- tt.validate(t, results)
- })
- }
-}
-
-// Test case-insensitive column lookup and null handling for MIN/MAX aggregations
-func TestAggregationComputer_MinMaxEdgeCases(t *testing.T) {
- engine := NewMockSQLEngine()
- computer := NewAggregationComputer(engine.SQLEngine)
-
- tests := []struct {
- name string
- dataSources *TopicDataSources
- aggregations []AggregationSpec
- validate func(t *testing.T, results []AggregationResult, err error)
- }{
- {
- name: "Case insensitive column lookup",
- dataSources: &TopicDataSources{
- ParquetFiles: map[string][]*ParquetFileStats{
- "/topics/test/partition-1": {
- {
- RowCount: 50,
- ColumnStats: map[string]*ParquetColumnStats{
- "ID": createMockColumnStats("ID", int64(5), int64(95)), // Uppercase column name
- },
- },
- },
- },
- ParquetRowCount: 50,
- LiveLogRowCount: 0,
- PartitionsCount: 1,
- },
- aggregations: []AggregationSpec{
- {Function: FuncMIN, Column: "id"}, // lowercase column name
- {Function: FuncMAX, Column: "id"},
- },
- validate: func(t *testing.T, results []AggregationResult, err error) {
- assert.NoError(t, err)
- assert.Len(t, results, 2)
- assert.Equal(t, int64(5), results[0].Min, "MIN should work with case-insensitive lookup")
- assert.Equal(t, int64(95), results[1].Max, "MAX should work with case-insensitive lookup")
- },
- },
- {
- name: "Null column stats handling",
- dataSources: &TopicDataSources{
- ParquetFiles: map[string][]*ParquetFileStats{
- "/topics/test/partition-1": {
- {
- RowCount: 50,
- ColumnStats: map[string]*ParquetColumnStats{
- "id": {
- ColumnName: "id",
- MinValue: nil, // Null min value
- MaxValue: nil, // Null max value
- NullCount: 50,
- RowCount: 50,
- },
- },
- },
- },
- },
- ParquetRowCount: 50,
- LiveLogRowCount: 0,
- PartitionsCount: 1,
- },
- aggregations: []AggregationSpec{
- {Function: FuncMIN, Column: "id"},
- {Function: FuncMAX, Column: "id"},
- },
- validate: func(t *testing.T, results []AggregationResult, err error) {
- assert.NoError(t, err)
- assert.Len(t, results, 2)
- // When stats are null, should fall back to system column or return nil
- // This tests that we don't crash on null stats
- },
- },
- {
- name: "Mixed data types - string column",
- dataSources: &TopicDataSources{
- ParquetFiles: map[string][]*ParquetFileStats{
- "/topics/test/partition-1": {
- {
- RowCount: 30,
- ColumnStats: map[string]*ParquetColumnStats{
- "name": createMockColumnStats("name", "Alice", "Zoe"),
- },
- },
- },
- },
- ParquetRowCount: 30,
- LiveLogRowCount: 0,
- PartitionsCount: 1,
- },
- aggregations: []AggregationSpec{
- {Function: FuncMIN, Column: "name"},
- {Function: FuncMAX, Column: "name"},
- },
- validate: func(t *testing.T, results []AggregationResult, err error) {
- assert.NoError(t, err)
- assert.Len(t, results, 2)
- assert.Equal(t, "Alice", results[0].Min)
- assert.Equal(t, "Zoe", results[1].Max)
- },
- },
- {
- name: "Mixed data types - float column",
- dataSources: &TopicDataSources{
- ParquetFiles: map[string][]*ParquetFileStats{
- "/topics/test/partition-1": {
- {
- RowCount: 25,
- ColumnStats: map[string]*ParquetColumnStats{
- "price": createMockColumnStats("price", float64(19.99), float64(299.50)),
- },
- },
- },
- },
- ParquetRowCount: 25,
- LiveLogRowCount: 0,
- PartitionsCount: 1,
- },
- aggregations: []AggregationSpec{
- {Function: FuncMIN, Column: "price"},
- {Function: FuncMAX, Column: "price"},
- },
- validate: func(t *testing.T, results []AggregationResult, err error) {
- assert.NoError(t, err)
- assert.Len(t, results, 2)
- assert.Equal(t, float64(19.99), results[0].Min)
- assert.Equal(t, float64(299.50), results[1].Max)
- },
- },
- {
- name: "Column not found in parquet stats",
- dataSources: &TopicDataSources{
- ParquetFiles: map[string][]*ParquetFileStats{
- "/topics/test/partition-1": {
- {
- RowCount: 20,
- ColumnStats: map[string]*ParquetColumnStats{
- "id": createMockColumnStats("id", int64(1), int64(100)),
- // Note: "nonexistent_column" is not in stats
- },
- },
- },
- },
- ParquetRowCount: 20,
- LiveLogRowCount: 10, // Has live logs to fall back to
- PartitionsCount: 1,
- },
- aggregations: []AggregationSpec{
- {Function: FuncMIN, Column: "nonexistent_column"},
- {Function: FuncMAX, Column: "nonexistent_column"},
- },
- validate: func(t *testing.T, results []AggregationResult, err error) {
- assert.NoError(t, err)
- assert.Len(t, results, 2)
- // Should fall back to live log processing or return nil
- // The key is that it shouldn't crash
- },
- },
- {
- name: "Multiple parquet files with different ranges",
- dataSources: &TopicDataSources{
- ParquetFiles: map[string][]*ParquetFileStats{
- "/topics/test/partition-1": {
- {
- RowCount: 30,
- ColumnStats: map[string]*ParquetColumnStats{
- "score": createMockColumnStats("score", int64(10), int64(50)),
- },
- },
- {
- RowCount: 40,
- ColumnStats: map[string]*ParquetColumnStats{
- "score": createMockColumnStats("score", int64(5), int64(75)), // Lower min, higher max
- },
- },
- },
- },
- ParquetRowCount: 70,
- LiveLogRowCount: 0,
- PartitionsCount: 1,
- },
- aggregations: []AggregationSpec{
- {Function: FuncMIN, Column: "score"},
- {Function: FuncMAX, Column: "score"},
- },
- validate: func(t *testing.T, results []AggregationResult, err error) {
- assert.NoError(t, err)
- assert.Len(t, results, 2)
- assert.Equal(t, int64(5), results[0].Min, "Should find global minimum across all files")
- assert.Equal(t, int64(75), results[1].Max, "Should find global maximum across all files")
- },
- },
- }
-
- partitions := []string{"/topics/test/partition-1"}
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- ctx := context.Background()
- results, err := computer.ComputeFastPathAggregations(ctx, tt.aggregations, tt.dataSources, partitions)
- tt.validate(t, results, err)
- })
- }
-}
-
-// Test the specific bug where MIN/MAX was returning empty values
-func TestAggregationComputer_MinMaxEmptyValuesBugFix(t *testing.T) {
- engine := NewMockSQLEngine()
- computer := NewAggregationComputer(engine.SQLEngine)
-
- // This test specifically addresses the bug where MIN/MAX returned empty
- // due to improper null checking and extraction logic
- dataSources := &TopicDataSources{
- ParquetFiles: map[string][]*ParquetFileStats{
- "/topics/test/test-topic/partition1": {
- {
- RowCount: 100,
- ColumnStats: map[string]*ParquetColumnStats{
- "id": {
- ColumnName: "id",
- MinValue: &schema_pb.Value{Kind: &schema_pb.Value_Int32Value{Int32Value: 0}}, // Min should be 0
- MaxValue: &schema_pb.Value{Kind: &schema_pb.Value_Int32Value{Int32Value: 99}}, // Max should be 99
- NullCount: 0,
- RowCount: 100,
- },
- },
- },
- },
- },
- ParquetRowCount: 100,
- LiveLogRowCount: 0, // No live logs, pure parquet stats
- PartitionsCount: 1,
- }
-
- partitions := []string{"/topics/test/test-topic/partition1"}
-
- tests := []struct {
- name string
- aggregSpec AggregationSpec
- expected interface{}
- }{
- {
- name: "MIN should return 0 not empty",
- aggregSpec: AggregationSpec{Function: FuncMIN, Column: "id"},
- expected: int32(0), // Should extract the actual minimum value
- },
- {
- name: "MAX should return 99 not empty",
- aggregSpec: AggregationSpec{Function: FuncMAX, Column: "id"},
- expected: int32(99), // Should extract the actual maximum value
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- ctx := context.Background()
- results, err := computer.ComputeFastPathAggregations(ctx, []AggregationSpec{tt.aggregSpec}, dataSources, partitions)
-
- assert.NoError(t, err)
- assert.Len(t, results, 1)
-
- // Verify the result is not nil/empty
- if tt.aggregSpec.Function == FuncMIN {
- assert.NotNil(t, results[0].Min, "MIN result should not be nil")
- assert.Equal(t, tt.expected, results[0].Min)
- } else if tt.aggregSpec.Function == FuncMAX {
- assert.NotNil(t, results[0].Max, "MAX result should not be nil")
- assert.Equal(t, tt.expected, results[0].Max)
- }
- })
- }
-}
-
-// Test the formatAggregationResult function with MIN/MAX edge cases
-func TestSQLEngine_FormatAggregationResult_MinMax(t *testing.T) {
- engine := NewTestSQLEngine()
-
- tests := []struct {
- name string
- spec AggregationSpec
- result AggregationResult
- expected string
- }{
- {
- name: "MIN with zero value should not be empty",
- spec: AggregationSpec{Function: FuncMIN, Column: "id"},
- result: AggregationResult{Min: int32(0)},
- expected: "0",
- },
- {
- name: "MAX with large value",
- spec: AggregationSpec{Function: FuncMAX, Column: "id"},
- result: AggregationResult{Max: int32(99)},
- expected: "99",
- },
- {
- name: "MIN with negative value",
- spec: AggregationSpec{Function: FuncMIN, Column: "score"},
- result: AggregationResult{Min: int64(-50)},
- expected: "-50",
- },
- {
- name: "MAX with float value",
- spec: AggregationSpec{Function: FuncMAX, Column: "price"},
- result: AggregationResult{Max: float64(299.99)},
- expected: "299.99",
- },
- {
- name: "MIN with string value",
- spec: AggregationSpec{Function: FuncMIN, Column: "name"},
- result: AggregationResult{Min: "Alice"},
- expected: "Alice",
- },
- {
- name: "MIN with nil should return NULL",
- spec: AggregationSpec{Function: FuncMIN, Column: "missing"},
- result: AggregationResult{Min: nil},
- expected: "", // NULL values display as empty
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- sqlValue := engine.formatAggregationResult(tt.spec, tt.result)
- assert.Equal(t, tt.expected, sqlValue.String())
- })
- }
-}
-
-// Test the direct formatAggregationResult scenario that was originally broken
-func TestSQLEngine_MinMaxBugFixIntegration(t *testing.T) {
- // This test focuses on the core bug fix without the complexity of table discovery
- // It directly tests the scenario where MIN/MAX returned empty due to the bug
-
- engine := NewTestSQLEngine()
-
- // Test the direct formatting path that was failing
- tests := []struct {
- name string
- aggregSpec AggregationSpec
- aggResult AggregationResult
- expectedEmpty bool
- expectedValue string
- }{
- {
- name: "MIN with zero should not be empty (the original bug)",
- aggregSpec: AggregationSpec{Function: FuncMIN, Column: "id", Alias: "MIN(id)"},
- aggResult: AggregationResult{Min: int32(0)}, // This was returning empty before fix
- expectedEmpty: false,
- expectedValue: "0",
- },
- {
- name: "MAX with valid value should not be empty",
- aggregSpec: AggregationSpec{Function: FuncMAX, Column: "id", Alias: "MAX(id)"},
- aggResult: AggregationResult{Max: int32(99)},
- expectedEmpty: false,
- expectedValue: "99",
- },
- {
- name: "MIN with negative value should work",
- aggregSpec: AggregationSpec{Function: FuncMIN, Column: "score", Alias: "MIN(score)"},
- aggResult: AggregationResult{Min: int64(-10)},
- expectedEmpty: false,
- expectedValue: "-10",
- },
- {
- name: "MIN with nil should be empty (expected behavior)",
- aggregSpec: AggregationSpec{Function: FuncMIN, Column: "missing", Alias: "MIN(missing)"},
- aggResult: AggregationResult{Min: nil},
- expectedEmpty: true,
- expectedValue: "",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Test the formatAggregationResult function directly
- sqlValue := engine.formatAggregationResult(tt.aggregSpec, tt.aggResult)
- result := sqlValue.String()
-
- if tt.expectedEmpty {
- assert.Empty(t, result, "Result should be empty for nil values")
- } else {
- assert.NotEmpty(t, result, "Result should not be empty")
- assert.Equal(t, tt.expectedValue, result)
- }
- })
- }
-}
-
-// Test the tryFastParquetAggregation method specifically for the bug
-func TestSQLEngine_FastParquetAggregationBugFix(t *testing.T) {
- // This test verifies that the fast path aggregation logic works correctly
- // and doesn't return nil/empty values when it should return actual data
-
- engine := NewMockSQLEngine()
- computer := NewAggregationComputer(engine.SQLEngine)
-
- // Create realistic data sources that mimic the user's scenario
- dataSources := &TopicDataSources{
- ParquetFiles: map[string][]*ParquetFileStats{
- "/topics/test/test-topic/v2025-09-01-22-54-02/0000-0630": {
- {
- RowCount: 100,
- ColumnStats: map[string]*ParquetColumnStats{
- "id": {
- ColumnName: "id",
- MinValue: &schema_pb.Value{Kind: &schema_pb.Value_Int32Value{Int32Value: 0}},
- MaxValue: &schema_pb.Value{Kind: &schema_pb.Value_Int32Value{Int32Value: 99}},
- NullCount: 0,
- RowCount: 100,
- },
- },
- },
- },
- },
- ParquetRowCount: 100,
- LiveLogRowCount: 0, // Pure parquet scenario
- PartitionsCount: 1,
- }
-
- partitions := []string{"/topics/test/test-topic/v2025-09-01-22-54-02/0000-0630"}
-
- tests := []struct {
- name string
- aggregations []AggregationSpec
- validateResults func(t *testing.T, results []AggregationResult)
- }{
- {
- name: "Single MIN aggregation should return value not nil",
- aggregations: []AggregationSpec{
- {Function: FuncMIN, Column: "id", Alias: "MIN(id)"},
- },
- validateResults: func(t *testing.T, results []AggregationResult) {
- assert.Len(t, results, 1)
- assert.NotNil(t, results[0].Min, "MIN result should not be nil")
- assert.Equal(t, int32(0), results[0].Min, "MIN should return the correct minimum value")
- },
- },
- {
- name: "Single MAX aggregation should return value not nil",
- aggregations: []AggregationSpec{
- {Function: FuncMAX, Column: "id", Alias: "MAX(id)"},
- },
- validateResults: func(t *testing.T, results []AggregationResult) {
- assert.Len(t, results, 1)
- assert.NotNil(t, results[0].Max, "MAX result should not be nil")
- assert.Equal(t, int32(99), results[0].Max, "MAX should return the correct maximum value")
- },
- },
- {
- name: "Combined MIN/MAX should both return values",
- aggregations: []AggregationSpec{
- {Function: FuncMIN, Column: "id", Alias: "MIN(id)"},
- {Function: FuncMAX, Column: "id", Alias: "MAX(id)"},
- },
- validateResults: func(t *testing.T, results []AggregationResult) {
- assert.Len(t, results, 2)
- assert.NotNil(t, results[0].Min, "MIN result should not be nil")
- assert.NotNil(t, results[1].Max, "MAX result should not be nil")
- assert.Equal(t, int32(0), results[0].Min)
- assert.Equal(t, int32(99), results[1].Max)
- },
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- ctx := context.Background()
- results, err := computer.ComputeFastPathAggregations(ctx, tt.aggregations, dataSources, partitions)
-
- assert.NoError(t, err, "ComputeFastPathAggregations should not error")
- tt.validateResults(t, results)
- })
- }
-}
-
-// Test ExecutionPlanBuilder
-func TestExecutionPlanBuilder_BuildAggregationPlan(t *testing.T) {
- engine := NewMockSQLEngine()
- builder := NewExecutionPlanBuilder(engine.SQLEngine)
-
- // Parse a simple SELECT statement using the native parser
- stmt, err := ParseSQL("SELECT COUNT(*) FROM test_topic")
- assert.NoError(t, err)
- selectStmt := stmt.(*SelectStatement)
-
- aggregations := []AggregationSpec{
- {Function: FuncCOUNT, Column: "*"},
- }
-
- strategy := AggregationStrategy{
- CanUseFastPath: true,
- Reason: "all_aggregations_supported",
- }
-
- dataSources := &TopicDataSources{
- ParquetRowCount: 100,
- LiveLogRowCount: 50,
- PartitionsCount: 3,
- ParquetFiles: map[string][]*ParquetFileStats{
- "partition-1": {{RowCount: 50}},
- "partition-2": {{RowCount: 50}},
- },
- }
-
- plan := builder.BuildAggregationPlan(selectStmt, aggregations, strategy, dataSources)
-
- assert.Equal(t, "SELECT", plan.QueryType)
- assert.Equal(t, "hybrid_fast_path", plan.ExecutionStrategy)
- assert.Contains(t, plan.DataSources, "parquet_stats")
- assert.Contains(t, plan.DataSources, "live_logs")
- assert.Equal(t, 3, plan.PartitionsScanned)
- assert.Equal(t, 2, plan.ParquetFilesScanned)
- assert.Contains(t, plan.OptimizationsUsed, "parquet_statistics")
- assert.Equal(t, []string{"COUNT(*)"}, plan.Aggregations)
- assert.Equal(t, int64(50), plan.TotalRowsProcessed) // Only live logs scanned
-}
-
-// Test Error Types
-func TestErrorTypes(t *testing.T) {
- t.Run("AggregationError", func(t *testing.T) {
- err := AggregationError{
- Operation: "MAX",
- Column: "id",
- Cause: errors.New("column not found"),
- }
-
- expected := "aggregation error in MAX(id): column not found"
- assert.Equal(t, expected, err.Error())
- })
-
- t.Run("DataSourceError", func(t *testing.T) {
- err := DataSourceError{
- Source: "partition_discovery:test.topic1",
- Cause: errors.New("network timeout"),
- }
-
- expected := "data source error in partition_discovery:test.topic1: network timeout"
- assert.Equal(t, expected, err.Error())
- })
-
- t.Run("OptimizationError", func(t *testing.T) {
- err := OptimizationError{
- Strategy: "fast_path_aggregation",
- Reason: "unsupported function: AVG",
- }
-
- expected := "optimization failed for fast_path_aggregation: unsupported function: AVG"
- assert.Equal(t, expected, err.Error())
- })
-}
-
-// Integration Tests
-func TestIntegration_FastPathOptimization(t *testing.T) {
- engine := NewMockSQLEngine()
-
- // Setup components
- optimizer := NewFastPathOptimizer(engine.SQLEngine)
- computer := NewAggregationComputer(engine.SQLEngine)
-
- // Mock data setup
- aggregations := []AggregationSpec{
- {Function: FuncCOUNT, Column: "*"},
- {Function: FuncMAX, Column: "id"},
- }
-
- // Step 1: Determine strategy
- strategy := optimizer.DetermineStrategy(aggregations)
- assert.True(t, strategy.CanUseFastPath)
-
- // Step 2: Mock data sources
- dataSources := &TopicDataSources{
- ParquetFiles: map[string][]*ParquetFileStats{
- "/topics/test/topic1/partition-1": {{
- RowCount: 75,
- ColumnStats: map[string]*ParquetColumnStats{
- "id": createMockColumnStats("id", int64(1), int64(100)),
- },
- }},
- },
- ParquetRowCount: 75,
- LiveLogRowCount: 25,
- PartitionsCount: 1,
- }
-
- partitions := []string{"/topics/test/topic1/partition-1"}
-
- // Step 3: Compute aggregations
- ctx := context.Background()
- results, err := computer.ComputeFastPathAggregations(ctx, aggregations, dataSources, partitions)
- assert.NoError(t, err)
- assert.Len(t, results, 2)
- assert.Equal(t, int64(100), results[0].Count) // 75 + 25
- assert.Equal(t, int64(100), results[1].Max) // From parquet stats mock
-}
-
-func TestIntegration_FallbackToFullScan(t *testing.T) {
- engine := NewMockSQLEngine()
- optimizer := NewFastPathOptimizer(engine.SQLEngine)
-
- // Unsupported aggregations
- aggregations := []AggregationSpec{
- {Function: "AVG", Column: "value"}, // Not supported
- }
-
- // Step 1: Strategy should reject fast path
- strategy := optimizer.DetermineStrategy(aggregations)
- assert.False(t, strategy.CanUseFastPath)
- assert.Equal(t, "unsupported_aggregation_functions", strategy.Reason)
- assert.NotEmpty(t, strategy.UnsupportedSpecs)
-}
-
-// Benchmark Tests
-func BenchmarkFastPathOptimizer_DetermineStrategy(b *testing.B) {
- engine := NewMockSQLEngine()
- optimizer := NewFastPathOptimizer(engine.SQLEngine)
-
- aggregations := []AggregationSpec{
- {Function: FuncCOUNT, Column: "*"},
- {Function: FuncMAX, Column: "id"},
- {Function: "MIN", Column: "value"},
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- strategy := optimizer.DetermineStrategy(aggregations)
- _ = strategy.CanUseFastPath
- }
-}
-
-func BenchmarkAggregationComputer_ComputeFastPathAggregations(b *testing.B) {
- engine := NewMockSQLEngine()
- computer := NewAggregationComputer(engine.SQLEngine)
-
- dataSources := &TopicDataSources{
- ParquetFiles: map[string][]*ParquetFileStats{
- "partition-1": {{
- RowCount: 1000,
- ColumnStats: map[string]*ParquetColumnStats{
- "id": createMockColumnStats("id", int64(1), int64(1000)),
- },
- }},
- },
- ParquetRowCount: 1000,
- LiveLogRowCount: 100,
- }
-
- aggregations := []AggregationSpec{
- {Function: FuncCOUNT, Column: "*"},
- {Function: FuncMAX, Column: "id"},
- }
-
- partitions := []string{"partition-1"}
- ctx := context.Background()
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- results, err := computer.ComputeFastPathAggregations(ctx, aggregations, dataSources, partitions)
- if err != nil {
- b.Fatal(err)
- }
- _ = results
- }
-}
-
-// Tests for convertLogEntryToRecordValue - Protocol Buffer parsing bug fix
-func TestSQLEngine_ConvertLogEntryToRecordValue_ValidProtobuf(t *testing.T) {
- engine := NewTestSQLEngine()
-
- // Create a valid RecordValue protobuf with user data
- originalRecord := &schema_pb.RecordValue{
- Fields: map[string]*schema_pb.Value{
- "id": {Kind: &schema_pb.Value_Int32Value{Int32Value: 42}},
- "name": {Kind: &schema_pb.Value_StringValue{StringValue: "test-user"}},
- "score": {Kind: &schema_pb.Value_DoubleValue{DoubleValue: 95.5}},
- },
- }
-
- // Serialize the protobuf (this is what MQ actually stores)
- protobufData, err := proto.Marshal(originalRecord)
- assert.NoError(t, err)
-
- // Create a LogEntry with the serialized data
- logEntry := &filer_pb.LogEntry{
- TsNs: 1609459200000000000, // 2021-01-01 00:00:00 UTC
- PartitionKeyHash: 123,
- Data: protobufData, // Protocol buffer data (not JSON!)
- Key: []byte("test-key-001"),
- }
-
- // Test the conversion
- result, source, err := engine.convertLogEntryToRecordValue(logEntry)
-
- // Verify no error
- assert.NoError(t, err)
- assert.Equal(t, "live_log", source)
- assert.NotNil(t, result)
- assert.NotNil(t, result.Fields)
-
- // Verify system columns are added correctly
- assert.Contains(t, result.Fields, SW_COLUMN_NAME_TIMESTAMP)
- assert.Contains(t, result.Fields, SW_COLUMN_NAME_KEY)
- assert.Equal(t, int64(1609459200000000000), result.Fields[SW_COLUMN_NAME_TIMESTAMP].GetInt64Value())
- assert.Equal(t, []byte("test-key-001"), result.Fields[SW_COLUMN_NAME_KEY].GetBytesValue())
-
- // Verify user data is preserved
- assert.Contains(t, result.Fields, "id")
- assert.Contains(t, result.Fields, "name")
- assert.Contains(t, result.Fields, "score")
- assert.Equal(t, int32(42), result.Fields["id"].GetInt32Value())
- assert.Equal(t, "test-user", result.Fields["name"].GetStringValue())
- assert.Equal(t, 95.5, result.Fields["score"].GetDoubleValue())
-}
-
-func TestSQLEngine_ConvertLogEntryToRecordValue_InvalidProtobuf(t *testing.T) {
- engine := NewTestSQLEngine()
-
- // Create LogEntry with invalid protobuf data (this would cause the original JSON parsing bug)
- logEntry := &filer_pb.LogEntry{
- TsNs: 1609459200000000000,
- PartitionKeyHash: 123,
- Data: []byte{0x17, 0x00, 0xFF, 0xFE}, // Invalid protobuf data (starts with \x17 like in the original error)
- Key: []byte("test-key"),
- }
-
- // Test the conversion
- result, source, err := engine.convertLogEntryToRecordValue(logEntry)
-
- // Should return error for invalid protobuf
- assert.Error(t, err)
- assert.Contains(t, err.Error(), "failed to unmarshal log entry protobuf")
- assert.Nil(t, result)
- assert.Empty(t, source)
-}
-
-func TestSQLEngine_ConvertLogEntryToRecordValue_EmptyProtobuf(t *testing.T) {
- engine := NewTestSQLEngine()
-
- // Create a minimal valid RecordValue (empty fields)
- emptyRecord := &schema_pb.RecordValue{
- Fields: map[string]*schema_pb.Value{},
- }
- protobufData, err := proto.Marshal(emptyRecord)
- assert.NoError(t, err)
-
- logEntry := &filer_pb.LogEntry{
- TsNs: 1609459200000000000,
- PartitionKeyHash: 456,
- Data: protobufData,
- Key: []byte("empty-key"),
- }
-
- // Test the conversion
- result, source, err := engine.convertLogEntryToRecordValue(logEntry)
-
- // Should succeed and add system columns
- assert.NoError(t, err)
- assert.Equal(t, "live_log", source)
- assert.NotNil(t, result)
- assert.NotNil(t, result.Fields)
-
- // Should have system columns
- assert.Contains(t, result.Fields, SW_COLUMN_NAME_TIMESTAMP)
- assert.Contains(t, result.Fields, SW_COLUMN_NAME_KEY)
- assert.Equal(t, int64(1609459200000000000), result.Fields[SW_COLUMN_NAME_TIMESTAMP].GetInt64Value())
- assert.Equal(t, []byte("empty-key"), result.Fields[SW_COLUMN_NAME_KEY].GetBytesValue())
-
- // Should have no user fields
- userFieldCount := 0
- for fieldName := range result.Fields {
- if fieldName != SW_COLUMN_NAME_TIMESTAMP && fieldName != SW_COLUMN_NAME_KEY {
- userFieldCount++
- }
- }
- assert.Equal(t, 0, userFieldCount)
-}
-
-func TestSQLEngine_ConvertLogEntryToRecordValue_NilFieldsMap(t *testing.T) {
- engine := NewTestSQLEngine()
-
- // Create RecordValue with nil Fields map (edge case)
- recordWithNilFields := &schema_pb.RecordValue{
- Fields: nil, // This should be handled gracefully
- }
- protobufData, err := proto.Marshal(recordWithNilFields)
- assert.NoError(t, err)
-
- logEntry := &filer_pb.LogEntry{
- TsNs: 1609459200000000000,
- PartitionKeyHash: 789,
- Data: protobufData,
- Key: []byte("nil-fields-key"),
- }
-
- // Test the conversion
- result, source, err := engine.convertLogEntryToRecordValue(logEntry)
-
- // Should succeed and create Fields map
- assert.NoError(t, err)
- assert.Equal(t, "live_log", source)
- assert.NotNil(t, result)
- assert.NotNil(t, result.Fields) // Should be created by the function
-
- // Should have system columns
- assert.Contains(t, result.Fields, SW_COLUMN_NAME_TIMESTAMP)
- assert.Contains(t, result.Fields, SW_COLUMN_NAME_KEY)
- assert.Equal(t, int64(1609459200000000000), result.Fields[SW_COLUMN_NAME_TIMESTAMP].GetInt64Value())
- assert.Equal(t, []byte("nil-fields-key"), result.Fields[SW_COLUMN_NAME_KEY].GetBytesValue())
-}
-
-func TestSQLEngine_ConvertLogEntryToRecordValue_SystemColumnOverride(t *testing.T) {
- engine := NewTestSQLEngine()
-
- // Create RecordValue that already has system column names (should be overridden)
- recordWithSystemCols := &schema_pb.RecordValue{
- Fields: map[string]*schema_pb.Value{
- "user_field": {Kind: &schema_pb.Value_StringValue{StringValue: "user-data"}},
- SW_COLUMN_NAME_TIMESTAMP: {Kind: &schema_pb.Value_Int64Value{Int64Value: 999999999}}, // Should be overridden
- SW_COLUMN_NAME_KEY: {Kind: &schema_pb.Value_StringValue{StringValue: "old-key"}}, // Should be overridden
- },
- }
- protobufData, err := proto.Marshal(recordWithSystemCols)
- assert.NoError(t, err)
-
- logEntry := &filer_pb.LogEntry{
- TsNs: 1609459200000000000,
- PartitionKeyHash: 100,
- Data: protobufData,
- Key: []byte("actual-key"),
- }
-
- // Test the conversion
- result, source, err := engine.convertLogEntryToRecordValue(logEntry)
-
- // Should succeed
- assert.NoError(t, err)
- assert.Equal(t, "live_log", source)
- assert.NotNil(t, result)
-
- // System columns should use LogEntry values, not protobuf values
- assert.Equal(t, int64(1609459200000000000), result.Fields[SW_COLUMN_NAME_TIMESTAMP].GetInt64Value())
- assert.Equal(t, []byte("actual-key"), result.Fields[SW_COLUMN_NAME_KEY].GetBytesValue())
-
- // User field should be preserved
- assert.Contains(t, result.Fields, "user_field")
- assert.Equal(t, "user-data", result.Fields["user_field"].GetStringValue())
-}
-
-func TestSQLEngine_ConvertLogEntryToRecordValue_ComplexDataTypes(t *testing.T) {
- engine := NewTestSQLEngine()
-
- // Test with various data types
- complexRecord := &schema_pb.RecordValue{
- Fields: map[string]*schema_pb.Value{
- "int32_field": {Kind: &schema_pb.Value_Int32Value{Int32Value: -42}},
- "int64_field": {Kind: &schema_pb.Value_Int64Value{Int64Value: 9223372036854775807}},
- "float_field": {Kind: &schema_pb.Value_FloatValue{FloatValue: 3.14159}},
- "double_field": {Kind: &schema_pb.Value_DoubleValue{DoubleValue: 2.718281828}},
- "bool_field": {Kind: &schema_pb.Value_BoolValue{BoolValue: true}},
- "string_field": {Kind: &schema_pb.Value_StringValue{StringValue: "test string with unicode party"}},
- "bytes_field": {Kind: &schema_pb.Value_BytesValue{BytesValue: []byte{0x01, 0x02, 0x03}}},
- },
- }
- protobufData, err := proto.Marshal(complexRecord)
- assert.NoError(t, err)
-
- logEntry := &filer_pb.LogEntry{
- TsNs: 1609459200000000000,
- PartitionKeyHash: 200,
- Data: protobufData,
- Key: []byte("complex-key"),
- }
-
- // Test the conversion
- result, source, err := engine.convertLogEntryToRecordValue(logEntry)
-
- // Should succeed
- assert.NoError(t, err)
- assert.Equal(t, "live_log", source)
- assert.NotNil(t, result)
-
- // Verify all data types are preserved
- assert.Equal(t, int32(-42), result.Fields["int32_field"].GetInt32Value())
- assert.Equal(t, int64(9223372036854775807), result.Fields["int64_field"].GetInt64Value())
- assert.Equal(t, float32(3.14159), result.Fields["float_field"].GetFloatValue())
- assert.Equal(t, 2.718281828, result.Fields["double_field"].GetDoubleValue())
- assert.Equal(t, true, result.Fields["bool_field"].GetBoolValue())
- assert.Equal(t, "test string with unicode party", result.Fields["string_field"].GetStringValue())
- assert.Equal(t, []byte{0x01, 0x02, 0x03}, result.Fields["bytes_field"].GetBytesValue())
-
- // System columns should still be present
- assert.Contains(t, result.Fields, SW_COLUMN_NAME_TIMESTAMP)
- assert.Contains(t, result.Fields, SW_COLUMN_NAME_KEY)
-}
-
-// Tests for log buffer deduplication functionality
-func TestSQLEngine_GetLogBufferStartFromFile_BinaryFormat(t *testing.T) {
- engine := NewTestSQLEngine()
-
- // Create sample buffer start (binary format)
- bufferStartBytes := make([]byte, 8)
- binary.BigEndian.PutUint64(bufferStartBytes, uint64(1609459100000000001))
-
- // Create file entry with buffer start + some chunks
- entry := &filer_pb.Entry{
- Name: "test-log-file",
- Extended: map[string][]byte{
- "buffer_start": bufferStartBytes,
- },
- Chunks: []*filer_pb.FileChunk{
- {FileId: "chunk1", Offset: 0, Size: 1000},
- {FileId: "chunk2", Offset: 1000, Size: 1000},
- {FileId: "chunk3", Offset: 2000, Size: 1000},
- },
- }
-
- // Test extraction
- result, err := engine.getLogBufferStartFromFile(entry)
- assert.NoError(t, err)
- assert.NotNil(t, result)
- assert.Equal(t, int64(1609459100000000001), result.StartIndex)
-
- // Test extraction works correctly with the binary format
-}
-
-func TestSQLEngine_GetLogBufferStartFromFile_NoMetadata(t *testing.T) {
- engine := NewTestSQLEngine()
-
- // Create file entry without buffer start
- entry := &filer_pb.Entry{
- Name: "test-log-file",
- Extended: nil,
- }
-
- // Test extraction
- result, err := engine.getLogBufferStartFromFile(entry)
- assert.NoError(t, err)
- assert.Nil(t, result)
-}
-
-func TestSQLEngine_GetLogBufferStartFromFile_InvalidData(t *testing.T) {
- engine := NewTestSQLEngine()
-
- // Create file entry with invalid buffer start (wrong size)
- entry := &filer_pb.Entry{
- Name: "test-log-file",
- Extended: map[string][]byte{
- "buffer_start": []byte("invalid-binary"),
- },
- }
-
- // Test extraction
- result, err := engine.getLogBufferStartFromFile(entry)
- assert.Error(t, err)
- assert.Contains(t, err.Error(), "invalid buffer_start format: expected 8 bytes")
- assert.Nil(t, result)
-}
-
-func TestSQLEngine_BuildLogBufferDeduplicationMap_NoBrokerClient(t *testing.T) {
- engine := NewTestSQLEngine()
- engine.catalog.brokerClient = nil // Simulate no broker client
-
- ctx := context.Background()
- result, err := engine.buildLogBufferDeduplicationMap(ctx, "/topics/test/test-topic")
-
- assert.NoError(t, err)
- assert.NotNil(t, result)
- assert.Empty(t, result)
-}
-
-func TestSQLEngine_LogBufferDeduplication_ServerRestartScenario(t *testing.T) {
- // Simulate scenario: Buffer indexes are now initialized with process start time
- // This tests that buffer start indexes are globally unique across server restarts
-
- // Before server restart: Process 1 buffer start (3 chunks)
- beforeRestartStart := LogBufferStart{
- StartIndex: 1609459100000000000, // Process 1 start time
- }
-
- // After server restart: Process 2 buffer start (3 chunks)
- afterRestartStart := LogBufferStart{
- StartIndex: 1609459300000000000, // Process 2 start time (DIFFERENT)
- }
-
- // Simulate 3 chunks for each file
- chunkCount := int64(3)
-
- // Calculate end indexes for range comparison
- beforeEnd := beforeRestartStart.StartIndex + chunkCount - 1 // [start, start+2]
- afterStart := afterRestartStart.StartIndex // [start, start+2]
-
- // Test range overlap detection (should NOT overlap)
- overlaps := beforeRestartStart.StartIndex <= (afterStart+chunkCount-1) && beforeEnd >= afterStart
- assert.False(t, overlaps, "Buffer ranges after restart should not overlap")
-
- // Verify the start indexes are globally unique
- assert.NotEqual(t, beforeRestartStart.StartIndex, afterRestartStart.StartIndex, "Start indexes should be different")
- assert.Less(t, beforeEnd, afterStart, "Ranges should be completely separate")
-
- // Expected values:
- // Before restart: [1609459100000000000, 1609459100000000002]
- // After restart: [1609459300000000000, 1609459300000000002]
- expectedBeforeEnd := int64(1609459100000000002)
- expectedAfterStart := int64(1609459300000000000)
-
- assert.Equal(t, expectedBeforeEnd, beforeEnd)
- assert.Equal(t, expectedAfterStart, afterStart)
-
- // This demonstrates that buffer start indexes initialized with process start time
- // prevent false positive duplicates across server restarts
-}
-
-// TestGetSQLValAlias tests the getSQLValAlias function, particularly for SQL injection prevention
-func TestGetSQLValAlias(t *testing.T) {
- engine := &SQLEngine{}
-
- tests := []struct {
- name string
- sqlVal *SQLVal
- expected string
- desc string
- }{
- {
- name: "simple string",
- sqlVal: &SQLVal{
- Type: StrVal,
- Val: []byte("hello"),
- },
- expected: "'hello'",
- desc: "Simple string should be wrapped in single quotes",
- },
- {
- name: "string with single quote",
- sqlVal: &SQLVal{
- Type: StrVal,
- Val: []byte("don't"),
- },
- expected: "'don''t'",
- desc: "String with single quote should have the quote escaped by doubling it",
- },
- {
- name: "string with multiple single quotes",
- sqlVal: &SQLVal{
- Type: StrVal,
- Val: []byte("'malicious'; DROP TABLE users; --"),
- },
- expected: "'''malicious''; DROP TABLE users; --'",
- desc: "String with SQL injection attempt should have all single quotes properly escaped",
- },
- {
- name: "empty string",
- sqlVal: &SQLVal{
- Type: StrVal,
- Val: []byte(""),
- },
- expected: "''",
- desc: "Empty string should result in empty quoted string",
- },
- {
- name: "integer value",
- sqlVal: &SQLVal{
- Type: IntVal,
- Val: []byte("123"),
- },
- expected: "123",
- desc: "Integer value should not be quoted",
- },
- {
- name: "float value",
- sqlVal: &SQLVal{
- Type: FloatVal,
- Val: []byte("123.45"),
- },
- expected: "123.45",
- desc: "Float value should not be quoted",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result := engine.getSQLValAlias(tt.sqlVal)
- assert.Equal(t, tt.expected, result, tt.desc)
- })
- }
-}
diff --git a/weed/query/engine/errors.go b/weed/query/engine/errors.go
index 6a297d92f..2c68ab10d 100644
--- a/weed/query/engine/errors.go
+++ b/weed/query/engine/errors.go
@@ -44,7 +44,7 @@ type ParseError struct {
func (e ParseError) Error() string {
if e.Cause != nil {
- return fmt.Sprintf("SQL parse error: %s (%v)", e.Message, e.Cause)
+ return fmt.Sprintf("SQL parse error: %s (caused by: %v)", e.Message, e.Cause)
}
return fmt.Sprintf("SQL parse error: %s", e.Message)
}
diff --git a/weed/query/engine/execution_plan_fast_path_test.go b/weed/query/engine/execution_plan_fast_path_test.go
deleted file mode 100644
index c0f08fa21..000000000
--- a/weed/query/engine/execution_plan_fast_path_test.go
+++ /dev/null
@@ -1,133 +0,0 @@
-package engine
-
-import (
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
- "github.com/stretchr/testify/assert"
-)
-
-// TestExecutionPlanFastPathDisplay tests that the execution plan correctly shows
-// "Parquet Statistics (fast path)" when fast path is used, not "Parquet Files (full scan)"
-func TestExecutionPlanFastPathDisplay(t *testing.T) {
- engine := NewMockSQLEngine()
-
- // Create realistic data sources for fast path scenario
- dataSources := &TopicDataSources{
- ParquetFiles: map[string][]*ParquetFileStats{
- "/topics/test/topic/partition-1": {
- {
- RowCount: 500,
- ColumnStats: map[string]*ParquetColumnStats{
- "id": {
- ColumnName: "id",
- MinValue: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 1}},
- MaxValue: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 500}},
- NullCount: 0,
- RowCount: 500,
- },
- },
- },
- },
- },
- ParquetRowCount: 500,
- LiveLogRowCount: 0, // Pure parquet scenario - ideal for fast path
- PartitionsCount: 1,
- }
-
- t.Run("Fast path execution plan shows correct data sources", func(t *testing.T) {
- optimizer := NewFastPathOptimizer(engine.SQLEngine)
-
- aggregations := []AggregationSpec{
- {Function: FuncCOUNT, Column: "*", Alias: "COUNT(*)"},
- }
-
- // Test the strategy determination
- strategy := optimizer.DetermineStrategy(aggregations)
- assert.True(t, strategy.CanUseFastPath, "Strategy should allow fast path for COUNT(*)")
- assert.Equal(t, "all_aggregations_supported", strategy.Reason)
-
- // Test data source list building
- builder := &ExecutionPlanBuilder{}
- dataSources := &TopicDataSources{
- ParquetFiles: map[string][]*ParquetFileStats{
- "/topics/test/topic/partition-1": {
- {RowCount: 500},
- },
- },
- ParquetRowCount: 500,
- LiveLogRowCount: 0,
- PartitionsCount: 1,
- }
-
- dataSourcesList := builder.buildDataSourcesList(strategy, dataSources)
-
- // When fast path is used, should show "parquet_stats" not "parquet_files"
- assert.Contains(t, dataSourcesList, "parquet_stats",
- "Data sources should contain 'parquet_stats' when fast path is used")
- assert.NotContains(t, dataSourcesList, "parquet_files",
- "Data sources should NOT contain 'parquet_files' when fast path is used")
-
- // Test that the formatting works correctly
- formattedSource := engine.SQLEngine.formatDataSource("parquet_stats")
- assert.Equal(t, "Parquet Statistics (fast path)", formattedSource,
- "parquet_stats should format to 'Parquet Statistics (fast path)'")
-
- formattedFullScan := engine.SQLEngine.formatDataSource("parquet_files")
- assert.Equal(t, "Parquet Files (full scan)", formattedFullScan,
- "parquet_files should format to 'Parquet Files (full scan)'")
- })
-
- t.Run("Slow path execution plan shows full scan data sources", func(t *testing.T) {
- builder := &ExecutionPlanBuilder{}
-
- // Create strategy that cannot use fast path
- strategy := AggregationStrategy{
- CanUseFastPath: false,
- Reason: "unsupported_aggregation_functions",
- }
-
- dataSourcesList := builder.buildDataSourcesList(strategy, dataSources)
-
- // When slow path is used, should show "parquet_files" and "live_logs"
- assert.Contains(t, dataSourcesList, "parquet_files",
- "Slow path should contain 'parquet_files'")
- assert.Contains(t, dataSourcesList, "live_logs",
- "Slow path should contain 'live_logs'")
- assert.NotContains(t, dataSourcesList, "parquet_stats",
- "Slow path should NOT contain 'parquet_stats'")
- })
-
- t.Run("Data source formatting works correctly", func(t *testing.T) {
- // Test just the data source formatting which is the key fix
-
- // Test parquet_stats formatting (fast path)
- fastPathFormatted := engine.SQLEngine.formatDataSource("parquet_stats")
- assert.Equal(t, "Parquet Statistics (fast path)", fastPathFormatted,
- "parquet_stats should format to show fast path usage")
-
- // Test parquet_files formatting (slow path)
- slowPathFormatted := engine.SQLEngine.formatDataSource("parquet_files")
- assert.Equal(t, "Parquet Files (full scan)", slowPathFormatted,
- "parquet_files should format to show full scan")
-
- // Test that data sources list is built correctly for fast path
- builder := &ExecutionPlanBuilder{}
- fastStrategy := AggregationStrategy{CanUseFastPath: true}
-
- fastSources := builder.buildDataSourcesList(fastStrategy, dataSources)
- assert.Contains(t, fastSources, "parquet_stats",
- "Fast path should include parquet_stats")
- assert.NotContains(t, fastSources, "parquet_files",
- "Fast path should NOT include parquet_files")
-
- // Test that data sources list is built correctly for slow path
- slowStrategy := AggregationStrategy{CanUseFastPath: false}
-
- slowSources := builder.buildDataSourcesList(slowStrategy, dataSources)
- assert.Contains(t, slowSources, "parquet_files",
- "Slow path should include parquet_files")
- assert.NotContains(t, slowSources, "parquet_stats",
- "Slow path should NOT include parquet_stats")
- })
-}
diff --git a/weed/query/engine/fast_path_fix_test.go b/weed/query/engine/fast_path_fix_test.go
deleted file mode 100644
index 3769e9215..000000000
--- a/weed/query/engine/fast_path_fix_test.go
+++ /dev/null
@@ -1,193 +0,0 @@
-package engine
-
-import (
- "context"
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
- "github.com/stretchr/testify/assert"
-)
-
-// TestFastPathCountFixRealistic tests the specific scenario mentioned in the bug report:
-// Fast path returning 0 for COUNT(*) when slow path returns 1803
-func TestFastPathCountFixRealistic(t *testing.T) {
- engine := NewMockSQLEngine()
-
- // Set up debug mode to see our new logging
- ctx := context.WithValue(context.Background(), "debug", true)
-
- // Create realistic data sources that mimic a scenario with 1803 rows
- dataSources := &TopicDataSources{
- ParquetFiles: map[string][]*ParquetFileStats{
- "/topics/test/large-topic/0000-1023": {
- {
- RowCount: 800,
- ColumnStats: map[string]*ParquetColumnStats{
- "id": {
- ColumnName: "id",
- MinValue: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 1}},
- MaxValue: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 800}},
- NullCount: 0,
- RowCount: 800,
- },
- },
- },
- {
- RowCount: 500,
- ColumnStats: map[string]*ParquetColumnStats{
- "id": {
- ColumnName: "id",
- MinValue: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 801}},
- MaxValue: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 1300}},
- NullCount: 0,
- RowCount: 500,
- },
- },
- },
- },
- "/topics/test/large-topic/1024-2047": {
- {
- RowCount: 300,
- ColumnStats: map[string]*ParquetColumnStats{
- "id": {
- ColumnName: "id",
- MinValue: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 1301}},
- MaxValue: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 1600}},
- NullCount: 0,
- RowCount: 300,
- },
- },
- },
- },
- },
- ParquetRowCount: 1600, // 800 + 500 + 300
- LiveLogRowCount: 203, // Additional live log data
- PartitionsCount: 2,
- LiveLogFilesCount: 15,
- }
-
- partitions := []string{
- "/topics/test/large-topic/0000-1023",
- "/topics/test/large-topic/1024-2047",
- }
-
- t.Run("COUNT(*) should return correct total (1803)", func(t *testing.T) {
- computer := NewAggregationComputer(engine.SQLEngine)
-
- aggregations := []AggregationSpec{
- {Function: FuncCOUNT, Column: "*", Alias: "COUNT(*)"},
- }
-
- results, err := computer.ComputeFastPathAggregations(ctx, aggregations, dataSources, partitions)
-
- assert.NoError(t, err, "Fast path aggregation should not error")
- assert.Len(t, results, 1, "Should return one result")
-
- // This is the key test - before our fix, this was returning 0
- expectedCount := int64(1803) // 1600 (parquet) + 203 (live log)
- actualCount := results[0].Count
-
- assert.Equal(t, expectedCount, actualCount,
- "COUNT(*) should return %d (1600 parquet + 203 live log), but got %d",
- expectedCount, actualCount)
- })
-
- t.Run("MIN/MAX should work with multiple partitions", func(t *testing.T) {
- computer := NewAggregationComputer(engine.SQLEngine)
-
- aggregations := []AggregationSpec{
- {Function: FuncMIN, Column: "id", Alias: "MIN(id)"},
- {Function: FuncMAX, Column: "id", Alias: "MAX(id)"},
- }
-
- results, err := computer.ComputeFastPathAggregations(ctx, aggregations, dataSources, partitions)
-
- assert.NoError(t, err, "Fast path aggregation should not error")
- assert.Len(t, results, 2, "Should return two results")
-
- // MIN should be the lowest across all parquet files
- assert.Equal(t, int64(1), results[0].Min, "MIN should be 1")
-
- // MAX should be the highest across all parquet files
- assert.Equal(t, int64(1600), results[1].Max, "MAX should be 1600")
- })
-}
-
-// TestFastPathDataSourceDiscoveryLogging tests that our debug logging works correctly
-func TestFastPathDataSourceDiscoveryLogging(t *testing.T) {
- // This test verifies that our enhanced data source collection structure is correct
-
- t.Run("DataSources structure validation", func(t *testing.T) {
- // Test the TopicDataSources structure initialization
- dataSources := &TopicDataSources{
- ParquetFiles: make(map[string][]*ParquetFileStats),
- ParquetRowCount: 0,
- LiveLogRowCount: 0,
- LiveLogFilesCount: 0,
- PartitionsCount: 0,
- }
-
- assert.NotNil(t, dataSources, "Data sources should not be nil")
- assert.NotNil(t, dataSources.ParquetFiles, "ParquetFiles map should be initialized")
- assert.GreaterOrEqual(t, dataSources.PartitionsCount, 0, "PartitionsCount should be non-negative")
- assert.GreaterOrEqual(t, dataSources.ParquetRowCount, int64(0), "ParquetRowCount should be non-negative")
- assert.GreaterOrEqual(t, dataSources.LiveLogRowCount, int64(0), "LiveLogRowCount should be non-negative")
- })
-}
-
-// TestFastPathValidationLogic tests the enhanced validation we added
-func TestFastPathValidationLogic(t *testing.T) {
- t.Run("Validation catches data source vs computation mismatch", func(t *testing.T) {
- // Create a scenario where data sources and computation might be inconsistent
- dataSources := &TopicDataSources{
- ParquetFiles: make(map[string][]*ParquetFileStats),
- ParquetRowCount: 1000, // Data sources say 1000 rows
- LiveLogRowCount: 0,
- PartitionsCount: 1,
- }
-
- // But aggregation result says different count (simulating the original bug)
- aggResults := []AggregationResult{
- {Count: 0}, // Bug: returns 0 when data sources show 1000
- }
-
- // This simulates the validation logic from tryFastParquetAggregation
- totalRows := dataSources.ParquetRowCount + dataSources.LiveLogRowCount
- countResult := aggResults[0].Count
-
- // Our validation should catch this mismatch
- assert.NotEqual(t, totalRows, countResult,
- "This test simulates the bug: data sources show %d but COUNT returns %d",
- totalRows, countResult)
-
- // In the real code, this would trigger a fallback to slow path
- validationPassed := (countResult == totalRows)
- assert.False(t, validationPassed, "Validation should fail for inconsistent data")
- })
-
- t.Run("Validation passes for consistent data", func(t *testing.T) {
- // Create a scenario where everything is consistent
- dataSources := &TopicDataSources{
- ParquetFiles: make(map[string][]*ParquetFileStats),
- ParquetRowCount: 1000,
- LiveLogRowCount: 803,
- PartitionsCount: 1,
- }
-
- // Aggregation result matches data sources
- aggResults := []AggregationResult{
- {Count: 1803}, // Correct: matches 1000 + 803
- }
-
- totalRows := dataSources.ParquetRowCount + dataSources.LiveLogRowCount
- countResult := aggResults[0].Count
-
- // Our validation should pass this
- assert.Equal(t, totalRows, countResult,
- "Validation should pass when data sources (%d) match COUNT result (%d)",
- totalRows, countResult)
-
- validationPassed := (countResult == totalRows)
- assert.True(t, validationPassed, "Validation should pass for consistent data")
- })
-}
diff --git a/weed/query/engine/parquet_scanner.go b/weed/query/engine/parquet_scanner.go
index 9bcced904..4c33df76f 100644
--- a/weed/query/engine/parquet_scanner.go
+++ b/weed/query/engine/parquet_scanner.go
@@ -1,280 +1,14 @@
package engine
import (
- "context"
"fmt"
"math/big"
"time"
- "github.com/seaweedfs/seaweedfs/weed/mq/schema"
- "github.com/seaweedfs/seaweedfs/weed/mq/topic"
- "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
- "github.com/seaweedfs/seaweedfs/weed/pb/mq_pb"
"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
"github.com/seaweedfs/seaweedfs/weed/query/sqltypes"
- "github.com/seaweedfs/seaweedfs/weed/util/chunk_cache"
)
-// ParquetScanner scans MQ topic Parquet files for SELECT queries
-// Assumptions:
-// 1. All MQ messages are stored in Parquet format in topic partitions
-// 2. Each partition directory contains dated Parquet files
-// 3. System columns (_ts_ns, _key) are added to user schema
-// 4. Predicate pushdown is used for efficient scanning
-type ParquetScanner struct {
- filerClient filer_pb.FilerClient
- chunkCache chunk_cache.ChunkCache
- topic topic.Topic
- recordSchema *schema_pb.RecordType
- parquetLevels *schema.ParquetLevels
-}
-
-// NewParquetScanner creates a scanner for a specific MQ topic
-// Assumption: Topic exists and has Parquet files in partition directories
-func NewParquetScanner(filerClient filer_pb.FilerClient, namespace, topicName string) (*ParquetScanner, error) {
- // Check if filerClient is available
- if filerClient == nil {
- return nil, fmt.Errorf("filerClient is required but not available")
- }
-
- // Create topic reference
- t := topic.Topic{
- Namespace: namespace,
- Name: topicName,
- }
-
- // Read topic configuration to get schema
- var topicConf *mq_pb.ConfigureTopicResponse
- var err error
- if err := filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error {
- topicConf, err = t.ReadConfFile(client)
- return err
- }); err != nil {
- return nil, fmt.Errorf("failed to read topic config: %v", err)
- }
-
- // Build complete schema with system columns - prefer flat schema if available
- var recordType *schema_pb.RecordType
-
- if topicConf.GetMessageRecordType() != nil {
- // New flat schema format - use directly
- recordType = topicConf.GetMessageRecordType()
- }
-
- if recordType == nil || len(recordType.Fields) == 0 {
- // For topics without schema, create a minimal schema with system fields and _value
- recordType = schema.RecordTypeBegin().
- WithField(SW_COLUMN_NAME_TIMESTAMP, schema.TypeInt64).
- WithField(SW_COLUMN_NAME_KEY, schema.TypeBytes).
- WithField(SW_COLUMN_NAME_VALUE, schema.TypeBytes). // Raw message value
- RecordTypeEnd()
- } else {
- // Add system columns that MQ adds to all records
- recordType = schema.NewRecordTypeBuilder(recordType).
- WithField(SW_COLUMN_NAME_TIMESTAMP, schema.TypeInt64).
- WithField(SW_COLUMN_NAME_KEY, schema.TypeBytes).
- RecordTypeEnd()
- }
-
- // Convert to Parquet levels for efficient reading
- parquetLevels, err := schema.ToParquetLevels(recordType)
- if err != nil {
- return nil, fmt.Errorf("failed to create Parquet levels: %v", err)
- }
-
- return &ParquetScanner{
- filerClient: filerClient,
- chunkCache: chunk_cache.NewChunkCacheInMemory(256), // Same as MQ logstore
- topic: t,
- recordSchema: recordType,
- parquetLevels: parquetLevels,
- }, nil
-}
-
-// ScanOptions configure how the scanner reads data
-type ScanOptions struct {
- // Time range filtering (Unix nanoseconds)
- StartTimeNs int64
- StopTimeNs int64
-
- // Column projection - if empty, select all columns
- Columns []string
-
- // Row limit - 0 means no limit
- Limit int
-
- // Predicate for WHERE clause filtering
- Predicate func(*schema_pb.RecordValue) bool
-}
-
-// ScanResult represents a single scanned record
-type ScanResult struct {
- Values map[string]*schema_pb.Value // Column name -> value
- Timestamp int64 // Message timestamp (_ts_ns)
- Key []byte // Message key (_key)
-}
-
-// Scan reads records from the topic's Parquet files
-// Assumptions:
-// 1. Scans all partitions of the topic
-// 2. Applies time filtering at Parquet level for efficiency
-// 3. Applies predicates and projections after reading
-func (ps *ParquetScanner) Scan(ctx context.Context, options ScanOptions) ([]ScanResult, error) {
- var results []ScanResult
-
- // Get all partitions for this topic
- // TODO: Implement proper partition discovery
- // For now, assume partition 0 exists
- partitions := []topic.Partition{{RangeStart: 0, RangeStop: 1000}}
-
- for _, partition := range partitions {
- partitionResults, err := ps.scanPartition(ctx, partition, options)
- if err != nil {
- return nil, fmt.Errorf("failed to scan partition %v: %v", partition, err)
- }
-
- results = append(results, partitionResults...)
-
- // Apply global limit across all partitions
- if options.Limit > 0 && len(results) >= options.Limit {
- results = results[:options.Limit]
- break
- }
- }
-
- return results, nil
-}
-
-// scanPartition scans a specific topic partition
-func (ps *ParquetScanner) scanPartition(ctx context.Context, partition topic.Partition, options ScanOptions) ([]ScanResult, error) {
- // partitionDir := topic.PartitionDir(ps.topic, partition) // TODO: Use for actual file listing
-
- var results []ScanResult
-
- // List Parquet files in partition directory
- // TODO: Implement proper file listing with date range filtering
- // For now, this is a placeholder that would list actual Parquet files
-
- // Simulate file processing - in real implementation, this would:
- // 1. List files in partitionDir via filerClient
- // 2. Filter files by date range if time filtering is enabled
- // 3. Process each Parquet file in chronological order
-
- // Placeholder: Create sample data for testing
- if len(results) == 0 {
- // Generate sample data for demonstration
- sampleData := ps.generateSampleData(options)
- results = append(results, sampleData...)
- }
-
- return results, nil
-}
-
-// generateSampleData creates sample data for testing when no real Parquet files exist
-func (ps *ParquetScanner) generateSampleData(options ScanOptions) []ScanResult {
- now := time.Now().UnixNano()
-
- sampleData := []ScanResult{
- {
- Values: map[string]*schema_pb.Value{
- "user_id": {Kind: &schema_pb.Value_Int32Value{Int32Value: 1001}},
- "event_type": {Kind: &schema_pb.Value_StringValue{StringValue: "login"}},
- "data": {Kind: &schema_pb.Value_StringValue{StringValue: `{"ip": "192.168.1.1"}`}},
- },
- Timestamp: now - 3600000000000, // 1 hour ago
- Key: []byte("user-1001"),
- },
- {
- Values: map[string]*schema_pb.Value{
- "user_id": {Kind: &schema_pb.Value_Int32Value{Int32Value: 1002}},
- "event_type": {Kind: &schema_pb.Value_StringValue{StringValue: "page_view"}},
- "data": {Kind: &schema_pb.Value_StringValue{StringValue: `{"page": "/dashboard"}`}},
- },
- Timestamp: now - 1800000000000, // 30 minutes ago
- Key: []byte("user-1002"),
- },
- {
- Values: map[string]*schema_pb.Value{
- "user_id": {Kind: &schema_pb.Value_Int32Value{Int32Value: 1001}},
- "event_type": {Kind: &schema_pb.Value_StringValue{StringValue: "logout"}},
- "data": {Kind: &schema_pb.Value_StringValue{StringValue: `{"session_duration": 3600}`}},
- },
- Timestamp: now - 900000000000, // 15 minutes ago
- Key: []byte("user-1001"),
- },
- }
-
- // Apply predicate filtering if specified
- if options.Predicate != nil {
- var filtered []ScanResult
- for _, result := range sampleData {
- // Convert to RecordValue for predicate testing
- recordValue := &schema_pb.RecordValue{Fields: make(map[string]*schema_pb.Value)}
- for k, v := range result.Values {
- recordValue.Fields[k] = v
- }
- recordValue.Fields[SW_COLUMN_NAME_TIMESTAMP] = &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: result.Timestamp}}
- recordValue.Fields[SW_COLUMN_NAME_KEY] = &schema_pb.Value{Kind: &schema_pb.Value_BytesValue{BytesValue: result.Key}}
-
- if options.Predicate(recordValue) {
- filtered = append(filtered, result)
- }
- }
- sampleData = filtered
- }
-
- // Apply limit
- if options.Limit > 0 && len(sampleData) > options.Limit {
- sampleData = sampleData[:options.Limit]
- }
-
- return sampleData
-}
-
-// ConvertToSQLResult converts ScanResults to SQL query results
-func (ps *ParquetScanner) ConvertToSQLResult(results []ScanResult, columns []string) *QueryResult {
- if len(results) == 0 {
- return &QueryResult{
- Columns: columns,
- Rows: [][]sqltypes.Value{},
- }
- }
-
- // Determine columns if not specified
- if len(columns) == 0 {
- columnSet := make(map[string]bool)
- for _, result := range results {
- for columnName := range result.Values {
- columnSet[columnName] = true
- }
- }
-
- columns = make([]string, 0, len(columnSet))
- for columnName := range columnSet {
- columns = append(columns, columnName)
- }
- }
-
- // Convert to SQL rows
- rows := make([][]sqltypes.Value, len(results))
- for i, result := range results {
- row := make([]sqltypes.Value, len(columns))
- for j, columnName := range columns {
- if value, exists := result.Values[columnName]; exists {
- row[j] = convertSchemaValueToSQL(value)
- } else {
- row[j] = sqltypes.NULL
- }
- }
- rows[i] = row
- }
-
- return &QueryResult{
- Columns: columns,
- Rows: rows,
- }
-}
-
// convertSchemaValueToSQL converts schema_pb.Value to sqltypes.Value
func convertSchemaValueToSQL(value *schema_pb.Value) sqltypes.Value {
if value == nil {
diff --git a/weed/query/engine/partition_path_fix_test.go b/weed/query/engine/partition_path_fix_test.go
deleted file mode 100644
index 8d92136e6..000000000
--- a/weed/query/engine/partition_path_fix_test.go
+++ /dev/null
@@ -1,117 +0,0 @@
-package engine
-
-import (
- "strings"
- "testing"
-
- "github.com/stretchr/testify/assert"
-)
-
-// TestPartitionPathHandling tests that partition paths are handled correctly
-// whether discoverTopicPartitions returns relative or absolute paths
-func TestPartitionPathHandling(t *testing.T) {
- engine := NewMockSQLEngine()
-
- t.Run("Mock discoverTopicPartitions returns correct paths", func(t *testing.T) {
- // Test that our mock engine handles absolute paths correctly
- engine.mockPartitions["test.user_events"] = []string{
- "/topics/test/user_events/v2025-09-03-15-36-29/0000-2520",
- "/topics/test/user_events/v2025-09-03-15-36-29/2521-5040",
- }
-
- partitions, err := engine.discoverTopicPartitions("test", "user_events")
- assert.NoError(t, err, "Should discover partitions without error")
- assert.Equal(t, 2, len(partitions), "Should return 2 partitions")
- assert.Contains(t, partitions[0], "/topics/test/user_events/", "Should contain absolute path")
- })
-
- t.Run("Mock discoverTopicPartitions handles relative paths", func(t *testing.T) {
- // Test relative paths scenario
- engine.mockPartitions["test.user_events"] = []string{
- "v2025-09-03-15-36-29/0000-2520",
- "v2025-09-03-15-36-29/2521-5040",
- }
-
- partitions, err := engine.discoverTopicPartitions("test", "user_events")
- assert.NoError(t, err, "Should discover partitions without error")
- assert.Equal(t, 2, len(partitions), "Should return 2 partitions")
- assert.True(t, !strings.HasPrefix(partitions[0], "/topics/"), "Should be relative path")
- })
-
- t.Run("Partition path building logic works correctly", func(t *testing.T) {
- topicBasePath := "/topics/test/user_events"
-
- testCases := []struct {
- name string
- relativePartition string
- expectedPath string
- }{
- {
- name: "Absolute path - use as-is",
- relativePartition: "/topics/test/user_events/v2025-09-03-15-36-29/0000-2520",
- expectedPath: "/topics/test/user_events/v2025-09-03-15-36-29/0000-2520",
- },
- {
- name: "Relative path - build full path",
- relativePartition: "v2025-09-03-15-36-29/0000-2520",
- expectedPath: "/topics/test/user_events/v2025-09-03-15-36-29/0000-2520",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- var partitionPath string
-
- // This is the same logic from our fixed code
- if strings.HasPrefix(tc.relativePartition, "/topics/") {
- // Already a full path - use as-is
- partitionPath = tc.relativePartition
- } else {
- // Relative path - build full path
- partitionPath = topicBasePath + "/" + tc.relativePartition
- }
-
- assert.Equal(t, tc.expectedPath, partitionPath,
- "Partition path should be built correctly")
-
- // Ensure no double slashes
- assert.NotContains(t, partitionPath, "//",
- "Partition path should not contain double slashes")
- })
- }
- })
-}
-
-// TestPartitionPathLogic tests the core logic for handling partition paths
-func TestPartitionPathLogic(t *testing.T) {
- t.Run("Building partition paths from discovered partitions", func(t *testing.T) {
- // Test the specific partition path building that was causing issues
-
- topicBasePath := "/topics/ecommerce/user_events"
-
- // This simulates the discoverTopicPartitions returning absolute paths (realistic scenario)
- relativePartitions := []string{
- "/topics/ecommerce/user_events/v2025-09-03-15-36-29/0000-2520",
- }
-
- // This is the code from our fix - test it directly
- partitions := make([]string, len(relativePartitions))
- for i, relPartition := range relativePartitions {
- // Handle both relative and absolute partition paths from discoverTopicPartitions
- if strings.HasPrefix(relPartition, "/topics/") {
- // Already a full path - use as-is
- partitions[i] = relPartition
- } else {
- // Relative path - build full path
- partitions[i] = topicBasePath + "/" + relPartition
- }
- }
-
- // Verify the path was handled correctly
- expectedPath := "/topics/ecommerce/user_events/v2025-09-03-15-36-29/0000-2520"
- assert.Equal(t, expectedPath, partitions[0], "Absolute path should be used as-is")
-
- // Ensure no double slashes (this was the original bug)
- assert.NotContains(t, partitions[0], "//", "Path should not contain double slashes")
- })
-}
diff --git a/weed/query/sqltypes/type.go b/weed/query/sqltypes/type.go
index f4f3dd471..2a0f40386 100644
--- a/weed/query/sqltypes/type.go
+++ b/weed/query/sqltypes/type.go
@@ -56,11 +56,6 @@ func IsBinary(t Type) bool {
return int(t)&flagIsBinary == flagIsBinary
}
-// isNumber returns true if the type is any type of number.
-func isNumber(t Type) bool {
- return IsIntegral(t) || IsFloat(t) || t == Decimal
-}
-
// IsTemporal returns true if Value is time type.
func IsTemporal(t Type) bool {
switch t {
diff --git a/weed/query/sqltypes/value.go b/weed/query/sqltypes/value.go
index 012de2b45..7c2599652 100644
--- a/weed/query/sqltypes/value.go
+++ b/weed/query/sqltypes/value.go
@@ -1,9 +1,7 @@
package sqltypes
import (
- "fmt"
"strconv"
- "time"
)
var (
@@ -19,32 +17,6 @@ type Value struct {
val []byte
}
-// NewValue builds a Value using typ and val. If the value and typ
-// don't match, it returns an error.
-func NewValue(typ Type, val []byte) (v Value, err error) {
- switch {
- case IsSigned(typ):
- if _, err := strconv.ParseInt(string(val), 0, 64); err != nil {
- return NULL, err
- }
- return MakeTrusted(typ, val), nil
- case IsUnsigned(typ):
- if _, err := strconv.ParseUint(string(val), 0, 64); err != nil {
- return NULL, err
- }
- return MakeTrusted(typ, val), nil
- case IsFloat(typ) || typ == Decimal:
- if _, err := strconv.ParseFloat(string(val), 64); err != nil {
- return NULL, err
- }
- return MakeTrusted(typ, val), nil
- case IsQuoted(typ) || typ == Bit || typ == Null:
- return MakeTrusted(typ, val), nil
- }
- // All other types are unsafe or invalid.
- return NULL, fmt.Errorf("invalid type specified for MakeValue: %v", typ)
-}
-
// MakeTrusted makes a new Value based on the type.
// This function should only be used if you know the value
// and type conform to the rules. Every place this function is
@@ -71,11 +43,6 @@ func NewInt32(v int32) Value {
return MakeTrusted(Int32, strconv.AppendInt(nil, int64(v), 10))
}
-// NewUint64 builds an Uint64 Value.
-func NewUint64(v uint64) Value {
- return MakeTrusted(Uint64, strconv.AppendUint(nil, v, 10))
-}
-
// NewFloat32 builds an Float64 Value.
func NewFloat32(v float32) Value {
return MakeTrusted(Float32, strconv.AppendFloat(nil, float64(v), 'f', -1, 64))
@@ -97,136 +64,11 @@ func NewVarBinary(v string) Value {
return MakeTrusted(VarBinary, []byte(v))
}
-// NewIntegral builds an integral type from a string representation.
-// The type will be Int64 or Uint64. Int64 will be preferred where possible.
-func NewIntegral(val string) (n Value, err error) {
- signed, err := strconv.ParseInt(val, 0, 64)
- if err == nil {
- return MakeTrusted(Int64, strconv.AppendInt(nil, signed, 10)), nil
- }
- unsigned, err := strconv.ParseUint(val, 0, 64)
- if err != nil {
- return Value{}, err
- }
- return MakeTrusted(Uint64, strconv.AppendUint(nil, unsigned, 10)), nil
-}
-
// MakeString makes a VarBinary Value.
func MakeString(val []byte) Value {
return MakeTrusted(VarBinary, val)
}
-// BuildValue builds a value from any go type. sqltype.Value is
-// also allowed.
-func BuildValue(goval interface{}) (v Value, err error) {
- // Look for the most common types first.
- switch goval := goval.(type) {
- case nil:
- // no op
- case []byte:
- v = MakeTrusted(VarBinary, goval)
- case int64:
- v = MakeTrusted(Int64, strconv.AppendInt(nil, int64(goval), 10))
- case uint64:
- v = MakeTrusted(Uint64, strconv.AppendUint(nil, uint64(goval), 10))
- case float64:
- v = MakeTrusted(Float64, strconv.AppendFloat(nil, goval, 'f', -1, 64))
- case int:
- v = MakeTrusted(Int64, strconv.AppendInt(nil, int64(goval), 10))
- case int8:
- v = MakeTrusted(Int8, strconv.AppendInt(nil, int64(goval), 10))
- case int16:
- v = MakeTrusted(Int16, strconv.AppendInt(nil, int64(goval), 10))
- case int32:
- v = MakeTrusted(Int32, strconv.AppendInt(nil, int64(goval), 10))
- case uint:
- v = MakeTrusted(Uint64, strconv.AppendUint(nil, uint64(goval), 10))
- case uint8:
- v = MakeTrusted(Uint8, strconv.AppendUint(nil, uint64(goval), 10))
- case uint16:
- v = MakeTrusted(Uint16, strconv.AppendUint(nil, uint64(goval), 10))
- case uint32:
- v = MakeTrusted(Uint32, strconv.AppendUint(nil, uint64(goval), 10))
- case float32:
- v = MakeTrusted(Float32, strconv.AppendFloat(nil, float64(goval), 'f', -1, 64))
- case string:
- v = MakeTrusted(VarBinary, []byte(goval))
- case time.Time:
- v = MakeTrusted(Datetime, []byte(goval.Format("2006-01-02 15:04:05")))
- case Value:
- v = goval
- case *BindVariable:
- return ValueFromBytes(goval.Type, goval.Value)
- default:
- return v, fmt.Errorf("unexpected type %T: %v", goval, goval)
- }
- return v, nil
-}
-
-// BuildConverted is like BuildValue except that it tries to
-// convert a string or []byte to an integral if the target type
-// is an integral. We don't perform other implicit conversions
-// because they're unsafe.
-func BuildConverted(typ Type, goval interface{}) (v Value, err error) {
- if IsIntegral(typ) {
- switch goval := goval.(type) {
- case []byte:
- return ValueFromBytes(typ, goval)
- case string:
- return ValueFromBytes(typ, []byte(goval))
- case Value:
- if goval.IsQuoted() {
- return ValueFromBytes(typ, goval.Raw())
- }
- }
- }
- return BuildValue(goval)
-}
-
-// ValueFromBytes builds a Value using typ and val. It ensures that val
-// matches the requested type. If type is an integral it's converted to
-// a canonical form. Otherwise, the original representation is preserved.
-func ValueFromBytes(typ Type, val []byte) (v Value, err error) {
- switch {
- case IsSigned(typ):
- signed, err := strconv.ParseInt(string(val), 0, 64)
- if err != nil {
- return NULL, err
- }
- v = MakeTrusted(typ, strconv.AppendInt(nil, signed, 10))
- case IsUnsigned(typ):
- unsigned, err := strconv.ParseUint(string(val), 0, 64)
- if err != nil {
- return NULL, err
- }
- v = MakeTrusted(typ, strconv.AppendUint(nil, unsigned, 10))
- case IsFloat(typ) || typ == Decimal:
- _, err := strconv.ParseFloat(string(val), 64)
- if err != nil {
- return NULL, err
- }
- // After verification, we preserve the original representation.
- fallthrough
- default:
- v = MakeTrusted(typ, val)
- }
- return v, nil
-}
-
-// BuildIntegral builds an integral type from a string representation.
-// The type will be Int64 or Uint64. Int64 will be preferred where possible.
-func BuildIntegral(val string) (n Value, err error) {
- signed, err := strconv.ParseInt(val, 0, 64)
- if err == nil {
- return MakeTrusted(Int64, strconv.AppendInt(nil, signed, 10)), nil
- }
- unsigned, err := strconv.ParseUint(val, 0, 64)
- if err != nil {
- return Value{}, err
- }
- return MakeTrusted(Uint64, strconv.AppendUint(nil, unsigned, 10)), nil
-}
-
// Type returns the type of Value.
func (v Value) Type() Type {
return v.typ
@@ -247,15 +89,6 @@ func (v Value) Len() int {
// Values represents the array of Value.
type Values []Value
-// Len implements the interface.
-func (vs Values) Len() int {
- len := 0
- for _, v := range vs {
- len += v.Len()
- }
- return len
-}
-
// String returns the raw value as a string.
func (v Value) String() string {
return BytesToString(v.val)
diff --git a/weed/remote_storage/remote_storage.go b/weed/remote_storage/remote_storage.go
index e23fd81df..0a6a63e1d 100644
--- a/weed/remote_storage/remote_storage.go
+++ b/weed/remote_storage/remote_storage.go
@@ -120,17 +120,6 @@ func GetAllRemoteStorageNames() string {
return strings.Join(storageNames, "|")
}
-func GetRemoteStorageNamesHasBucket() string {
- var storageNames []string
- for k, m := range RemoteStorageClientMakers {
- if m.HasBucket() {
- storageNames = append(storageNames, k)
- }
- }
- sort.Strings(storageNames)
- return strings.Join(storageNames, "|")
-}
-
func ParseRemoteLocation(remoteConfType string, remote string) (remoteStorageLocation *remote_pb.RemoteStorageLocation, err error) {
maker, found := RemoteStorageClientMakers[remoteConfType]
if !found {
diff --git a/weed/s3api/auth_credentials.go b/weed/s3api/auth_credentials.go
index a09f6c6d4..ec950cbab 100644
--- a/weed/s3api/auth_credentials.go
+++ b/weed/s3api/auth_credentials.go
@@ -144,6 +144,10 @@ func (c *Credential) isCredentialExpired() bool {
}
// NewIdentityAccessManagement creates a new IAM manager
+func NewIdentityAccessManagement(option *S3ApiServerOption, filerClient *wdclient.FilerClient) *IdentityAccessManagement {
+ return NewIdentityAccessManagementWithStore(option, filerClient, "")
+}
+
// SetFilerClient updates the filer client and its associated credential store
func (iam *IdentityAccessManagement) SetFilerClient(filerClient *wdclient.FilerClient) {
iam.m.Lock()
@@ -196,10 +200,6 @@ func parseExternalUrlToHost(externalUrl string) (string, error) {
return net.JoinHostPort(host, port), nil
}
-func NewIdentityAccessManagement(option *S3ApiServerOption, filerClient *wdclient.FilerClient) *IdentityAccessManagement {
- return NewIdentityAccessManagementWithStore(option, filerClient, "")
-}
-
func NewIdentityAccessManagementWithStore(option *S3ApiServerOption, filerClient *wdclient.FilerClient, explicitStore string) *IdentityAccessManagement {
var externalHost string
if option.ExternalUrl != "" {
diff --git a/weed/s3api/auth_credentials_test.go b/weed/s3api/auth_credentials_test.go
deleted file mode 100644
index 1e84b93db..000000000
--- a/weed/s3api/auth_credentials_test.go
+++ /dev/null
@@ -1,1393 +0,0 @@
-package s3api
-
-import (
- "context"
- "crypto/tls"
- "fmt"
- "net/http"
- "os"
- "reflect"
- "sync"
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/credential"
- "github.com/seaweedfs/seaweedfs/weed/credential/memory"
- "github.com/seaweedfs/seaweedfs/weed/pb/iam_pb"
- "github.com/seaweedfs/seaweedfs/weed/s3api/policy_engine"
- . "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
- "github.com/seaweedfs/seaweedfs/weed/util/wildcard"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-
- jsonpb "google.golang.org/protobuf/encoding/protojson"
-
- _ "github.com/seaweedfs/seaweedfs/weed/credential/filer_etc"
-)
-
-type loadConfigurationDropsPoliciesStore struct {
- *memory.MemoryStore
- loadManagedPoliciesCalled bool
-}
-
-func (store *loadConfigurationDropsPoliciesStore) LoadConfiguration(ctx context.Context) (*iam_pb.S3ApiConfiguration, error) {
- config, err := store.MemoryStore.LoadConfiguration(ctx)
- if err != nil {
- return nil, err
- }
- stripped := *config
- stripped.Policies = nil
- return &stripped, nil
-}
-
-func (store *loadConfigurationDropsPoliciesStore) LoadManagedPolicies(ctx context.Context) ([]*iam_pb.Policy, error) {
- store.loadManagedPoliciesCalled = true
-
- config, err := store.MemoryStore.LoadConfiguration(ctx)
- if err != nil {
- return nil, err
- }
-
- policies := make([]*iam_pb.Policy, 0, len(config.Policies))
- for _, policy := range config.Policies {
- policies = append(policies, &iam_pb.Policy{
- Name: policy.Name,
- Content: policy.Content,
- })
- }
-
- return policies, nil
-}
-
-type inlinePolicyRuntimeStore struct {
- *memory.MemoryStore
- inlinePolicies map[string]map[string]policy_engine.PolicyDocument
-}
-
-func (store *inlinePolicyRuntimeStore) LoadInlinePolicies(ctx context.Context) (map[string]map[string]policy_engine.PolicyDocument, error) {
- _ = ctx
- return store.inlinePolicies, nil
-}
-
-func newPolicyAuthRequest(t *testing.T, method string) *http.Request {
- t.Helper()
- req, err := http.NewRequest(method, "http://s3.amazonaws.com/test-bucket/test-object", nil)
- require.NoError(t, err)
- return req
-}
-
-func TestIdentityListFileFormat(t *testing.T) {
-
- s3ApiConfiguration := &iam_pb.S3ApiConfiguration{}
-
- identity1 := &iam_pb.Identity{
- Name: "some_name",
- Credentials: []*iam_pb.Credential{
- {
- AccessKey: "some_access_key1",
- SecretKey: "some_secret_key2",
- },
- },
- Actions: []string{
- ACTION_ADMIN,
- ACTION_READ,
- ACTION_WRITE,
- },
- }
- identity2 := &iam_pb.Identity{
- Name: "some_read_only_user",
- Credentials: []*iam_pb.Credential{
- {
- AccessKey: "some_access_key1",
- SecretKey: "some_secret_key1",
- },
- },
- Actions: []string{
- ACTION_READ,
- },
- }
- identity3 := &iam_pb.Identity{
- Name: "some_normal_user",
- Credentials: []*iam_pb.Credential{
- {
- AccessKey: "some_access_key2",
- SecretKey: "some_secret_key2",
- },
- },
- Actions: []string{
- ACTION_READ,
- ACTION_WRITE,
- },
- }
-
- s3ApiConfiguration.Identities = append(s3ApiConfiguration.Identities, identity1)
- s3ApiConfiguration.Identities = append(s3ApiConfiguration.Identities, identity2)
- s3ApiConfiguration.Identities = append(s3ApiConfiguration.Identities, identity3)
-
- m := jsonpb.MarshalOptions{
- EmitUnpopulated: true,
- Indent: " ",
- }
-
- text, _ := m.Marshal(s3ApiConfiguration)
-
- println(string(text))
-
-}
-
-func TestCanDo(t *testing.T) {
- ident1 := &Identity{
- Name: "anything",
- Actions: []Action{
- "Write:bucket1/a/b/c/*",
- "Write:bucket1/a/b/other",
- },
- }
- // object specific
- assert.Equal(t, true, ident1.CanDo(ACTION_WRITE, "bucket1", "/a/b/c/d.txt"))
- assert.Equal(t, true, ident1.CanDo(ACTION_WRITE, "bucket1", "/a/b/c/d/e.txt"))
- assert.Equal(t, false, ident1.CanDo(ACTION_DELETE_BUCKET, "bucket1", ""))
- assert.Equal(t, false, ident1.CanDo(ACTION_WRITE, "bucket1", "/a/b/other/some"), "action without *")
- assert.Equal(t, false, ident1.CanDo(ACTION_WRITE, "bucket1", "/a/b/*"), "action on parent directory")
-
- // bucket specific
- ident2 := &Identity{
- Name: "anything",
- Actions: []Action{
- "Read:bucket1",
- "Write:bucket1/*",
- "WriteAcp:bucket1",
- },
- }
- assert.Equal(t, true, ident2.CanDo(ACTION_READ, "bucket1", "/a/b/c/d.txt"))
- assert.Equal(t, true, ident2.CanDo(ACTION_WRITE, "bucket1", "/a/b/c/d.txt"))
- assert.Equal(t, true, ident2.CanDo(ACTION_WRITE_ACP, "bucket1", ""))
- assert.Equal(t, false, ident2.CanDo(ACTION_READ_ACP, "bucket1", ""))
- assert.Equal(t, false, ident2.CanDo(ACTION_LIST, "bucket1", "/a/b/c/d.txt"))
-
- // across buckets
- ident3 := &Identity{
- Name: "anything",
- Actions: []Action{
- "Read",
- "Write",
- },
- }
- assert.Equal(t, true, ident3.CanDo(ACTION_READ, "bucket1", "/a/b/c/d.txt"))
- assert.Equal(t, true, ident3.CanDo(ACTION_WRITE, "bucket1", "/a/b/c/d.txt"))
- assert.Equal(t, false, ident3.CanDo(ACTION_LIST, "bucket1", "/a/b/other/some"))
- assert.Equal(t, false, ident3.CanDo(ACTION_WRITE_ACP, "bucket1", ""))
-
- // partial buckets
- ident4 := &Identity{
- Name: "anything",
- Actions: []Action{
- "Read:special_*",
- "ReadAcp:special_*",
- },
- }
- assert.Equal(t, true, ident4.CanDo(ACTION_READ, "special_bucket", "/a/b/c/d.txt"))
- assert.Equal(t, true, ident4.CanDo(ACTION_READ_ACP, "special_bucket", ""))
- assert.Equal(t, false, ident4.CanDo(ACTION_READ, "bucket1", "/a/b/c/d.txt"))
-
- // admin buckets
- ident5 := &Identity{
- Name: "anything",
- Actions: []Action{
- "Admin:special_*",
- },
- }
- assert.Equal(t, true, ident5.CanDo(ACTION_READ, "special_bucket", "/a/b/c/d.txt"))
- assert.Equal(t, true, ident5.CanDo(ACTION_READ_ACP, "special_bucket", ""))
- assert.Equal(t, true, ident5.CanDo(ACTION_WRITE, "special_bucket", "/a/b/c/d.txt"))
- assert.Equal(t, true, ident5.CanDo(ACTION_WRITE_ACP, "special_bucket", ""))
-
- // anonymous buckets
- ident6 := &Identity{
- Name: "anonymous",
- Actions: []Action{
- "Read",
- },
- }
- assert.Equal(t, true, ident6.CanDo(ACTION_READ, "anything_bucket", "/a/b/c/d.txt"))
-
- //test deleteBucket operation
- ident7 := &Identity{
- Name: "anything",
- Actions: []Action{
- "DeleteBucket:bucket1",
- },
- }
- assert.Equal(t, true, ident7.CanDo(ACTION_DELETE_BUCKET, "bucket1", ""))
-}
-
-func TestMatchWildcardPattern(t *testing.T) {
- tests := []struct {
- pattern string
- target string
- match bool
- }{
- // Basic * wildcard tests
- {"Bucket/*", "Bucket/a/b", true},
- {"Bucket/*", "x/Bucket/a", false},
- {"Bucket/*/admin", "Bucket/x/admin", true},
- {"Bucket/*/admin", "Bucket/x/y/admin", true},
- {"Bucket/*/admin", "Bucket////x////uwu////y////admin", true},
- {"abc*def", "abcXYZdef", true},
- {"abc*def", "abcXYZdefZZ", false},
- {"syr/*", "syr/a/b", true},
-
- // ? wildcard tests (matches exactly one character)
- {"ab?d", "abcd", true},
- {"ab?d", "abXd", true},
- {"ab?d", "abd", false}, // ? must match exactly one character
- {"ab?d", "abcXd", false}, // ? matches only one character
- {"a?c", "abc", true},
- {"a?c", "aXc", true},
- {"a?c", "ac", false},
- {"???", "abc", true},
- {"???", "ab", false},
- {"???", "abcd", false},
-
- // Combined * and ? wildcards
- {"a*?", "ab", true}, // * matches empty, ? matches 'b'
- {"a*?", "abc", true}, // * matches 'b', ? matches 'c'
- {"a*?", "a", false}, // ? must match something
- {"a?*", "ab", true}, // ? matches 'b', * matches empty
- {"a?*", "abc", true}, // ? matches 'b', * matches 'c'
- {"a?*b", "aXb", true}, // ? matches 'X', * matches empty
- {"a?*b", "aXYZb", true},
- {"*?*", "a", true},
- {"*?*", "", false}, // ? requires at least one character
-
- // Edge cases: * matches empty string
- {"a*b", "ab", true}, // * matches empty string
- {"a**b", "ab", true}, // multiple stars match empty
- {"a**b", "axb", true}, // multiple stars match 'x'
- {"a**b", "axyb", true},
- {"*", "", true},
- {"*", "anything", true},
- {"**", "", true},
- {"**", "anything", true},
-
- // Edge cases: empty strings
- {"", "", true},
- {"a", "", false},
- {"", "a", false},
-
- // Trailing * matches empty
- {"a*", "a", true},
- {"a*", "abc", true},
- {"abc*", "abc", true},
- {"abc*", "abcdef", true},
-
- // Leading * matches empty
- {"*a", "a", true},
- {"*a", "XXXa", true},
- {"*abc", "abc", true},
- {"*abc", "XXXabc", true},
-
- // Multiple wildcards
- {"*a*", "a", true},
- {"*a*", "Xa", true},
- {"*a*", "aX", true},
- {"*a*", "XaX", true},
- {"*a*b*", "ab", true},
- {"*a*b*", "XaYbZ", true},
-
- // Exact match (no wildcards)
- {"exact", "exact", true},
- {"exact", "notexact", false},
- {"exact", "exactnot", false},
-
- // S3-style action patterns
- {"Read:bucket*", "Read:bucket-test", true},
- {"Read:bucket*", "Read:bucket", true},
- {"Write:bucket/path/*", "Write:bucket/path/file.txt", true},
- {"Admin:*", "Admin:anything", true},
- }
-
- for _, tt := range tests {
- t.Run(tt.pattern+"_"+tt.target, func(t *testing.T) {
- result := wildcard.MatchesWildcard(tt.pattern, tt.target)
- if result != tt.match {
- t.Errorf("wildcard.MatchesWildcard(%q, %q) = %v, want %v", tt.pattern, tt.target, result, tt.match)
- }
- })
- }
-}
-
-func TestVerifyActionPermissionPolicyFallback(t *testing.T) {
- buildRequest := func(t *testing.T, method string) *http.Request {
- t.Helper()
- req, err := http.NewRequest(method, "http://s3.amazonaws.com/test-bucket/test-object", nil)
- assert.NoError(t, err)
- return req
- }
-
- t.Run("policy allow grants access", func(t *testing.T) {
- iam := &IdentityAccessManagement{}
- err := iam.PutPolicy("allowGet", `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:GetObject","Resource":"arn:aws:s3:::test-bucket/*"}]}`)
- assert.NoError(t, err)
-
- identity := &Identity{
- Name: "policy-user",
- Account: &AccountAdmin,
- PolicyNames: []string{"allowGet"},
- }
-
- errCode := iam.VerifyActionPermission(buildRequest(t, http.MethodGet), identity, Action(ACTION_READ), "test-bucket", "test-object")
- assert.Equal(t, s3err.ErrNone, errCode)
- })
-
- t.Run("explicit deny overrides allow", func(t *testing.T) {
- iam := &IdentityAccessManagement{}
- err := iam.PutPolicy("allowAllGet", `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:GetObject","Resource":"arn:aws:s3:::test-bucket/*"}]}`)
- assert.NoError(t, err)
- err = iam.PutPolicy("denySecret", `{"Version":"2012-10-17","Statement":[{"Effect":"Deny","Action":"s3:GetObject","Resource":"arn:aws:s3:::test-bucket/secret.txt"}]}`)
- assert.NoError(t, err)
-
- identity := &Identity{
- Name: "policy-user",
- Account: &AccountAdmin,
- PolicyNames: []string{"allowAllGet", "denySecret"},
- }
-
- errCode := iam.VerifyActionPermission(buildRequest(t, http.MethodGet), identity, Action(ACTION_READ), "test-bucket", "secret.txt")
- assert.Equal(t, s3err.ErrAccessDenied, errCode)
- })
-
- t.Run("implicit deny when no statement matches", func(t *testing.T) {
- iam := &IdentityAccessManagement{}
- err := iam.PutPolicy("allowOtherBucket", `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:GetObject","Resource":"arn:aws:s3:::other-bucket/*"}]}`)
- assert.NoError(t, err)
-
- identity := &Identity{
- Name: "policy-user",
- Account: &AccountAdmin,
- PolicyNames: []string{"allowOtherBucket"},
- }
-
- errCode := iam.VerifyActionPermission(buildRequest(t, http.MethodGet), identity, Action(ACTION_READ), "test-bucket", "test-object")
- assert.Equal(t, s3err.ErrAccessDenied, errCode)
- })
-
- t.Run("invalid policy document does not allow", func(t *testing.T) {
- iam := &IdentityAccessManagement{}
- err := iam.PutPolicy("invalidPolicy", "{not-json")
- assert.NoError(t, err)
-
- identity := &Identity{
- Name: "policy-user",
- Account: &AccountAdmin,
- PolicyNames: []string{"invalidPolicy"},
- }
-
- errCode := iam.VerifyActionPermission(buildRequest(t, http.MethodGet), identity, Action(ACTION_READ), "test-bucket", "test-object")
- assert.Equal(t, s3err.ErrAccessDenied, errCode)
- })
-
- t.Run("notresource excludes denied object", func(t *testing.T) {
- iam := &IdentityAccessManagement{}
- err := iam.PutPolicy("denyNotResource", `{"Version":"2012-10-17","Statement":[{"Effect":"Deny","Action":"s3:GetObject","NotResource":"arn:aws:s3:::test-bucket/public/*"}]}`)
- assert.NoError(t, err)
- err = iam.PutPolicy("allowAllGet", `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:GetObject","Resource":"arn:aws:s3:::test-bucket/*"}]}`)
- assert.NoError(t, err)
-
- identity := &Identity{
- Name: "policy-user",
- Account: &AccountAdmin,
- PolicyNames: []string{"allowAllGet", "denyNotResource"},
- }
-
- errCode := iam.VerifyActionPermission(buildRequest(t, http.MethodGet), identity, Action(ACTION_READ), "test-bucket", "private/secret.txt")
- assert.Equal(t, s3err.ErrAccessDenied, errCode)
-
- errCode = iam.VerifyActionPermission(buildRequest(t, http.MethodGet), identity, Action(ACTION_READ), "test-bucket", "public/readme.txt")
- assert.Equal(t, s3err.ErrNone, errCode)
- })
-
- t.Run("condition securetransport enforced", func(t *testing.T) {
- iam := &IdentityAccessManagement{}
- err := iam.PutPolicy("allowTLSOnly", `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:GetObject","Resource":"arn:aws:s3:::test-bucket/*","Condition":{"Bool":{"aws:SecureTransport":"true"}}}]}`)
- assert.NoError(t, err)
-
- identity := &Identity{
- Name: "policy-user",
- Account: &AccountAdmin,
- PolicyNames: []string{"allowTLSOnly"},
- }
-
- httpReq := buildRequest(t, http.MethodGet)
- errCode := iam.VerifyActionPermission(httpReq, identity, Action(ACTION_READ), "test-bucket", "test-object")
- assert.Equal(t, s3err.ErrAccessDenied, errCode)
-
- httpsReq := buildRequest(t, http.MethodGet)
- httpsReq.TLS = &tls.ConnectionState{}
- errCode = iam.VerifyActionPermission(httpsReq, identity, Action(ACTION_READ), "test-bucket", "test-object")
- assert.Equal(t, s3err.ErrNone, errCode)
- })
-
- t.Run("attached policies override coarse legacy actions", func(t *testing.T) {
- iam := &IdentityAccessManagement{}
- err := iam.PutPolicy("putOnly", `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:PutObject","Resource":"arn:aws:s3:::test-bucket/*"}]}`)
- assert.NoError(t, err)
-
- identity := &Identity{
- Name: "policy-user",
- Account: &AccountAdmin,
- Actions: []Action{"Write:test-bucket"},
- PolicyNames: []string{"putOnly"},
- }
-
- putErrCode := iam.VerifyActionPermission(buildRequest(t, http.MethodPut), identity, Action(ACTION_WRITE), "test-bucket", "test-object")
- assert.Equal(t, s3err.ErrNone, putErrCode)
-
- deleteErrCode := iam.VerifyActionPermission(buildRequest(t, http.MethodDelete), identity, Action(ACTION_WRITE), "test-bucket", "test-object")
- assert.Equal(t, s3err.ErrAccessDenied, deleteErrCode)
- })
-
- t.Run("valid policy updated to invalid denies access", func(t *testing.T) {
- iam := &IdentityAccessManagement{}
- err := iam.PutPolicy("myPolicy", `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:GetObject","Resource":"arn:aws:s3:::test-bucket/*"}]}`)
- assert.NoError(t, err)
-
- identity := &Identity{
- Name: "policy-user",
- Account: &AccountAdmin,
- PolicyNames: []string{"myPolicy"},
- }
-
- errCode := iam.VerifyActionPermission(buildRequest(t, http.MethodGet), identity, Action(ACTION_READ), "test-bucket", "test-object")
- assert.Equal(t, s3err.ErrNone, errCode)
-
- // Update to invalid JSON — should revoke access.
- err = iam.PutPolicy("myPolicy", "{broken")
- assert.NoError(t, err)
-
- errCode = iam.VerifyActionPermission(buildRequest(t, http.MethodGet), identity, Action(ACTION_READ), "test-bucket", "test-object")
- assert.Equal(t, s3err.ErrAccessDenied, errCode)
- })
-
- t.Run("actions based path still works", func(t *testing.T) {
- iam := &IdentityAccessManagement{}
- identity := &Identity{
- Name: "legacy-user",
- Account: &AccountAdmin,
- Actions: []Action{"Read:test-bucket"},
- }
-
- errCode := iam.VerifyActionPermission(buildRequest(t, http.MethodGet), identity, Action(ACTION_READ), "test-bucket", "any-object")
- assert.Equal(t, s3err.ErrNone, errCode)
- })
-}
-
-func TestLoadS3ApiConfigurationFromCredentialManagerHydratesManagedPolicies(t *testing.T) {
- baseStore := &memory.MemoryStore{}
- assert.NoError(t, baseStore.Initialize(nil, ""))
-
- store := &loadConfigurationDropsPoliciesStore{MemoryStore: baseStore}
- cm := &credential.CredentialManager{Store: store}
-
- config := &iam_pb.S3ApiConfiguration{
- Identities: []*iam_pb.Identity{
- {
- Name: "managed-user",
- PolicyNames: []string{"managedGet"},
- Credentials: []*iam_pb.Credential{
- {AccessKey: "AKIAMANAGED000001", SecretKey: "managed-secret"},
- },
- },
- },
- Policies: []*iam_pb.Policy{
- {
- Name: "managedGet",
- Content: `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:GetObject","Resource":"arn:aws:s3:::test-bucket/*"}]}`,
- },
- },
- }
- assert.NoError(t, cm.SaveConfiguration(context.Background(), config))
-
- iam := &IdentityAccessManagement{credentialManager: cm}
- assert.NoError(t, iam.LoadS3ApiConfigurationFromCredentialManager())
- assert.True(t, store.loadManagedPoliciesCalled)
-
- identity := iam.lookupByIdentityName("managed-user")
- if !assert.NotNil(t, identity) {
- return
- }
-
- errCode := iam.VerifyActionPermission(newPolicyAuthRequest(t, http.MethodGet), identity, Action(ACTION_READ), "test-bucket", "test-object")
- assert.Equal(t, s3err.ErrNone, errCode)
-}
-
-func TestLoadS3ApiConfigurationFromCredentialManagerHydratesManagedPoliciesThroughPropagatingStore(t *testing.T) {
- baseStore := &memory.MemoryStore{}
- assert.NoError(t, baseStore.Initialize(nil, ""))
-
- upstream := &loadConfigurationDropsPoliciesStore{MemoryStore: baseStore}
- wrappedStore := credential.NewPropagatingCredentialStore(upstream, nil, nil)
- cm := &credential.CredentialManager{Store: wrappedStore}
-
- config := &iam_pb.S3ApiConfiguration{
- Identities: []*iam_pb.Identity{
- {
- Name: "managed-user",
- PolicyNames: []string{"managedGet"},
- Credentials: []*iam_pb.Credential{
- {AccessKey: "AKIAMANAGED000010", SecretKey: "managed-secret"},
- },
- },
- },
- Policies: []*iam_pb.Policy{
- {
- Name: "managedGet",
- Content: `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:GetObject","Resource":"arn:aws:s3:::test-bucket/*"}]}`,
- },
- },
- }
- assert.NoError(t, cm.SaveConfiguration(context.Background(), config))
-
- iam := &IdentityAccessManagement{credentialManager: cm}
- assert.NoError(t, iam.LoadS3ApiConfigurationFromCredentialManager())
- assert.True(t, upstream.loadManagedPoliciesCalled)
-
- identity := iam.lookupByIdentityName("managed-user")
- if !assert.NotNil(t, identity) {
- return
- }
-
- errCode := iam.VerifyActionPermission(newPolicyAuthRequest(t, http.MethodGet), identity, Action(ACTION_READ), "test-bucket", "test-object")
- assert.Equal(t, s3err.ErrNone, errCode)
-}
-
-func TestLoadS3ApiConfigurationFromCredentialManagerSyncsPoliciesToIAMManager(t *testing.T) {
- ctx := context.Background()
- baseStore := &memory.MemoryStore{}
- assert.NoError(t, baseStore.Initialize(nil, ""))
-
- cm := &credential.CredentialManager{Store: baseStore}
- config := &iam_pb.S3ApiConfiguration{
- Identities: []*iam_pb.Identity{
- {
- Name: "managed-user",
- PolicyNames: []string{"managedPut"},
- Credentials: []*iam_pb.Credential{
- {AccessKey: "AKIAMANAGED000002", SecretKey: "managed-secret"},
- },
- },
- },
- Policies: []*iam_pb.Policy{
- {
- Name: "managedPut",
- Content: `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":["s3:PutObject","s3:ListBucket"],"Resource":["arn:aws:s3:::cli-allowed-bucket","arn:aws:s3:::cli-allowed-bucket/*"]}]}`,
- },
- },
- }
- assert.NoError(t, cm.SaveConfiguration(ctx, config))
-
- iamManager, err := loadIAMManagerFromConfig("", func() string { return "localhost:8888" }, func() string {
- return "fallback-key-for-zero-config"
- })
- assert.NoError(t, err)
- iamManager.SetUserStore(cm)
-
- iam := &IdentityAccessManagement{credentialManager: cm}
- iam.SetIAMIntegration(NewS3IAMIntegration(iamManager, ""))
-
- assert.NoError(t, iam.LoadS3ApiConfigurationFromCredentialManager())
-
- identity := iam.lookupByIdentityName("managed-user")
- if !assert.NotNil(t, identity) {
- return
- }
-
- allowedErrCode := iam.VerifyActionPermission(newPolicyAuthRequest(t, http.MethodPut), identity, Action(ACTION_WRITE), "cli-allowed-bucket", "test-object")
- assert.Equal(t, s3err.ErrNone, allowedErrCode)
-
- forbiddenErrCode := iam.VerifyActionPermission(newPolicyAuthRequest(t, http.MethodPut), identity, Action(ACTION_WRITE), "cli-forbidden-bucket", "test-object")
- assert.Equal(t, s3err.ErrAccessDenied, forbiddenErrCode)
-}
-
-func TestLoadS3ApiConfigurationFromCredentialManagerHydratesInlinePolicies(t *testing.T) {
- baseStore := &memory.MemoryStore{}
- assert.NoError(t, baseStore.Initialize(nil, ""))
-
- inlinePolicy := policy_engine.PolicyDocument{
- Version: policy_engine.PolicyVersion2012_10_17,
- Statement: []policy_engine.PolicyStatement{
- {
- Effect: policy_engine.PolicyEffectAllow,
- Action: policy_engine.NewStringOrStringSlice("s3:PutObject"),
- Resource: policy_engine.NewStringOrStringSlicePtr("arn:aws:s3:::test-bucket/*"),
- },
- },
- }
-
- store := &inlinePolicyRuntimeStore{
- MemoryStore: baseStore,
- inlinePolicies: map[string]map[string]policy_engine.PolicyDocument{
- "inline-user": {
- "PutOnly": inlinePolicy,
- },
- },
- }
- cm := &credential.CredentialManager{Store: store}
-
- config := &iam_pb.S3ApiConfiguration{
- Identities: []*iam_pb.Identity{
- {
- Name: "inline-user",
- Actions: []string{"Write:test-bucket"},
- Credentials: []*iam_pb.Credential{
- {AccessKey: "AKIAINLINE0000001", SecretKey: "inline-secret"},
- },
- },
- },
- }
- assert.NoError(t, cm.SaveConfiguration(context.Background(), config))
-
- iam := &IdentityAccessManagement{credentialManager: cm}
- assert.NoError(t, iam.LoadS3ApiConfigurationFromCredentialManager())
-
- identity := iam.lookupByIdentityName("inline-user")
- if !assert.NotNil(t, identity) {
- return
- }
- assert.Contains(t, identity.PolicyNames, inlinePolicyRuntimeName("inline-user", "PutOnly"))
-
- putErrCode := iam.VerifyActionPermission(newPolicyAuthRequest(t, http.MethodPut), identity, Action(ACTION_WRITE), "test-bucket", "test-object")
- assert.Equal(t, s3err.ErrNone, putErrCode)
-
- deleteErrCode := iam.VerifyActionPermission(newPolicyAuthRequest(t, http.MethodDelete), identity, Action(ACTION_WRITE), "test-bucket", "test-object")
- assert.Equal(t, s3err.ErrAccessDenied, deleteErrCode)
-}
-
-func TestLoadS3ApiConfigurationFromCredentialManagerHydratesInlinePoliciesThroughPropagatingStore(t *testing.T) {
- baseStore := &memory.MemoryStore{}
- assert.NoError(t, baseStore.Initialize(nil, ""))
-
- inlinePolicy := policy_engine.PolicyDocument{
- Version: policy_engine.PolicyVersion2012_10_17,
- Statement: []policy_engine.PolicyStatement{
- {
- Effect: policy_engine.PolicyEffectAllow,
- Action: policy_engine.NewStringOrStringSlice("s3:PutObject"),
- Resource: policy_engine.NewStringOrStringSlicePtr("arn:aws:s3:::test-bucket/*"),
- },
- },
- }
-
- upstream := &inlinePolicyRuntimeStore{
- MemoryStore: baseStore,
- inlinePolicies: map[string]map[string]policy_engine.PolicyDocument{
- "inline-user": {
- "PutOnly": inlinePolicy,
- },
- },
- }
- wrappedStore := credential.NewPropagatingCredentialStore(upstream, nil, nil)
- cm := &credential.CredentialManager{Store: wrappedStore}
-
- config := &iam_pb.S3ApiConfiguration{
- Identities: []*iam_pb.Identity{
- {
- Name: "inline-user",
- Actions: []string{"Write:test-bucket"},
- Credentials: []*iam_pb.Credential{
- {AccessKey: "AKIAINLINE0000010", SecretKey: "inline-secret"},
- },
- },
- },
- }
- assert.NoError(t, cm.SaveConfiguration(context.Background(), config))
-
- iam := &IdentityAccessManagement{credentialManager: cm}
- assert.NoError(t, iam.LoadS3ApiConfigurationFromCredentialManager())
-
- identity := iam.lookupByIdentityName("inline-user")
- if !assert.NotNil(t, identity) {
- return
- }
- assert.Contains(t, identity.PolicyNames, inlinePolicyRuntimeName("inline-user", "PutOnly"))
-
- putErrCode := iam.VerifyActionPermission(newPolicyAuthRequest(t, http.MethodPut), identity, Action(ACTION_WRITE), "test-bucket", "test-object")
- assert.Equal(t, s3err.ErrNone, putErrCode)
-
- deleteErrCode := iam.VerifyActionPermission(newPolicyAuthRequest(t, http.MethodDelete), identity, Action(ACTION_WRITE), "test-bucket", "test-object")
- assert.Equal(t, s3err.ErrAccessDenied, deleteErrCode)
-}
-
-func TestLoadConfigurationDropsPoliciesStoreDoesNotMutateSourceConfig(t *testing.T) {
- baseStore := &memory.MemoryStore{}
- require.NoError(t, baseStore.Initialize(nil, ""))
-
- config := &iam_pb.S3ApiConfiguration{
- Policies: []*iam_pb.Policy{
- {Name: "managedGet", Content: `{"Version":"2012-10-17","Statement":[]}`},
- },
- }
- require.NoError(t, baseStore.SaveConfiguration(context.Background(), config))
-
- store := &loadConfigurationDropsPoliciesStore{MemoryStore: baseStore}
-
- stripped, err := store.LoadConfiguration(context.Background())
- require.NoError(t, err)
- assert.Nil(t, stripped.Policies)
-
- source, err := baseStore.LoadConfiguration(context.Background())
- require.NoError(t, err)
- require.Len(t, source.Policies, 1)
- assert.Equal(t, "managedGet", source.Policies[0].Name)
-}
-
-func TestMergePoliciesIntoConfigurationSkipsNilPolicies(t *testing.T) {
- config := &iam_pb.S3ApiConfiguration{
- Policies: []*iam_pb.Policy{
- nil,
- {Name: "existing", Content: "old"},
- },
- }
-
- mergePoliciesIntoConfiguration(config, []*iam_pb.Policy{
- nil,
- {Name: "", Content: "ignored"},
- {Name: "existing", Content: "updated"},
- {Name: "new", Content: "created"},
- })
-
- require.Len(t, config.Policies, 3)
- assert.Nil(t, config.Policies[0])
- assert.Equal(t, "existing", config.Policies[1].Name)
- assert.Equal(t, "updated", config.Policies[1].Content)
- assert.Equal(t, "new", config.Policies[2].Name)
- assert.Equal(t, "created", config.Policies[2].Content)
-}
-
-type LoadS3ApiConfigurationTestCase struct {
- pbAccount *iam_pb.Account
- pbIdent *iam_pb.Identity
- expectIdent *Identity
-}
-
-func TestLoadS3ApiConfiguration(t *testing.T) {
- specifiedAccount := Account{
- Id: "specifiedAccountID",
- DisplayName: "specifiedAccountName",
- EmailAddress: "specifiedAccounEmail@example.com",
- }
- pbSpecifiedAccount := iam_pb.Account{
- Id: "specifiedAccountID",
- DisplayName: "specifiedAccountName",
- EmailAddress: "specifiedAccounEmail@example.com",
- }
- testCases := map[string]*LoadS3ApiConfigurationTestCase{
- "notSpecifyAccountId": {
- pbIdent: &iam_pb.Identity{
- Name: "notSpecifyAccountId",
- Actions: []string{
- "Read",
- "Write",
- },
- Credentials: []*iam_pb.Credential{
- {
- AccessKey: "some_access_key1",
- SecretKey: "some_secret_key2",
- },
- },
- },
- expectIdent: &Identity{
- Name: "notSpecifyAccountId",
- Account: &AccountAdmin,
- PrincipalArn: fmt.Sprintf("arn:aws:iam::%s:user/notSpecifyAccountId", defaultAccountID),
- Actions: []Action{
- "Read",
- "Write",
- },
- Credentials: []*Credential{
- {
- AccessKey: "some_access_key1",
- SecretKey: "some_secret_key2",
- },
- },
- },
- },
- "specifiedAccountID": {
- pbAccount: &pbSpecifiedAccount,
- pbIdent: &iam_pb.Identity{
- Name: "specifiedAccountID",
- Account: &pbSpecifiedAccount,
- Actions: []string{
- "Read",
- "Write",
- },
- },
- expectIdent: &Identity{
- Name: "specifiedAccountID",
- Account: &specifiedAccount,
- PrincipalArn: fmt.Sprintf("arn:aws:iam::%s:user/specifiedAccountID", defaultAccountID),
- Actions: []Action{
- "Read",
- "Write",
- },
- },
- },
- "anonymous": {
- pbIdent: &iam_pb.Identity{
- Name: "anonymous",
- Actions: []string{
- "Read",
- "Write",
- },
- },
- expectIdent: &Identity{
- Name: "anonymous",
- Account: &AccountAnonymous,
- PrincipalArn: "*",
- Actions: []Action{
- "Read",
- "Write",
- },
- },
- },
- }
-
- config := &iam_pb.S3ApiConfiguration{
- Identities: make([]*iam_pb.Identity, 0),
- }
- for _, v := range testCases {
- config.Identities = append(config.Identities, v.pbIdent)
- if v.pbAccount != nil {
- config.Accounts = append(config.Accounts, v.pbAccount)
- }
- }
-
- iam := IdentityAccessManagement{}
- err := iam.loadS3ApiConfiguration(config)
- if err != nil {
- return
- }
-
- for _, ident := range iam.identities {
- tc := testCases[ident.Name]
- if !reflect.DeepEqual(ident, tc.expectIdent) {
- t.Errorf("not expect for ident name %s", ident.Name)
- }
- }
-}
-
-func TestNewIdentityAccessManagementWithStoreEnvVars(t *testing.T) {
- // Save original environment
- originalAccessKeyId := os.Getenv("AWS_ACCESS_KEY_ID")
- originalSecretAccessKey := os.Getenv("AWS_SECRET_ACCESS_KEY")
-
- // Clean up after test
- defer func() {
- if originalAccessKeyId != "" {
- os.Setenv("AWS_ACCESS_KEY_ID", originalAccessKeyId)
- } else {
- os.Unsetenv("AWS_ACCESS_KEY_ID")
- }
- if originalSecretAccessKey != "" {
- os.Setenv("AWS_SECRET_ACCESS_KEY", originalSecretAccessKey)
- } else {
- os.Unsetenv("AWS_SECRET_ACCESS_KEY")
- }
- }()
-
- tests := []struct {
- name string
- accessKeyId string
- secretAccessKey string
- expectEnvIdentity bool
- expectedName string
- description string
- }{
- {
- name: "Environment variables used as fallback",
- accessKeyId: "AKIA1234567890ABCDEF",
- secretAccessKey: "secret123456789012345678901234567890abcdef12",
- expectEnvIdentity: true,
- expectedName: "admin-AKIA1234",
- description: "When no config file and no filer config, environment variables should be used",
- },
- {
- name: "Short access key fallback",
- accessKeyId: "SHORT",
- secretAccessKey: "secret123456789012345678901234567890abcdef12",
- expectEnvIdentity: true,
- expectedName: "admin-SHORT",
- description: "Short access keys should work correctly as fallback",
- },
- {
- name: "No env vars means no identities",
- accessKeyId: "",
- secretAccessKey: "",
- expectEnvIdentity: false,
- expectedName: "",
- description: "When no env vars and no config, should have no identities",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Reset the memory store to avoid test pollution
- if store := credential.Stores[0]; store.GetName() == credential.StoreTypeMemory {
- if memStore, ok := store.(interface{ Reset() }); ok {
- memStore.Reset()
- }
- }
-
- // Set up environment variables
- if tt.accessKeyId != "" {
- os.Setenv("AWS_ACCESS_KEY_ID", tt.accessKeyId)
- } else {
- os.Unsetenv("AWS_ACCESS_KEY_ID")
- }
- if tt.secretAccessKey != "" {
- os.Setenv("AWS_SECRET_ACCESS_KEY", tt.secretAccessKey)
- } else {
- os.Unsetenv("AWS_SECRET_ACCESS_KEY")
- }
-
- // Create IAM instance with memory store for testing (no config file)
- option := &S3ApiServerOption{
- Config: "", // No config file - this should trigger environment variable fallback
- }
- iam := NewIdentityAccessManagementWithStore(option, nil, string(credential.StoreTypeMemory))
-
- if tt.expectEnvIdentity {
- // Should have exactly one identity from environment variables
- assert.Len(t, iam.identities, 1, "Should have exactly one identity from environment variables")
-
- identity := iam.identities[0]
- assert.Equal(t, tt.expectedName, identity.Name, "Identity name should match expected")
- assert.Len(t, identity.Credentials, 1, "Should have one credential")
- assert.Equal(t, tt.accessKeyId, identity.Credentials[0].AccessKey, "Access key should match environment variable")
- assert.Equal(t, tt.secretAccessKey, identity.Credentials[0].SecretKey, "Secret key should match environment variable")
- assert.Contains(t, identity.Actions, Action(ACTION_ADMIN), "Should have admin action")
- } else {
- // When no env vars, should have no identities (since no config file)
- assert.Len(t, iam.identities, 0, "Should have no identities when no env vars and no config file")
- }
- })
- }
-}
-
-// TestConfigFileWithNoIdentitiesAllowsEnvVars tests that when a config file exists
-// but contains no identities (e.g., only KMS settings), environment variables should still work.
-// This test validates the fix for issue #7311.
-func TestConfigFileWithNoIdentitiesAllowsEnvVars(t *testing.T) {
- // Reset the memory store to avoid test pollution
- if store := credential.Stores[0]; store.GetName() == credential.StoreTypeMemory {
- if memStore, ok := store.(interface{ Reset() }); ok {
- memStore.Reset()
- }
- }
-
- // Set environment variables
- testAccessKey := "AKIATEST1234567890AB"
- testSecretKey := "testSecret1234567890123456789012345678901234"
- t.Setenv("AWS_ACCESS_KEY_ID", testAccessKey)
- t.Setenv("AWS_SECRET_ACCESS_KEY", testSecretKey)
-
- // Create a temporary config file with only KMS settings (no identities)
- configContent := `{
- "kms": {
- "default": {
- "provider": "local",
- "config": {
- "keyPath": "/tmp/test-key"
- }
- }
- }
-}`
- tmpFile, err := os.CreateTemp("", "s3-config-*.json")
- assert.NoError(t, err, "Should create temp config file")
- defer os.Remove(tmpFile.Name())
-
- _, err = tmpFile.Write([]byte(configContent))
- assert.NoError(t, err, "Should write config content")
- tmpFile.Close()
-
- // Create IAM instance with config file that has no identities
- option := &S3ApiServerOption{
- Config: tmpFile.Name(),
- }
- iam := NewIdentityAccessManagementWithStore(option, nil, string(credential.StoreTypeMemory))
-
- // Should have exactly one identity from environment variables
- assert.Len(t, iam.identities, 1, "Should have exactly one identity from environment variables even when config file exists with no identities")
-
- identity := iam.identities[0]
- assert.Equal(t, "admin-AKIATEST", identity.Name, "Identity name should be based on access key")
- assert.Len(t, identity.Credentials, 1, "Should have one credential")
- assert.Equal(t, testAccessKey, identity.Credentials[0].AccessKey, "Access key should match environment variable")
- assert.Equal(t, testSecretKey, identity.Credentials[0].SecretKey, "Secret key should match environment variable")
- assert.Contains(t, identity.Actions, Action(ACTION_ADMIN), "Should have admin action")
-}
-
-// TestBucketLevelListPermissions tests that bucket-level List permissions work correctly
-// This test validates the fix for issue #7066
-func TestBucketLevelListPermissions(t *testing.T) {
- // Test the functionality that was broken in issue #7066
-
- t.Run("Bucket Wildcard Permissions", func(t *testing.T) {
- // Create identity with bucket-level List permission using wildcards
- identity := &Identity{
- Name: "bucket-user",
- Actions: []Action{
- "List:mybucket*",
- "Read:mybucket*",
- "ReadAcp:mybucket*",
- "Write:mybucket*",
- "WriteAcp:mybucket*",
- "Tagging:mybucket*",
- },
- }
-
- // Test cases for bucket-level wildcard permissions
- testCases := []struct {
- name string
- action Action
- bucket string
- object string
- shouldAllow bool
- description string
- }{
- {
- name: "exact bucket match",
- action: "List",
- bucket: "mybucket",
- object: "",
- shouldAllow: true,
- description: "Should allow access to exact bucket name",
- },
- {
- name: "bucket with suffix",
- action: "List",
- bucket: "mybucket-prod",
- object: "",
- shouldAllow: true,
- description: "Should allow access to bucket with matching prefix",
- },
- {
- name: "bucket with numbers",
- action: "List",
- bucket: "mybucket123",
- object: "",
- shouldAllow: true,
- description: "Should allow access to bucket with numbers",
- },
- {
- name: "different bucket",
- action: "List",
- bucket: "otherbucket",
- object: "",
- shouldAllow: false,
- description: "Should deny access to bucket with different prefix",
- },
- {
- name: "partial match",
- action: "List",
- bucket: "notmybucket",
- object: "",
- shouldAllow: false,
- description: "Should deny access to bucket that contains but doesn't start with the prefix",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- result := identity.CanDo(tc.action, tc.bucket, tc.object)
- assert.Equal(t, tc.shouldAllow, result, tc.description)
- })
- }
- })
-
- t.Run("Global List Permission", func(t *testing.T) {
- // Create identity with global List permission
- identity := &Identity{
- Name: "global-user",
- Actions: []Action{
- "List",
- },
- }
-
- // Should allow access to any bucket
- testCases := []string{"anybucket", "mybucket", "test-bucket", "prod-data"}
-
- for _, bucket := range testCases {
- result := identity.CanDo("List", bucket, "")
- assert.True(t, result, "Global List permission should allow access to bucket %s", bucket)
- }
- })
-
- t.Run("No Wildcard Exact Match", func(t *testing.T) {
- // Create identity with exact bucket permission (no wildcard)
- identity := &Identity{
- Name: "exact-user",
- Actions: []Action{
- "List:specificbucket",
- },
- }
-
- // Should only allow access to the exact bucket
- assert.True(t, identity.CanDo("List", "specificbucket", ""), "Should allow access to exact bucket")
- assert.False(t, identity.CanDo("List", "specificbucket-test", ""), "Should deny access to bucket with suffix")
- assert.False(t, identity.CanDo("List", "otherbucket", ""), "Should deny access to different bucket")
- })
-
- t.Log("This test validates the fix for issue #7066")
- t.Log("Bucket-level List permissions like 'List:bucket*' work correctly")
- t.Log("ListBucketsHandler now uses consistent authentication flow")
-}
-
-// TestListBucketsAuthRequest tests that authRequest works correctly for ListBuckets operations
-// This test validates that the fix for the regression identified in PR #7067 works correctly
-func TestListBucketsAuthRequest(t *testing.T) {
- t.Run("ListBuckets special case handling", func(t *testing.T) {
- // Create identity with bucket-specific permissions (no global List permission)
- identity := &Identity{
- Name: "bucket-user",
- Account: &AccountAdmin,
- Actions: []Action{
- Action("List:mybucket*"),
- Action("Read:mybucket*"),
- },
- }
-
- // Test 1: ListBuckets operation should succeed (bucket = "")
- // This would have failed before the fix because CanDo("List", "", "") would return false
- // After the fix, it bypasses the CanDo check for ListBuckets operations
-
- // Simulate what happens in authRequest for ListBuckets:
- // action = ACTION_LIST, bucket = "", object = ""
-
- // Before fix: identity.CanDo(ACTION_LIST, "", "") would fail
- // After fix: the CanDo check should be bypassed
-
- // Test the individual CanDo method to show it would fail without the special case
- result := identity.CanDo(Action(ACTION_LIST), "", "")
- assert.False(t, result, "CanDo should return false for empty bucket with bucket-specific permissions")
-
- // Test with a specific bucket that matches the permission
- result2 := identity.CanDo(Action(ACTION_LIST), "mybucket", "")
- assert.True(t, result2, "CanDo should return true for matching bucket")
-
- // Test with a specific bucket that doesn't match
- result3 := identity.CanDo(Action(ACTION_LIST), "otherbucket", "")
- assert.False(t, result3, "CanDo should return false for non-matching bucket")
- })
-
- t.Run("Object listing maintains permission enforcement", func(t *testing.T) {
- // Create identity with bucket-specific permissions
- identity := &Identity{
- Name: "bucket-user",
- Account: &AccountAdmin,
- Actions: []Action{
- Action("List:mybucket*"),
- },
- }
-
- // For object listing operations, the normal permission checks should still apply
- // These operations have a specific bucket in the URL
-
- // Should succeed for allowed bucket
- result1 := identity.CanDo(Action(ACTION_LIST), "mybucket", "prefix/")
- assert.True(t, result1, "Should allow listing objects in permitted bucket")
-
- result2 := identity.CanDo(Action(ACTION_LIST), "mybucket-prod", "")
- assert.True(t, result2, "Should allow listing objects in wildcard-matched bucket")
-
- // Should fail for disallowed bucket
- result3 := identity.CanDo(Action(ACTION_LIST), "otherbucket", "")
- assert.False(t, result3, "Should deny listing objects in non-permitted bucket")
- })
-
- t.Log("This test validates the fix for the regression identified in PR #7067")
- t.Log("ListBuckets operation bypasses global permission check when bucket is empty")
- t.Log("Object listing still properly enforces bucket-level permissions")
-}
-
-// TestSignatureVerificationDoesNotCheckPermissions tests that signature verification
-// only validates the signature and identity, not permissions. Permissions should be
-// checked later in authRequest based on the actual operation.
-// This test validates the fix for issue #7334
-func TestSignatureVerificationDoesNotCheckPermissions(t *testing.T) {
- t.Run("List-only user can authenticate via signature", func(t *testing.T) {
- // Create IAM with a user that only has List permissions on specific buckets
- iam := &IdentityAccessManagement{
- hashes: make(map[string]*sync.Pool),
- hashCounters: make(map[string]*int32),
- }
-
- err := iam.loadS3ApiConfiguration(&iam_pb.S3ApiConfiguration{
- Identities: []*iam_pb.Identity{
- {
- Name: "list-only-user",
- Credentials: []*iam_pb.Credential{
- {
- AccessKey: "list_access_key",
- SecretKey: "list_secret_key",
- },
- },
- Actions: []string{
- "List:bucket-123",
- "Read:bucket-123",
- },
- },
- },
- })
- assert.NoError(t, err)
-
- // Before the fix, signature verification would fail because it checked for Write permission
- // After the fix, signature verification should succeed (only checking signature validity)
- // The actual permission check happens later in authRequest with the correct action
-
- // The user should be able to authenticate (signature verification passes)
- // But authorization for specific actions is checked separately
- identity, cred, found := iam.lookupByAccessKey("list_access_key")
- assert.True(t, found, "Should find the user by access key")
- assert.Equal(t, "list-only-user", identity.Name)
- assert.Equal(t, "list_secret_key", cred.SecretKey)
-
- // User should have the correct permissions
- assert.True(t, identity.CanDo(Action(ACTION_LIST), "bucket-123", ""))
- assert.True(t, identity.CanDo(Action(ACTION_READ), "bucket-123", ""))
-
- // User should NOT have write permissions
- assert.False(t, identity.CanDo(Action(ACTION_WRITE), "bucket-123", ""))
- })
-
- t.Log("This test validates the fix for issue #7334")
- t.Log("Signature verification no longer checks for Write permission")
- t.Log("This allows list-only and read-only users to authenticate via AWS Signature V4")
-}
-
-func TestStaticIdentityProtection(t *testing.T) {
- iam := NewIdentityAccessManagement(&S3ApiServerOption{}, nil)
-
- // Add a static identity
- staticIdent := &Identity{
- Name: "static-user",
- IsStatic: true,
- }
- iam.m.Lock()
- if iam.nameToIdentity == nil {
- iam.nameToIdentity = make(map[string]*Identity)
- }
- iam.identities = append(iam.identities, staticIdent)
- iam.nameToIdentity[staticIdent.Name] = staticIdent
- iam.m.Unlock()
-
- // Add a dynamic identity
- dynamicIdent := &Identity{
- Name: "dynamic-user",
- IsStatic: false,
- }
- iam.m.Lock()
- iam.identities = append(iam.identities, dynamicIdent)
- iam.nameToIdentity[dynamicIdent.Name] = dynamicIdent
- iam.m.Unlock()
-
- // Try to remove static identity
- iam.RemoveIdentity("static-user")
-
- // Verify static identity still exists
- iam.m.RLock()
- _, ok := iam.nameToIdentity["static-user"]
- iam.m.RUnlock()
- assert.True(t, ok, "Static identity should not be removed")
-
- // Try to remove dynamic identity
- iam.RemoveIdentity("dynamic-user")
-
- // Verify dynamic identity is removed
- iam.m.RLock()
- _, ok = iam.nameToIdentity["dynamic-user"]
- iam.m.RUnlock()
- assert.False(t, ok, "Dynamic identity should have been removed")
-}
-
-func TestParseExternalUrlToHost(t *testing.T) {
- tests := []struct {
- name string
- input string
- expected string
- expectErr bool
- }{
- {
- name: "empty string",
- input: "",
- expected: "",
- },
- {
- name: "HTTPS with default port stripped",
- input: "https://api.example.com:443",
- expected: "api.example.com",
- },
- {
- name: "HTTP with default port stripped",
- input: "http://api.example.com:80",
- expected: "api.example.com",
- },
- {
- name: "HTTPS with non-standard port preserved",
- input: "https://api.example.com:9000",
- expected: "api.example.com:9000",
- },
- {
- name: "HTTP with non-standard port preserved",
- input: "http://api.example.com:8080",
- expected: "api.example.com:8080",
- },
- {
- name: "HTTPS without port",
- input: "https://api.example.com",
- expected: "api.example.com",
- },
- {
- name: "HTTP without port",
- input: "http://api.example.com",
- expected: "api.example.com",
- },
- {
- name: "IPv6 with non-standard port",
- input: "https://[::1]:9000",
- expected: "[::1]:9000",
- },
- {
- name: "IPv6 with default HTTPS port stripped",
- input: "https://[::1]:443",
- expected: "::1",
- },
- {
- name: "IPv6 without port",
- input: "https://[::1]",
- expected: "::1",
- },
- {
- name: "invalid URL",
- input: "://not-a-url",
- expectErr: true,
- },
- {
- name: "missing host",
- input: "https://",
- expectErr: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result, err := parseExternalUrlToHost(tt.input)
- if tt.expectErr {
- assert.Error(t, err)
- return
- }
- assert.NoError(t, err)
- assert.Equal(t, tt.expected, result)
- })
- }
-}
diff --git a/weed/s3api/bucket_metadata.go b/weed/s3api/bucket_metadata.go
index f13fbb949..ce0162be8 100644
--- a/weed/s3api/bucket_metadata.go
+++ b/weed/s3api/bucket_metadata.go
@@ -224,12 +224,6 @@ func (r *BucketRegistry) removeMetadataCache(bucket string) {
delete(r.metadataCache, bucket)
}
-func (r *BucketRegistry) markNotFound(bucket string) {
- r.notFoundLock.Lock()
- defer r.notFoundLock.Unlock()
- r.notFound[bucket] = struct{}{}
-}
-
func (r *BucketRegistry) unMarkNotFound(bucket string) {
r.notFoundLock.Lock()
defer r.notFoundLock.Unlock()
diff --git a/weed/s3api/filer_multipart_test.go b/weed/s3api/filer_multipart_test.go
deleted file mode 100644
index 92ecbeba9..000000000
--- a/weed/s3api/filer_multipart_test.go
+++ /dev/null
@@ -1,267 +0,0 @@
-package s3api
-
-import (
- "encoding/hex"
- "net/http"
- "testing"
- "time"
-
- "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/service/s3"
- "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
- "github.com/stretchr/testify/assert"
-)
-
-func TestInitiateMultipartUploadResult(t *testing.T) {
-
- expected := `
-example-bucketexample-objectVXBsb2FkIElEIGZvciA2aWWpbmcncyBteS1tb3ZpZS5tMnRzIHVwbG9hZA`
- response := &InitiateMultipartUploadResult{
- CreateMultipartUploadOutput: s3.CreateMultipartUploadOutput{
- Bucket: aws.String("example-bucket"),
- Key: aws.String("example-object"),
- UploadId: aws.String("VXBsb2FkIElEIGZvciA2aWWpbmcncyBteS1tb3ZpZS5tMnRzIHVwbG9hZA"),
- },
- }
-
- encoded := string(s3err.EncodeXMLResponse(response))
- if encoded != expected {
- t.Errorf("unexpected output: %s\nexpecting:%s", encoded, expected)
- }
-
-}
-
-func TestListPartsResult(t *testing.T) {
-
- expected := `
-"12345678"1970-01-01T00:00:00Z1123`
- response := &ListPartsResult{
- Part: []*s3.Part{
- {
- PartNumber: aws.Int64(int64(1)),
- LastModified: aws.Time(time.Unix(0, 0).UTC()),
- Size: aws.Int64(int64(123)),
- ETag: aws.String("\"12345678\""),
- },
- },
- }
-
- encoded := string(s3err.EncodeXMLResponse(response))
- if encoded != expected {
- t.Errorf("unexpected output: %s\nexpecting:%s", encoded, expected)
- }
-
-}
-
-func TestCompleteMultipartResultIncludesVersionId(t *testing.T) {
- r := &http.Request{Host: "localhost", Header: make(http.Header)}
- input := &s3.CompleteMultipartUploadInput{
- Bucket: aws.String("example-bucket"),
- Key: aws.String("example-object"),
- }
-
- entry := &filer_pb.Entry{
- Extended: map[string][]byte{
- s3_constants.ExtVersionIdKey: []byte("version-123"),
- },
- }
-
- result := completeMultipartResult(r, input, "\"etag-value\"", entry)
- if assert.NotNil(t, result.VersionId) {
- assert.Equal(t, "version-123", *result.VersionId)
- }
-}
-
-func TestCompleteMultipartResultOmitsNullVersionId(t *testing.T) {
- r := &http.Request{Host: "localhost", Header: make(http.Header)}
- input := &s3.CompleteMultipartUploadInput{
- Bucket: aws.String("example-bucket"),
- Key: aws.String("example-object"),
- }
-
- entry := &filer_pb.Entry{
- Extended: map[string][]byte{
- s3_constants.ExtVersionIdKey: []byte("null"),
- },
- }
-
- result := completeMultipartResult(r, input, "\"etag-value\"", entry)
- assert.Nil(t, result.VersionId)
-}
-
-func Test_parsePartNumber(t *testing.T) {
- tests := []struct {
- name string
- fileName string
- partNum int
- }{
- {
- "first",
- "0001_uuid.part",
- 1,
- },
- {
- "second",
- "0002.part",
- 2,
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- partNumber, _ := parsePartNumber(tt.fileName)
- assert.Equalf(t, tt.partNum, partNumber, "parsePartNumber(%v)", tt.fileName)
- })
- }
-}
-
-func TestGetEntryNameAndDir(t *testing.T) {
- s3a := &S3ApiServer{
- option: &S3ApiServerOption{
- BucketsPath: "/buckets",
- },
- }
-
- tests := []struct {
- name string
- bucket string
- key string
- expectedName string
- expectedDirEnd string // We check the suffix since dir includes BucketsPath
- }{
- {
- name: "simple file at root",
- bucket: "test-bucket",
- key: "/file.txt",
- expectedName: "file.txt",
- expectedDirEnd: "/buckets/test-bucket",
- },
- {
- name: "file in subdirectory",
- bucket: "test-bucket",
- key: "/folder/file.txt",
- expectedName: "file.txt",
- expectedDirEnd: "/buckets/test-bucket/folder",
- },
- {
- name: "file in nested subdirectory",
- bucket: "test-bucket",
- key: "/folder/subfolder/file.txt",
- expectedName: "file.txt",
- expectedDirEnd: "/buckets/test-bucket/folder/subfolder",
- },
- {
- name: "key without leading slash",
- bucket: "test-bucket",
- key: "folder/file.txt",
- expectedName: "file.txt",
- expectedDirEnd: "/buckets/test-bucket/folder",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- input := &s3.CompleteMultipartUploadInput{
- Bucket: aws.String(tt.bucket),
- Key: aws.String(tt.key),
- }
- entryName, dirName := s3a.getEntryNameAndDir(input)
- assert.Equal(t, tt.expectedName, entryName, "entry name mismatch")
- assert.Equal(t, tt.expectedDirEnd, dirName, "directory mismatch")
- })
- }
-}
-
-func TestValidateCompletePartETag(t *testing.T) {
- t.Run("matches_composite_etag_from_extended", func(t *testing.T) {
- entry := &filer_pb.Entry{
- Extended: map[string][]byte{
- s3_constants.ExtETagKey: []byte("ea58527f14c6ae0dd53089966e44941b-2"),
- },
- Attributes: &filer_pb.FuseAttributes{},
- }
- match, invalid, part, stored := validateCompletePartETag(`"ea58527f14c6ae0dd53089966e44941b-2"`, entry)
- assert.True(t, match)
- assert.False(t, invalid)
- assert.Equal(t, "ea58527f14c6ae0dd53089966e44941b-2", part)
- assert.Equal(t, "ea58527f14c6ae0dd53089966e44941b-2", stored)
- })
-
- t.Run("matches_md5_from_attributes", func(t *testing.T) {
- md5Bytes, err := hex.DecodeString("324b2665939fde5b8678d3a8b5c46970")
- assert.NoError(t, err)
- entry := &filer_pb.Entry{
- Attributes: &filer_pb.FuseAttributes{
- Md5: md5Bytes,
- },
- }
- match, invalid, part, stored := validateCompletePartETag("324b2665939fde5b8678d3a8b5c46970", entry)
- assert.True(t, match)
- assert.False(t, invalid)
- assert.Equal(t, "324b2665939fde5b8678d3a8b5c46970", part)
- assert.Equal(t, "324b2665939fde5b8678d3a8b5c46970", stored)
- })
-
- t.Run("detects_mismatch", func(t *testing.T) {
- entry := &filer_pb.Entry{
- Extended: map[string][]byte{
- s3_constants.ExtETagKey: []byte("67fdd2e302502ff9f9b606bc036e6892-2"),
- },
- Attributes: &filer_pb.FuseAttributes{},
- }
- match, invalid, _, _ := validateCompletePartETag("686f7d71bacdcd539dd4e17a0d7f1e5f-2", entry)
- assert.False(t, match)
- assert.False(t, invalid)
- })
-
- t.Run("flags_empty_client_etag_as_invalid", func(t *testing.T) {
- entry := &filer_pb.Entry{
- Extended: map[string][]byte{
- s3_constants.ExtETagKey: []byte("67fdd2e302502ff9f9b606bc036e6892-2"),
- },
- Attributes: &filer_pb.FuseAttributes{},
- }
- match, invalid, _, _ := validateCompletePartETag(`""`, entry)
- assert.False(t, match)
- assert.True(t, invalid)
- })
-}
-
-func TestCompleteMultipartUploadRejectsOutOfOrderParts(t *testing.T) {
- s3a := NewS3ApiServerForTest()
- input := &s3.CompleteMultipartUploadInput{
- Bucket: aws.String("bucket"),
- Key: aws.String("object"),
- UploadId: aws.String("upload"),
- }
- parts := &CompleteMultipartUpload{
- Parts: []CompletedPart{
- {PartNumber: 2, ETag: "\"etag-2\""},
- {PartNumber: 1, ETag: "\"etag-1\""},
- },
- }
-
- result, errCode := s3a.completeMultipartUpload(&http.Request{Header: make(http.Header)}, input, parts)
- assert.Nil(t, result)
- assert.Equal(t, s3err.ErrInvalidPartOrder, errCode)
-}
-
-func TestCompleteMultipartUploadAllowsDuplicatePartNumbers(t *testing.T) {
- s3a := NewS3ApiServerForTest()
- input := &s3.CompleteMultipartUploadInput{
- Bucket: aws.String("bucket"),
- Key: aws.String("object"),
- UploadId: aws.String("upload"),
- }
- parts := &CompleteMultipartUpload{
- Parts: []CompletedPart{
- {PartNumber: 1, ETag: "\"etag-older\""},
- {PartNumber: 1, ETag: "\"etag-newer\""},
- },
- }
-
- result, errCode := s3a.completeMultipartUpload(&http.Request{Header: make(http.Header)}, input, parts)
- assert.Nil(t, result)
- assert.Equal(t, s3err.ErrNoSuchUpload, errCode)
-}
diff --git a/weed/s3api/iam_optional_test.go b/weed/s3api/iam_optional_test.go
index 583b14791..4d35e4df9 100644
--- a/weed/s3api/iam_optional_test.go
+++ b/weed/s3api/iam_optional_test.go
@@ -3,9 +3,22 @@ package s3api
import (
"testing"
+ "github.com/seaweedfs/seaweedfs/weed/credential"
"github.com/stretchr/testify/assert"
)
+// resetMemoryStore resets the shared in-memory credential store so that tests
+// that rely on an empty store are not polluted by earlier tests.
+func resetMemoryStore() {
+ for _, store := range credential.Stores {
+ if store.GetName() == credential.StoreTypeMemory {
+ if resettable, ok := store.(interface{ Reset() }); ok {
+ resettable.Reset()
+ }
+ }
+ }
+}
+
func TestLoadIAMManagerWithNoConfig(t *testing.T) {
// Verify that IAM can be initialized without any config
option := &S3ApiServerOption{
@@ -17,6 +30,9 @@ func TestLoadIAMManagerWithNoConfig(t *testing.T) {
}
func TestLoadIAMManagerFromConfig_EmptyConfigWithFallbackKey(t *testing.T) {
+ // Reset the shared memory store to avoid state leaking from other tests.
+ resetMemoryStore()
+
// Initialize IAM with empty config — no anonymous identity is configured,
// so LookupAnonymous should return not-found.
option := &S3ApiServerOption{
diff --git a/weed/s3api/iceberg/commit_helpers.go b/weed/s3api/iceberg/commit_helpers.go
index 958cc753f..742f24cae 100644
--- a/weed/s3api/iceberg/commit_helpers.go
+++ b/weed/s3api/iceberg/commit_helpers.go
@@ -6,8 +6,6 @@ import (
"errors"
"fmt"
"net/http"
- "path"
- "strconv"
"strings"
"github.com/apache/iceberg-go/table"
@@ -25,10 +23,6 @@ type icebergRequestError struct {
message string
}
-func (e *icebergRequestError) Error() string {
- return e.message
-}
-
type createOnCommitInput struct {
bucketARN string
markerBucket string
@@ -88,19 +82,6 @@ func isS3TablesAlreadyExists(err error) bool {
(tableErr.Type == s3tables.ErrCodeTableAlreadyExists || tableErr.Type == s3tables.ErrCodeNamespaceAlreadyExists || strings.Contains(strings.ToLower(tableErr.Message), "already exists"))
}
-func parseMetadataVersionFromLocation(metadataLocation string) int {
- base := path.Base(metadataLocation)
- if !strings.HasPrefix(base, "v") || !strings.HasSuffix(base, ".metadata.json") {
- return 0
- }
- rawVersion := strings.TrimPrefix(strings.TrimSuffix(base, ".metadata.json"), "v")
- version, err := strconv.Atoi(rawVersion)
- if err != nil || version <= 0 {
- return 0
- }
- return version
-}
-
func (s *Server) finalizeCreateOnCommit(ctx context.Context, input createOnCommitInput) (*CommitTableResponse, *icebergRequestError) {
builder, err := table.MetadataBuilderFromBase(input.baseMetadata, input.baseMetadataLoc)
if err != nil {
diff --git a/weed/s3api/iceberg/iceberg_stage_create_helpers_test.go b/weed/s3api/iceberg/iceberg_stage_create_helpers_test.go
deleted file mode 100644
index 81fd7aba9..000000000
--- a/weed/s3api/iceberg/iceberg_stage_create_helpers_test.go
+++ /dev/null
@@ -1,76 +0,0 @@
-package iceberg
-
-import (
- "strings"
- "testing"
-
- "github.com/apache/iceberg-go/table"
- "github.com/google/uuid"
-)
-
-func TestHasAssertCreateRequirement(t *testing.T) {
- requirements := table.Requirements{table.AssertCreate()}
- if !hasAssertCreateRequirement(requirements) {
- t.Fatalf("hasAssertCreateRequirement() = false, want true")
- }
-
- requirements = table.Requirements{table.AssertDefaultSortOrderID(0)}
- if hasAssertCreateRequirement(requirements) {
- t.Fatalf("hasAssertCreateRequirement() = true, want false")
- }
-}
-
-func TestParseMetadataVersionFromLocation(t *testing.T) {
- testCases := []struct {
- location string
- version int
- }{
- {location: "s3://b/ns/t/metadata/v1.metadata.json", version: 1},
- {location: "s3://b/ns/t/metadata/v25.metadata.json", version: 25},
- {location: "v1.metadata.json", version: 1},
- {location: "s3://b/ns/t/metadata/v0.metadata.json", version: 0},
- {location: "s3://b/ns/t/metadata/v-1.metadata.json", version: 0},
- {location: "s3://b/ns/t/metadata/vABC.metadata.json", version: 0},
- {location: "s3://b/ns/t/metadata/current.json", version: 0},
- {location: "", version: 0},
- }
-
- for _, tc := range testCases {
- t.Run(tc.location, func(t *testing.T) {
- if got := parseMetadataVersionFromLocation(tc.location); got != tc.version {
- t.Errorf("parseMetadataVersionFromLocation(%q) = %d, want %d", tc.location, got, tc.version)
- }
- })
- }
-}
-
-func TestStageCreateMarkerNamespaceKey(t *testing.T) {
- key := stageCreateMarkerNamespaceKey([]string{"a", "b"})
- if key == "a\x1fb" {
- t.Fatalf("stageCreateMarkerNamespaceKey() returned unescaped namespace key %q", key)
- }
- if !strings.Contains(key, "%1F") {
- t.Fatalf("stageCreateMarkerNamespaceKey() = %q, want escaped unit separator", key)
- }
-}
-
-func TestStageCreateMarkerDir(t *testing.T) {
- dir := stageCreateMarkerDir("warehouse", []string{"ns"}, "orders")
- if !strings.Contains(dir, stageCreateMarkerDirName) {
- t.Fatalf("stageCreateMarkerDir() = %q, want marker dir segment %q", dir, stageCreateMarkerDirName)
- }
- if !strings.HasSuffix(dir, "/orders") {
- t.Fatalf("stageCreateMarkerDir() = %q, want suffix /orders", dir)
- }
-}
-
-func TestStageCreateStagedTablePath(t *testing.T) {
- tableUUID := uuid.MustParse("11111111-2222-3333-4444-555555555555")
- stagedPath := stageCreateStagedTablePath([]string{"ns"}, "orders", tableUUID)
- if !strings.Contains(stagedPath, stageCreateMarkerDirName) {
- t.Fatalf("stageCreateStagedTablePath() = %q, want marker dir segment %q", stagedPath, stageCreateMarkerDirName)
- }
- if !strings.HasSuffix(stagedPath, "/"+tableUUID.String()) {
- t.Fatalf("stageCreateStagedTablePath() = %q, want UUID suffix %q", stagedPath, tableUUID.String())
- }
-}
diff --git a/weed/s3api/object_lock_utils.go b/weed/s3api/object_lock_utils.go
index 9455cb12c..d58bc7b8e 100644
--- a/weed/s3api/object_lock_utils.go
+++ b/weed/s3api/object_lock_utils.go
@@ -2,8 +2,6 @@ package s3api
import (
"context"
- "encoding/xml"
- "fmt"
"strconv"
"time"
@@ -35,21 +33,6 @@ func StoreVersioningInExtended(entry *filer_pb.Entry, enabled bool) error {
return nil
}
-// LoadVersioningFromExtended loads versioning configuration from entry extended attributes
-func LoadVersioningFromExtended(entry *filer_pb.Entry) (bool, bool) {
- if entry == nil || entry.Extended == nil {
- return false, false // not found, default to suspended
- }
-
- // Check for S3 API compatible key
- if versioningBytes, exists := entry.Extended[s3_constants.ExtVersioningKey]; exists {
- enabled := string(versioningBytes) == s3_constants.VersioningEnabled
- return enabled, true
- }
-
- return false, false // not found
-}
-
// GetVersioningStatus returns the versioning status as a string: "", "Enabled", or "Suspended"
// Empty string means versioning was never enabled
func GetVersioningStatus(entry *filer_pb.Entry) string {
@@ -90,15 +73,6 @@ func CreateObjectLockConfiguration(enabled bool, mode string, days int, years in
return config
}
-// ObjectLockConfigurationToXML converts ObjectLockConfiguration to XML bytes
-func ObjectLockConfigurationToXML(config *ObjectLockConfiguration) ([]byte, error) {
- if config == nil {
- return nil, fmt.Errorf("object lock configuration is nil")
- }
-
- return xml.Marshal(config)
-}
-
// StoreObjectLockConfigurationInExtended stores Object Lock configuration in entry extended attributes
func StoreObjectLockConfigurationInExtended(entry *filer_pb.Entry, config *ObjectLockConfiguration) error {
if entry.Extended == nil {
@@ -379,18 +353,6 @@ func validateDefaultRetention(retention *DefaultRetention) error {
return nil
}
-// ====================================================================
-// SHARED OBJECT LOCK CHECKING FUNCTIONS
-// ====================================================================
-// These functions delegate to s3_objectlock package to avoid code duplication.
-// They are kept here for backward compatibility with existing callers.
-
-// EntryHasActiveLock checks if an entry has an active retention or legal hold
-// Delegates to s3_objectlock.EntryHasActiveLock
-func EntryHasActiveLock(entry *filer_pb.Entry, currentTime time.Time) bool {
- return s3_objectlock.EntryHasActiveLock(entry, currentTime)
-}
-
// HasObjectsWithActiveLocks checks if any objects in the bucket have active retention or legal hold
// Delegates to s3_objectlock.HasObjectsWithActiveLocks
func HasObjectsWithActiveLocks(ctx context.Context, client filer_pb.SeaweedFilerClient, bucketPath string) (bool, error) {
diff --git a/weed/s3api/policy/post-policy.go b/weed/s3api/policy/post-policy.go
deleted file mode 100644
index 3250cdf49..000000000
--- a/weed/s3api/policy/post-policy.go
+++ /dev/null
@@ -1,321 +0,0 @@
-package policy
-
-/*
- * MinIO Go Library for Amazon S3 Compatible Cloud Storage
- * Copyright 2015-2017 MinIO, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-import (
- "encoding/base64"
- "fmt"
- "net/http"
- "strings"
- "time"
-
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
-)
-
-// expirationDateFormat date format for expiration key in json policy.
-const expirationDateFormat = "2006-01-02T15:04:05.999Z"
-
-// policyCondition explanation:
-// http://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-HTTPPOSTConstructPolicy.html
-//
-// Example:
-//
-// policyCondition {
-// matchType: "$eq",
-// key: "$Content-Type",
-// value: "image/png",
-// }
-type policyCondition struct {
- matchType string
- condition string
- value string
-}
-
-// PostPolicy - Provides strict static type conversion and validation
-// for Amazon S3's POST policy JSON string.
-type PostPolicy struct {
- // Expiration date and time of the POST policy.
- expiration time.Time
- // Collection of different policy conditions.
- conditions []policyCondition
- // ContentLengthRange minimum and maximum allowable size for the
- // uploaded content.
- contentLengthRange struct {
- min int64
- max int64
- }
-
- // Post form data.
- formData map[string]string
-}
-
-// NewPostPolicy - Instantiate new post policy.
-func NewPostPolicy() *PostPolicy {
- p := &PostPolicy{}
- p.conditions = make([]policyCondition, 0)
- p.formData = make(map[string]string)
- return p
-}
-
-// SetExpires - Sets expiration time for the new policy.
-func (p *PostPolicy) SetExpires(t time.Time) error {
- if t.IsZero() {
- return errInvalidArgument("No expiry time set.")
- }
- p.expiration = t
- return nil
-}
-
-// SetKey - Sets an object name for the policy based upload.
-func (p *PostPolicy) SetKey(key string) error {
- if strings.TrimSpace(key) == "" || key == "" {
- return errInvalidArgument("Object name is empty.")
- }
- policyCond := policyCondition{
- matchType: "eq",
- condition: "$key",
- value: key,
- }
- if err := p.addNewPolicy(policyCond); err != nil {
- return err
- }
- p.formData["key"] = key
- return nil
-}
-
-// SetKeyStartsWith - Sets an object name that an policy based upload
-// can start with.
-func (p *PostPolicy) SetKeyStartsWith(keyStartsWith string) error {
- if strings.TrimSpace(keyStartsWith) == "" || keyStartsWith == "" {
- return errInvalidArgument("Object prefix is empty.")
- }
- policyCond := policyCondition{
- matchType: "starts-with",
- condition: "$key",
- value: keyStartsWith,
- }
- if err := p.addNewPolicy(policyCond); err != nil {
- return err
- }
- p.formData["key"] = keyStartsWith
- return nil
-}
-
-// SetBucket - Sets bucket at which objects will be uploaded to.
-func (p *PostPolicy) SetBucket(bucketName string) error {
- if strings.TrimSpace(bucketName) == "" || bucketName == "" {
- return errInvalidArgument("Bucket name is empty.")
- }
- policyCond := policyCondition{
- matchType: "eq",
- condition: "$bucket",
- value: bucketName,
- }
- if err := p.addNewPolicy(policyCond); err != nil {
- return err
- }
- p.formData["bucket"] = bucketName
- return nil
-}
-
-// SetCondition - Sets condition for credentials, date and algorithm
-func (p *PostPolicy) SetCondition(matchType, condition, value string) error {
- if strings.TrimSpace(value) == "" || value == "" {
- return errInvalidArgument("No value specified for condition")
- }
-
- policyCond := policyCondition{
- matchType: matchType,
- condition: "$" + condition,
- value: value,
- }
- if condition == "X-Amz-Credential" || condition == "X-Amz-Date" || condition == "X-Amz-Algorithm" {
- if err := p.addNewPolicy(policyCond); err != nil {
- return err
- }
- p.formData[condition] = value
- return nil
- }
- return errInvalidArgument("Invalid condition in policy")
-}
-
-// SetContentType - Sets content-type of the object for this policy
-// based upload.
-func (p *PostPolicy) SetContentType(contentType string) error {
- if strings.TrimSpace(contentType) == "" || contentType == "" {
- return errInvalidArgument("No content type specified.")
- }
- policyCond := policyCondition{
- matchType: "eq",
- condition: "$Content-Type",
- value: contentType,
- }
- if err := p.addNewPolicy(policyCond); err != nil {
- return err
- }
- p.formData["Content-Type"] = contentType
- return nil
-}
-
-// SetContentLengthRange - Set new min and max content length
-// condition for all incoming uploads.
-func (p *PostPolicy) SetContentLengthRange(min, max int64) error {
- if min > max {
- return errInvalidArgument("Minimum limit is larger than maximum limit.")
- }
- if min < 0 {
- return errInvalidArgument("Minimum limit cannot be negative.")
- }
- if max < 0 {
- return errInvalidArgument("Maximum limit cannot be negative.")
- }
- p.contentLengthRange.min = min
- p.contentLengthRange.max = max
- return nil
-}
-
-// SetSuccessActionRedirect - Sets the redirect success url of the object for this policy
-// based upload.
-func (p *PostPolicy) SetSuccessActionRedirect(redirect string) error {
- if strings.TrimSpace(redirect) == "" || redirect == "" {
- return errInvalidArgument("Redirect is empty")
- }
- policyCond := policyCondition{
- matchType: "eq",
- condition: "$success_action_redirect",
- value: redirect,
- }
- if err := p.addNewPolicy(policyCond); err != nil {
- return err
- }
- p.formData["success_action_redirect"] = redirect
- return nil
-}
-
-// SetSuccessStatusAction - Sets the status success code of the object for this policy
-// based upload.
-func (p *PostPolicy) SetSuccessStatusAction(status string) error {
- if strings.TrimSpace(status) == "" || status == "" {
- return errInvalidArgument("Status is empty")
- }
- policyCond := policyCondition{
- matchType: "eq",
- condition: "$success_action_status",
- value: status,
- }
- if err := p.addNewPolicy(policyCond); err != nil {
- return err
- }
- p.formData["success_action_status"] = status
- return nil
-}
-
-// SetUserMetadata - Set user metadata as a key/value couple.
-// Can be retrieved through a HEAD request or an event.
-func (p *PostPolicy) SetUserMetadata(key string, value string) error {
- if strings.TrimSpace(key) == "" || key == "" {
- return errInvalidArgument("Key is empty")
- }
- if strings.TrimSpace(value) == "" || value == "" {
- return errInvalidArgument("Value is empty")
- }
- headerName := fmt.Sprintf("x-amz-meta-%s", key)
- policyCond := policyCondition{
- matchType: "eq",
- condition: fmt.Sprintf("$%s", headerName),
- value: value,
- }
- if err := p.addNewPolicy(policyCond); err != nil {
- return err
- }
- p.formData[headerName] = value
- return nil
-}
-
-// SetUserData - Set user data as a key/value couple.
-// Can be retrieved through a HEAD request or an event.
-func (p *PostPolicy) SetUserData(key string, value string) error {
- if key == "" {
- return errInvalidArgument("Key is empty")
- }
- if value == "" {
- return errInvalidArgument("Value is empty")
- }
- headerName := fmt.Sprintf("x-amz-%s", key)
- policyCond := policyCondition{
- matchType: "eq",
- condition: fmt.Sprintf("$%s", headerName),
- value: value,
- }
- if err := p.addNewPolicy(policyCond); err != nil {
- return err
- }
- p.formData[headerName] = value
- return nil
-}
-
-// addNewPolicy - internal helper to validate adding new policies.
-func (p *PostPolicy) addNewPolicy(policyCond policyCondition) error {
- if policyCond.matchType == "" || policyCond.condition == "" || policyCond.value == "" {
- return errInvalidArgument("Policy fields are empty.")
- }
- p.conditions = append(p.conditions, policyCond)
- return nil
-}
-
-// String function for printing policy in json formatted string.
-func (p PostPolicy) String() string {
- return string(p.marshalJSON())
-}
-
-// marshalJSON - Provides Marshaled JSON in bytes.
-func (p PostPolicy) marshalJSON() []byte {
- expirationStr := `"expiration":"` + p.expiration.Format(expirationDateFormat) + `"`
- var conditionsStr string
- conditions := []string{}
- for _, po := range p.conditions {
- conditions = append(conditions, fmt.Sprintf("[\"%s\",\"%s\",\"%s\"]", po.matchType, po.condition, po.value))
- }
- if p.contentLengthRange.min != 0 || p.contentLengthRange.max != 0 {
- conditions = append(conditions, fmt.Sprintf("[\"content-length-range\", %d, %d]",
- p.contentLengthRange.min, p.contentLengthRange.max))
- }
- if len(conditions) > 0 {
- conditionsStr = `"conditions":[` + strings.Join(conditions, ",") + "]"
- }
- retStr := "{"
- retStr = retStr + expirationStr + ","
- retStr = retStr + conditionsStr
- retStr = retStr + "}"
- return []byte(retStr)
-}
-
-// base64 - Produces base64 of PostPolicy's Marshaled json.
-func (p PostPolicy) base64() string {
- return base64.StdEncoding.EncodeToString(p.marshalJSON())
-}
-
-// errInvalidArgument - Invalid argument response.
-func errInvalidArgument(message string) error {
- return s3err.RESTErrorResponse{
- StatusCode: http.StatusBadRequest,
- Code: "InvalidArgument",
- Message: message,
- RequestID: "client",
- }
-}
diff --git a/weed/s3api/policy/postpolicyform_test.go b/weed/s3api/policy/postpolicyform_test.go
deleted file mode 100644
index 1a9d78b0e..000000000
--- a/weed/s3api/policy/postpolicyform_test.go
+++ /dev/null
@@ -1,106 +0,0 @@
-package policy
-
-/*
- * MinIO Cloud Storage, (C) 2016 MinIO, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-import (
- "encoding/base64"
- "fmt"
- "net/http"
- "testing"
- "time"
-)
-
-// Test Post Policy parsing and checking conditions
-func TestPostPolicyForm(t *testing.T) {
- pp := NewPostPolicy()
- pp.SetBucket("testbucket")
- pp.SetContentType("image/jpeg")
- pp.SetUserMetadata("uuid", "14365123651274")
- pp.SetKeyStartsWith("user/user1/filename")
- pp.SetContentLengthRange(1048579, 10485760)
- pp.SetSuccessStatusAction("201")
-
- type testCase struct {
- Bucket string
- Key string
- XAmzDate string
- XAmzAlgorithm string
- XAmzCredential string
- XAmzMetaUUID string
- ContentType string
- SuccessActionStatus string
- Policy string
- Expired bool
- expectedErr error
- }
-
- testCases := []testCase{
- // Everything is fine with this test
- {Bucket: "testbucket", Key: "user/user1/filename/${filename}/myfile.txt", XAmzMetaUUID: "14365123651274", SuccessActionStatus: "201", XAmzCredential: "KVGKMDUQ23TCZXTLTHLP/20160727/us-east-1/s3/aws4_request", XAmzDate: "20160727T000000Z", XAmzAlgorithm: "AWS4-HMAC-SHA256", ContentType: "image/jpeg", expectedErr: nil},
- // Expired policy document
- {Bucket: "testbucket", Key: "user/user1/filename/${filename}/myfile.txt", XAmzMetaUUID: "14365123651274", SuccessActionStatus: "201", XAmzCredential: "KVGKMDUQ23TCZXTLTHLP/20160727/us-east-1/s3/aws4_request", XAmzDate: "20160727T000000Z", XAmzAlgorithm: "AWS4-HMAC-SHA256", ContentType: "image/jpeg", Expired: true, expectedErr: fmt.Errorf("Invalid according to Policy: Policy expired")},
- // Different AMZ date
- {Bucket: "testbucket", Key: "user/user1/filename/${filename}/myfile.txt", XAmzMetaUUID: "14365123651274", XAmzDate: "2017T000000Z", XAmzAlgorithm: "AWS4-HMAC-SHA256", ContentType: "image/jpeg", expectedErr: fmt.Errorf("Invalid according to Policy: Policy Condition failed")},
- // Key which doesn't start with user/user1/filename
- {Bucket: "testbucket", Key: "myfile.txt", XAmzDate: "20160727T000000Z", XAmzMetaUUID: "14365123651274", XAmzAlgorithm: "AWS4-HMAC-SHA256", ContentType: "image/jpeg", expectedErr: fmt.Errorf("Invalid according to Policy: Policy Condition failed")},
- // Incorrect bucket name.
- {Bucket: "incorrect", Key: "user/user1/filename/myfile.txt", XAmzMetaUUID: "14365123651274", XAmzDate: "20160727T000000Z", XAmzAlgorithm: "AWS4-HMAC-SHA256", ContentType: "image/jpeg", expectedErr: fmt.Errorf("Invalid according to Policy: Policy Condition failed")},
- // Incorrect key name
- {Bucket: "testbucket", Key: "incorrect", XAmzDate: "20160727T000000Z", XAmzMetaUUID: "14365123651274", XAmzAlgorithm: "AWS4-HMAC-SHA256", ContentType: "image/jpeg", expectedErr: fmt.Errorf("Invalid according to Policy: Policy Condition failed")},
- // Incorrect date
- {Bucket: "testbucket", Key: "user/user1/filename/${filename}/myfile.txt", XAmzMetaUUID: "14365123651274", XAmzDate: "incorrect", XAmzAlgorithm: "AWS4-HMAC-SHA256", ContentType: "image/jpeg", expectedErr: fmt.Errorf("Invalid according to Policy: Policy Condition failed")},
- // Incorrect ContentType
- {Bucket: "testbucket", Key: "user/user1/filename/${filename}/myfile.txt", XAmzMetaUUID: "14365123651274", XAmzDate: "20160727T000000Z", XAmzAlgorithm: "AWS4-HMAC-SHA256", ContentType: "incorrect", expectedErr: fmt.Errorf("Invalid according to Policy: Policy Condition failed")},
- // Incorrect Metadata
- {Bucket: "testbucket", Key: "user/user1/filename/${filename}/myfile.txt", XAmzMetaUUID: "151274", SuccessActionStatus: "201", XAmzCredential: "KVGKMDUQ23TCZXTLTHLP/20160727/us-east-1/s3/aws4_request", XAmzDate: "20160727T000000Z", XAmzAlgorithm: "AWS4-HMAC-SHA256", ContentType: "image/jpeg", expectedErr: fmt.Errorf("Invalid according to Policy: Policy Condition failed: [eq, $x-amz-meta-uuid, 14365123651274]")},
- }
- // Validate all the test cases.
- for i, tt := range testCases {
- formValues := make(http.Header)
- formValues.Set("Bucket", tt.Bucket)
- formValues.Set("Key", tt.Key)
- formValues.Set("Content-Type", tt.ContentType)
- formValues.Set("X-Amz-Date", tt.XAmzDate)
- formValues.Set("X-Amz-Meta-Uuid", tt.XAmzMetaUUID)
- formValues.Set("X-Amz-Algorithm", tt.XAmzAlgorithm)
- formValues.Set("X-Amz-Credential", tt.XAmzCredential)
- if tt.Expired {
- // Expired already.
- pp.SetExpires(time.Now().UTC().AddDate(0, 0, -10))
- } else {
- // Expires in 10 days.
- pp.SetExpires(time.Now().UTC().AddDate(0, 0, 10))
- }
-
- formValues.Set("Policy", base64.StdEncoding.EncodeToString([]byte(pp.String())))
- formValues.Set("Success_action_status", tt.SuccessActionStatus)
- policyBytes, err := base64.StdEncoding.DecodeString(base64.StdEncoding.EncodeToString([]byte(pp.String())))
- if err != nil {
- t.Fatal(err)
- }
-
- postPolicyForm, err := ParsePostPolicyForm(string(policyBytes))
- if err != nil {
- t.Fatal(err)
- }
-
- err = CheckPostPolicy(formValues, postPolicyForm)
- if err != nil && tt.expectedErr != nil && err.Error() != tt.expectedErr.Error() {
- t.Fatalf("Test %d:, Expected %s, got %s", i+1, tt.expectedErr.Error(), err.Error())
- }
- }
-}
diff --git a/weed/s3api/policy_engine/conditions.go b/weed/s3api/policy_engine/conditions.go
index b32f11594..af55b06c2 100644
--- a/weed/s3api/policy_engine/conditions.go
+++ b/weed/s3api/policy_engine/conditions.go
@@ -125,22 +125,6 @@ func (c *NormalizedValueCache) evictLeastRecentlyUsed() {
delete(c.cache, tail.key)
}
-// Clear clears all cached values
-func (c *NormalizedValueCache) Clear() {
- c.mu.Lock()
- defer c.mu.Unlock()
- c.cache = make(map[string]*LRUNode)
- c.head.next = c.tail
- c.tail.prev = c.head
-}
-
-// GetStats returns cache statistics
-func (c *NormalizedValueCache) GetStats() (size int, maxSize int) {
- c.mu.RLock()
- defer c.mu.RUnlock()
- return len(c.cache), c.maxSize
-}
-
// Global cache instance with size limit
var normalizedValueCache = NewNormalizedValueCache(1000)
@@ -769,34 +753,3 @@ func EvaluateConditions(conditions PolicyConditions, contextValues map[string][]
return true
}
-
-// EvaluateConditionsLegacy evaluates conditions using the old interface{} format for backward compatibility
-// objectEntry is the object's metadata from entry.Extended (can be nil)
-func EvaluateConditionsLegacy(conditions map[string]interface{}, contextValues map[string][]string, objectEntry map[string][]byte) bool {
- if len(conditions) == 0 {
- return true // No conditions means always true
- }
-
- for operator, conditionMap := range conditions {
- conditionEvaluator, err := GetConditionEvaluator(operator)
- if err != nil {
- glog.Warningf("Unsupported condition operator: %s", operator)
- continue
- }
-
- conditionMapTyped, ok := conditionMap.(map[string]interface{})
- if !ok {
- glog.Warningf("Invalid condition format for operator: %s", operator)
- continue
- }
-
- for key, value := range conditionMapTyped {
- contextVals := getConditionContextValue(key, contextValues, objectEntry)
- if !conditionEvaluator.Evaluate(value, contextVals) {
- return false // If any condition fails, the whole condition block fails
- }
- }
- }
-
- return true
-}
diff --git a/weed/s3api/policy_engine/engine.go b/weed/s3api/policy_engine/engine.go
index bf66ebfd2..d39b4b2ce 100644
--- a/weed/s3api/policy_engine/engine.go
+++ b/weed/s3api/policy_engine/engine.go
@@ -610,92 +610,6 @@ func BuildActionName(action string) string {
return fmt.Sprintf("s3:%s", action)
}
-// IsReadAction checks if an action is a read action
-func IsReadAction(action string) bool {
- readActions := []string{
- "s3:GetObject",
- "s3:GetObjectVersion",
- "s3:GetObjectAcl",
- "s3:GetObjectVersionAcl",
- "s3:GetObjectTagging",
- "s3:GetObjectVersionTagging",
- "s3:ListBucket",
- "s3:ListBucketVersions",
- "s3:GetBucketLocation",
- "s3:GetBucketVersioning",
- "s3:GetBucketAcl",
- "s3:GetBucketCors",
- "s3:GetBucketPolicy",
- "s3:GetBucketTagging",
- "s3:GetBucketNotification",
- "s3:GetBucketObjectLockConfiguration",
- "s3:GetObjectRetention",
- "s3:GetObjectLegalHold",
- }
-
- for _, readAction := range readActions {
- if action == readAction {
- return true
- }
- }
- return false
-}
-
-// IsWriteAction checks if an action is a write action
-func IsWriteAction(action string) bool {
- writeActions := []string{
- "s3:PutObject",
- "s3:PutObjectAcl",
- "s3:PutObjectTagging",
- "s3:DeleteObject",
- "s3:DeleteObjectVersion",
- "s3:DeleteObjectTagging",
- "s3:AbortMultipartUpload",
- "s3:ListMultipartUploads",
- "s3:ListParts",
- "s3:PutBucketAcl",
- "s3:PutBucketCors",
- "s3:PutBucketPolicy",
- "s3:PutBucketTagging",
- "s3:PutBucketNotification",
- "s3:PutBucketVersioning",
- "s3:DeleteBucketPolicy",
- "s3:DeleteBucketTagging",
- "s3:DeleteBucketCors",
- "s3:PutBucketObjectLockConfiguration",
- "s3:PutObjectRetention",
- "s3:PutObjectLegalHold",
- "s3:BypassGovernanceRetention",
- }
-
- for _, writeAction := range writeActions {
- if action == writeAction {
- return true
- }
- }
- return false
-}
-
-// GetBucketNameFromArn extracts bucket name from ARN
-func GetBucketNameFromArn(arn string) string {
- if strings.HasPrefix(arn, "arn:aws:s3:::") {
- parts := strings.SplitN(arn[13:], "/", 2)
- return parts[0]
- }
- return ""
-}
-
-// GetObjectNameFromArn extracts object name from ARN
-func GetObjectNameFromArn(arn string) string {
- if strings.HasPrefix(arn, "arn:aws:s3:::") {
- parts := strings.SplitN(arn[13:], "/", 2)
- if len(parts) > 1 {
- return parts[1]
- }
- }
- return ""
-}
-
// GetPolicyStatements returns all policy statements for a bucket
func (engine *PolicyEngine) GetPolicyStatements(bucketName string) []PolicyStatement {
engine.mutex.RLock()
diff --git a/weed/s3api/policy_engine/engine_test.go b/weed/s3api/policy_engine/engine_test.go
index 1ad8c434a..452c01775 100644
--- a/weed/s3api/policy_engine/engine_test.go
+++ b/weed/s3api/policy_engine/engine_test.go
@@ -6,7 +6,6 @@ import (
"testing"
"github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
"github.com/seaweedfs/seaweedfs/weed/util/wildcard"
)
@@ -226,47 +225,6 @@ func TestConditionEvaluators(t *testing.T) {
}
}
-func TestConvertIdentityToPolicy(t *testing.T) {
- identityActions := []string{
- "Read:bucket1/*",
- "Write:bucket1/*",
- "Admin:bucket2",
- }
-
- policy, err := ConvertIdentityToPolicy(identityActions)
- if err != nil {
- t.Fatalf("Failed to convert identity to policy: %v", err)
- }
-
- if policy.Version != "2012-10-17" {
- t.Errorf("Expected version 2012-10-17, got %s", policy.Version)
- }
-
- if len(policy.Statement) != 3 {
- t.Errorf("Expected 3 statements, got %d", len(policy.Statement))
- }
-
- // Check first statement (Read)
- stmt := policy.Statement[0]
- if stmt.Effect != PolicyEffectAllow {
- t.Errorf("Expected Allow effect, got %s", stmt.Effect)
- }
-
- actions := normalizeToStringSlice(stmt.Action)
- // Read action now includes: GetObject, GetObjectVersion, ListBucket, ListBucketVersions,
- // GetObjectAcl, GetObjectVersionAcl, GetObjectTagging, GetObjectVersionTagging,
- // GetBucketLocation, GetBucketVersioning, GetBucketAcl, GetBucketCors, GetBucketTagging, GetBucketNotification
- if len(actions) != 14 {
- t.Errorf("Expected 14 read actions, got %d: %v", len(actions), actions)
- }
-
- resources := normalizeToStringSlice(stmt.Resource)
- // Read action now includes both bucket ARN (for ListBucket*) and object ARN (for GetObject*)
- if len(resources) != 2 {
- t.Errorf("Expected 2 resources (bucket and bucket/*), got %d: %v", len(resources), resources)
- }
-}
-
func TestPolicyValidation(t *testing.T) {
tests := []struct {
name string
@@ -794,41 +752,6 @@ func TestCompilePolicy(t *testing.T) {
}
}
-// TestNewPolicyBackedIAMWithLegacy tests the constructor overload
-func TestNewPolicyBackedIAMWithLegacy(t *testing.T) {
- // Mock legacy IAM
- mockLegacyIAM := &MockLegacyIAM{}
-
- // Test the new constructor
- policyBackedIAM := NewPolicyBackedIAMWithLegacy(mockLegacyIAM)
-
- // Verify that the legacy IAM is set
- if policyBackedIAM.legacyIAM != mockLegacyIAM {
- t.Errorf("Expected legacy IAM to be set, but it wasn't")
- }
-
- // Verify that the policy engine is initialized
- if policyBackedIAM.policyEngine == nil {
- t.Errorf("Expected policy engine to be initialized, but it wasn't")
- }
-
- // Compare with the traditional approach
- traditionalIAM := NewPolicyBackedIAM()
- traditionalIAM.SetLegacyIAM(mockLegacyIAM)
-
- // Both should behave the same
- if policyBackedIAM.legacyIAM != traditionalIAM.legacyIAM {
- t.Errorf("Expected both approaches to result in the same legacy IAM")
- }
-}
-
-// MockLegacyIAM implements the LegacyIAM interface for testing
-type MockLegacyIAM struct{}
-
-func (m *MockLegacyIAM) authRequest(r *http.Request, action Action) (Identity, s3err.ErrorCode) {
- return nil, s3err.ErrNone
-}
-
// TestExistingObjectTagCondition tests s3:ExistingObjectTag/ condition support
func TestExistingObjectTagCondition(t *testing.T) {
engine := NewPolicyEngine()
diff --git a/weed/s3api/policy_engine/integration.go b/weed/s3api/policy_engine/integration.go
deleted file mode 100644
index d1d36d02a..000000000
--- a/weed/s3api/policy_engine/integration.go
+++ /dev/null
@@ -1,642 +0,0 @@
-package policy_engine
-
-import (
- "fmt"
- "net/http"
- "strings"
-
- "github.com/seaweedfs/seaweedfs/weed/glog"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
-)
-
-// Action represents an S3 action - this should match the type in auth_credentials.go
-type Action string
-
-// Identity represents a user identity - this should match the type in auth_credentials.go
-type Identity interface {
- CanDo(action Action, bucket string, objectKey string) bool
-}
-
-// PolicyBackedIAM provides policy-based access control with fallback to legacy IAM
-type PolicyBackedIAM struct {
- policyEngine *PolicyEngine
- legacyIAM LegacyIAM // Interface to delegate to existing IAM system
-}
-
-// LegacyIAM interface for delegating to existing IAM implementation
-type LegacyIAM interface {
- authRequest(r *http.Request, action Action) (Identity, s3err.ErrorCode)
-}
-
-// NewPolicyBackedIAM creates a new policy-backed IAM system
-func NewPolicyBackedIAM() *PolicyBackedIAM {
- return &PolicyBackedIAM{
- policyEngine: NewPolicyEngine(),
- legacyIAM: nil, // Will be set when integrated with existing IAM
- }
-}
-
-// NewPolicyBackedIAMWithLegacy creates a new policy-backed IAM system with legacy IAM set
-func NewPolicyBackedIAMWithLegacy(legacyIAM LegacyIAM) *PolicyBackedIAM {
- return &PolicyBackedIAM{
- policyEngine: NewPolicyEngine(),
- legacyIAM: legacyIAM,
- }
-}
-
-// SetLegacyIAM sets the legacy IAM system for fallback
-func (p *PolicyBackedIAM) SetLegacyIAM(legacyIAM LegacyIAM) {
- p.legacyIAM = legacyIAM
-}
-
-// SetBucketPolicy sets the policy for a bucket
-func (p *PolicyBackedIAM) SetBucketPolicy(bucketName string, policyJSON string) error {
- return p.policyEngine.SetBucketPolicy(bucketName, policyJSON)
-}
-
-// GetBucketPolicy gets the policy for a bucket
-func (p *PolicyBackedIAM) GetBucketPolicy(bucketName string) (*PolicyDocument, error) {
- return p.policyEngine.GetBucketPolicy(bucketName)
-}
-
-// DeleteBucketPolicy deletes the policy for a bucket
-func (p *PolicyBackedIAM) DeleteBucketPolicy(bucketName string) error {
- return p.policyEngine.DeleteBucketPolicy(bucketName)
-}
-
-// CanDo checks if a principal can perform an action on a resource
-func (p *PolicyBackedIAM) CanDo(action, bucketName, objectName, principal string, r *http.Request) bool {
- // If there's a bucket policy, evaluate it
- if p.policyEngine.HasPolicyForBucket(bucketName) {
- result := p.policyEngine.EvaluatePolicyForRequest(bucketName, objectName, action, principal, r)
- switch result {
- case PolicyResultAllow:
- return true
- case PolicyResultDeny:
- return false
- case PolicyResultIndeterminate:
- // Fall through to legacy system
- }
- }
-
- // No bucket policy or indeterminate result, use legacy conversion
- return p.evaluateLegacyAction(action, bucketName, objectName, principal)
-}
-
-// evaluateLegacyAction evaluates actions using legacy identity-based rules
-func (p *PolicyBackedIAM) evaluateLegacyAction(action, bucketName, objectName, principal string) bool {
- // If we have a legacy IAM system to delegate to, use it
- if p.legacyIAM != nil {
- // Create a dummy request for legacy evaluation
- // In real implementation, this would use the actual request
- r := &http.Request{
- Header: make(http.Header),
- }
-
- // Convert the action string to Action type
- legacyAction := Action(action)
-
- // Use legacy IAM to check permission
- identity, errCode := p.legacyIAM.authRequest(r, legacyAction)
- if errCode != s3err.ErrNone {
- return false
- }
-
- // If we have an identity, check if it can perform the action
- if identity != nil {
- return identity.CanDo(legacyAction, bucketName, objectName)
- }
- }
-
- // No legacy IAM available, convert to policy and evaluate
- return p.evaluateUsingPolicyConversion(action, bucketName, objectName, principal)
-}
-
-// evaluateUsingPolicyConversion converts legacy action to policy and evaluates
-func (p *PolicyBackedIAM) evaluateUsingPolicyConversion(action, bucketName, objectName, principal string) bool {
- // For now, use a conservative approach for legacy actions
- // In a real implementation, this would integrate with the existing identity system
- glog.V(2).Infof("Legacy action evaluation for %s on %s/%s by %s", action, bucketName, objectName, principal)
-
- // Return false to maintain security until proper legacy integration is implemented
- // This ensures no unintended access is granted
- return false
-}
-
-// extractBucketAndPrefix extracts bucket name and prefix from a resource pattern.
-// Examples:
-//
-// "bucket" -> bucket="bucket", prefix=""
-// "bucket/*" -> bucket="bucket", prefix=""
-// "bucket/prefix/*" -> bucket="bucket", prefix="prefix"
-// "bucket/a/b/c/*" -> bucket="bucket", prefix="a/b/c"
-func extractBucketAndPrefix(pattern string) (string, string) {
- // Validate input
- pattern = strings.TrimSpace(pattern)
- if pattern == "" || pattern == "/" {
- return "", ""
- }
-
- // Remove trailing /* if present
- pattern = strings.TrimSuffix(pattern, "/*")
-
- // Remove a single trailing slash to avoid empty path segments
- if strings.HasSuffix(pattern, "/") {
- pattern = pattern[:len(pattern)-1]
- }
- if pattern == "" {
- return "", ""
- }
-
- // Split on the first /
- parts := strings.SplitN(pattern, "/", 2)
- bucket := strings.TrimSpace(parts[0])
- if bucket == "" {
- return "", ""
- }
-
- if len(parts) == 1 {
- // No slash, entire pattern is bucket
- return bucket, ""
- }
- // Has slash, first part is bucket, rest is prefix
- prefix := strings.Trim(parts[1], "/")
- return bucket, prefix
-}
-
-// buildObjectResourceArn generates ARNs for object-level access.
-// It properly handles both bucket-level (all objects) and prefix-level access.
-// Returns empty slice if bucket is invalid to prevent generating malformed ARNs.
-func buildObjectResourceArn(resourcePattern string) []string {
- bucket, prefix := extractBucketAndPrefix(resourcePattern)
- // If bucket is empty, the pattern is invalid; avoid generating malformed ARNs
- if bucket == "" {
- return []string{}
- }
- if prefix != "" {
- // Prefix-based access: restrict to objects under this prefix
- return []string{fmt.Sprintf("arn:aws:s3:::%s/%s/*", bucket, prefix)}
- }
- // Bucket-level access: all objects in bucket
- return []string{fmt.Sprintf("arn:aws:s3:::%s/*", bucket)}
-}
-
-// ConvertIdentityToPolicy converts a legacy identity action to an AWS policy
-func ConvertIdentityToPolicy(identityActions []string) (*PolicyDocument, error) {
- statements := make([]PolicyStatement, 0)
-
- for _, action := range identityActions {
- stmt, err := convertSingleAction(action)
- if err != nil {
- glog.Warningf("Failed to convert action %s: %v", action, err)
- continue
- }
- if stmt != nil {
- statements = append(statements, *stmt)
- }
- }
-
- if len(statements) == 0 {
- return nil, fmt.Errorf("no valid statements generated")
- }
-
- return &PolicyDocument{
- Version: PolicyVersion2012_10_17,
- Statement: statements,
- }, nil
-}
-
-// convertSingleAction converts a single legacy action to a policy statement.
-// action format: "ActionType:ResourcePattern" (e.g., "Write:bucket/prefix/*")
-func convertSingleAction(action string) (*PolicyStatement, error) {
- parts := strings.Split(action, ":")
- if len(parts) != 2 {
- return nil, fmt.Errorf("invalid action format: %s", action)
- }
-
- actionType := parts[0]
- resourcePattern := parts[1]
-
- var s3Actions []string
- var resources []string
-
- switch actionType {
- case "Read":
- // Read includes both object-level (GetObject, GetObjectAcl, GetObjectTagging, GetObjectVersions)
- // and bucket-level operations (ListBucket, GetBucketLocation, GetBucketVersioning, GetBucketCors, etc.)
- s3Actions = []string{
- "s3:GetObject",
- "s3:GetObjectVersion",
- "s3:GetObjectAcl",
- "s3:GetObjectVersionAcl",
- "s3:GetObjectTagging",
- "s3:GetObjectVersionTagging",
- "s3:ListBucket",
- "s3:ListBucketVersions",
- "s3:GetBucketLocation",
- "s3:GetBucketVersioning",
- "s3:GetBucketAcl",
- "s3:GetBucketCors",
- "s3:GetBucketTagging",
- "s3:GetBucketNotification",
- }
- bucket, _ := extractBucketAndPrefix(resourcePattern)
- objectResources := buildObjectResourceArn(resourcePattern)
- // Include both bucket ARN (for ListBucket* and Get*Bucket operations) and object ARNs (for GetObject* operations)
- if bucket != "" {
- resources = append([]string{fmt.Sprintf("arn:aws:s3:::%s", bucket)}, objectResources...)
- } else {
- resources = objectResources
- }
-
- case "Write":
- // Write includes object-level writes (PutObject, DeleteObject, PutObjectAcl, DeleteObjectVersion, DeleteObjectTagging, PutObjectTagging)
- // and bucket-level writes (PutBucketVersioning, PutBucketCors, DeleteBucketCors, PutBucketAcl, PutBucketTagging, DeleteBucketTagging, PutBucketNotification)
- // and multipart upload operations (AbortMultipartUpload, ListMultipartUploads, ListParts).
- // ListMultipartUploads and ListParts are included because they are part of the multipart upload workflow
- // and require Write permissions to be meaningful (no point listing uploads if you can't abort/complete them).
- s3Actions = []string{
- "s3:PutObject",
- "s3:PutObjectAcl",
- "s3:PutObjectTagging",
- "s3:DeleteObject",
- "s3:DeleteObjectVersion",
- "s3:DeleteObjectTagging",
- "s3:AbortMultipartUpload",
- "s3:ListMultipartUploads",
- "s3:ListParts",
- "s3:PutBucketAcl",
- "s3:PutBucketCors",
- "s3:PutBucketTagging",
- "s3:PutBucketNotification",
- "s3:PutBucketVersioning",
- "s3:DeleteBucketTagging",
- "s3:DeleteBucketCors",
- }
- bucket, _ := extractBucketAndPrefix(resourcePattern)
- objectResources := buildObjectResourceArn(resourcePattern)
- // Include bucket ARN so bucket-level write operations (e.g., PutBucketVersioning, PutBucketCors)
- // have the correct resource, while still allowing object-level writes.
- if bucket != "" {
- resources = append([]string{fmt.Sprintf("arn:aws:s3:::%s", bucket)}, objectResources...)
- } else {
- resources = objectResources
- }
-
- case "Admin":
- s3Actions = []string{"s3:*"}
- bucket, prefix := extractBucketAndPrefix(resourcePattern)
- if bucket == "" {
- // Invalid pattern, return error
- return nil, fmt.Errorf("Admin action requires a valid bucket name")
- }
- if prefix != "" {
- // Subpath admin access: restrict to objects under this prefix
- resources = []string{
- fmt.Sprintf("arn:aws:s3:::%s", bucket),
- fmt.Sprintf("arn:aws:s3:::%s/%s/*", bucket, prefix),
- }
- } else {
- // Bucket-level admin access: full bucket permissions
- resources = []string{
- fmt.Sprintf("arn:aws:s3:::%s", bucket),
- fmt.Sprintf("arn:aws:s3:::%s/*", bucket),
- }
- }
-
- case "List":
- // List includes bucket listing operations and also ListAllMyBuckets
- s3Actions = []string{"s3:ListBucket", "s3:ListBucketVersions", "s3:ListAllMyBuckets"}
- // ListBucket actions only require bucket ARN, not object-level ARNs
- bucket, _ := extractBucketAndPrefix(resourcePattern)
- if bucket != "" {
- resources = []string{fmt.Sprintf("arn:aws:s3:::%s", bucket)}
- } else {
- // Invalid pattern, return empty resources to fail validation
- resources = []string{}
- }
-
- case "Tagging":
- // Tagging includes both object-level and bucket-level tagging operations
- s3Actions = []string{
- "s3:GetObjectTagging",
- "s3:PutObjectTagging",
- "s3:DeleteObjectTagging",
- "s3:GetBucketTagging",
- "s3:PutBucketTagging",
- "s3:DeleteBucketTagging",
- }
- bucket, _ := extractBucketAndPrefix(resourcePattern)
- objectResources := buildObjectResourceArn(resourcePattern)
- // Include bucket ARN so bucket-level tagging operations have the correct resource
- if bucket != "" {
- resources = append([]string{fmt.Sprintf("arn:aws:s3:::%s", bucket)}, objectResources...)
- } else {
- resources = objectResources
- }
-
- case "BypassGovernanceRetention":
- s3Actions = []string{"s3:BypassGovernanceRetention"}
- resources = buildObjectResourceArn(resourcePattern)
-
- case "GetObjectRetention":
- s3Actions = []string{"s3:GetObjectRetention"}
- resources = buildObjectResourceArn(resourcePattern)
-
- case "PutObjectRetention":
- s3Actions = []string{"s3:PutObjectRetention"}
- resources = buildObjectResourceArn(resourcePattern)
-
- case "GetObjectLegalHold":
- s3Actions = []string{"s3:GetObjectLegalHold"}
- resources = buildObjectResourceArn(resourcePattern)
-
- case "PutObjectLegalHold":
- s3Actions = []string{"s3:PutObjectLegalHold"}
- resources = buildObjectResourceArn(resourcePattern)
-
- case "GetBucketObjectLockConfiguration":
- s3Actions = []string{"s3:GetBucketObjectLockConfiguration"}
- bucket, _ := extractBucketAndPrefix(resourcePattern)
- if bucket != "" {
- resources = []string{fmt.Sprintf("arn:aws:s3:::%s", bucket)}
- } else {
- // Invalid pattern, return empty resources to fail validation
- resources = []string{}
- }
-
- case "PutBucketObjectLockConfiguration":
- s3Actions = []string{"s3:PutBucketObjectLockConfiguration"}
- bucket, _ := extractBucketAndPrefix(resourcePattern)
- if bucket != "" {
- resources = []string{fmt.Sprintf("arn:aws:s3:::%s", bucket)}
- } else {
- // Invalid pattern, return empty resources to fail validation
- resources = []string{}
- }
-
- default:
- return nil, fmt.Errorf("unknown action type: %s", actionType)
- }
-
- return &PolicyStatement{
- Effect: PolicyEffectAllow,
- Action: NewStringOrStringSlice(s3Actions...),
- Resource: NewStringOrStringSlicePtr(resources...),
- }, nil
-}
-
-// GetActionMappings returns the mapping of legacy actions to S3 actions
-func GetActionMappings() map[string][]string {
- return map[string][]string{
- "Read": {
- "s3:GetObject",
- "s3:GetObjectVersion",
- "s3:GetObjectAcl",
- "s3:GetObjectVersionAcl",
- "s3:GetObjectTagging",
- "s3:GetObjectVersionTagging",
- "s3:ListBucket",
- "s3:ListBucketVersions",
- "s3:GetBucketLocation",
- "s3:GetBucketVersioning",
- "s3:GetBucketAcl",
- "s3:GetBucketCors",
- "s3:GetBucketTagging",
- "s3:GetBucketNotification",
- },
- "Write": {
- "s3:PutObject",
- "s3:PutObjectAcl",
- "s3:PutObjectTagging",
- "s3:DeleteObject",
- "s3:DeleteObjectVersion",
- "s3:DeleteObjectTagging",
- "s3:AbortMultipartUpload",
- "s3:ListMultipartUploads",
- "s3:ListParts",
- "s3:PutBucketAcl",
- "s3:PutBucketCors",
- "s3:PutBucketTagging",
- "s3:PutBucketNotification",
- "s3:PutBucketVersioning",
- "s3:DeleteBucketTagging",
- "s3:DeleteBucketCors",
- },
- "Admin": {
- "s3:*",
- },
- "List": {
- "s3:ListBucket",
- "s3:ListBucketVersions",
- "s3:ListAllMyBuckets",
- },
- "Tagging": {
- "s3:GetObjectTagging",
- "s3:PutObjectTagging",
- "s3:DeleteObjectTagging",
- "s3:GetBucketTagging",
- "s3:PutBucketTagging",
- "s3:DeleteBucketTagging",
- },
- "BypassGovernanceRetention": {
- "s3:BypassGovernanceRetention",
- },
- "GetObjectRetention": {
- "s3:GetObjectRetention",
- },
- "PutObjectRetention": {
- "s3:PutObjectRetention",
- },
- "GetObjectLegalHold": {
- "s3:GetObjectLegalHold",
- },
- "PutObjectLegalHold": {
- "s3:PutObjectLegalHold",
- },
- "GetBucketObjectLockConfiguration": {
- "s3:GetBucketObjectLockConfiguration",
- },
- "PutBucketObjectLockConfiguration": {
- "s3:PutBucketObjectLockConfiguration",
- },
- }
-}
-
-// ValidateActionMapping validates that a legacy action can be mapped to S3 actions
-func ValidateActionMapping(action string) error {
- mappings := GetActionMappings()
-
- parts := strings.Split(action, ":")
- if len(parts) != 2 {
- return fmt.Errorf("invalid action format: %s, expected format: 'ActionType:Resource'", action)
- }
-
- actionType := parts[0]
- resource := parts[1]
-
- if _, exists := mappings[actionType]; !exists {
- return fmt.Errorf("unknown action type: %s", actionType)
- }
-
- if resource == "" {
- return fmt.Errorf("resource cannot be empty")
- }
-
- return nil
-}
-
-// ConvertLegacyActions converts an array of legacy actions to S3 actions
-func ConvertLegacyActions(legacyActions []string) ([]string, error) {
- mappings := GetActionMappings()
- s3Actions := make([]string, 0)
-
- for _, legacyAction := range legacyActions {
- if err := ValidateActionMapping(legacyAction); err != nil {
- return nil, err
- }
-
- parts := strings.Split(legacyAction, ":")
- actionType := parts[0]
-
- if actionType == "Admin" {
- // Admin gives all permissions, so we can just return s3:*
- return []string{"s3:*"}, nil
- }
-
- if mapped, exists := mappings[actionType]; exists {
- s3Actions = append(s3Actions, mapped...)
- }
- }
-
- // Remove duplicates
- uniqueActions := make([]string, 0)
- seen := make(map[string]bool)
- for _, action := range s3Actions {
- if !seen[action] {
- uniqueActions = append(uniqueActions, action)
- seen[action] = true
- }
- }
-
- return uniqueActions, nil
-}
-
-// GetResourcesFromLegacyAction extracts resources from a legacy action.
-// It delegates to convertSingleAction to ensure consistent resource ARN generation
-// across the codebase and avoid duplicating action-type-specific logic.
-func GetResourcesFromLegacyAction(legacyAction string) ([]string, error) {
- stmt, err := convertSingleAction(legacyAction)
- if err != nil {
- return nil, err
- }
- return stmt.Resource.Strings(), nil
-}
-
-// CreatePolicyFromLegacyIdentity creates a policy document from legacy identity actions
-func CreatePolicyFromLegacyIdentity(identityName string, actions []string) (*PolicyDocument, error) {
- statements := make([]PolicyStatement, 0)
-
- // Group actions by resource pattern
- resourceActions := make(map[string][]string)
-
- for _, action := range actions {
- // Validate action format before processing
- if err := ValidateActionMapping(action); err != nil {
- glog.Warningf("Skipping invalid action %q for identity %q: %v", action, identityName, err)
- continue
- }
-
- parts := strings.Split(action, ":")
- if len(parts) != 2 {
- continue
- }
-
- resourcePattern := parts[1]
- actionType := parts[0]
-
- if _, exists := resourceActions[resourcePattern]; !exists {
- resourceActions[resourcePattern] = make([]string, 0)
- }
- resourceActions[resourcePattern] = append(resourceActions[resourcePattern], actionType)
- }
-
- // Create statements for each resource pattern
- for resourcePattern, actionTypes := range resourceActions {
- s3Actions := make([]string, 0)
- resourceSet := make(map[string]struct{})
-
- // Collect S3 actions and aggregate resource ARNs from all action types.
- // Different action types have different resource ARN requirements:
- // - List: bucket-level ARNs only
- // - Read/Write/Tagging: object-level ARNs
- // - Admin: full bucket access
- // We must merge all required ARNs for the combined policy statement.
- for _, actionType := range actionTypes {
- if actionType == "Admin" {
- s3Actions = []string{"s3:*"}
- // Admin action determines the resources, so we can break after processing it.
- res, err := GetResourcesFromLegacyAction(fmt.Sprintf("Admin:%s", resourcePattern))
- if err != nil {
- glog.Warningf("Failed to get resources for Admin action on %s: %v", resourcePattern, err)
- resourceSet = nil // Invalidate to skip this statement
- break
- }
- for _, r := range res {
- resourceSet[r] = struct{}{}
- }
- break
- }
-
- if mapped, exists := GetActionMappings()[actionType]; exists {
- s3Actions = append(s3Actions, mapped...)
- res, err := GetResourcesFromLegacyAction(fmt.Sprintf("%s:%s", actionType, resourcePattern))
- if err != nil {
- glog.Warningf("Failed to get resources for %s action on %s: %v", actionType, resourcePattern, err)
- resourceSet = nil // Invalidate to skip this statement
- break
- }
- for _, r := range res {
- resourceSet[r] = struct{}{}
- }
- }
- }
-
- if resourceSet == nil || len(s3Actions) == 0 {
- continue
- }
-
- resources := make([]string, 0, len(resourceSet))
- for r := range resourceSet {
- resources = append(resources, r)
- }
-
- statement := PolicyStatement{
- Sid: fmt.Sprintf("%s-%s", identityName, strings.ReplaceAll(resourcePattern, "/", "-")),
- Effect: PolicyEffectAllow,
- Action: NewStringOrStringSlice(s3Actions...),
- Resource: NewStringOrStringSlicePtr(resources...),
- }
-
- statements = append(statements, statement)
- }
-
- if len(statements) == 0 {
- return nil, fmt.Errorf("no valid statements generated for identity %s", identityName)
- }
-
- return &PolicyDocument{
- Version: PolicyVersion2012_10_17,
- Statement: statements,
- }, nil
-}
-
-// HasPolicyForBucket checks if a bucket has a policy
-func (p *PolicyBackedIAM) HasPolicyForBucket(bucketName string) bool {
- return p.policyEngine.HasPolicyForBucket(bucketName)
-}
-
-// GetPolicyEngine returns the underlying policy engine
-func (p *PolicyBackedIAM) GetPolicyEngine() *PolicyEngine {
- return p.policyEngine
-}
diff --git a/weed/s3api/policy_engine/integration_test.go b/weed/s3api/policy_engine/integration_test.go
deleted file mode 100644
index 6e74e51cb..000000000
--- a/weed/s3api/policy_engine/integration_test.go
+++ /dev/null
@@ -1,373 +0,0 @@
-package policy_engine
-
-import (
- "testing"
-
- "github.com/stretchr/testify/assert"
-)
-
-// TestConvertSingleActionDeleteObject tests support for s3:DeleteObject action (Issue #7864)
-func TestConvertSingleActionDeleteObject(t *testing.T) {
- // Test that Write action includes DeleteObject S3 action
- stmt, err := convertSingleAction("Write:bucket")
- assert.NoError(t, err)
- assert.NotNil(t, stmt)
-
- // Check that s3:DeleteObject is included in the actions
- actions := stmt.Action.Strings()
- assert.Contains(t, actions, "s3:DeleteObject", "Write action should include s3:DeleteObject")
- assert.Contains(t, actions, "s3:PutObject", "Write action should include s3:PutObject")
-}
-
-// TestConvertSingleActionSubpath tests subpath handling for legacy actions (Issue #7864)
-func TestConvertSingleActionSubpath(t *testing.T) {
- testCases := []struct {
- name string
- action string
- expectedActions []string
- expectedResources []string
- description string
- }{
- {
- name: "Write_on_bucket",
- action: "Write:mybucket",
- expectedActions: []string{"s3:PutObject", "s3:DeleteObject", "s3:PutObjectAcl", "s3:DeleteObjectVersion", "s3:PutObjectTagging", "s3:DeleteObjectTagging", "s3:AbortMultipartUpload", "s3:ListMultipartUploads", "s3:ListParts", "s3:PutBucketAcl", "s3:PutBucketCors", "s3:PutBucketTagging", "s3:PutBucketNotification", "s3:PutBucketVersioning", "s3:DeleteBucketTagging", "s3:DeleteBucketCors"},
- expectedResources: []string{"arn:aws:s3:::mybucket", "arn:aws:s3:::mybucket/*"},
- description: "Write permission on bucket should include bucket and object ARNs",
- },
- {
- name: "Write_on_bucket_with_wildcard",
- action: "Write:mybucket/*",
- expectedActions: []string{"s3:PutObject", "s3:DeleteObject", "s3:PutObjectAcl", "s3:DeleteObjectVersion", "s3:PutObjectTagging", "s3:DeleteObjectTagging", "s3:AbortMultipartUpload", "s3:ListMultipartUploads", "s3:ListParts", "s3:PutBucketAcl", "s3:PutBucketCors", "s3:PutBucketTagging", "s3:PutBucketNotification", "s3:PutBucketVersioning", "s3:DeleteBucketTagging", "s3:DeleteBucketCors"},
- expectedResources: []string{"arn:aws:s3:::mybucket", "arn:aws:s3:::mybucket/*"},
- description: "Write permission with /* should include bucket and object ARNs",
- },
- {
- name: "Write_on_subpath",
- action: "Write:mybucket/sub_path/*",
- expectedActions: []string{"s3:PutObject", "s3:DeleteObject", "s3:PutObjectAcl", "s3:DeleteObjectVersion", "s3:PutObjectTagging", "s3:DeleteObjectTagging", "s3:AbortMultipartUpload", "s3:ListMultipartUploads", "s3:ListParts", "s3:PutBucketAcl", "s3:PutBucketCors", "s3:PutBucketTagging", "s3:PutBucketNotification", "s3:PutBucketVersioning", "s3:DeleteBucketTagging", "s3:DeleteBucketCors"},
- expectedResources: []string{"arn:aws:s3:::mybucket", "arn:aws:s3:::mybucket/sub_path/*"},
- description: "Write permission on subpath should include bucket and subpath objects ARNs",
- },
- {
- name: "Read_on_subpath",
- action: "Read:mybucket/documents/*",
- expectedActions: []string{"s3:GetObject", "s3:GetObjectVersion", "s3:ListBucket", "s3:ListBucketVersions", "s3:GetObjectAcl", "s3:GetObjectVersionAcl", "s3:GetObjectTagging", "s3:GetObjectVersionTagging", "s3:GetBucketLocation", "s3:GetBucketVersioning", "s3:GetBucketAcl", "s3:GetBucketCors", "s3:GetBucketTagging", "s3:GetBucketNotification"},
- expectedResources: []string{"arn:aws:s3:::mybucket", "arn:aws:s3:::mybucket/documents/*"},
- description: "Read permission on subpath should include bucket ARN and subpath objects",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- stmt, err := convertSingleAction(tc.action)
- assert.NoError(t, err, tc.description)
- assert.NotNil(t, stmt)
-
- // Check actions
- actions := stmt.Action.Strings()
- for _, expectedAction := range tc.expectedActions {
- assert.Contains(t, actions, expectedAction,
- "Action %s should be included for %s", expectedAction, tc.action)
- }
-
- // Check resources - verify all expected resources are present
- resources := stmt.Resource.Strings()
- assert.ElementsMatch(t, resources, tc.expectedResources,
- "Resources should match exactly for %s. Got %v, expected %v", tc.action, resources, tc.expectedResources)
- })
- }
-}
-
-// TestConvertSingleActionSubpathDeleteAllowed tests that DeleteObject works on subpaths
-func TestConvertSingleActionSubpathDeleteAllowed(t *testing.T) {
- // This test specifically addresses Issue #7864 part 1:
- // "when a user is granted permission to a subpath, eg s3.configure -user someuser
- // -actions Write -buckets some_bucket/sub_path/* -apply
- // the user will only be able to put, but not delete object under somebucket/sub_path"
-
- stmt, err := convertSingleAction("Write:some_bucket/sub_path/*")
- assert.NoError(t, err)
-
- // The fix: s3:DeleteObject should be in the allowed actions
- actions := stmt.Action.Strings()
- assert.Contains(t, actions, "s3:DeleteObject",
- "Write permission on subpath should allow deletion of objects in that path")
-
- // The resource should be restricted to the subpath
- resources := stmt.Resource.Strings()
- assert.Contains(t, resources, "arn:aws:s3:::some_bucket/sub_path/*",
- "Delete permission should apply to objects under the subpath")
-}
-
-// TestConvertSingleActionNestedPaths tests deeply nested paths
-func TestConvertSingleActionNestedPaths(t *testing.T) {
- testCases := []struct {
- action string
- expectedResources []string
- }{
- {
- action: "Write:bucket/a/b/c/*",
- expectedResources: []string{"arn:aws:s3:::bucket", "arn:aws:s3:::bucket/a/b/c/*"},
- },
- {
- action: "Read:bucket/data/documents/2024/*",
- expectedResources: []string{"arn:aws:s3:::bucket", "arn:aws:s3:::bucket/data/documents/2024/*"},
- },
- }
-
- for _, tc := range testCases {
- stmt, err := convertSingleAction(tc.action)
- assert.NoError(t, err)
-
- resources := stmt.Resource.Strings()
- assert.ElementsMatch(t, resources, tc.expectedResources)
- }
-}
-
-// TestGetResourcesFromLegacyAction tests that GetResourcesFromLegacyAction generates
-// action-appropriate resources consistent with convertSingleAction
-func TestGetResourcesFromLegacyAction(t *testing.T) {
- testCases := []struct {
- name string
- action string
- expectedResources []string
- description string
- }{
- // List actions - bucket-only (no object ARNs)
- {
- name: "List_on_bucket",
- action: "List:mybucket",
- expectedResources: []string{"arn:aws:s3:::mybucket"},
- description: "List action should only have bucket ARN",
- },
- {
- name: "List_on_bucket_with_wildcard",
- action: "List:mybucket/*",
- expectedResources: []string{"arn:aws:s3:::mybucket"},
- description: "List action should only have bucket ARN regardless of wildcard",
- },
- // Read actions - bucket and object-level ARNs (includes List* and Get* operations)
- {
- name: "Read_on_bucket",
- action: "Read:mybucket",
- expectedResources: []string{"arn:aws:s3:::mybucket", "arn:aws:s3:::mybucket/*"},
- description: "Read action should have both bucket and object ARNs",
- },
- {
- name: "Read_on_subpath",
- action: "Read:mybucket/documents/*",
- expectedResources: []string{"arn:aws:s3:::mybucket", "arn:aws:s3:::mybucket/documents/*"},
- description: "Read action on subpath should have bucket ARN and object ARN for subpath",
- },
- // Write actions - bucket and object ARNs (includes bucket-level operations)
- {
- name: "Write_on_subpath",
- action: "Write:mybucket/sub_path/*",
- expectedResources: []string{"arn:aws:s3:::mybucket", "arn:aws:s3:::mybucket/sub_path/*"},
- description: "Write action should have bucket and object ARNs",
- },
- // Admin actions - both bucket and object ARNs
- {
- name: "Admin_on_bucket",
- action: "Admin:mybucket",
- expectedResources: []string{"arn:aws:s3:::mybucket", "arn:aws:s3:::mybucket/*"},
- description: "Admin action should have both bucket and object ARNs",
- },
- {
- name: "Admin_on_subpath",
- action: "Admin:mybucket/admin/section/*",
- expectedResources: []string{"arn:aws:s3:::mybucket", "arn:aws:s3:::mybucket/admin/section/*"},
- description: "Admin action on subpath should restrict to subpath, preventing privilege escalation",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- resources, err := GetResourcesFromLegacyAction(tc.action)
- assert.NoError(t, err, tc.description)
- assert.ElementsMatch(t, resources, tc.expectedResources,
- "Resources should match expected. Got %v, expected %v", resources, tc.expectedResources)
-
- // Also verify consistency with convertSingleAction where applicable
- stmt, err := convertSingleAction(tc.action)
- assert.NoError(t, err)
-
- stmtResources := stmt.Resource.Strings()
- assert.ElementsMatch(t, resources, stmtResources,
- "GetResourcesFromLegacyAction should match convertSingleAction resources for %s", tc.action)
- })
- }
-}
-
-// TestExtractBucketAndPrefixEdgeCases validates edge case handling in extractBucketAndPrefix
-func TestExtractBucketAndPrefixEdgeCases(t *testing.T) {
- testCases := []struct {
- name string
- pattern string
- expectedBucket string
- expectedPrefix string
- description string
- }{
- {
- name: "Empty string",
- pattern: "",
- expectedBucket: "",
- expectedPrefix: "",
- description: "Empty pattern should return empty strings",
- },
- {
- name: "Whitespace only",
- pattern: " ",
- expectedBucket: "",
- expectedPrefix: "",
- description: "Whitespace-only pattern should return empty strings",
- },
- {
- name: "Slash only",
- pattern: "/",
- expectedBucket: "",
- expectedPrefix: "",
- description: "Slash-only pattern should return empty strings",
- },
- {
- name: "Double slash prefix",
- pattern: "bucket//prefix/*",
- expectedBucket: "bucket",
- expectedPrefix: "prefix",
- description: "Double slash should be normalized (trailing slashes removed)",
- },
- {
- name: "Normal bucket",
- pattern: "mybucket",
- expectedBucket: "mybucket",
- expectedPrefix: "",
- description: "Bucket-only pattern should work correctly",
- },
- {
- name: "Bucket with prefix",
- pattern: "mybucket/myprefix/*",
- expectedBucket: "mybucket",
- expectedPrefix: "myprefix",
- description: "Bucket with prefix should be parsed correctly",
- },
- {
- name: "Nested prefix",
- pattern: "mybucket/a/b/c/*",
- expectedBucket: "mybucket",
- expectedPrefix: "a/b/c",
- description: "Nested prefix should be preserved",
- },
- {
- name: "Bucket with trailing slash",
- pattern: "mybucket/",
- expectedBucket: "mybucket",
- expectedPrefix: "",
- description: "Trailing slash on bucket should be normalized",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- bucket, prefix := extractBucketAndPrefix(tc.pattern)
- assert.Equal(t, tc.expectedBucket, bucket, tc.description)
- assert.Equal(t, tc.expectedPrefix, prefix, tc.description)
- })
- }
-}
-
-// TestCreatePolicyFromLegacyIdentityMultipleActions validates correct resource ARN aggregation
-// when multiple action types target the same resource pattern
-func TestCreatePolicyFromLegacyIdentityMultipleActions(t *testing.T) {
- testCases := []struct {
- name string
- identityName string
- actions []string
- expectedStatements int
- expectedActionsInStmt1 []string
- expectedResourcesInStmt1 []string
- description string
- }{
- {
- name: "List_and_Write_on_subpath",
- identityName: "data-manager",
- actions: []string{"List:mybucket/data/*", "Write:mybucket/data/*"},
- expectedStatements: 1,
- expectedActionsInStmt1: []string{
- "s3:ListBucket", "s3:ListBucketVersions", "s3:ListAllMyBuckets",
- "s3:PutObject", "s3:DeleteObject", "s3:PutObjectAcl", "s3:DeleteObjectVersion",
- "s3:PutObjectTagging", "s3:DeleteObjectTagging", "s3:AbortMultipartUpload",
- "s3:ListMultipartUploads", "s3:ListParts", "s3:PutBucketAcl", "s3:PutBucketCors",
- "s3:PutBucketTagging", "s3:PutBucketNotification", "s3:PutBucketVersioning",
- "s3:DeleteBucketTagging", "s3:DeleteBucketCors",
- },
- expectedResourcesInStmt1: []string{
- "arn:aws:s3:::mybucket", // From List and Write actions
- "arn:aws:s3:::mybucket/data/*", // From Write action
- },
- description: "List + Write on same subpath should aggregate all actions and both bucket and object ARNs",
- },
- {
- name: "Read_and_Tagging_on_bucket",
- identityName: "tag-reader",
- actions: []string{"Read:mybucket", "Tagging:mybucket"},
- expectedStatements: 1,
- expectedActionsInStmt1: []string{
- "s3:GetObject", "s3:GetObjectVersion",
- "s3:ListBucket", "s3:ListBucketVersions",
- "s3:GetObjectAcl", "s3:GetObjectVersionAcl",
- "s3:GetObjectTagging", "s3:GetObjectVersionTagging",
- "s3:PutObjectTagging", "s3:DeleteObjectTagging",
- "s3:GetBucketLocation", "s3:GetBucketVersioning",
- "s3:GetBucketAcl", "s3:GetBucketCors", "s3:GetBucketTagging",
- "s3:GetBucketNotification", "s3:PutBucketTagging", "s3:DeleteBucketTagging",
- },
- expectedResourcesInStmt1: []string{
- "arn:aws:s3:::mybucket",
- "arn:aws:s3:::mybucket/*",
- },
- description: "Read + Tagging on same bucket should aggregate all bucket and object-level actions and ARNs",
- },
- {
- name: "Admin_with_other_actions",
- identityName: "admin-user",
- actions: []string{"Admin:mybucket/admin/*", "Write:mybucket/admin/*"},
- expectedStatements: 1,
- expectedActionsInStmt1: []string{"s3:*"},
- expectedResourcesInStmt1: []string{
- "arn:aws:s3:::mybucket",
- "arn:aws:s3:::mybucket/admin/*",
- },
- description: "Admin action should dominate and set s3:*, other actions still processed for resources",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- policy, err := CreatePolicyFromLegacyIdentity(tc.identityName, tc.actions)
- assert.NoError(t, err, tc.description)
- assert.NotNil(t, policy)
-
- // Check statement count
- assert.Equal(t, tc.expectedStatements, len(policy.Statement),
- "Expected %d statement(s), got %d", tc.expectedStatements, len(policy.Statement))
-
- if tc.expectedStatements > 0 {
- stmt := policy.Statement[0]
-
- // Check actions
- actualActions := stmt.Action.Strings()
- for _, expectedAction := range tc.expectedActionsInStmt1 {
- assert.Contains(t, actualActions, expectedAction,
- "Action %s should be included in statement", expectedAction)
- }
-
- // Check resources - all expected resources should be present
- actualResources := stmt.Resource.Strings()
- assert.ElementsMatch(t, tc.expectedResourcesInStmt1, actualResources,
- "Statement should aggregate all required resource ARNs. Got %v, expected %v",
- actualResources, tc.expectedResourcesInStmt1)
- }
- })
- }
-}
diff --git a/weed/s3api/policy_engine/types.go b/weed/s3api/policy_engine/types.go
index 862023b34..f1623ff15 100644
--- a/weed/s3api/policy_engine/types.go
+++ b/weed/s3api/policy_engine/types.go
@@ -490,11 +490,6 @@ func GetBucketFromResource(resource string) string {
return ""
}
-// IsObjectResource checks if resource refers to objects
-func IsObjectResource(resource string) bool {
- return strings.Contains(resource, "/")
-}
-
// MatchesAction checks if an action matches any of the compiled action matchers.
// It also implicitly grants multipart upload actions if s3:PutObject is allowed,
// since multipart upload is an implementation detail of putting objects.
diff --git a/weed/s3api/s3_bucket_encryption.go b/weed/s3api/s3_bucket_encryption.go
index 5a9fb7499..10901f8ac 100644
--- a/weed/s3api/s3_bucket_encryption.go
+++ b/weed/s3api/s3_bucket_encryption.go
@@ -288,70 +288,3 @@ func (s3a *S3ApiServer) GetDefaultEncryptionHeaders(bucket string) map[string]st
return headers
}
-
-// IsDefaultEncryptionEnabled checks if default encryption is enabled for a configuration
-func IsDefaultEncryptionEnabled(config *s3_pb.EncryptionConfiguration) bool {
- return config != nil && config.SseAlgorithm != ""
-}
-
-// GetDefaultEncryptionHeaders generates default encryption headers from configuration
-func GetDefaultEncryptionHeaders(config *s3_pb.EncryptionConfiguration) map[string]string {
- if config == nil || config.SseAlgorithm == "" {
- return nil
- }
-
- headers := make(map[string]string)
- headers[s3_constants.AmzServerSideEncryption] = config.SseAlgorithm
-
- if config.SseAlgorithm == "aws:kms" && config.KmsKeyId != "" {
- headers[s3_constants.AmzServerSideEncryptionAwsKmsKeyId] = config.KmsKeyId
- }
-
- return headers
-}
-
-// encryptionConfigFromXMLBytes parses XML bytes to encryption configuration
-func encryptionConfigFromXMLBytes(xmlBytes []byte) (*s3_pb.EncryptionConfiguration, error) {
- var xmlConfig ServerSideEncryptionConfiguration
- if err := xml.Unmarshal(xmlBytes, &xmlConfig); err != nil {
- return nil, err
- }
-
- // Validate namespace - should be empty or the standard AWS namespace
- if xmlConfig.XMLName.Space != "" && xmlConfig.XMLName.Space != "http://s3.amazonaws.com/doc/2006-03-01/" {
- return nil, fmt.Errorf("invalid XML namespace: %s", xmlConfig.XMLName.Space)
- }
-
- // Validate the configuration
- if len(xmlConfig.Rules) == 0 {
- return nil, fmt.Errorf("encryption configuration must have at least one rule")
- }
-
- rule := xmlConfig.Rules[0]
- if rule.ApplyServerSideEncryptionByDefault.SSEAlgorithm == "" {
- return nil, fmt.Errorf("encryption algorithm is required")
- }
-
- // Validate algorithm
- validAlgorithms := map[string]bool{
- "AES256": true,
- "aws:kms": true,
- }
-
- if !validAlgorithms[rule.ApplyServerSideEncryptionByDefault.SSEAlgorithm] {
- return nil, fmt.Errorf("unsupported encryption algorithm: %s", rule.ApplyServerSideEncryptionByDefault.SSEAlgorithm)
- }
-
- config := encryptionConfigFromXML(&xmlConfig)
- return config, nil
-}
-
-// encryptionConfigToXMLBytes converts encryption configuration to XML bytes
-func encryptionConfigToXMLBytes(config *s3_pb.EncryptionConfiguration) ([]byte, error) {
- if config == nil {
- return nil, fmt.Errorf("encryption configuration is nil")
- }
-
- xmlConfig := encryptionConfigToXML(config)
- return xml.Marshal(xmlConfig)
-}
diff --git a/weed/s3api/s3_iam_middleware.go b/weed/s3api/s3_iam_middleware.go
index 7820f3803..af454dee8 100644
--- a/weed/s3api/s3_iam_middleware.go
+++ b/weed/s3api/s3_iam_middleware.go
@@ -13,7 +13,6 @@ import (
"github.com/seaweedfs/seaweedfs/weed/iam/integration"
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
"github.com/seaweedfs/seaweedfs/weed/iam/sts"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
"github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
)
@@ -381,52 +380,6 @@ func buildS3ResourceArn(bucket string, objectKey string) string {
return "arn:aws:s3:::" + bucket + "/" + objectKey
}
-// mapLegacyActionToIAM provides fallback mapping for legacy actions
-// This ensures backward compatibility while the system transitions to granular actions
-func mapLegacyActionToIAM(legacyAction Action) string {
- switch legacyAction {
- case s3_constants.ACTION_READ:
- return "s3:GetObject" // Fallback for unmapped read operations
- case s3_constants.ACTION_WRITE:
- return "s3:PutObject" // Fallback for unmapped write operations
- case s3_constants.ACTION_LIST:
- return "s3:ListBucket" // Fallback for unmapped list operations
- case s3_constants.ACTION_TAGGING:
- return "s3:GetObjectTagging" // Fallback for unmapped tagging operations
- case s3_constants.ACTION_READ_ACP:
- return "s3:GetObjectAcl" // Fallback for unmapped ACL read operations
- case s3_constants.ACTION_WRITE_ACP:
- return "s3:PutObjectAcl" // Fallback for unmapped ACL write operations
- case s3_constants.ACTION_DELETE_BUCKET:
- return "s3:DeleteBucket" // Fallback for unmapped bucket delete operations
- case s3_constants.ACTION_ADMIN:
- return "s3:*" // Fallback for unmapped admin operations
-
- // Handle granular multipart actions (already correctly mapped)
- case s3_constants.S3_ACTION_CREATE_MULTIPART:
- return s3_constants.S3_ACTION_CREATE_MULTIPART
- case s3_constants.S3_ACTION_UPLOAD_PART:
- return s3_constants.S3_ACTION_UPLOAD_PART
- case s3_constants.S3_ACTION_COMPLETE_MULTIPART:
- return s3_constants.S3_ACTION_COMPLETE_MULTIPART
- case s3_constants.S3_ACTION_ABORT_MULTIPART:
- return s3_constants.S3_ACTION_ABORT_MULTIPART
- case s3_constants.S3_ACTION_LIST_MULTIPART_UPLOADS:
- return s3_constants.S3_ACTION_LIST_MULTIPART_UPLOADS
- case s3_constants.S3_ACTION_LIST_PARTS:
- return s3_constants.S3_ACTION_LIST_PARTS
-
- default:
- // If it's already a properly formatted S3 action, return as-is
- actionStr := string(legacyAction)
- if strings.HasPrefix(actionStr, "s3:") {
- return actionStr
- }
- // Fallback: convert to S3 action format
- return "s3:" + actionStr
- }
-}
-
// extractRequestContext extracts request context for policy conditions
func extractRequestContext(r *http.Request) map[string]interface{} {
context := make(map[string]interface{})
@@ -553,79 +506,6 @@ type EnhancedS3ApiServer struct {
iamIntegration IAMIntegration
}
-// NewEnhancedS3ApiServer creates an S3 API server with IAM integration
-func NewEnhancedS3ApiServer(baseServer *S3ApiServer, iamManager *integration.IAMManager) *EnhancedS3ApiServer {
- // Set the IAM integration on the base server
- baseServer.SetIAMIntegration(iamManager)
-
- return &EnhancedS3ApiServer{
- S3ApiServer: baseServer,
- iamIntegration: NewS3IAMIntegration(iamManager, "localhost:8888"),
- }
-}
-
-// AuthenticateJWTRequest handles JWT authentication for S3 requests
-func (enhanced *EnhancedS3ApiServer) AuthenticateJWTRequest(r *http.Request) (*Identity, s3err.ErrorCode) {
- ctx := r.Context()
-
- // Use our IAM integration for JWT authentication
- iamIdentity, errCode := enhanced.iamIntegration.AuthenticateJWT(ctx, r)
- if errCode != s3err.ErrNone {
- return nil, errCode
- }
-
- // Convert IAMIdentity to the existing Identity structure
- identity := &Identity{
- Name: iamIdentity.Name,
- Account: iamIdentity.Account,
- // Note: Actions will be determined by policy evaluation
- Actions: []Action{}, // Empty - authorization handled by policy engine
- PolicyNames: iamIdentity.PolicyNames,
- }
-
- // Store session token for later authorization
- r.Header.Set("X-SeaweedFS-Session-Token", iamIdentity.SessionToken)
- r.Header.Set("X-SeaweedFS-Principal", iamIdentity.Principal)
-
- return identity, s3err.ErrNone
-}
-
-// AuthorizeRequest handles authorization for S3 requests using policy engine
-func (enhanced *EnhancedS3ApiServer) AuthorizeRequest(r *http.Request, identity *Identity, action Action) s3err.ErrorCode {
- ctx := r.Context()
-
- // Get session info from request headers (set during authentication)
- sessionToken := r.Header.Get("X-SeaweedFS-Session-Token")
- principal := r.Header.Get("X-SeaweedFS-Principal")
-
- if sessionToken == "" || principal == "" {
- glog.V(3).Info("No session information available for authorization")
- return s3err.ErrAccessDenied
- }
-
- // Extract bucket and object from request
- bucket, object := s3_constants.GetBucketAndObject(r)
- prefix := s3_constants.GetPrefix(r)
-
- // For List operations, use prefix for permission checking if available
- if action == s3_constants.ACTION_LIST && object == "" && prefix != "" {
- object = prefix
- } else if (object == "/" || object == "") && prefix != "" {
- object = prefix
- }
-
- // Create IAM identity for authorization
- iamIdentity := &IAMIdentity{
- Name: identity.Name,
- Principal: principal,
- SessionToken: sessionToken,
- Account: identity.Account,
- }
-
- // Use our IAM integration for authorization
- return enhanced.iamIntegration.AuthorizeAction(ctx, iamIdentity, action, bucket, object, r)
-}
-
// OIDCIdentity represents an identity validated through OIDC
type OIDCIdentity struct {
UserID string
diff --git a/weed/s3api/s3_iam_simple_test.go b/weed/s3api/s3_iam_simple_test.go
deleted file mode 100644
index c2c68321f..000000000
--- a/weed/s3api/s3_iam_simple_test.go
+++ /dev/null
@@ -1,584 +0,0 @@
-package s3api
-
-import (
- "context"
- "net/http"
- "net/http/httptest"
- "net/url"
- "testing"
- "time"
-
- "github.com/seaweedfs/seaweedfs/weed/iam/integration"
- "github.com/seaweedfs/seaweedfs/weed/iam/policy"
- "github.com/seaweedfs/seaweedfs/weed/iam/sts"
- "github.com/seaweedfs/seaweedfs/weed/iam/utils"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-func newTestS3IAMManagerWithDefaultEffect(t *testing.T, defaultEffect string) *integration.IAMManager {
- t.Helper()
-
- iamManager := integration.NewIAMManager()
- config := &integration.IAMConfig{
- STS: &sts.STSConfig{
- TokenDuration: sts.FlexibleDuration{Duration: time.Hour},
- MaxSessionLength: sts.FlexibleDuration{Duration: time.Hour * 12},
- Issuer: "test-sts",
- SigningKey: []byte("test-signing-key-32-characters-long"),
- },
- Policy: &policy.PolicyEngineConfig{
- DefaultEffect: defaultEffect,
- StoreType: "memory",
- },
- Roles: &integration.RoleStoreConfig{
- StoreType: "memory",
- },
- }
-
- err := iamManager.Initialize(config, func() string {
- return "localhost:8888"
- })
- require.NoError(t, err)
-
- return iamManager
-}
-
-func newTestS3IAMManager(t *testing.T) *integration.IAMManager {
- t.Helper()
- return newTestS3IAMManagerWithDefaultEffect(t, "Deny")
-}
-
-// TestS3IAMMiddleware tests the basic S3 IAM middleware functionality
-func TestS3IAMMiddleware(t *testing.T) {
- iamManager := newTestS3IAMManager(t)
-
- // Create S3 IAM integration
- s3IAMIntegration := NewS3IAMIntegration(iamManager, "localhost:8888")
-
- // Test that integration is created successfully
- assert.NotNil(t, s3IAMIntegration)
- assert.True(t, s3IAMIntegration.enabled)
-}
-
-func TestS3IAMMiddlewareStaticV4ManagedPolicies(t *testing.T) {
- ctx := context.Background()
- iamManager := newTestS3IAMManager(t)
-
- allowPolicy := &policy.PolicyDocument{
- Version: "2012-10-17",
- Statement: []policy.Statement{
- {
- Effect: "Allow",
- Action: policy.StringList{"s3:PutObject", "s3:ListBucket"},
- Resource: policy.StringList{"arn:aws:s3:::cli-allowed-bucket", "arn:aws:s3:::cli-allowed-bucket/*"},
- },
- },
- }
- require.NoError(t, iamManager.CreatePolicy(ctx, "localhost:8888", "cli-bucket-access-policy", allowPolicy))
-
- s3IAMIntegration := NewS3IAMIntegration(iamManager, "localhost:8888")
- identity := &IAMIdentity{
- Name: "cli-test-user",
- Principal: "arn:aws:iam::000000000000:user/cli-test-user",
- PolicyNames: []string{"cli-bucket-access-policy"},
- }
-
- putReq := httptest.NewRequest(http.MethodPut, "http://example.com/cli-allowed-bucket/test-file.txt", http.NoBody)
- putErrCode := s3IAMIntegration.AuthorizeAction(ctx, identity, s3_constants.ACTION_WRITE, "cli-allowed-bucket", "test-file.txt", putReq)
- assert.Equal(t, s3err.ErrNone, putErrCode)
-
- listReq := httptest.NewRequest(http.MethodGet, "http://example.com/cli-allowed-bucket/", http.NoBody)
- listErrCode := s3IAMIntegration.AuthorizeAction(ctx, identity, s3_constants.ACTION_LIST, "cli-allowed-bucket", "", listReq)
- assert.Equal(t, s3err.ErrNone, listErrCode)
-}
-
-func TestS3IAMMiddlewareAttachedPoliciesRestrictDefaultAllow(t *testing.T) {
- ctx := context.Background()
- iamManager := newTestS3IAMManagerWithDefaultEffect(t, "Allow")
-
- allowPolicy := &policy.PolicyDocument{
- Version: "2012-10-17",
- Statement: []policy.Statement{
- {
- Effect: "Allow",
- Action: policy.StringList{"s3:PutObject", "s3:ListBucket"},
- Resource: policy.StringList{"arn:aws:s3:::cli-allowed-bucket", "arn:aws:s3:::cli-allowed-bucket/*"},
- },
- },
- }
- require.NoError(t, iamManager.CreatePolicy(ctx, "localhost:8888", "cli-bucket-access-policy", allowPolicy))
-
- s3IAMIntegration := NewS3IAMIntegration(iamManager, "localhost:8888")
- identity := &IAMIdentity{
- Name: "cli-test-user",
- Principal: "arn:aws:iam::000000000000:user/cli-test-user",
- PolicyNames: []string{"cli-bucket-access-policy"},
- }
-
- allowedReq := httptest.NewRequest(http.MethodPut, "http://example.com/cli-allowed-bucket/test-file.txt", http.NoBody)
- allowedErrCode := s3IAMIntegration.AuthorizeAction(ctx, identity, s3_constants.ACTION_WRITE, "cli-allowed-bucket", "test-file.txt", allowedReq)
- assert.Equal(t, s3err.ErrNone, allowedErrCode)
-
- forbiddenReq := httptest.NewRequest(http.MethodPut, "http://example.com/cli-forbidden-bucket/forbidden-file.txt", http.NoBody)
- forbiddenErrCode := s3IAMIntegration.AuthorizeAction(ctx, identity, s3_constants.ACTION_WRITE, "cli-forbidden-bucket", "forbidden-file.txt", forbiddenReq)
- assert.Equal(t, s3err.ErrAccessDenied, forbiddenErrCode)
-
- forbiddenListReq := httptest.NewRequest(http.MethodGet, "http://example.com/cli-forbidden-bucket/", http.NoBody)
- forbiddenListErrCode := s3IAMIntegration.AuthorizeAction(ctx, identity, s3_constants.ACTION_LIST, "cli-forbidden-bucket", "", forbiddenListReq)
- assert.Equal(t, s3err.ErrAccessDenied, forbiddenListErrCode)
-}
-
-// TestS3IAMMiddlewareJWTAuth tests JWT authentication
-func TestS3IAMMiddlewareJWTAuth(t *testing.T) {
- // Skip for now since it requires full setup
- t.Skip("JWT authentication test requires full IAM setup")
-
- // Create IAM integration
- s3iam := NewS3IAMIntegration(nil, "localhost:8888") // Disabled integration
-
- // Create test request with JWT token
- req := httptest.NewRequest("GET", "/test-bucket/test-object", http.NoBody)
- req.Header.Set("Authorization", "Bearer test-token")
-
- // Test authentication (should return not implemented when disabled)
- ctx := context.Background()
- identity, errCode := s3iam.AuthenticateJWT(ctx, req)
-
- assert.Nil(t, identity)
- assert.NotEqual(t, errCode, 0) // Should return an error
-}
-
-// TestBuildS3ResourceArn tests resource ARN building
-func TestBuildS3ResourceArn(t *testing.T) {
- tests := []struct {
- name string
- bucket string
- object string
- expected string
- }{
- {
- name: "empty bucket and object",
- bucket: "",
- object: "",
- expected: "arn:aws:s3:::*",
- },
- {
- name: "bucket only",
- bucket: "test-bucket",
- object: "",
- expected: "arn:aws:s3:::test-bucket",
- },
- {
- name: "bucket and object",
- bucket: "test-bucket",
- object: "test-object.txt",
- expected: "arn:aws:s3:::test-bucket/test-object.txt",
- },
- {
- name: "bucket and object with leading slash",
- bucket: "test-bucket",
- object: "/test-object.txt",
- expected: "arn:aws:s3:::test-bucket/test-object.txt",
- },
- {
- name: "bucket and nested object",
- bucket: "test-bucket",
- object: "folder/subfolder/test-object.txt",
- expected: "arn:aws:s3:::test-bucket/folder/subfolder/test-object.txt",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result := buildS3ResourceArn(tt.bucket, tt.object)
- assert.Equal(t, tt.expected, result)
- })
- }
-}
-
-// TestDetermineGranularS3Action tests granular S3 action determination from HTTP requests
-func TestDetermineGranularS3Action(t *testing.T) {
- tests := []struct {
- name string
- method string
- bucket string
- objectKey string
- queryParams map[string]string
- fallbackAction Action
- expected string
- description string
- }{
- // Object-level operations
- {
- name: "get_object",
- method: "GET",
- bucket: "test-bucket",
- objectKey: "test-object.txt",
- queryParams: map[string]string{},
- fallbackAction: s3_constants.ACTION_READ,
- expected: "s3:GetObject",
- description: "Basic object retrieval",
- },
- {
- name: "get_object_acl",
- method: "GET",
- bucket: "test-bucket",
- objectKey: "test-object.txt",
- queryParams: map[string]string{"acl": ""},
- fallbackAction: s3_constants.ACTION_READ_ACP,
- expected: "s3:GetObjectAcl",
- description: "Object ACL retrieval",
- },
- {
- name: "get_object_tagging",
- method: "GET",
- bucket: "test-bucket",
- objectKey: "test-object.txt",
- queryParams: map[string]string{"tagging": ""},
- fallbackAction: s3_constants.ACTION_TAGGING,
- expected: "s3:GetObjectTagging",
- description: "Object tagging retrieval",
- },
- {
- name: "put_object",
- method: "PUT",
- bucket: "test-bucket",
- objectKey: "test-object.txt",
- queryParams: map[string]string{},
- fallbackAction: s3_constants.ACTION_WRITE,
- expected: "s3:PutObject",
- description: "Basic object upload",
- },
- {
- name: "put_object_acl",
- method: "PUT",
- bucket: "test-bucket",
- objectKey: "test-object.txt",
- queryParams: map[string]string{"acl": ""},
- fallbackAction: s3_constants.ACTION_WRITE_ACP,
- expected: "s3:PutObjectAcl",
- description: "Object ACL modification",
- },
- {
- name: "delete_object",
- method: "DELETE",
- bucket: "test-bucket",
- objectKey: "test-object.txt",
- queryParams: map[string]string{},
- fallbackAction: s3_constants.ACTION_WRITE, // DELETE object uses WRITE fallback
- expected: "s3:DeleteObject",
- description: "Object deletion - correctly mapped to DeleteObject (not PutObject)",
- },
- {
- name: "delete_object_tagging",
- method: "DELETE",
- bucket: "test-bucket",
- objectKey: "test-object.txt",
- queryParams: map[string]string{"tagging": ""},
- fallbackAction: s3_constants.ACTION_TAGGING,
- expected: "s3:DeleteObjectTagging",
- description: "Object tag deletion",
- },
-
- // Multipart upload operations
- {
- name: "create_multipart_upload",
- method: "POST",
- bucket: "test-bucket",
- objectKey: "large-file.txt",
- queryParams: map[string]string{"uploads": ""},
- fallbackAction: s3_constants.ACTION_WRITE,
- expected: "s3:CreateMultipartUpload",
- description: "Multipart upload initiation",
- },
- {
- name: "upload_part",
- method: "PUT",
- bucket: "test-bucket",
- objectKey: "large-file.txt",
- queryParams: map[string]string{"uploadId": "12345", "partNumber": "1"},
- fallbackAction: s3_constants.ACTION_WRITE,
- expected: "s3:UploadPart",
- description: "Multipart part upload",
- },
- {
- name: "complete_multipart_upload",
- method: "POST",
- bucket: "test-bucket",
- objectKey: "large-file.txt",
- queryParams: map[string]string{"uploadId": "12345"},
- fallbackAction: s3_constants.ACTION_WRITE,
- expected: "s3:CompleteMultipartUpload",
- description: "Multipart upload completion",
- },
- {
- name: "abort_multipart_upload",
- method: "DELETE",
- bucket: "test-bucket",
- objectKey: "large-file.txt",
- queryParams: map[string]string{"uploadId": "12345"},
- fallbackAction: s3_constants.ACTION_WRITE,
- expected: "s3:AbortMultipartUpload",
- description: "Multipart upload abort",
- },
-
- // Bucket-level operations
- {
- name: "list_bucket",
- method: "GET",
- bucket: "test-bucket",
- objectKey: "",
- queryParams: map[string]string{},
- fallbackAction: s3_constants.ACTION_LIST,
- expected: "s3:ListBucket",
- description: "Bucket listing",
- },
- {
- name: "get_bucket_acl",
- method: "GET",
- bucket: "test-bucket",
- objectKey: "",
- queryParams: map[string]string{"acl": ""},
- fallbackAction: s3_constants.ACTION_READ_ACP,
- expected: "s3:GetBucketAcl",
- description: "Bucket ACL retrieval",
- },
- {
- name: "put_bucket_policy",
- method: "PUT",
- bucket: "test-bucket",
- objectKey: "",
- queryParams: map[string]string{"policy": ""},
- fallbackAction: s3_constants.ACTION_WRITE,
- expected: "s3:PutBucketPolicy",
- description: "Bucket policy modification",
- },
- {
- name: "delete_bucket",
- method: "DELETE",
- bucket: "test-bucket",
- objectKey: "",
- queryParams: map[string]string{},
- fallbackAction: s3_constants.ACTION_DELETE_BUCKET,
- expected: "s3:DeleteBucket",
- description: "Bucket deletion",
- },
- {
- name: "list_multipart_uploads",
- method: "GET",
- bucket: "test-bucket",
- objectKey: "",
- queryParams: map[string]string{"uploads": ""},
- fallbackAction: s3_constants.ACTION_LIST,
- expected: "s3:ListBucketMultipartUploads",
- description: "List multipart uploads in bucket",
- },
-
- // Fallback scenarios
- {
- name: "legacy_read_fallback",
- method: "GET",
- bucket: "",
- objectKey: "",
- queryParams: map[string]string{},
- fallbackAction: s3_constants.ACTION_READ,
- expected: "s3:GetObject",
- description: "Legacy read action fallback",
- },
- {
- name: "already_granular_action",
- method: "GET",
- bucket: "",
- objectKey: "",
- queryParams: map[string]string{},
- fallbackAction: "s3:GetBucketLocation", // Already granular
- expected: "s3:GetBucketLocation",
- description: "Already granular action passed through",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Create HTTP request with query parameters
- req := &http.Request{
- Method: tt.method,
- URL: &url.URL{Path: "/" + tt.bucket + "/" + tt.objectKey},
- }
-
- // Add query parameters
- query := req.URL.Query()
- for key, value := range tt.queryParams {
- query.Set(key, value)
- }
- req.URL.RawQuery = query.Encode()
-
- // Test the action determination
- result := ResolveS3Action(req, string(tt.fallbackAction), tt.bucket, tt.objectKey)
-
- assert.Equal(t, tt.expected, result,
- "Test %s failed: %s. Expected %s but got %s",
- tt.name, tt.description, tt.expected, result)
- })
- }
-}
-
-// TestMapLegacyActionToIAM tests the legacy action fallback mapping
-func TestMapLegacyActionToIAM(t *testing.T) {
- tests := []struct {
- name string
- legacyAction Action
- expected string
- }{
- {
- name: "read_action_fallback",
- legacyAction: s3_constants.ACTION_READ,
- expected: "s3:GetObject",
- },
- {
- name: "write_action_fallback",
- legacyAction: s3_constants.ACTION_WRITE,
- expected: "s3:PutObject",
- },
- {
- name: "admin_action_fallback",
- legacyAction: s3_constants.ACTION_ADMIN,
- expected: "s3:*",
- },
- {
- name: "granular_multipart_action",
- legacyAction: s3_constants.S3_ACTION_CREATE_MULTIPART,
- expected: s3_constants.S3_ACTION_CREATE_MULTIPART,
- },
- {
- name: "unknown_action_with_s3_prefix",
- legacyAction: "s3:CustomAction",
- expected: "s3:CustomAction",
- },
- {
- name: "unknown_action_without_prefix",
- legacyAction: "CustomAction",
- expected: "s3:CustomAction",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result := mapLegacyActionToIAM(tt.legacyAction)
- assert.Equal(t, tt.expected, result)
- })
- }
-}
-
-// TestExtractSourceIP tests source IP extraction from requests
-func TestExtractSourceIP(t *testing.T) {
- tests := []struct {
- name string
- setupReq func() *http.Request
- expectedIP string
- }{
- {
- name: "X-Forwarded-For header",
- setupReq: func() *http.Request {
- req := httptest.NewRequest("GET", "/test", http.NoBody)
- req.Header.Set("X-Forwarded-For", "192.168.1.100, 10.0.0.1")
- // Set RemoteAddr to private IP to simulate trusted proxy
- req.RemoteAddr = "127.0.0.1:12345"
- return req
- },
- expectedIP: "192.168.1.100",
- },
- {
- name: "X-Real-IP header",
- setupReq: func() *http.Request {
- req := httptest.NewRequest("GET", "/test", http.NoBody)
- req.Header.Set("X-Real-IP", "192.168.1.200")
- // Set RemoteAddr to private IP to simulate trusted proxy
- req.RemoteAddr = "127.0.0.1:12345"
- return req
- },
- expectedIP: "192.168.1.200",
- },
- {
- name: "RemoteAddr fallback",
- setupReq: func() *http.Request {
- req := httptest.NewRequest("GET", "/test", http.NoBody)
- req.RemoteAddr = "192.168.1.300:12345"
- return req
- },
- expectedIP: "192.168.1.300",
- },
- {
- name: "Untrusted proxy - public RemoteAddr ignores X-Forwarded-For",
- setupReq: func() *http.Request {
- req := httptest.NewRequest("GET", "/test", http.NoBody)
- req.Header.Set("X-Forwarded-For", "192.168.1.100")
- // Public IP - headers should NOT be trusted
- req.RemoteAddr = "8.8.8.8:12345"
- return req
- },
- expectedIP: "8.8.8.8", // Should use RemoteAddr, not X-Forwarded-For
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- req := tt.setupReq()
- result := extractSourceIP(req)
- assert.Equal(t, tt.expectedIP, result)
- })
- }
-}
-
-// TestExtractRoleNameFromPrincipal tests role name extraction
-func TestExtractRoleNameFromPrincipal(t *testing.T) {
- tests := []struct {
- name string
- principal string
- expected string
- }{
- {
- name: "valid assumed role ARN",
- principal: "arn:aws:sts::assumed-role/S3ReadOnlyRole/session-123",
- expected: "S3ReadOnlyRole",
- },
- {
- name: "invalid format",
- principal: "invalid-principal",
- expected: "", // Returns empty string to signal invalid format
- },
- {
- name: "missing session name",
- principal: "arn:aws:sts::assumed-role/TestRole",
- expected: "TestRole", // Extracts role name even without session name
- },
- {
- name: "empty principal",
- principal: "",
- expected: "",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result := utils.ExtractRoleNameFromPrincipal(tt.principal)
- assert.Equal(t, tt.expected, result)
- })
- }
-}
-
-// TestIAMIdentityIsAdmin tests the IsAdmin method
-func TestIAMIdentityIsAdmin(t *testing.T) {
- identity := &IAMIdentity{
- Name: "test-identity",
- Principal: "arn:aws:sts::assumed-role/TestRole/session",
- SessionToken: "test-token",
- }
-
- // In our implementation, IsAdmin always returns false since admin status
- // is determined by policies, not identity
- result := identity.IsAdmin()
- assert.False(t, result)
-}
diff --git a/weed/s3api/s3_multipart_iam.go b/weed/s3api/s3_multipart_iam.go
deleted file mode 100644
index de3bccae9..000000000
--- a/weed/s3api/s3_multipart_iam.go
+++ /dev/null
@@ -1,420 +0,0 @@
-package s3api
-
-import (
- "fmt"
- "net/http"
- "strconv"
- "strings"
- "time"
-
- "github.com/seaweedfs/seaweedfs/weed/glog"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
-)
-
-// S3MultipartIAMManager handles IAM integration for multipart upload operations
-type S3MultipartIAMManager struct {
- s3iam *S3IAMIntegration
-}
-
-// NewS3MultipartIAMManager creates a new multipart IAM manager
-func NewS3MultipartIAMManager(s3iam *S3IAMIntegration) *S3MultipartIAMManager {
- return &S3MultipartIAMManager{
- s3iam: s3iam,
- }
-}
-
-// MultipartUploadRequest represents a multipart upload request
-type MultipartUploadRequest struct {
- Bucket string `json:"bucket"` // S3 bucket name
- ObjectKey string `json:"object_key"` // S3 object key
- UploadID string `json:"upload_id"` // Multipart upload ID
- PartNumber int `json:"part_number"` // Part number for upload part
- Operation string `json:"operation"` // Multipart operation type
- SessionToken string `json:"session_token"` // JWT session token
- Headers map[string]string `json:"headers"` // Request headers
- ContentSize int64 `json:"content_size"` // Content size for validation
-}
-
-// MultipartUploadPolicy represents security policies for multipart uploads
-type MultipartUploadPolicy struct {
- MaxPartSize int64 `json:"max_part_size"` // Maximum part size (5GB AWS limit)
- MinPartSize int64 `json:"min_part_size"` // Minimum part size (5MB AWS limit, except last part)
- MaxParts int `json:"max_parts"` // Maximum number of parts (10,000 AWS limit)
- MaxUploadDuration time.Duration `json:"max_upload_duration"` // Maximum time to complete multipart upload
- AllowedContentTypes []string `json:"allowed_content_types"` // Allowed content types
- RequiredHeaders []string `json:"required_headers"` // Required headers for validation
- IPWhitelist []string `json:"ip_whitelist"` // Allowed IP addresses/ranges
-}
-
-// MultipartOperation represents different multipart upload operations
-type MultipartOperation string
-
-const (
- MultipartOpInitiate MultipartOperation = "initiate"
- MultipartOpUploadPart MultipartOperation = "upload_part"
- MultipartOpComplete MultipartOperation = "complete"
- MultipartOpAbort MultipartOperation = "abort"
- MultipartOpList MultipartOperation = "list"
- MultipartOpListParts MultipartOperation = "list_parts"
-)
-
-// ValidateMultipartOperationWithIAM validates multipart operations using IAM policies
-func (iam *IdentityAccessManagement) ValidateMultipartOperationWithIAM(r *http.Request, identity *Identity, operation MultipartOperation) s3err.ErrorCode {
- if iam.iamIntegration == nil {
- // Fall back to standard validation
- return s3err.ErrNone
- }
-
- // Extract bucket and object from request
- bucket, object := s3_constants.GetBucketAndObject(r)
-
- // Determine the S3 action based on multipart operation
- action := determineMultipartS3Action(operation)
-
- // Extract session token from request
- sessionToken := extractSessionTokenFromRequest(r)
- if sessionToken == "" {
- // No session token - use standard auth
- return s3err.ErrNone
- }
-
- // Retrieve the actual principal ARN from the request header
- // This header is set during initial authentication and contains the correct assumed role ARN
- principalArn := r.Header.Get("X-SeaweedFS-Principal")
- if principalArn == "" {
- glog.V(2).Info("IAM authorization for multipart operation failed: missing principal ARN in request header")
- return s3err.ErrAccessDenied
- }
-
- // Create IAM identity for authorization
- iamIdentity := &IAMIdentity{
- Name: identity.Name,
- Principal: principalArn,
- SessionToken: sessionToken,
- Account: identity.Account,
- }
-
- // Authorize using IAM
- ctx := r.Context()
- errCode := iam.iamIntegration.AuthorizeAction(ctx, iamIdentity, action, bucket, object, r)
- if errCode != s3err.ErrNone {
- glog.V(3).Infof("IAM authorization failed for multipart operation: principal=%s operation=%s action=%s bucket=%s object=%s",
- iamIdentity.Principal, operation, action, bucket, object)
- return errCode
- }
-
- glog.V(3).Infof("IAM authorization succeeded for multipart operation: principal=%s operation=%s action=%s bucket=%s object=%s",
- iamIdentity.Principal, operation, action, bucket, object)
- return s3err.ErrNone
-}
-
-// ValidateMultipartRequestWithPolicy validates multipart request against security policy
-func (policy *MultipartUploadPolicy) ValidateMultipartRequestWithPolicy(req *MultipartUploadRequest) error {
- if req == nil {
- return fmt.Errorf("multipart request cannot be nil")
- }
-
- // Validate part size for upload part operations
- if req.Operation == string(MultipartOpUploadPart) {
- if req.ContentSize > policy.MaxPartSize {
- return fmt.Errorf("part size %d exceeds maximum allowed %d", req.ContentSize, policy.MaxPartSize)
- }
-
- // Minimum part size validation (except for last part)
- // Note: Last part validation would require knowing if this is the final part
- if req.ContentSize < policy.MinPartSize && req.ContentSize > 0 {
- glog.V(2).Infof("Part size %d is below minimum %d - assuming last part", req.ContentSize, policy.MinPartSize)
- }
-
- // Validate part number
- if req.PartNumber < 1 || req.PartNumber > policy.MaxParts {
- return fmt.Errorf("part number %d is invalid (must be 1-%d)", req.PartNumber, policy.MaxParts)
- }
- }
-
- // Validate required headers first
- if req.Headers != nil {
- for _, requiredHeader := range policy.RequiredHeaders {
- if _, exists := req.Headers[requiredHeader]; !exists {
- // Check lowercase version
- if _, exists := req.Headers[strings.ToLower(requiredHeader)]; !exists {
- return fmt.Errorf("required header %s is missing", requiredHeader)
- }
- }
- }
- }
-
- // Validate content type if specified
- if len(policy.AllowedContentTypes) > 0 && req.Headers != nil {
- contentType := req.Headers["Content-Type"]
- if contentType == "" {
- contentType = req.Headers["content-type"]
- }
-
- allowed := false
- for _, allowedType := range policy.AllowedContentTypes {
- if contentType == allowedType {
- allowed = true
- break
- }
- }
-
- if !allowed {
- return fmt.Errorf("content type %s is not allowed", contentType)
- }
- }
-
- return nil
-}
-
-// Enhanced multipart handlers with IAM integration
-
-// NewMultipartUploadWithIAM handles initiate multipart upload with IAM validation
-func (s3a *S3ApiServer) NewMultipartUploadWithIAM(w http.ResponseWriter, r *http.Request) {
- // Validate IAM permissions first
- if s3a.iam.iamIntegration != nil {
- if identity, errCode := s3a.iam.authRequest(r, s3_constants.ACTION_WRITE); errCode != s3err.ErrNone {
- s3err.WriteErrorResponse(w, r, errCode)
- return
- } else {
- // Additional multipart-specific IAM validation
- if errCode := s3a.iam.ValidateMultipartOperationWithIAM(r, identity, MultipartOpInitiate); errCode != s3err.ErrNone {
- s3err.WriteErrorResponse(w, r, errCode)
- return
- }
- }
- }
-
- // Delegate to existing handler
- s3a.NewMultipartUploadHandler(w, r)
-}
-
-// CompleteMultipartUploadWithIAM handles complete multipart upload with IAM validation
-func (s3a *S3ApiServer) CompleteMultipartUploadWithIAM(w http.ResponseWriter, r *http.Request) {
- // Validate IAM permissions first
- if s3a.iam.iamIntegration != nil {
- if identity, errCode := s3a.iam.authRequest(r, s3_constants.ACTION_WRITE); errCode != s3err.ErrNone {
- s3err.WriteErrorResponse(w, r, errCode)
- return
- } else {
- // Additional multipart-specific IAM validation
- if errCode := s3a.iam.ValidateMultipartOperationWithIAM(r, identity, MultipartOpComplete); errCode != s3err.ErrNone {
- s3err.WriteErrorResponse(w, r, errCode)
- return
- }
- }
- }
-
- // Delegate to existing handler
- s3a.CompleteMultipartUploadHandler(w, r)
-}
-
-// AbortMultipartUploadWithIAM handles abort multipart upload with IAM validation
-func (s3a *S3ApiServer) AbortMultipartUploadWithIAM(w http.ResponseWriter, r *http.Request) {
- // Validate IAM permissions first
- if s3a.iam.iamIntegration != nil {
- if identity, errCode := s3a.iam.authRequest(r, s3_constants.ACTION_WRITE); errCode != s3err.ErrNone {
- s3err.WriteErrorResponse(w, r, errCode)
- return
- } else {
- // Additional multipart-specific IAM validation
- if errCode := s3a.iam.ValidateMultipartOperationWithIAM(r, identity, MultipartOpAbort); errCode != s3err.ErrNone {
- s3err.WriteErrorResponse(w, r, errCode)
- return
- }
- }
- }
-
- // Delegate to existing handler
- s3a.AbortMultipartUploadHandler(w, r)
-}
-
-// ListMultipartUploadsWithIAM handles list multipart uploads with IAM validation
-func (s3a *S3ApiServer) ListMultipartUploadsWithIAM(w http.ResponseWriter, r *http.Request) {
- // Validate IAM permissions first
- if s3a.iam.iamIntegration != nil {
- if identity, errCode := s3a.iam.authRequest(r, s3_constants.ACTION_LIST); errCode != s3err.ErrNone {
- s3err.WriteErrorResponse(w, r, errCode)
- return
- } else {
- // Additional multipart-specific IAM validation
- if errCode := s3a.iam.ValidateMultipartOperationWithIAM(r, identity, MultipartOpList); errCode != s3err.ErrNone {
- s3err.WriteErrorResponse(w, r, errCode)
- return
- }
- }
- }
-
- // Delegate to existing handler
- s3a.ListMultipartUploadsHandler(w, r)
-}
-
-// UploadPartWithIAM handles upload part with IAM validation
-func (s3a *S3ApiServer) UploadPartWithIAM(w http.ResponseWriter, r *http.Request) {
- // Validate IAM permissions first
- if s3a.iam.iamIntegration != nil {
- if identity, errCode := s3a.iam.authRequest(r, s3_constants.ACTION_WRITE); errCode != s3err.ErrNone {
- s3err.WriteErrorResponse(w, r, errCode)
- return
- } else {
- // Additional multipart-specific IAM validation
- if errCode := s3a.iam.ValidateMultipartOperationWithIAM(r, identity, MultipartOpUploadPart); errCode != s3err.ErrNone {
- s3err.WriteErrorResponse(w, r, errCode)
- return
- }
-
- // Validate part size and other policies
- if err := s3a.validateUploadPartRequest(r); err != nil {
- glog.Errorf("Upload part validation failed: %v", err)
- s3err.WriteErrorResponse(w, r, s3err.ErrInvalidRequest)
- return
- }
- }
- }
-
- // Delegate to existing object PUT handler (which handles upload part)
- s3a.PutObjectHandler(w, r)
-}
-
-// Helper functions
-
-// determineMultipartS3Action maps multipart operations to granular S3 actions
-// This enables fine-grained IAM policies for multipart upload operations
-func determineMultipartS3Action(operation MultipartOperation) Action {
- switch operation {
- case MultipartOpInitiate:
- return s3_constants.S3_ACTION_CREATE_MULTIPART
- case MultipartOpUploadPart:
- return s3_constants.S3_ACTION_UPLOAD_PART
- case MultipartOpComplete:
- return s3_constants.S3_ACTION_COMPLETE_MULTIPART
- case MultipartOpAbort:
- return s3_constants.S3_ACTION_ABORT_MULTIPART
- case MultipartOpList:
- return s3_constants.S3_ACTION_LIST_MULTIPART_UPLOADS
- case MultipartOpListParts:
- return s3_constants.S3_ACTION_LIST_PARTS
- default:
- // Fail closed for unmapped operations to prevent unintended access
- glog.Errorf("unmapped multipart operation: %s", operation)
- return "s3:InternalErrorUnknownMultipartAction" // Non-existent action ensures denial
- }
-}
-
-// extractSessionTokenFromRequest extracts session token from various request sources
-func extractSessionTokenFromRequest(r *http.Request) string {
- // Check Authorization header for Bearer token
- if authHeader := r.Header.Get("Authorization"); authHeader != "" {
- if strings.HasPrefix(authHeader, "Bearer ") {
- return strings.TrimPrefix(authHeader, "Bearer ")
- }
- }
-
- // Check X-Amz-Security-Token header
- if token := r.Header.Get("X-Amz-Security-Token"); token != "" {
- return token
- }
-
- // Check query parameters for presigned URL tokens
- if token := r.URL.Query().Get("X-Amz-Security-Token"); token != "" {
- return token
- }
-
- return ""
-}
-
-// validateUploadPartRequest validates upload part request against policies
-func (s3a *S3ApiServer) validateUploadPartRequest(r *http.Request) error {
- // Get default multipart policy
- policy := DefaultMultipartUploadPolicy()
-
- // Extract part number from query
- partNumberStr := r.URL.Query().Get("partNumber")
- if partNumberStr == "" {
- return fmt.Errorf("missing partNumber parameter")
- }
-
- partNumber, err := strconv.Atoi(partNumberStr)
- if err != nil {
- return fmt.Errorf("invalid partNumber: %v", err)
- }
-
- // Get content length
- contentLength := r.ContentLength
- if contentLength < 0 {
- contentLength = 0
- }
-
- // Create multipart request for validation
- bucket, object := s3_constants.GetBucketAndObject(r)
- multipartReq := &MultipartUploadRequest{
- Bucket: bucket,
- ObjectKey: object,
- PartNumber: partNumber,
- Operation: string(MultipartOpUploadPart),
- ContentSize: contentLength,
- Headers: make(map[string]string),
- }
-
- // Copy relevant headers
- for key, values := range r.Header {
- if len(values) > 0 {
- multipartReq.Headers[key] = values[0]
- }
- }
-
- // Validate against policy
- return policy.ValidateMultipartRequestWithPolicy(multipartReq)
-}
-
-// DefaultMultipartUploadPolicy returns a default multipart upload security policy
-func DefaultMultipartUploadPolicy() *MultipartUploadPolicy {
- return &MultipartUploadPolicy{
- MaxPartSize: 5 * 1024 * 1024 * 1024, // 5GB AWS limit
- MinPartSize: 5 * 1024 * 1024, // 5MB AWS minimum (except last part)
- MaxParts: 10000, // AWS limit
- MaxUploadDuration: 7 * 24 * time.Hour, // 7 days to complete upload
- AllowedContentTypes: []string{}, // Empty means all types allowed
- RequiredHeaders: []string{}, // No required headers by default
- IPWhitelist: []string{}, // Empty means no IP restrictions
- }
-}
-
-// MultipartUploadSession represents an ongoing multipart upload session
-type MultipartUploadSession struct {
- UploadID string `json:"upload_id"`
- Bucket string `json:"bucket"`
- ObjectKey string `json:"object_key"`
- Initiator string `json:"initiator"` // User who initiated the upload
- Owner string `json:"owner"` // Object owner
- CreatedAt time.Time `json:"created_at"` // When upload was initiated
- Parts []MultipartUploadPart `json:"parts"` // Uploaded parts
- Metadata map[string]string `json:"metadata"` // Object metadata
- Policy *MultipartUploadPolicy `json:"policy"` // Applied security policy
- SessionToken string `json:"session_token"` // IAM session token
-}
-
-// MultipartUploadPart represents an uploaded part
-type MultipartUploadPart struct {
- PartNumber int `json:"part_number"`
- Size int64 `json:"size"`
- ETag string `json:"etag"`
- LastModified time.Time `json:"last_modified"`
- Checksum string `json:"checksum"` // Optional integrity checksum
-}
-
-// GetMultipartUploadSessions retrieves active multipart upload sessions for a bucket
-func (s3a *S3ApiServer) GetMultipartUploadSessions(bucket string) ([]*MultipartUploadSession, error) {
- // This would typically query the filer for active multipart uploads
- // For now, return empty list as this is a placeholder for the full implementation
- return []*MultipartUploadSession{}, nil
-}
-
-// CleanupExpiredMultipartUploads removes expired multipart upload sessions
-func (s3a *S3ApiServer) CleanupExpiredMultipartUploads(maxAge time.Duration) error {
- // This would typically scan for and remove expired multipart uploads
- // Implementation would depend on how multipart sessions are stored in the filer
- glog.V(2).Infof("Cleanup expired multipart uploads older than %v", maxAge)
- return nil
-}
diff --git a/weed/s3api/s3_multipart_iam_test.go b/weed/s3api/s3_multipart_iam_test.go
deleted file mode 100644
index 12546eb7a..000000000
--- a/weed/s3api/s3_multipart_iam_test.go
+++ /dev/null
@@ -1,614 +0,0 @@
-package s3api
-
-import (
- "context"
- "net/http"
- "net/http/httptest"
- "testing"
- "time"
-
- "github.com/golang-jwt/jwt/v5"
- "github.com/seaweedfs/seaweedfs/weed/iam/integration"
- "github.com/seaweedfs/seaweedfs/weed/iam/ldap"
- "github.com/seaweedfs/seaweedfs/weed/iam/oidc"
- "github.com/seaweedfs/seaweedfs/weed/iam/policy"
- "github.com/seaweedfs/seaweedfs/weed/iam/sts"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-// createTestJWTMultipart creates a test JWT token with the specified issuer, subject and signing key
-func createTestJWTMultipart(t *testing.T, issuer, subject, signingKey string) string {
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
- "iss": issuer,
- "sub": subject,
- "aud": "test-client-id",
- "exp": time.Now().Add(time.Hour).Unix(),
- "iat": time.Now().Unix(),
- // Add claims that trust policy validation expects
- "idp": "test-oidc", // Identity provider claim for trust policy matching
- })
-
- tokenString, err := token.SignedString([]byte(signingKey))
- require.NoError(t, err)
- return tokenString
-}
-
-// TestMultipartIAMValidation tests IAM validation for multipart operations
-func TestMultipartIAMValidation(t *testing.T) {
- // Set up IAM system
- iamManager := setupTestIAMManagerForMultipart(t)
- s3iam := NewS3IAMIntegration(iamManager, "localhost:8888")
- s3iam.enabled = true
-
- // Create IAM with integration
- iam := &IdentityAccessManagement{
- isAuthEnabled: true,
- }
- iam.SetIAMIntegration(s3iam)
-
- // Set up roles
- ctx := context.Background()
- setupTestRolesForMultipart(ctx, iamManager)
-
- // Create a valid JWT token for testing
- validJWTToken := createTestJWTMultipart(t, "https://test-issuer.com", "test-user-123", "test-signing-key")
-
- // Get session token
- response, err := iamManager.AssumeRoleWithWebIdentity(ctx, &sts.AssumeRoleWithWebIdentityRequest{
- RoleArn: "arn:aws:iam::role/S3WriteRole",
- WebIdentityToken: validJWTToken,
- RoleSessionName: "multipart-test-session",
- })
- require.NoError(t, err)
-
- sessionToken := response.Credentials.SessionToken
-
- tests := []struct {
- name string
- operation MultipartOperation
- method string
- path string
- sessionToken string
- expectedResult s3err.ErrorCode
- }{
- {
- name: "Initiate multipart upload",
- operation: MultipartOpInitiate,
- method: "POST",
- path: "/test-bucket/test-file.txt?uploads",
- sessionToken: sessionToken,
- expectedResult: s3err.ErrNone,
- },
- {
- name: "Upload part",
- operation: MultipartOpUploadPart,
- method: "PUT",
- path: "/test-bucket/test-file.txt?partNumber=1&uploadId=test-upload-id",
- sessionToken: sessionToken,
- expectedResult: s3err.ErrNone,
- },
- {
- name: "Complete multipart upload",
- operation: MultipartOpComplete,
- method: "POST",
- path: "/test-bucket/test-file.txt?uploadId=test-upload-id",
- sessionToken: sessionToken,
- expectedResult: s3err.ErrNone,
- },
- {
- name: "Abort multipart upload",
- operation: MultipartOpAbort,
- method: "DELETE",
- path: "/test-bucket/test-file.txt?uploadId=test-upload-id",
- sessionToken: sessionToken,
- expectedResult: s3err.ErrNone,
- },
- {
- name: "List multipart uploads",
- operation: MultipartOpList,
- method: "GET",
- path: "/test-bucket?uploads",
- sessionToken: sessionToken,
- expectedResult: s3err.ErrNone,
- },
- {
- name: "Upload part without session token",
- operation: MultipartOpUploadPart,
- method: "PUT",
- path: "/test-bucket/test-file.txt?partNumber=1&uploadId=test-upload-id",
- sessionToken: "",
- expectedResult: s3err.ErrNone, // Falls back to standard auth
- },
- {
- name: "Upload part with invalid session token",
- operation: MultipartOpUploadPart,
- method: "PUT",
- path: "/test-bucket/test-file.txt?partNumber=1&uploadId=test-upload-id",
- sessionToken: "invalid-token",
- expectedResult: s3err.ErrAccessDenied,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Create request for multipart operation
- req := createMultipartRequest(t, tt.method, tt.path, tt.sessionToken)
-
- // Create identity for testing
- identity := &Identity{
- Name: "test-user",
- Account: &AccountAdmin,
- }
-
- // Test validation
- result := iam.ValidateMultipartOperationWithIAM(req, identity, tt.operation)
- assert.Equal(t, tt.expectedResult, result, "Multipart IAM validation result should match expected")
- })
- }
-}
-
-// TestMultipartUploadPolicy tests multipart upload security policies
-func TestMultipartUploadPolicy(t *testing.T) {
- policy := &MultipartUploadPolicy{
- MaxPartSize: 10 * 1024 * 1024, // 10MB for testing
- MinPartSize: 5 * 1024 * 1024, // 5MB minimum
- MaxParts: 100, // 100 parts max for testing
- AllowedContentTypes: []string{"application/json", "text/plain"},
- RequiredHeaders: []string{"Content-Type"},
- }
-
- tests := []struct {
- name string
- request *MultipartUploadRequest
- expectedError string
- }{
- {
- name: "Valid upload part request",
- request: &MultipartUploadRequest{
- Bucket: "test-bucket",
- ObjectKey: "test-file.txt",
- PartNumber: 1,
- Operation: string(MultipartOpUploadPart),
- ContentSize: 8 * 1024 * 1024, // 8MB
- Headers: map[string]string{
- "Content-Type": "application/json",
- },
- },
- expectedError: "",
- },
- {
- name: "Part size too large",
- request: &MultipartUploadRequest{
- Bucket: "test-bucket",
- ObjectKey: "test-file.txt",
- PartNumber: 1,
- Operation: string(MultipartOpUploadPart),
- ContentSize: 15 * 1024 * 1024, // 15MB exceeds limit
- Headers: map[string]string{
- "Content-Type": "application/json",
- },
- },
- expectedError: "part size",
- },
- {
- name: "Invalid part number (too high)",
- request: &MultipartUploadRequest{
- Bucket: "test-bucket",
- ObjectKey: "test-file.txt",
- PartNumber: 150, // Exceeds max parts
- Operation: string(MultipartOpUploadPart),
- ContentSize: 8 * 1024 * 1024,
- Headers: map[string]string{
- "Content-Type": "application/json",
- },
- },
- expectedError: "part number",
- },
- {
- name: "Invalid part number (too low)",
- request: &MultipartUploadRequest{
- Bucket: "test-bucket",
- ObjectKey: "test-file.txt",
- PartNumber: 0, // Must be >= 1
- Operation: string(MultipartOpUploadPart),
- ContentSize: 8 * 1024 * 1024,
- Headers: map[string]string{
- "Content-Type": "application/json",
- },
- },
- expectedError: "part number",
- },
- {
- name: "Content type not allowed",
- request: &MultipartUploadRequest{
- Bucket: "test-bucket",
- ObjectKey: "test-file.txt",
- PartNumber: 1,
- Operation: string(MultipartOpUploadPart),
- ContentSize: 8 * 1024 * 1024,
- Headers: map[string]string{
- "Content-Type": "video/mp4", // Not in allowed list
- },
- },
- expectedError: "content type video/mp4 is not allowed",
- },
- {
- name: "Missing required header",
- request: &MultipartUploadRequest{
- Bucket: "test-bucket",
- ObjectKey: "test-file.txt",
- PartNumber: 1,
- Operation: string(MultipartOpUploadPart),
- ContentSize: 8 * 1024 * 1024,
- Headers: map[string]string{}, // Missing Content-Type
- },
- expectedError: "required header Content-Type is missing",
- },
- {
- name: "Non-upload operation (should not validate size)",
- request: &MultipartUploadRequest{
- Bucket: "test-bucket",
- ObjectKey: "test-file.txt",
- Operation: string(MultipartOpInitiate),
- Headers: map[string]string{
- "Content-Type": "application/json",
- },
- },
- expectedError: "",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- err := policy.ValidateMultipartRequestWithPolicy(tt.request)
-
- if tt.expectedError == "" {
- assert.NoError(t, err, "Policy validation should succeed")
- } else {
- assert.Error(t, err, "Policy validation should fail")
- assert.Contains(t, err.Error(), tt.expectedError, "Error message should contain expected text")
- }
- })
- }
-}
-
-// TestMultipartS3ActionMapping tests the mapping of multipart operations to S3 actions
-func TestMultipartS3ActionMapping(t *testing.T) {
- tests := []struct {
- operation MultipartOperation
- expectedAction Action
- }{
- {MultipartOpInitiate, s3_constants.S3_ACTION_CREATE_MULTIPART},
- {MultipartOpUploadPart, s3_constants.S3_ACTION_UPLOAD_PART},
- {MultipartOpComplete, s3_constants.S3_ACTION_COMPLETE_MULTIPART},
- {MultipartOpAbort, s3_constants.S3_ACTION_ABORT_MULTIPART},
- {MultipartOpList, s3_constants.S3_ACTION_LIST_MULTIPART_UPLOADS},
- {MultipartOpListParts, s3_constants.S3_ACTION_LIST_PARTS},
- {MultipartOperation("unknown"), "s3:InternalErrorUnknownMultipartAction"}, // Fail-closed for security
- }
-
- for _, tt := range tests {
- t.Run(string(tt.operation), func(t *testing.T) {
- action := determineMultipartS3Action(tt.operation)
- assert.Equal(t, tt.expectedAction, action, "S3 action mapping should match expected")
- })
- }
-}
-
-// TestSessionTokenExtraction tests session token extraction from various sources
-func TestSessionTokenExtraction(t *testing.T) {
- tests := []struct {
- name string
- setupRequest func() *http.Request
- expectedToken string
- }{
- {
- name: "Bearer token in Authorization header",
- setupRequest: func() *http.Request {
- req := httptest.NewRequest("PUT", "/test-bucket/test-file.txt", nil)
- req.Header.Set("Authorization", "Bearer test-session-token-123")
- return req
- },
- expectedToken: "test-session-token-123",
- },
- {
- name: "X-Amz-Security-Token header",
- setupRequest: func() *http.Request {
- req := httptest.NewRequest("PUT", "/test-bucket/test-file.txt", nil)
- req.Header.Set("X-Amz-Security-Token", "security-token-456")
- return req
- },
- expectedToken: "security-token-456",
- },
- {
- name: "X-Amz-Security-Token query parameter",
- setupRequest: func() *http.Request {
- req := httptest.NewRequest("PUT", "/test-bucket/test-file.txt?X-Amz-Security-Token=query-token-789", nil)
- return req
- },
- expectedToken: "query-token-789",
- },
- {
- name: "No token present",
- setupRequest: func() *http.Request {
- return httptest.NewRequest("PUT", "/test-bucket/test-file.txt", nil)
- },
- expectedToken: "",
- },
- {
- name: "Authorization header without Bearer",
- setupRequest: func() *http.Request {
- req := httptest.NewRequest("PUT", "/test-bucket/test-file.txt", nil)
- req.Header.Set("Authorization", "AWS access_key:signature")
- return req
- },
- expectedToken: "",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- req := tt.setupRequest()
- token := extractSessionTokenFromRequest(req)
- assert.Equal(t, tt.expectedToken, token, "Extracted token should match expected")
- })
- }
-}
-
-// TestUploadPartValidation tests upload part request validation
-func TestUploadPartValidation(t *testing.T) {
- s3Server := &S3ApiServer{}
-
- tests := []struct {
- name string
- setupRequest func() *http.Request
- expectedError string
- }{
- {
- name: "Valid upload part request",
- setupRequest: func() *http.Request {
- req := httptest.NewRequest("PUT", "/test-bucket/test-file.txt?partNumber=1&uploadId=test-123", nil)
- req.Header.Set("Content-Type", "application/octet-stream")
- req.ContentLength = 6 * 1024 * 1024 // 6MB
- return req
- },
- expectedError: "",
- },
- {
- name: "Missing partNumber parameter",
- setupRequest: func() *http.Request {
- req := httptest.NewRequest("PUT", "/test-bucket/test-file.txt?uploadId=test-123", nil)
- req.Header.Set("Content-Type", "application/octet-stream")
- req.ContentLength = 6 * 1024 * 1024
- return req
- },
- expectedError: "missing partNumber parameter",
- },
- {
- name: "Invalid partNumber format",
- setupRequest: func() *http.Request {
- req := httptest.NewRequest("PUT", "/test-bucket/test-file.txt?partNumber=abc&uploadId=test-123", nil)
- req.Header.Set("Content-Type", "application/octet-stream")
- req.ContentLength = 6 * 1024 * 1024
- return req
- },
- expectedError: "invalid partNumber",
- },
- {
- name: "Part size too large",
- setupRequest: func() *http.Request {
- req := httptest.NewRequest("PUT", "/test-bucket/test-file.txt?partNumber=1&uploadId=test-123", nil)
- req.Header.Set("Content-Type", "application/octet-stream")
- req.ContentLength = 6 * 1024 * 1024 * 1024 // 6GB exceeds 5GB limit
- return req
- },
- expectedError: "part size",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- req := tt.setupRequest()
- err := s3Server.validateUploadPartRequest(req)
-
- if tt.expectedError == "" {
- assert.NoError(t, err, "Upload part validation should succeed")
- } else {
- assert.Error(t, err, "Upload part validation should fail")
- assert.Contains(t, err.Error(), tt.expectedError, "Error message should contain expected text")
- }
- })
- }
-}
-
-// TestDefaultMultipartUploadPolicy tests the default policy configuration
-func TestDefaultMultipartUploadPolicy(t *testing.T) {
- policy := DefaultMultipartUploadPolicy()
-
- assert.Equal(t, int64(5*1024*1024*1024), policy.MaxPartSize, "Max part size should be 5GB")
- assert.Equal(t, int64(5*1024*1024), policy.MinPartSize, "Min part size should be 5MB")
- assert.Equal(t, 10000, policy.MaxParts, "Max parts should be 10,000")
- assert.Equal(t, 7*24*time.Hour, policy.MaxUploadDuration, "Max upload duration should be 7 days")
- assert.Empty(t, policy.AllowedContentTypes, "Should allow all content types by default")
- assert.Empty(t, policy.RequiredHeaders, "Should have no required headers by default")
- assert.Empty(t, policy.IPWhitelist, "Should have no IP restrictions by default")
-}
-
-// TestMultipartUploadSession tests multipart upload session structure
-func TestMultipartUploadSession(t *testing.T) {
- session := &MultipartUploadSession{
- UploadID: "test-upload-123",
- Bucket: "test-bucket",
- ObjectKey: "test-file.txt",
- Initiator: "arn:aws:iam::user/testuser",
- Owner: "arn:aws:iam::user/testuser",
- CreatedAt: time.Now(),
- Parts: []MultipartUploadPart{
- {
- PartNumber: 1,
- Size: 5 * 1024 * 1024,
- ETag: "abc123",
- LastModified: time.Now(),
- Checksum: "sha256:def456",
- },
- },
- Metadata: map[string]string{
- "Content-Type": "application/octet-stream",
- "x-amz-meta-custom": "value",
- },
- Policy: DefaultMultipartUploadPolicy(),
- SessionToken: "session-token-789",
- }
-
- assert.NotEmpty(t, session.UploadID, "Upload ID should not be empty")
- assert.NotEmpty(t, session.Bucket, "Bucket should not be empty")
- assert.NotEmpty(t, session.ObjectKey, "Object key should not be empty")
- assert.Len(t, session.Parts, 1, "Should have one part")
- assert.Equal(t, 1, session.Parts[0].PartNumber, "Part number should be 1")
- assert.NotNil(t, session.Policy, "Policy should not be nil")
-}
-
-// Helper functions for tests
-
-func setupTestIAMManagerForMultipart(t *testing.T) *integration.IAMManager {
- // Create IAM manager
- manager := integration.NewIAMManager()
-
- // Initialize with test configuration
- config := &integration.IAMConfig{
- STS: &sts.STSConfig{
- TokenDuration: sts.FlexibleDuration{Duration: time.Hour},
- MaxSessionLength: sts.FlexibleDuration{Duration: time.Hour * 12},
- Issuer: "test-sts",
- SigningKey: []byte("test-signing-key-32-characters-long"),
- },
- Policy: &policy.PolicyEngineConfig{
- DefaultEffect: "Deny",
- StoreType: "memory",
- },
- Roles: &integration.RoleStoreConfig{
- StoreType: "memory",
- },
- }
-
- err := manager.Initialize(config, func() string {
- return "localhost:8888" // Mock filer address for testing
- })
- require.NoError(t, err)
-
- // Set up test identity providers
- setupTestProvidersForMultipart(t, manager)
-
- return manager
-}
-
-func setupTestProvidersForMultipart(t *testing.T, manager *integration.IAMManager) {
- // Set up OIDC provider
- oidcProvider := oidc.NewMockOIDCProvider("test-oidc")
- oidcConfig := &oidc.OIDCConfig{
- Issuer: "https://test-issuer.com",
- ClientID: "test-client-id",
- }
- err := oidcProvider.Initialize(oidcConfig)
- require.NoError(t, err)
- oidcProvider.SetupDefaultTestData()
-
- // Set up LDAP provider
- ldapProvider := ldap.NewMockLDAPProvider("test-ldap")
- err = ldapProvider.Initialize(nil) // Mock doesn't need real config
- require.NoError(t, err)
- ldapProvider.SetupDefaultTestData()
-
- // Register providers
- err = manager.RegisterIdentityProvider(oidcProvider)
- require.NoError(t, err)
- err = manager.RegisterIdentityProvider(ldapProvider)
- require.NoError(t, err)
-}
-
-func setupTestRolesForMultipart(ctx context.Context, manager *integration.IAMManager) {
- // Create write policy for multipart operations
- writePolicy := &policy.PolicyDocument{
- Version: "2012-10-17",
- Statement: []policy.Statement{
- {
- Sid: "AllowS3MultipartOperations",
- Effect: "Allow",
- Action: []string{
- "s3:PutObject",
- "s3:GetObject",
- "s3:ListBucket",
- "s3:DeleteObject",
- "s3:CreateMultipartUpload",
- "s3:UploadPart",
- "s3:CompleteMultipartUpload",
- "s3:AbortMultipartUpload",
- "s3:ListBucketMultipartUploads",
- "s3:ListMultipartUploadParts",
- },
- Resource: []string{
- "arn:aws:s3:::*",
- "arn:aws:s3:::*/*",
- },
- },
- },
- }
-
- manager.CreatePolicy(ctx, "", "S3WritePolicy", writePolicy)
-
- // Create write role
- manager.CreateRole(ctx, "", "S3WriteRole", &integration.RoleDefinition{
- RoleName: "S3WriteRole",
- TrustPolicy: &policy.PolicyDocument{
- Version: "2012-10-17",
- Statement: []policy.Statement{
- {
- Effect: "Allow",
- Principal: map[string]interface{}{
- "Federated": "test-oidc",
- },
- Action: []string{"sts:AssumeRoleWithWebIdentity"},
- },
- },
- },
- AttachedPolicies: []string{"S3WritePolicy"},
- })
-
- // Create a role for multipart users
- manager.CreateRole(ctx, "", "MultipartUser", &integration.RoleDefinition{
- RoleName: "MultipartUser",
- TrustPolicy: &policy.PolicyDocument{
- Version: "2012-10-17",
- Statement: []policy.Statement{
- {
- Effect: "Allow",
- Principal: map[string]interface{}{
- "Federated": "test-oidc",
- },
- Action: []string{"sts:AssumeRoleWithWebIdentity"},
- },
- },
- },
- AttachedPolicies: []string{"S3WritePolicy"},
- })
-}
-
-func createMultipartRequest(t *testing.T, method, path, sessionToken string) *http.Request {
- req := httptest.NewRequest(method, path, nil)
-
- // Add session token if provided
- if sessionToken != "" {
- req.Header.Set("Authorization", "Bearer "+sessionToken)
- // Set the principal ARN header that matches the assumed role from the test setup
- // This corresponds to the role "arn:aws:iam::role/S3WriteRole" with session name "multipart-test-session"
- req.Header.Set("X-SeaweedFS-Principal", "arn:aws:sts::assumed-role/S3WriteRole/multipart-test-session")
- }
-
- // Add common headers
- req.Header.Set("Content-Type", "application/octet-stream")
-
- return req
-}
diff --git a/weed/s3api/s3_policy_templates.go b/weed/s3api/s3_policy_templates.go
deleted file mode 100644
index 1506c68ee..000000000
--- a/weed/s3api/s3_policy_templates.go
+++ /dev/null
@@ -1,618 +0,0 @@
-package s3api
-
-import (
- "time"
-
- "github.com/seaweedfs/seaweedfs/weed/iam/policy"
-)
-
-// S3PolicyTemplates provides pre-built IAM policy templates for common S3 use cases
-type S3PolicyTemplates struct{}
-
-// NewS3PolicyTemplates creates a new policy templates provider
-func NewS3PolicyTemplates() *S3PolicyTemplates {
- return &S3PolicyTemplates{}
-}
-
-// GetS3ReadOnlyPolicy returns a policy that allows read-only access to all S3 resources
-func (t *S3PolicyTemplates) GetS3ReadOnlyPolicy() *policy.PolicyDocument {
- return &policy.PolicyDocument{
- Version: "2012-10-17",
- Statement: []policy.Statement{
- {
- Sid: "S3ReadOnlyAccess",
- Effect: "Allow",
- Action: []string{
- "s3:GetObject",
- "s3:GetObjectVersion",
- "s3:ListBucket",
- "s3:ListBucketVersions",
- "s3:GetBucketLocation",
- "s3:GetBucketVersioning",
- "s3:ListAllMyBuckets",
- },
- Resource: []string{
- "arn:aws:s3:::*",
- "arn:aws:s3:::*/*",
- },
- },
- },
- }
-}
-
-// GetS3WriteOnlyPolicy returns a policy that allows write-only access to all S3 resources
-func (t *S3PolicyTemplates) GetS3WriteOnlyPolicy() *policy.PolicyDocument {
- return &policy.PolicyDocument{
- Version: "2012-10-17",
- Statement: []policy.Statement{
- {
- Sid: "S3WriteOnlyAccess",
- Effect: "Allow",
- Action: []string{
- "s3:PutObject",
- "s3:PutObjectAcl",
- "s3:CreateMultipartUpload",
- "s3:UploadPart",
- "s3:CompleteMultipartUpload",
- "s3:AbortMultipartUpload",
- "s3:ListMultipartUploads",
- "s3:ListParts",
- },
- Resource: []string{
- "arn:aws:s3:::*",
- "arn:aws:s3:::*/*",
- },
- },
- },
- }
-}
-
-// GetS3AdminPolicy returns a policy that allows full admin access to all S3 resources
-func (t *S3PolicyTemplates) GetS3AdminPolicy() *policy.PolicyDocument {
- return &policy.PolicyDocument{
- Version: "2012-10-17",
- Statement: []policy.Statement{
- {
- Sid: "S3FullAccess",
- Effect: "Allow",
- Action: []string{
- "s3:*",
- },
- Resource: []string{
- "arn:aws:s3:::*",
- "arn:aws:s3:::*/*",
- },
- },
- },
- }
-}
-
-// GetBucketSpecificReadPolicy returns a policy for read-only access to a specific bucket
-func (t *S3PolicyTemplates) GetBucketSpecificReadPolicy(bucketName string) *policy.PolicyDocument {
- return &policy.PolicyDocument{
- Version: "2012-10-17",
- Statement: []policy.Statement{
- {
- Sid: "BucketSpecificReadAccess",
- Effect: "Allow",
- Action: []string{
- "s3:GetObject",
- "s3:GetObjectVersion",
- "s3:ListBucket",
- "s3:ListBucketVersions",
- "s3:GetBucketLocation",
- },
- Resource: []string{
- "arn:aws:s3:::" + bucketName,
- "arn:aws:s3:::" + bucketName + "/*",
- },
- },
- },
- }
-}
-
-// GetBucketSpecificWritePolicy returns a policy for write-only access to a specific bucket
-func (t *S3PolicyTemplates) GetBucketSpecificWritePolicy(bucketName string) *policy.PolicyDocument {
- return &policy.PolicyDocument{
- Version: "2012-10-17",
- Statement: []policy.Statement{
- {
- Sid: "BucketSpecificWriteAccess",
- Effect: "Allow",
- Action: []string{
- "s3:PutObject",
- "s3:PutObjectAcl",
- "s3:CreateMultipartUpload",
- "s3:UploadPart",
- "s3:CompleteMultipartUpload",
- "s3:AbortMultipartUpload",
- "s3:ListMultipartUploads",
- "s3:ListParts",
- },
- Resource: []string{
- "arn:aws:s3:::" + bucketName,
- "arn:aws:s3:::" + bucketName + "/*",
- },
- },
- },
- }
-}
-
-// GetPathBasedAccessPolicy returns a policy that restricts access to a specific path within a bucket
-func (t *S3PolicyTemplates) GetPathBasedAccessPolicy(bucketName, pathPrefix string) *policy.PolicyDocument {
- return &policy.PolicyDocument{
- Version: "2012-10-17",
- Statement: []policy.Statement{
- {
- Sid: "ListBucketPermission",
- Effect: "Allow",
- Action: []string{
- "s3:ListBucket",
- },
- Resource: []string{
- "arn:aws:s3:::" + bucketName,
- },
- Condition: map[string]map[string]interface{}{
- "StringLike": map[string]interface{}{
- "s3:prefix": []string{pathPrefix + "/*"},
- },
- },
- },
- {
- Sid: "PathBasedObjectAccess",
- Effect: "Allow",
- Action: []string{
- "s3:GetObject",
- "s3:PutObject",
- "s3:DeleteObject",
- "s3:CreateMultipartUpload",
- "s3:UploadPart",
- "s3:CompleteMultipartUpload",
- "s3:AbortMultipartUpload",
- },
- Resource: []string{
- "arn:aws:s3:::" + bucketName + "/" + pathPrefix + "/*",
- },
- },
- },
- }
-}
-
-// GetIPRestrictedPolicy returns a policy that restricts access based on source IP
-func (t *S3PolicyTemplates) GetIPRestrictedPolicy(allowedCIDRs []string) *policy.PolicyDocument {
- return &policy.PolicyDocument{
- Version: "2012-10-17",
- Statement: []policy.Statement{
- {
- Sid: "IPRestrictedS3Access",
- Effect: "Allow",
- Action: []string{
- "s3:*",
- },
- Resource: []string{
- "arn:aws:s3:::*",
- "arn:aws:s3:::*/*",
- },
- Condition: map[string]map[string]interface{}{
- "IpAddress": map[string]interface{}{
- "aws:SourceIp": allowedCIDRs,
- },
- },
- },
- },
- }
-}
-
-// GetTimeBasedAccessPolicy returns a policy that allows access only during specific hours
-func (t *S3PolicyTemplates) GetTimeBasedAccessPolicy(startHour, endHour int) *policy.PolicyDocument {
- return &policy.PolicyDocument{
- Version: "2012-10-17",
- Statement: []policy.Statement{
- {
- Sid: "TimeBasedS3Access",
- Effect: "Allow",
- Action: []string{
- "s3:GetObject",
- "s3:PutObject",
- "s3:ListBucket",
- },
- Resource: []string{
- "arn:aws:s3:::*",
- "arn:aws:s3:::*/*",
- },
- Condition: map[string]map[string]interface{}{
- "DateGreaterThan": map[string]interface{}{
- "aws:CurrentTime": time.Now().Format("2006-01-02") + "T" +
- formatHour(startHour) + ":00:00Z",
- },
- "DateLessThan": map[string]interface{}{
- "aws:CurrentTime": time.Now().Format("2006-01-02") + "T" +
- formatHour(endHour) + ":00:00Z",
- },
- },
- },
- },
- }
-}
-
-// GetMultipartUploadPolicy returns a policy specifically for multipart upload operations
-func (t *S3PolicyTemplates) GetMultipartUploadPolicy(bucketName string) *policy.PolicyDocument {
- return &policy.PolicyDocument{
- Version: "2012-10-17",
- Statement: []policy.Statement{
- {
- Sid: "MultipartUploadOperations",
- Effect: "Allow",
- Action: []string{
- "s3:CreateMultipartUpload",
- "s3:UploadPart",
- "s3:CompleteMultipartUpload",
- "s3:AbortMultipartUpload",
- "s3:ListMultipartUploads",
- "s3:ListParts",
- },
- Resource: []string{
- "arn:aws:s3:::" + bucketName + "/*",
- },
- },
- {
- Sid: "ListBucketForMultipart",
- Effect: "Allow",
- Action: []string{
- "s3:ListBucket",
- },
- Resource: []string{
- "arn:aws:s3:::" + bucketName,
- },
- },
- },
- }
-}
-
-// GetPresignedURLPolicy returns a policy for generating and using presigned URLs
-func (t *S3PolicyTemplates) GetPresignedURLPolicy(bucketName string) *policy.PolicyDocument {
- return &policy.PolicyDocument{
- Version: "2012-10-17",
- Statement: []policy.Statement{
- {
- Sid: "PresignedURLAccess",
- Effect: "Allow",
- Action: []string{
- "s3:GetObject",
- "s3:PutObject",
- },
- Resource: []string{
- "arn:aws:s3:::" + bucketName + "/*",
- },
- Condition: map[string]map[string]interface{}{
- "StringEquals": map[string]interface{}{
- "s3:x-amz-signature-version": "AWS4-HMAC-SHA256",
- },
- },
- },
- },
- }
-}
-
-// GetTemporaryAccessPolicy returns a policy for temporary access with expiration
-func (t *S3PolicyTemplates) GetTemporaryAccessPolicy(bucketName string, expirationHours int) *policy.PolicyDocument {
- expirationTime := time.Now().Add(time.Duration(expirationHours) * time.Hour)
-
- return &policy.PolicyDocument{
- Version: "2012-10-17",
- Statement: []policy.Statement{
- {
- Sid: "TemporaryS3Access",
- Effect: "Allow",
- Action: []string{
- "s3:GetObject",
- "s3:PutObject",
- "s3:ListBucket",
- },
- Resource: []string{
- "arn:aws:s3:::" + bucketName,
- "arn:aws:s3:::" + bucketName + "/*",
- },
- Condition: map[string]map[string]interface{}{
- "DateLessThan": map[string]interface{}{
- "aws:CurrentTime": expirationTime.UTC().Format("2006-01-02T15:04:05Z"),
- },
- },
- },
- },
- }
-}
-
-// GetContentTypeRestrictedPolicy returns a policy that restricts uploads to specific content types
-func (t *S3PolicyTemplates) GetContentTypeRestrictedPolicy(bucketName string, allowedContentTypes []string) *policy.PolicyDocument {
- return &policy.PolicyDocument{
- Version: "2012-10-17",
- Statement: []policy.Statement{
- {
- Sid: "ContentTypeRestrictedUpload",
- Effect: "Allow",
- Action: []string{
- "s3:PutObject",
- "s3:CreateMultipartUpload",
- "s3:UploadPart",
- "s3:CompleteMultipartUpload",
- },
- Resource: []string{
- "arn:aws:s3:::" + bucketName + "/*",
- },
- Condition: map[string]map[string]interface{}{
- "StringEquals": map[string]interface{}{
- "s3:content-type": allowedContentTypes,
- },
- },
- },
- {
- Sid: "ReadAccess",
- Effect: "Allow",
- Action: []string{
- "s3:GetObject",
- "s3:ListBucket",
- },
- Resource: []string{
- "arn:aws:s3:::" + bucketName,
- "arn:aws:s3:::" + bucketName + "/*",
- },
- },
- },
- }
-}
-
-// GetDenyDeletePolicy returns a policy that allows all operations except delete
-func (t *S3PolicyTemplates) GetDenyDeletePolicy() *policy.PolicyDocument {
- return &policy.PolicyDocument{
- Version: "2012-10-17",
- Statement: []policy.Statement{
- {
- Sid: "AllowAllExceptDelete",
- Effect: "Allow",
- Action: []string{
- "s3:GetObject",
- "s3:GetObjectVersion",
- "s3:PutObject",
- "s3:PutObjectAcl",
- "s3:ListBucket",
- "s3:ListBucketVersions",
- "s3:CreateMultipartUpload",
- "s3:UploadPart",
- "s3:CompleteMultipartUpload",
- "s3:AbortMultipartUpload",
- "s3:ListMultipartUploads",
- "s3:ListParts",
- },
- Resource: []string{
- "arn:aws:s3:::*",
- "arn:aws:s3:::*/*",
- },
- },
- {
- Sid: "DenyDeleteOperations",
- Effect: "Deny",
- Action: []string{
- "s3:DeleteObject",
- "s3:DeleteObjectVersion",
- "s3:DeleteBucket",
- },
- Resource: []string{
- "arn:aws:s3:::*",
- "arn:aws:s3:::*/*",
- },
- },
- },
- }
-}
-
-// Helper function to format hour with leading zero
-func formatHour(hour int) string {
- if hour < 10 {
- return "0" + string(rune('0'+hour))
- }
- return string(rune('0'+hour/10)) + string(rune('0'+hour%10))
-}
-
-// PolicyTemplateDefinition represents metadata about a policy template
-type PolicyTemplateDefinition struct {
- Name string `json:"name"`
- Description string `json:"description"`
- Category string `json:"category"`
- UseCase string `json:"use_case"`
- Parameters []PolicyTemplateParam `json:"parameters,omitempty"`
- Policy *policy.PolicyDocument `json:"policy"`
-}
-
-// PolicyTemplateParam represents a parameter for customizing policy templates
-type PolicyTemplateParam struct {
- Name string `json:"name"`
- Type string `json:"type"`
- Description string `json:"description"`
- Required bool `json:"required"`
- DefaultValue string `json:"default_value,omitempty"`
- Example string `json:"example,omitempty"`
-}
-
-// GetAllPolicyTemplates returns all available policy templates with metadata
-func (t *S3PolicyTemplates) GetAllPolicyTemplates() []PolicyTemplateDefinition {
- return []PolicyTemplateDefinition{
- {
- Name: "S3ReadOnlyAccess",
- Description: "Provides read-only access to all S3 buckets and objects",
- Category: "Basic Access",
- UseCase: "Data consumers, backup services, monitoring applications",
- Policy: t.GetS3ReadOnlyPolicy(),
- },
- {
- Name: "S3WriteOnlyAccess",
- Description: "Provides write-only access to all S3 buckets and objects",
- Category: "Basic Access",
- UseCase: "Data ingestion services, backup applications",
- Policy: t.GetS3WriteOnlyPolicy(),
- },
- {
- Name: "S3AdminAccess",
- Description: "Provides full administrative access to all S3 resources",
- Category: "Administrative",
- UseCase: "S3 administrators, service accounts with full control",
- Policy: t.GetS3AdminPolicy(),
- },
- {
- Name: "BucketSpecificRead",
- Description: "Provides read-only access to a specific bucket",
- Category: "Bucket-Specific",
- UseCase: "Applications that need access to specific data sets",
- Parameters: []PolicyTemplateParam{
- {
- Name: "bucketName",
- Type: "string",
- Description: "Name of the S3 bucket to grant access to",
- Required: true,
- Example: "my-data-bucket",
- },
- },
- Policy: t.GetBucketSpecificReadPolicy("${bucketName}"),
- },
- {
- Name: "BucketSpecificWrite",
- Description: "Provides write-only access to a specific bucket",
- Category: "Bucket-Specific",
- UseCase: "Upload services, data ingestion for specific datasets",
- Parameters: []PolicyTemplateParam{
- {
- Name: "bucketName",
- Type: "string",
- Description: "Name of the S3 bucket to grant access to",
- Required: true,
- Example: "my-upload-bucket",
- },
- },
- Policy: t.GetBucketSpecificWritePolicy("${bucketName}"),
- },
- {
- Name: "PathBasedAccess",
- Description: "Restricts access to a specific path/prefix within a bucket",
- Category: "Path-Restricted",
- UseCase: "Multi-tenant applications, user-specific directories",
- Parameters: []PolicyTemplateParam{
- {
- Name: "bucketName",
- Type: "string",
- Description: "Name of the S3 bucket",
- Required: true,
- Example: "shared-bucket",
- },
- {
- Name: "pathPrefix",
- Type: "string",
- Description: "Path prefix to restrict access to",
- Required: true,
- Example: "user123/documents",
- },
- },
- Policy: t.GetPathBasedAccessPolicy("${bucketName}", "${pathPrefix}"),
- },
- {
- Name: "IPRestrictedAccess",
- Description: "Allows access only from specific IP addresses or ranges",
- Category: "Security",
- UseCase: "Corporate networks, office-based access, VPN restrictions",
- Parameters: []PolicyTemplateParam{
- {
- Name: "allowedCIDRs",
- Type: "array",
- Description: "List of allowed IP addresses or CIDR ranges",
- Required: true,
- Example: "[\"192.168.1.0/24\", \"10.0.0.0/8\"]",
- },
- },
- Policy: t.GetIPRestrictedPolicy([]string{"${allowedCIDRs}"}),
- },
- {
- Name: "MultipartUploadOnly",
- Description: "Allows only multipart upload operations on a specific bucket",
- Category: "Upload-Specific",
- UseCase: "Large file upload services, streaming applications",
- Parameters: []PolicyTemplateParam{
- {
- Name: "bucketName",
- Type: "string",
- Description: "Name of the S3 bucket for multipart uploads",
- Required: true,
- Example: "large-files-bucket",
- },
- },
- Policy: t.GetMultipartUploadPolicy("${bucketName}"),
- },
- {
- Name: "PresignedURLAccess",
- Description: "Policy for generating and using presigned URLs",
- Category: "Presigned URLs",
- UseCase: "Frontend applications, temporary file sharing",
- Parameters: []PolicyTemplateParam{
- {
- Name: "bucketName",
- Type: "string",
- Description: "Name of the S3 bucket for presigned URL access",
- Required: true,
- Example: "shared-files-bucket",
- },
- },
- Policy: t.GetPresignedURLPolicy("${bucketName}"),
- },
- {
- Name: "ContentTypeRestricted",
- Description: "Restricts uploads to specific content types",
- Category: "Content Control",
- UseCase: "Image galleries, document repositories, media libraries",
- Parameters: []PolicyTemplateParam{
- {
- Name: "bucketName",
- Type: "string",
- Description: "Name of the S3 bucket",
- Required: true,
- Example: "media-bucket",
- },
- {
- Name: "allowedContentTypes",
- Type: "array",
- Description: "List of allowed MIME content types",
- Required: true,
- Example: "[\"image/jpeg\", \"image/png\", \"video/mp4\"]",
- },
- },
- Policy: t.GetContentTypeRestrictedPolicy("${bucketName}", []string{"${allowedContentTypes}"}),
- },
- {
- Name: "DenyDeleteAccess",
- Description: "Allows all operations except delete (immutable storage)",
- Category: "Data Protection",
- UseCase: "Compliance storage, audit logs, backup retention",
- Policy: t.GetDenyDeletePolicy(),
- },
- }
-}
-
-// GetPolicyTemplateByName returns a specific policy template by name
-func (t *S3PolicyTemplates) GetPolicyTemplateByName(name string) *PolicyTemplateDefinition {
- templates := t.GetAllPolicyTemplates()
- for _, template := range templates {
- if template.Name == name {
- return &template
- }
- }
- return nil
-}
-
-// GetPolicyTemplatesByCategory returns all policy templates in a specific category
-func (t *S3PolicyTemplates) GetPolicyTemplatesByCategory(category string) []PolicyTemplateDefinition {
- var result []PolicyTemplateDefinition
- templates := t.GetAllPolicyTemplates()
- for _, template := range templates {
- if template.Category == category {
- result = append(result, template)
- }
- }
- return result
-}
diff --git a/weed/s3api/s3_policy_templates_test.go b/weed/s3api/s3_policy_templates_test.go
deleted file mode 100644
index 453260c2a..000000000
--- a/weed/s3api/s3_policy_templates_test.go
+++ /dev/null
@@ -1,504 +0,0 @@
-package s3api
-
-import (
- "fmt"
- "testing"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-func TestS3PolicyTemplates(t *testing.T) {
- templates := NewS3PolicyTemplates()
-
- t.Run("S3ReadOnlyPolicy", func(t *testing.T) {
- policy := templates.GetS3ReadOnlyPolicy()
-
- require.NotNil(t, policy)
- assert.Equal(t, "2012-10-17", policy.Version)
- assert.Len(t, policy.Statement, 1)
-
- stmt := policy.Statement[0]
- assert.Equal(t, "Allow", stmt.Effect)
- assert.Equal(t, "S3ReadOnlyAccess", stmt.Sid)
- assert.Contains(t, stmt.Action, "s3:GetObject")
- assert.Contains(t, stmt.Action, "s3:ListBucket")
- assert.NotContains(t, stmt.Action, "s3:PutObject")
- assert.NotContains(t, stmt.Action, "s3:DeleteObject")
-
- assert.Contains(t, stmt.Resource, "arn:aws:s3:::*")
- assert.Contains(t, stmt.Resource, "arn:aws:s3:::*/*")
- })
-
- t.Run("S3WriteOnlyPolicy", func(t *testing.T) {
- policy := templates.GetS3WriteOnlyPolicy()
-
- require.NotNil(t, policy)
- assert.Equal(t, "2012-10-17", policy.Version)
- assert.Len(t, policy.Statement, 1)
-
- stmt := policy.Statement[0]
- assert.Equal(t, "Allow", stmt.Effect)
- assert.Equal(t, "S3WriteOnlyAccess", stmt.Sid)
- assert.Contains(t, stmt.Action, "s3:PutObject")
- assert.Contains(t, stmt.Action, "s3:CreateMultipartUpload")
- assert.NotContains(t, stmt.Action, "s3:GetObject")
- assert.NotContains(t, stmt.Action, "s3:DeleteObject")
-
- assert.Contains(t, stmt.Resource, "arn:aws:s3:::*")
- assert.Contains(t, stmt.Resource, "arn:aws:s3:::*/*")
- })
-
- t.Run("S3AdminPolicy", func(t *testing.T) {
- policy := templates.GetS3AdminPolicy()
-
- require.NotNil(t, policy)
- assert.Equal(t, "2012-10-17", policy.Version)
- assert.Len(t, policy.Statement, 1)
-
- stmt := policy.Statement[0]
- assert.Equal(t, "Allow", stmt.Effect)
- assert.Equal(t, "S3FullAccess", stmt.Sid)
- assert.Contains(t, stmt.Action, "s3:*")
-
- assert.Contains(t, stmt.Resource, "arn:aws:s3:::*")
- assert.Contains(t, stmt.Resource, "arn:aws:s3:::*/*")
- })
-}
-
-func TestBucketSpecificPolicies(t *testing.T) {
- templates := NewS3PolicyTemplates()
- bucketName := "test-bucket"
-
- t.Run("BucketSpecificReadPolicy", func(t *testing.T) {
- policy := templates.GetBucketSpecificReadPolicy(bucketName)
-
- require.NotNil(t, policy)
- assert.Equal(t, "2012-10-17", policy.Version)
- assert.Len(t, policy.Statement, 1)
-
- stmt := policy.Statement[0]
- assert.Equal(t, "Allow", stmt.Effect)
- assert.Equal(t, "BucketSpecificReadAccess", stmt.Sid)
- assert.Contains(t, stmt.Action, "s3:GetObject")
- assert.Contains(t, stmt.Action, "s3:ListBucket")
- assert.NotContains(t, stmt.Action, "s3:PutObject")
-
- expectedBucketArn := "arn:aws:s3:::" + bucketName
- expectedObjectArn := "arn:aws:s3:::" + bucketName + "/*"
- assert.Contains(t, stmt.Resource, expectedBucketArn)
- assert.Contains(t, stmt.Resource, expectedObjectArn)
- })
-
- t.Run("BucketSpecificWritePolicy", func(t *testing.T) {
- policy := templates.GetBucketSpecificWritePolicy(bucketName)
-
- require.NotNil(t, policy)
- assert.Equal(t, "2012-10-17", policy.Version)
- assert.Len(t, policy.Statement, 1)
-
- stmt := policy.Statement[0]
- assert.Equal(t, "Allow", stmt.Effect)
- assert.Equal(t, "BucketSpecificWriteAccess", stmt.Sid)
- assert.Contains(t, stmt.Action, "s3:PutObject")
- assert.Contains(t, stmt.Action, "s3:CreateMultipartUpload")
- assert.NotContains(t, stmt.Action, "s3:GetObject")
-
- expectedBucketArn := "arn:aws:s3:::" + bucketName
- expectedObjectArn := "arn:aws:s3:::" + bucketName + "/*"
- assert.Contains(t, stmt.Resource, expectedBucketArn)
- assert.Contains(t, stmt.Resource, expectedObjectArn)
- })
-}
-
-func TestPathBasedAccessPolicy(t *testing.T) {
- templates := NewS3PolicyTemplates()
- bucketName := "shared-bucket"
- pathPrefix := "user123/documents"
-
- policy := templates.GetPathBasedAccessPolicy(bucketName, pathPrefix)
-
- require.NotNil(t, policy)
- assert.Equal(t, "2012-10-17", policy.Version)
- assert.Len(t, policy.Statement, 2)
-
- // First statement: List bucket with prefix condition
- listStmt := policy.Statement[0]
- assert.Equal(t, "Allow", listStmt.Effect)
- assert.Equal(t, "ListBucketPermission", listStmt.Sid)
- assert.Contains(t, listStmt.Action, "s3:ListBucket")
- assert.Contains(t, listStmt.Resource, "arn:aws:s3:::"+bucketName)
- assert.NotNil(t, listStmt.Condition)
-
- // Second statement: Object operations on path
- objectStmt := policy.Statement[1]
- assert.Equal(t, "Allow", objectStmt.Effect)
- assert.Equal(t, "PathBasedObjectAccess", objectStmt.Sid)
- assert.Contains(t, objectStmt.Action, "s3:GetObject")
- assert.Contains(t, objectStmt.Action, "s3:PutObject")
- assert.Contains(t, objectStmt.Action, "s3:DeleteObject")
-
- expectedObjectArn := "arn:aws:s3:::" + bucketName + "/" + pathPrefix + "/*"
- assert.Contains(t, objectStmt.Resource, expectedObjectArn)
-}
-
-func TestIPRestrictedPolicy(t *testing.T) {
- templates := NewS3PolicyTemplates()
- allowedCIDRs := []string{"192.168.1.0/24", "10.0.0.0/8"}
-
- policy := templates.GetIPRestrictedPolicy(allowedCIDRs)
-
- require.NotNil(t, policy)
- assert.Equal(t, "2012-10-17", policy.Version)
- assert.Len(t, policy.Statement, 1)
-
- stmt := policy.Statement[0]
- assert.Equal(t, "Allow", stmt.Effect)
- assert.Equal(t, "IPRestrictedS3Access", stmt.Sid)
- assert.Contains(t, stmt.Action, "s3:*")
- assert.NotNil(t, stmt.Condition)
-
- // Check IP condition structure
- condition := stmt.Condition
- ipAddress, exists := condition["IpAddress"]
- assert.True(t, exists)
-
- sourceIp, exists := ipAddress["aws:SourceIp"]
- assert.True(t, exists)
- assert.Equal(t, allowedCIDRs, sourceIp)
-}
-
-func TestTimeBasedAccessPolicy(t *testing.T) {
- templates := NewS3PolicyTemplates()
- startHour := 9 // 9 AM
- endHour := 17 // 5 PM
-
- policy := templates.GetTimeBasedAccessPolicy(startHour, endHour)
-
- require.NotNil(t, policy)
- assert.Equal(t, "2012-10-17", policy.Version)
- assert.Len(t, policy.Statement, 1)
-
- stmt := policy.Statement[0]
- assert.Equal(t, "Allow", stmt.Effect)
- assert.Equal(t, "TimeBasedS3Access", stmt.Sid)
- assert.Contains(t, stmt.Action, "s3:GetObject")
- assert.Contains(t, stmt.Action, "s3:PutObject")
- assert.Contains(t, stmt.Action, "s3:ListBucket")
- assert.NotNil(t, stmt.Condition)
-
- // Check time condition structure
- condition := stmt.Condition
- _, hasGreater := condition["DateGreaterThan"]
- _, hasLess := condition["DateLessThan"]
- assert.True(t, hasGreater)
- assert.True(t, hasLess)
-}
-
-func TestMultipartUploadPolicyTemplate(t *testing.T) {
- templates := NewS3PolicyTemplates()
- bucketName := "large-files"
-
- policy := templates.GetMultipartUploadPolicy(bucketName)
-
- require.NotNil(t, policy)
- assert.Equal(t, "2012-10-17", policy.Version)
- assert.Len(t, policy.Statement, 2)
-
- // First statement: Multipart operations
- multipartStmt := policy.Statement[0]
- assert.Equal(t, "Allow", multipartStmt.Effect)
- assert.Equal(t, "MultipartUploadOperations", multipartStmt.Sid)
- assert.Contains(t, multipartStmt.Action, "s3:CreateMultipartUpload")
- assert.Contains(t, multipartStmt.Action, "s3:UploadPart")
- assert.Contains(t, multipartStmt.Action, "s3:CompleteMultipartUpload")
- assert.Contains(t, multipartStmt.Action, "s3:AbortMultipartUpload")
- assert.Contains(t, multipartStmt.Action, "s3:ListMultipartUploads")
- assert.Contains(t, multipartStmt.Action, "s3:ListParts")
-
- expectedObjectArn := "arn:aws:s3:::" + bucketName + "/*"
- assert.Contains(t, multipartStmt.Resource, expectedObjectArn)
-
- // Second statement: List bucket
- listStmt := policy.Statement[1]
- assert.Equal(t, "Allow", listStmt.Effect)
- assert.Equal(t, "ListBucketForMultipart", listStmt.Sid)
- assert.Contains(t, listStmt.Action, "s3:ListBucket")
-
- expectedBucketArn := "arn:aws:s3:::" + bucketName
- assert.Contains(t, listStmt.Resource, expectedBucketArn)
-}
-
-func TestPresignedURLPolicy(t *testing.T) {
- templates := NewS3PolicyTemplates()
- bucketName := "shared-files"
-
- policy := templates.GetPresignedURLPolicy(bucketName)
-
- require.NotNil(t, policy)
- assert.Equal(t, "2012-10-17", policy.Version)
- assert.Len(t, policy.Statement, 1)
-
- stmt := policy.Statement[0]
- assert.Equal(t, "Allow", stmt.Effect)
- assert.Equal(t, "PresignedURLAccess", stmt.Sid)
- assert.Contains(t, stmt.Action, "s3:GetObject")
- assert.Contains(t, stmt.Action, "s3:PutObject")
- assert.NotNil(t, stmt.Condition)
-
- expectedObjectArn := "arn:aws:s3:::" + bucketName + "/*"
- assert.Contains(t, stmt.Resource, expectedObjectArn)
-
- // Check signature version condition
- condition := stmt.Condition
- stringEquals, exists := condition["StringEquals"]
- assert.True(t, exists)
-
- signatureVersion, exists := stringEquals["s3:x-amz-signature-version"]
- assert.True(t, exists)
- assert.Equal(t, "AWS4-HMAC-SHA256", signatureVersion)
-}
-
-func TestTemporaryAccessPolicy(t *testing.T) {
- templates := NewS3PolicyTemplates()
- bucketName := "temp-bucket"
- expirationHours := 24
-
- policy := templates.GetTemporaryAccessPolicy(bucketName, expirationHours)
-
- require.NotNil(t, policy)
- assert.Equal(t, "2012-10-17", policy.Version)
- assert.Len(t, policy.Statement, 1)
-
- stmt := policy.Statement[0]
- assert.Equal(t, "Allow", stmt.Effect)
- assert.Equal(t, "TemporaryS3Access", stmt.Sid)
- assert.Contains(t, stmt.Action, "s3:GetObject")
- assert.Contains(t, stmt.Action, "s3:PutObject")
- assert.Contains(t, stmt.Action, "s3:ListBucket")
- assert.NotNil(t, stmt.Condition)
-
- // Check expiration condition
- condition := stmt.Condition
- dateLessThan, exists := condition["DateLessThan"]
- assert.True(t, exists)
-
- currentTime, exists := dateLessThan["aws:CurrentTime"]
- assert.True(t, exists)
- assert.IsType(t, "", currentTime) // Should be a string timestamp
-}
-
-func TestContentTypeRestrictedPolicy(t *testing.T) {
- templates := NewS3PolicyTemplates()
- bucketName := "media-bucket"
- allowedTypes := []string{"image/jpeg", "image/png", "video/mp4"}
-
- policy := templates.GetContentTypeRestrictedPolicy(bucketName, allowedTypes)
-
- require.NotNil(t, policy)
- assert.Equal(t, "2012-10-17", policy.Version)
- assert.Len(t, policy.Statement, 2)
-
- // First statement: Upload with content type restriction
- uploadStmt := policy.Statement[0]
- assert.Equal(t, "Allow", uploadStmt.Effect)
- assert.Equal(t, "ContentTypeRestrictedUpload", uploadStmt.Sid)
- assert.Contains(t, uploadStmt.Action, "s3:PutObject")
- assert.Contains(t, uploadStmt.Action, "s3:CreateMultipartUpload")
- assert.NotNil(t, uploadStmt.Condition)
-
- // Check content type condition
- condition := uploadStmt.Condition
- stringEquals, exists := condition["StringEquals"]
- assert.True(t, exists)
-
- contentType, exists := stringEquals["s3:content-type"]
- assert.True(t, exists)
- assert.Equal(t, allowedTypes, contentType)
-
- // Second statement: Read access without restrictions
- readStmt := policy.Statement[1]
- assert.Equal(t, "Allow", readStmt.Effect)
- assert.Equal(t, "ReadAccess", readStmt.Sid)
- assert.Contains(t, readStmt.Action, "s3:GetObject")
- assert.Contains(t, readStmt.Action, "s3:ListBucket")
- assert.Nil(t, readStmt.Condition) // No conditions for read access
-}
-
-func TestDenyDeletePolicy(t *testing.T) {
- templates := NewS3PolicyTemplates()
-
- policy := templates.GetDenyDeletePolicy()
-
- require.NotNil(t, policy)
- assert.Equal(t, "2012-10-17", policy.Version)
- assert.Len(t, policy.Statement, 2)
-
- // First statement: Allow everything except delete
- allowStmt := policy.Statement[0]
- assert.Equal(t, "Allow", allowStmt.Effect)
- assert.Equal(t, "AllowAllExceptDelete", allowStmt.Sid)
- assert.Contains(t, allowStmt.Action, "s3:GetObject")
- assert.Contains(t, allowStmt.Action, "s3:PutObject")
- assert.Contains(t, allowStmt.Action, "s3:ListBucket")
- assert.NotContains(t, allowStmt.Action, "s3:DeleteObject")
- assert.NotContains(t, allowStmt.Action, "s3:DeleteBucket")
-
- // Second statement: Explicitly deny delete operations
- denyStmt := policy.Statement[1]
- assert.Equal(t, "Deny", denyStmt.Effect)
- assert.Equal(t, "DenyDeleteOperations", denyStmt.Sid)
- assert.Contains(t, denyStmt.Action, "s3:DeleteObject")
- assert.Contains(t, denyStmt.Action, "s3:DeleteObjectVersion")
- assert.Contains(t, denyStmt.Action, "s3:DeleteBucket")
-}
-
-func TestPolicyTemplateMetadata(t *testing.T) {
- templates := NewS3PolicyTemplates()
-
- t.Run("GetAllPolicyTemplates", func(t *testing.T) {
- allTemplates := templates.GetAllPolicyTemplates()
-
- assert.Greater(t, len(allTemplates), 10) // Should have many templates
-
- // Check that each template has required fields
- for _, template := range allTemplates {
- assert.NotEmpty(t, template.Name)
- assert.NotEmpty(t, template.Description)
- assert.NotEmpty(t, template.Category)
- assert.NotEmpty(t, template.UseCase)
- assert.NotNil(t, template.Policy)
- assert.Equal(t, "2012-10-17", template.Policy.Version)
- }
- })
-
- t.Run("GetPolicyTemplateByName", func(t *testing.T) {
- // Test existing template
- template := templates.GetPolicyTemplateByName("S3ReadOnlyAccess")
- require.NotNil(t, template)
- assert.Equal(t, "S3ReadOnlyAccess", template.Name)
- assert.Equal(t, "Basic Access", template.Category)
-
- // Test non-existing template
- nonExistent := templates.GetPolicyTemplateByName("NonExistentTemplate")
- assert.Nil(t, nonExistent)
- })
-
- t.Run("GetPolicyTemplatesByCategory", func(t *testing.T) {
- basicAccessTemplates := templates.GetPolicyTemplatesByCategory("Basic Access")
- assert.GreaterOrEqual(t, len(basicAccessTemplates), 2)
-
- for _, template := range basicAccessTemplates {
- assert.Equal(t, "Basic Access", template.Category)
- }
-
- // Test non-existing category
- emptyCategory := templates.GetPolicyTemplatesByCategory("NonExistentCategory")
- assert.Empty(t, emptyCategory)
- })
-
- t.Run("PolicyTemplateParameters", func(t *testing.T) {
- allTemplates := templates.GetAllPolicyTemplates()
-
- // Find a template with parameters (like BucketSpecificRead)
- var templateWithParams *PolicyTemplateDefinition
- for _, template := range allTemplates {
- if template.Name == "BucketSpecificRead" {
- templateWithParams = &template
- break
- }
- }
-
- require.NotNil(t, templateWithParams)
- assert.Greater(t, len(templateWithParams.Parameters), 0)
-
- param := templateWithParams.Parameters[0]
- assert.Equal(t, "bucketName", param.Name)
- assert.Equal(t, "string", param.Type)
- assert.True(t, param.Required)
- assert.NotEmpty(t, param.Description)
- assert.NotEmpty(t, param.Example)
- })
-}
-
-func TestFormatHourHelper(t *testing.T) {
- tests := []struct {
- hour int
- expected string
- }{
- {0, "00"},
- {5, "05"},
- {9, "09"},
- {10, "10"},
- {15, "15"},
- {23, "23"},
- }
-
- for _, tt := range tests {
- t.Run(fmt.Sprintf("Hour_%d", tt.hour), func(t *testing.T) {
- result := formatHour(tt.hour)
- assert.Equal(t, tt.expected, result)
- })
- }
-}
-
-func TestPolicyTemplateCategories(t *testing.T) {
- templates := NewS3PolicyTemplates()
- allTemplates := templates.GetAllPolicyTemplates()
-
- // Extract all categories
- categoryMap := make(map[string]int)
- for _, template := range allTemplates {
- categoryMap[template.Category]++
- }
-
- // Expected categories
- expectedCategories := []string{
- "Basic Access",
- "Administrative",
- "Bucket-Specific",
- "Path-Restricted",
- "Security",
- "Upload-Specific",
- "Presigned URLs",
- "Content Control",
- "Data Protection",
- }
-
- for _, expectedCategory := range expectedCategories {
- count, exists := categoryMap[expectedCategory]
- assert.True(t, exists, "Category %s should exist", expectedCategory)
- assert.Greater(t, count, 0, "Category %s should have at least one template", expectedCategory)
- }
-}
-
-func TestPolicyValidation(t *testing.T) {
- templates := NewS3PolicyTemplates()
- allTemplates := templates.GetAllPolicyTemplates()
-
- // Test that all policies have valid structure
- for _, template := range allTemplates {
- t.Run("Policy_"+template.Name, func(t *testing.T) {
- policy := template.Policy
-
- // Basic validation
- assert.Equal(t, "2012-10-17", policy.Version)
- assert.Greater(t, len(policy.Statement), 0)
-
- // Validate each statement
- for i, stmt := range policy.Statement {
- assert.NotEmpty(t, stmt.Effect, "Statement %d should have effect", i)
- assert.Contains(t, []string{"Allow", "Deny"}, stmt.Effect, "Statement %d effect should be Allow or Deny", i)
- assert.Greater(t, len(stmt.Action), 0, "Statement %d should have actions", i)
- assert.Greater(t, len(stmt.Resource), 0, "Statement %d should have resources", i)
-
- // Check resource format
- for _, resource := range stmt.Resource {
- if resource != "*" {
- assert.Contains(t, resource, "arn:aws:s3:::", "Resource should be valid AWS S3 ARN: %s", resource)
- }
- }
- }
- })
- }
-}
diff --git a/weed/s3api/s3_presigned_url_iam.go b/weed/s3api/s3_presigned_url_iam.go
deleted file mode 100644
index b731b1634..000000000
--- a/weed/s3api/s3_presigned_url_iam.go
+++ /dev/null
@@ -1,355 +0,0 @@
-package s3api
-
-import (
- "context"
- "crypto/sha256"
- "encoding/hex"
- "fmt"
- "net/http"
- "net/url"
- "strconv"
- "strings"
- "time"
-
- "github.com/seaweedfs/seaweedfs/weed/glog"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
-)
-
-// S3PresignedURLManager handles IAM integration for presigned URLs
-type S3PresignedURLManager struct {
- s3iam *S3IAMIntegration
-}
-
-// NewS3PresignedURLManager creates a new presigned URL manager with IAM integration
-func NewS3PresignedURLManager(s3iam *S3IAMIntegration) *S3PresignedURLManager {
- return &S3PresignedURLManager{
- s3iam: s3iam,
- }
-}
-
-// PresignedURLRequest represents a request to generate a presigned URL
-type PresignedURLRequest struct {
- Method string `json:"method"` // HTTP method (GET, PUT, POST, DELETE)
- Bucket string `json:"bucket"` // S3 bucket name
- ObjectKey string `json:"object_key"` // S3 object key
- Expiration time.Duration `json:"expiration"` // URL expiration duration
- SessionToken string `json:"session_token"` // JWT session token for IAM
- Headers map[string]string `json:"headers"` // Additional headers to sign
- QueryParams map[string]string `json:"query_params"` // Additional query parameters
-}
-
-// PresignedURLResponse represents the generated presigned URL
-type PresignedURLResponse struct {
- URL string `json:"url"` // The presigned URL
- Method string `json:"method"` // HTTP method
- Headers map[string]string `json:"headers"` // Required headers
- ExpiresAt time.Time `json:"expires_at"` // URL expiration time
- SignedHeaders []string `json:"signed_headers"` // List of signed headers
- CanonicalQuery string `json:"canonical_query"` // Canonical query string
-}
-
-// ValidatePresignedURLWithIAM validates a presigned URL request using IAM policies
-func (iam *IdentityAccessManagement) ValidatePresignedURLWithIAM(r *http.Request, identity *Identity) s3err.ErrorCode {
- if iam.iamIntegration == nil {
- // Fall back to standard validation
- return s3err.ErrNone
- }
-
- // Extract bucket and object from request
- bucket, object := s3_constants.GetBucketAndObject(r)
-
- // Determine the S3 action from HTTP method and path
- action := determineS3ActionFromRequest(r, bucket, object)
-
- // Check if the user has permission for this action
- ctx := r.Context()
- sessionToken := extractSessionTokenFromPresignedURL(r)
- if sessionToken == "" {
- // No session token in presigned URL - use standard auth
- return s3err.ErrNone
- }
-
- // Create a temporary cloned request with Authorization header to reuse the secure AuthenticateJWT logic
- // This ensures we use the same robust validation (STS vs OIDC, signature verification, etc.)
- // as standard requests, preventing security regressions.
- authReq := r.Clone(ctx)
- authReq.Header.Set("Authorization", "Bearer "+sessionToken)
-
- // Authenticate the token using the centralized IAM integration
- iamIdentity, errCode := iam.iamIntegration.AuthenticateJWT(ctx, authReq)
- if errCode != s3err.ErrNone {
- glog.V(3).Infof("JWT authentication failed for presigned URL: %v", errCode)
- return errCode
- }
-
- // Authorize using IAM
- errCode = iam.iamIntegration.AuthorizeAction(ctx, iamIdentity, action, bucket, object, r)
- if errCode != s3err.ErrNone {
- glog.V(3).Infof("IAM authorization failed for presigned URL: principal=%s action=%s bucket=%s object=%s",
- iamIdentity.Principal, action, bucket, object)
- return errCode
- }
-
- glog.V(3).Infof("IAM authorization succeeded for presigned URL: principal=%s action=%s bucket=%s object=%s",
- iamIdentity.Principal, action, bucket, object)
- return s3err.ErrNone
-}
-
-// GeneratePresignedURLWithIAM generates a presigned URL with IAM policy validation
-func (pm *S3PresignedURLManager) GeneratePresignedURLWithIAM(ctx context.Context, req *PresignedURLRequest, baseURL string) (*PresignedURLResponse, error) {
- if pm.s3iam == nil || !pm.s3iam.enabled {
- return nil, fmt.Errorf("IAM integration not enabled")
- }
- if req == nil || strings.TrimSpace(req.SessionToken) == "" {
- return nil, fmt.Errorf("IAM authorization failed: session token is required")
- }
-
- authRequest := &http.Request{
- Method: req.Method,
- URL: &url.URL{Path: "/" + req.Bucket + "/" + req.ObjectKey},
- Header: make(http.Header),
- }
- authRequest.Header.Set("Authorization", "Bearer "+req.SessionToken)
- authRequest = authRequest.WithContext(ctx)
-
- iamIdentity, errCode := pm.s3iam.AuthenticateJWT(ctx, authRequest)
- if errCode != s3err.ErrNone {
- return nil, fmt.Errorf("IAM authorization failed: invalid session token")
- }
-
- // Determine S3 action from method
- action := determineS3ActionFromMethodAndPath(req.Method, req.Bucket, req.ObjectKey)
-
- // Check IAM permissions before generating URL
- errCode = pm.s3iam.AuthorizeAction(ctx, iamIdentity, action, req.Bucket, req.ObjectKey, authRequest)
- if errCode != s3err.ErrNone {
- return nil, fmt.Errorf("IAM authorization failed: user does not have permission for action %s on resource %s/%s", action, req.Bucket, req.ObjectKey)
- }
-
- // Generate presigned URL with validated permissions
- return pm.generatePresignedURL(req, baseURL, iamIdentity)
-}
-
-// generatePresignedURL creates the actual presigned URL
-func (pm *S3PresignedURLManager) generatePresignedURL(req *PresignedURLRequest, baseURL string, identity *IAMIdentity) (*PresignedURLResponse, error) {
- // Calculate expiration time
- expiresAt := time.Now().Add(req.Expiration)
-
- // Build the base URL
- urlPath := "/" + req.Bucket
- if req.ObjectKey != "" {
- urlPath += "/" + req.ObjectKey
- }
-
- // Create query parameters for AWS signature v4
- queryParams := make(map[string]string)
- for k, v := range req.QueryParams {
- queryParams[k] = v
- }
-
- // Add AWS signature v4 parameters
- queryParams["X-Amz-Algorithm"] = "AWS4-HMAC-SHA256"
- queryParams["X-Amz-Credential"] = fmt.Sprintf("seaweedfs/%s/us-east-1/s3/aws4_request", expiresAt.Format("20060102"))
- queryParams["X-Amz-Date"] = expiresAt.Format("20060102T150405Z")
- queryParams["X-Amz-Expires"] = strconv.Itoa(int(req.Expiration.Seconds()))
- queryParams["X-Amz-SignedHeaders"] = "host"
-
- // Add session token if available
- if identity.SessionToken != "" {
- queryParams["X-Amz-Security-Token"] = identity.SessionToken
- }
-
- // Build canonical query string
- canonicalQuery := buildCanonicalQuery(queryParams)
-
- // For now, we'll create a mock signature
- // In production, this would use proper AWS signature v4 signing
- mockSignature := generateMockSignature(req.Method, urlPath, canonicalQuery, identity.SessionToken)
- queryParams["X-Amz-Signature"] = mockSignature
-
- // Build final URL
- finalQuery := buildCanonicalQuery(queryParams)
- fullURL := baseURL + urlPath + "?" + finalQuery
-
- // Prepare response
- headers := make(map[string]string)
- for k, v := range req.Headers {
- headers[k] = v
- }
-
- return &PresignedURLResponse{
- URL: fullURL,
- Method: req.Method,
- Headers: headers,
- ExpiresAt: expiresAt,
- SignedHeaders: []string{"host"},
- CanonicalQuery: canonicalQuery,
- }, nil
-}
-
-// Helper functions
-
-// determineS3ActionFromRequest determines the S3 action based on HTTP request
-func determineS3ActionFromRequest(r *http.Request, bucket, object string) Action {
- return determineS3ActionFromMethodAndPath(r.Method, bucket, object)
-}
-
-// determineS3ActionFromMethodAndPath determines the S3 action based on method and path
-func determineS3ActionFromMethodAndPath(method, bucket, object string) Action {
- switch method {
- case "GET":
- if object == "" {
- return s3_constants.ACTION_LIST // ListBucket
- } else {
- return s3_constants.ACTION_READ // GetObject
- }
- case "PUT", "POST":
- return s3_constants.ACTION_WRITE // PutObject
- case "DELETE":
- if object == "" {
- return s3_constants.ACTION_DELETE_BUCKET // DeleteBucket
- } else {
- return s3_constants.ACTION_WRITE // DeleteObject (uses WRITE action)
- }
- case "HEAD":
- if object == "" {
- return s3_constants.ACTION_LIST // HeadBucket
- } else {
- return s3_constants.ACTION_READ // HeadObject
- }
- default:
- return s3_constants.ACTION_READ // Default to read
- }
-}
-
-// extractSessionTokenFromPresignedURL extracts session token from presigned URL query parameters
-func extractSessionTokenFromPresignedURL(r *http.Request) string {
- // Check for X-Amz-Security-Token in query parameters
- if token := r.URL.Query().Get("X-Amz-Security-Token"); token != "" {
- return token
- }
-
- // Check for session token in other possible locations
- if token := r.URL.Query().Get("SessionToken"); token != "" {
- return token
- }
-
- return ""
-}
-
-// buildCanonicalQuery builds a canonical query string for AWS signature
-func buildCanonicalQuery(params map[string]string) string {
- var keys []string
- for k := range params {
- keys = append(keys, k)
- }
-
- // Sort keys for canonical order
- for i := 0; i < len(keys); i++ {
- for j := i + 1; j < len(keys); j++ {
- if keys[i] > keys[j] {
- keys[i], keys[j] = keys[j], keys[i]
- }
- }
- }
-
- var parts []string
- for _, k := range keys {
- parts = append(parts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(params[k])))
- }
-
- return strings.Join(parts, "&")
-}
-
-// generateMockSignature generates a mock signature for testing purposes
-func generateMockSignature(method, path, query, sessionToken string) string {
- // This is a simplified signature for demonstration
- // In production, use proper AWS signature v4 calculation
- data := fmt.Sprintf("%s\n%s\n%s\n%s", method, path, query, sessionToken)
- hash := sha256.Sum256([]byte(data))
- return hex.EncodeToString(hash[:])[:16] // Truncate for readability
-}
-
-// ValidatePresignedURLExpiration validates that a presigned URL hasn't expired
-func ValidatePresignedURLExpiration(r *http.Request) error {
- query := r.URL.Query()
-
- // Get X-Amz-Date and X-Amz-Expires
- dateStr := query.Get("X-Amz-Date")
- expiresStr := query.Get("X-Amz-Expires")
-
- if dateStr == "" || expiresStr == "" {
- return fmt.Errorf("missing required presigned URL parameters")
- }
-
- // Parse date (always in UTC)
- signedDate, err := time.Parse("20060102T150405Z", dateStr)
- if err != nil {
- return fmt.Errorf("invalid X-Amz-Date format: %v", err)
- }
-
- // Parse expires
- expires, err := strconv.Atoi(expiresStr)
- if err != nil {
- return fmt.Errorf("invalid X-Amz-Expires format: %v", err)
- }
-
- // Check expiration - compare in UTC
- expirationTime := signedDate.Add(time.Duration(expires) * time.Second)
- now := time.Now().UTC()
- if now.After(expirationTime) {
- return fmt.Errorf("presigned URL has expired")
- }
-
- return nil
-}
-
-// PresignedURLSecurityPolicy represents security constraints for presigned URL generation
-type PresignedURLSecurityPolicy struct {
- MaxExpirationDuration time.Duration `json:"max_expiration_duration"` // Maximum allowed expiration
- AllowedMethods []string `json:"allowed_methods"` // Allowed HTTP methods
- RequiredHeaders []string `json:"required_headers"` // Headers that must be present
- IPWhitelist []string `json:"ip_whitelist"` // Allowed IP addresses/ranges
- MaxFileSize int64 `json:"max_file_size"` // Maximum file size for uploads
-}
-
-// DefaultPresignedURLSecurityPolicy returns a default security policy
-func DefaultPresignedURLSecurityPolicy() *PresignedURLSecurityPolicy {
- return &PresignedURLSecurityPolicy{
- MaxExpirationDuration: 7 * 24 * time.Hour, // 7 days max
- AllowedMethods: []string{"GET", "PUT", "POST", "HEAD"},
- RequiredHeaders: []string{},
- IPWhitelist: []string{}, // Empty means no IP restrictions
- MaxFileSize: 5 * 1024 * 1024 * 1024, // 5GB default
- }
-}
-
-// ValidatePresignedURLRequest validates a presigned URL request against security policy
-func (policy *PresignedURLSecurityPolicy) ValidatePresignedURLRequest(req *PresignedURLRequest) error {
- // Check expiration duration
- if req.Expiration > policy.MaxExpirationDuration {
- return fmt.Errorf("expiration duration %v exceeds maximum allowed %v", req.Expiration, policy.MaxExpirationDuration)
- }
-
- // Check HTTP method
- methodAllowed := false
- for _, allowedMethod := range policy.AllowedMethods {
- if req.Method == allowedMethod {
- methodAllowed = true
- break
- }
- }
- if !methodAllowed {
- return fmt.Errorf("HTTP method %s is not allowed", req.Method)
- }
-
- // Check required headers
- for _, requiredHeader := range policy.RequiredHeaders {
- if _, exists := req.Headers[requiredHeader]; !exists {
- return fmt.Errorf("required header %s is missing", requiredHeader)
- }
- }
-
- return nil
-}
diff --git a/weed/s3api/s3_presigned_url_iam_test.go b/weed/s3api/s3_presigned_url_iam_test.go
deleted file mode 100644
index 5d50f06dc..000000000
--- a/weed/s3api/s3_presigned_url_iam_test.go
+++ /dev/null
@@ -1,631 +0,0 @@
-package s3api
-
-import (
- "context"
- "net/http"
- "net/http/httptest"
- "testing"
- "time"
-
- "github.com/golang-jwt/jwt/v5"
- "github.com/seaweedfs/seaweedfs/weed/iam/integration"
- "github.com/seaweedfs/seaweedfs/weed/iam/ldap"
- "github.com/seaweedfs/seaweedfs/weed/iam/oidc"
- "github.com/seaweedfs/seaweedfs/weed/iam/policy"
- "github.com/seaweedfs/seaweedfs/weed/iam/sts"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-// createTestJWTPresigned creates a test JWT token with the specified issuer, subject and signing key
-func createTestJWTPresigned(t *testing.T, issuer, subject, signingKey string) string {
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
- "iss": issuer,
- "sub": subject,
- "aud": "test-client-id",
- "exp": time.Now().Add(time.Hour).Unix(),
- "iat": time.Now().Unix(),
- // Add claims that trust policy validation expects
- "idp": "test-oidc", // Identity provider claim for trust policy matching
- })
-
- tokenString, err := token.SignedString([]byte(signingKey))
- require.NoError(t, err)
- return tokenString
-}
-
-// TestPresignedURLIAMValidation tests IAM validation for presigned URLs
-func TestPresignedURLIAMValidation(t *testing.T) {
- // Set up IAM system
- iamManager := setupTestIAMManagerForPresigned(t)
- s3iam := NewS3IAMIntegration(iamManager, "localhost:8888")
-
- // Create IAM with integration
- iam := &IdentityAccessManagement{
- isAuthEnabled: true,
- }
- iam.SetIAMIntegration(s3iam)
-
- // Set up roles
- ctx := context.Background()
- setupTestRolesForPresigned(ctx, iamManager)
-
- // Create a valid JWT token for testing
- validJWTToken := createTestJWTPresigned(t, "https://test-issuer.com", "test-user-123", "test-signing-key")
-
- // Get session token
- response, err := iamManager.AssumeRoleWithWebIdentity(ctx, &sts.AssumeRoleWithWebIdentityRequest{
- RoleArn: "arn:aws:iam::role/S3ReadOnlyRole",
- WebIdentityToken: validJWTToken,
- RoleSessionName: "presigned-test-session",
- })
- require.NoError(t, err)
-
- sessionToken := response.Credentials.SessionToken
-
- tests := []struct {
- name string
- method string
- path string
- sessionToken string
- expectedResult s3err.ErrorCode
- }{
- {
- name: "GET object with read permissions",
- method: "GET",
- path: "/test-bucket/test-file.txt",
- sessionToken: sessionToken,
- expectedResult: s3err.ErrNone,
- },
- {
- name: "PUT object with read-only permissions (should fail)",
- method: "PUT",
- path: "/test-bucket/new-file.txt",
- sessionToken: sessionToken,
- expectedResult: s3err.ErrAccessDenied,
- },
- {
- name: "GET object without session token",
- method: "GET",
- path: "/test-bucket/test-file.txt",
- sessionToken: "",
- expectedResult: s3err.ErrNone, // Falls back to standard auth
- },
- {
- name: "Invalid session token",
- method: "GET",
- path: "/test-bucket/test-file.txt",
- sessionToken: "invalid-token",
- expectedResult: s3err.ErrAccessDenied,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Create request with presigned URL parameters
- req := createPresignedURLRequest(t, tt.method, tt.path, tt.sessionToken)
-
- // Create identity for testing
- identity := &Identity{
- Name: "test-user",
- Account: &AccountAdmin,
- }
-
- // Test validation
- result := iam.ValidatePresignedURLWithIAM(req, identity)
- assert.Equal(t, tt.expectedResult, result, "IAM validation result should match expected")
- })
- }
-}
-
-// TestPresignedURLGeneration tests IAM-aware presigned URL generation
-func TestPresignedURLGeneration(t *testing.T) {
- // Set up IAM system
- iamManager := setupTestIAMManagerForPresigned(t)
- s3iam := NewS3IAMIntegration(iamManager, "localhost:8888")
- s3iam.enabled = true // Enable IAM integration
- presignedManager := NewS3PresignedURLManager(s3iam)
-
- ctx := context.Background()
- setupTestRolesForPresigned(ctx, iamManager)
-
- // Create a valid JWT token for testing
- validJWTToken := createTestJWTPresigned(t, "https://test-issuer.com", "test-user-123", "test-signing-key")
-
- // Get session token
- response, err := iamManager.AssumeRoleWithWebIdentity(ctx, &sts.AssumeRoleWithWebIdentityRequest{
- RoleArn: "arn:aws:iam::role/S3AdminRole",
- WebIdentityToken: validJWTToken,
- RoleSessionName: "presigned-gen-test-session",
- })
- require.NoError(t, err)
-
- sessionToken := response.Credentials.SessionToken
-
- tests := []struct {
- name string
- request *PresignedURLRequest
- shouldSucceed bool
- expectedError string
- }{
- {
- name: "Generate valid presigned GET URL",
- request: &PresignedURLRequest{
- Method: "GET",
- Bucket: "test-bucket",
- ObjectKey: "test-file.txt",
- Expiration: time.Hour,
- SessionToken: sessionToken,
- },
- shouldSucceed: true,
- },
- {
- name: "Generate valid presigned PUT URL",
- request: &PresignedURLRequest{
- Method: "PUT",
- Bucket: "test-bucket",
- ObjectKey: "new-file.txt",
- Expiration: time.Hour,
- SessionToken: sessionToken,
- },
- shouldSucceed: true,
- },
- {
- name: "Generate URL with invalid session token",
- request: &PresignedURLRequest{
- Method: "GET",
- Bucket: "test-bucket",
- ObjectKey: "test-file.txt",
- Expiration: time.Hour,
- SessionToken: "invalid-token",
- },
- shouldSucceed: false,
- expectedError: "IAM authorization failed",
- },
- {
- name: "Generate URL without session token",
- request: &PresignedURLRequest{
- Method: "GET",
- Bucket: "test-bucket",
- ObjectKey: "test-file.txt",
- Expiration: time.Hour,
- },
- shouldSucceed: false,
- expectedError: "IAM authorization failed",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- response, err := presignedManager.GeneratePresignedURLWithIAM(ctx, tt.request, "http://localhost:8333")
-
- if tt.shouldSucceed {
- assert.NoError(t, err, "Presigned URL generation should succeed")
- if response != nil {
- assert.NotEmpty(t, response.URL, "URL should not be empty")
- assert.Equal(t, tt.request.Method, response.Method, "Method should match")
- assert.True(t, response.ExpiresAt.After(time.Now()), "URL should not be expired")
- } else {
- t.Errorf("Response should not be nil when generation should succeed")
- }
- } else {
- assert.Error(t, err, "Presigned URL generation should fail")
- if tt.expectedError != "" {
- assert.Contains(t, err.Error(), tt.expectedError, "Error message should contain expected text")
- }
- }
- })
- }
-}
-
-func TestPresignedURLGenerationUsesAuthenticatedPrincipal(t *testing.T) {
- iamManager := setupTestIAMManagerForPresigned(t)
- s3iam := NewS3IAMIntegration(iamManager, "localhost:8888")
- s3iam.enabled = true
- presignedManager := NewS3PresignedURLManager(s3iam)
-
- ctx := context.Background()
- setupTestRolesForPresigned(ctx, iamManager)
-
- validJWTToken := createTestJWTPresigned(t, "https://test-issuer.com", "test-user-123", "test-signing-key")
-
- response, err := iamManager.AssumeRoleWithWebIdentity(ctx, &sts.AssumeRoleWithWebIdentityRequest{
- RoleArn: "arn:aws:iam::role/S3ReadOnlyRole",
- WebIdentityToken: validJWTToken,
- RoleSessionName: "presigned-read-only-session",
- })
- require.NoError(t, err)
-
- _, err = presignedManager.GeneratePresignedURLWithIAM(ctx, &PresignedURLRequest{
- Method: "PUT",
- Bucket: "test-bucket",
- ObjectKey: "new-file.txt",
- Expiration: time.Hour,
- SessionToken: response.Credentials.SessionToken,
- }, "http://localhost:8333")
- require.Error(t, err)
- assert.Contains(t, err.Error(), "IAM authorization failed")
-}
-
-// TestPresignedURLExpiration tests URL expiration validation
-func TestPresignedURLExpiration(t *testing.T) {
- tests := []struct {
- name string
- setupRequest func() *http.Request
- expectedError string
- }{
- {
- name: "Valid non-expired URL",
- setupRequest: func() *http.Request {
- req := httptest.NewRequest("GET", "/test-bucket/test-file.txt", nil)
- q := req.URL.Query()
- // Set date to 30 minutes ago with 2 hours expiration for safe margin
- q.Set("X-Amz-Date", time.Now().UTC().Add(-30*time.Minute).Format("20060102T150405Z"))
- q.Set("X-Amz-Expires", "7200") // 2 hours
- req.URL.RawQuery = q.Encode()
- return req
- },
- expectedError: "",
- },
- {
- name: "Expired URL",
- setupRequest: func() *http.Request {
- req := httptest.NewRequest("GET", "/test-bucket/test-file.txt", nil)
- q := req.URL.Query()
- // Set date to 2 hours ago with 1 hour expiration
- q.Set("X-Amz-Date", time.Now().UTC().Add(-2*time.Hour).Format("20060102T150405Z"))
- q.Set("X-Amz-Expires", "3600") // 1 hour
- req.URL.RawQuery = q.Encode()
- return req
- },
- expectedError: "presigned URL has expired",
- },
- {
- name: "Missing date parameter",
- setupRequest: func() *http.Request {
- req := httptest.NewRequest("GET", "/test-bucket/test-file.txt", nil)
- q := req.URL.Query()
- q.Set("X-Amz-Expires", "3600")
- req.URL.RawQuery = q.Encode()
- return req
- },
- expectedError: "missing required presigned URL parameters",
- },
- {
- name: "Invalid date format",
- setupRequest: func() *http.Request {
- req := httptest.NewRequest("GET", "/test-bucket/test-file.txt", nil)
- q := req.URL.Query()
- q.Set("X-Amz-Date", "invalid-date")
- q.Set("X-Amz-Expires", "3600")
- req.URL.RawQuery = q.Encode()
- return req
- },
- expectedError: "invalid X-Amz-Date format",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- req := tt.setupRequest()
- err := ValidatePresignedURLExpiration(req)
-
- if tt.expectedError == "" {
- assert.NoError(t, err, "Validation should succeed")
- } else {
- assert.Error(t, err, "Validation should fail")
- assert.Contains(t, err.Error(), tt.expectedError, "Error message should contain expected text")
- }
- })
- }
-}
-
-// TestPresignedURLSecurityPolicy tests security policy enforcement
-func TestPresignedURLSecurityPolicy(t *testing.T) {
- policy := &PresignedURLSecurityPolicy{
- MaxExpirationDuration: 24 * time.Hour,
- AllowedMethods: []string{"GET", "PUT"},
- RequiredHeaders: []string{"Content-Type"},
- MaxFileSize: 1024 * 1024, // 1MB
- }
-
- tests := []struct {
- name string
- request *PresignedURLRequest
- expectedError string
- }{
- {
- name: "Valid request",
- request: &PresignedURLRequest{
- Method: "GET",
- Bucket: "test-bucket",
- ObjectKey: "test-file.txt",
- Expiration: 12 * time.Hour,
- Headers: map[string]string{"Content-Type": "application/json"},
- },
- expectedError: "",
- },
- {
- name: "Expiration too long",
- request: &PresignedURLRequest{
- Method: "GET",
- Bucket: "test-bucket",
- ObjectKey: "test-file.txt",
- Expiration: 48 * time.Hour, // Exceeds 24h limit
- Headers: map[string]string{"Content-Type": "application/json"},
- },
- expectedError: "expiration duration",
- },
- {
- name: "Method not allowed",
- request: &PresignedURLRequest{
- Method: "DELETE", // Not in allowed methods
- Bucket: "test-bucket",
- ObjectKey: "test-file.txt",
- Expiration: 12 * time.Hour,
- Headers: map[string]string{"Content-Type": "application/json"},
- },
- expectedError: "HTTP method DELETE is not allowed",
- },
- {
- name: "Missing required header",
- request: &PresignedURLRequest{
- Method: "GET",
- Bucket: "test-bucket",
- ObjectKey: "test-file.txt",
- Expiration: 12 * time.Hour,
- Headers: map[string]string{}, // Missing Content-Type
- },
- expectedError: "required header Content-Type is missing",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- err := policy.ValidatePresignedURLRequest(tt.request)
-
- if tt.expectedError == "" {
- assert.NoError(t, err, "Policy validation should succeed")
- } else {
- assert.Error(t, err, "Policy validation should fail")
- assert.Contains(t, err.Error(), tt.expectedError, "Error message should contain expected text")
- }
- })
- }
-}
-
-// TestS3ActionDetermination tests action determination from HTTP methods
-func TestS3ActionDetermination(t *testing.T) {
- tests := []struct {
- name string
- method string
- bucket string
- object string
- expectedAction Action
- }{
- {
- name: "GET object",
- method: "GET",
- bucket: "test-bucket",
- object: "test-file.txt",
- expectedAction: s3_constants.ACTION_READ,
- },
- {
- name: "GET bucket (list)",
- method: "GET",
- bucket: "test-bucket",
- object: "",
- expectedAction: s3_constants.ACTION_LIST,
- },
- {
- name: "PUT object",
- method: "PUT",
- bucket: "test-bucket",
- object: "new-file.txt",
- expectedAction: s3_constants.ACTION_WRITE,
- },
- {
- name: "DELETE object",
- method: "DELETE",
- bucket: "test-bucket",
- object: "old-file.txt",
- expectedAction: s3_constants.ACTION_WRITE,
- },
- {
- name: "DELETE bucket",
- method: "DELETE",
- bucket: "test-bucket",
- object: "",
- expectedAction: s3_constants.ACTION_DELETE_BUCKET,
- },
- {
- name: "HEAD object",
- method: "HEAD",
- bucket: "test-bucket",
- object: "test-file.txt",
- expectedAction: s3_constants.ACTION_READ,
- },
- {
- name: "POST object",
- method: "POST",
- bucket: "test-bucket",
- object: "upload-file.txt",
- expectedAction: s3_constants.ACTION_WRITE,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- action := determineS3ActionFromMethodAndPath(tt.method, tt.bucket, tt.object)
- assert.Equal(t, tt.expectedAction, action, "S3 action should match expected")
- })
- }
-}
-
-// Helper functions for tests
-
-func setupTestIAMManagerForPresigned(t *testing.T) *integration.IAMManager {
- // Create IAM manager
- manager := integration.NewIAMManager()
-
- // Initialize with test configuration
- config := &integration.IAMConfig{
- STS: &sts.STSConfig{
- TokenDuration: sts.FlexibleDuration{Duration: time.Hour},
- MaxSessionLength: sts.FlexibleDuration{Duration: time.Hour * 12},
- Issuer: "test-sts",
- SigningKey: []byte("test-signing-key-32-characters-long"),
- },
- Policy: &policy.PolicyEngineConfig{
- DefaultEffect: "Deny",
- StoreType: "memory",
- },
- Roles: &integration.RoleStoreConfig{
- StoreType: "memory",
- },
- }
-
- err := manager.Initialize(config, func() string {
- return "localhost:8888" // Mock filer address for testing
- })
- require.NoError(t, err)
-
- // Set up test identity providers
- setupTestProvidersForPresigned(t, manager)
-
- return manager
-}
-
-func setupTestProvidersForPresigned(t *testing.T, manager *integration.IAMManager) {
- // Set up OIDC provider
- oidcProvider := oidc.NewMockOIDCProvider("test-oidc")
- oidcConfig := &oidc.OIDCConfig{
- Issuer: "https://test-issuer.com",
- ClientID: "test-client-id",
- }
- err := oidcProvider.Initialize(oidcConfig)
- require.NoError(t, err)
- oidcProvider.SetupDefaultTestData()
-
- // Set up LDAP provider
- ldapProvider := ldap.NewMockLDAPProvider("test-ldap")
- err = ldapProvider.Initialize(nil) // Mock doesn't need real config
- require.NoError(t, err)
- ldapProvider.SetupDefaultTestData()
-
- // Register providers
- err = manager.RegisterIdentityProvider(oidcProvider)
- require.NoError(t, err)
- err = manager.RegisterIdentityProvider(ldapProvider)
- require.NoError(t, err)
-}
-
-func setupTestRolesForPresigned(ctx context.Context, manager *integration.IAMManager) {
- // Create read-only policy
- readOnlyPolicy := &policy.PolicyDocument{
- Version: "2012-10-17",
- Statement: []policy.Statement{
- {
- Sid: "AllowS3ReadOperations",
- Effect: "Allow",
- Action: []string{"s3:GetObject", "s3:ListBucket", "s3:HeadObject"},
- Resource: []string{
- "arn:aws:s3:::*",
- "arn:aws:s3:::*/*",
- },
- },
- },
- }
-
- manager.CreatePolicy(ctx, "", "S3ReadOnlyPolicy", readOnlyPolicy)
-
- // Create read-only role
- manager.CreateRole(ctx, "", "S3ReadOnlyRole", &integration.RoleDefinition{
- RoleName: "S3ReadOnlyRole",
- TrustPolicy: &policy.PolicyDocument{
- Version: "2012-10-17",
- Statement: []policy.Statement{
- {
- Effect: "Allow",
- Principal: map[string]interface{}{
- "Federated": "test-oidc",
- },
- Action: []string{"sts:AssumeRoleWithWebIdentity"},
- },
- },
- },
- AttachedPolicies: []string{"S3ReadOnlyPolicy"},
- })
-
- // Create admin policy
- adminPolicy := &policy.PolicyDocument{
- Version: "2012-10-17",
- Statement: []policy.Statement{
- {
- Sid: "AllowAllS3Operations",
- Effect: "Allow",
- Action: []string{"s3:*"},
- Resource: []string{
- "arn:aws:s3:::*",
- "arn:aws:s3:::*/*",
- },
- },
- },
- }
-
- manager.CreatePolicy(ctx, "", "S3AdminPolicy", adminPolicy)
-
- // Create admin role
- manager.CreateRole(ctx, "", "S3AdminRole", &integration.RoleDefinition{
- RoleName: "S3AdminRole",
- TrustPolicy: &policy.PolicyDocument{
- Version: "2012-10-17",
- Statement: []policy.Statement{
- {
- Effect: "Allow",
- Principal: map[string]interface{}{
- "Federated": "test-oidc",
- },
- Action: []string{"sts:AssumeRoleWithWebIdentity"},
- },
- },
- },
- AttachedPolicies: []string{"S3AdminPolicy"},
- })
-
- // Create a role for presigned URL users with admin permissions for testing
- manager.CreateRole(ctx, "", "PresignedUser", &integration.RoleDefinition{
- RoleName: "PresignedUser",
- TrustPolicy: &policy.PolicyDocument{
- Version: "2012-10-17",
- Statement: []policy.Statement{
- {
- Effect: "Allow",
- Principal: map[string]interface{}{
- "Federated": "test-oidc",
- },
- Action: []string{"sts:AssumeRoleWithWebIdentity"},
- },
- },
- },
- AttachedPolicies: []string{"S3AdminPolicy"}, // Use admin policy for testing
- })
-}
-
-func createPresignedURLRequest(t *testing.T, method, path, sessionToken string) *http.Request {
- req := httptest.NewRequest(method, path, nil)
-
- // Add presigned URL parameters if session token is provided
- if sessionToken != "" {
- q := req.URL.Query()
- q.Set("X-Amz-Algorithm", "AWS4-HMAC-SHA256")
- q.Set("X-Amz-Security-Token", sessionToken)
- q.Set("X-Amz-Date", time.Now().Format("20060102T150405Z"))
- q.Set("X-Amz-Expires", "3600")
- req.URL.RawQuery = q.Encode()
- }
-
- return req
-}
diff --git a/weed/s3api/s3_sse_bucket_test.go b/weed/s3api/s3_sse_bucket_test.go
deleted file mode 100644
index 74ad9296b..000000000
--- a/weed/s3api/s3_sse_bucket_test.go
+++ /dev/null
@@ -1,401 +0,0 @@
-package s3api
-
-import (
- "fmt"
- "strings"
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/pb/s3_pb"
-)
-
-// TestBucketDefaultSSEKMSEnforcement tests bucket default encryption enforcement
-func TestBucketDefaultSSEKMSEnforcement(t *testing.T) {
- kmsKey := SetupTestKMS(t)
- defer kmsKey.Cleanup()
-
- // Create bucket encryption configuration
- config := &s3_pb.EncryptionConfiguration{
- SseAlgorithm: "aws:kms",
- KmsKeyId: kmsKey.KeyID,
- BucketKeyEnabled: false,
- }
-
- t.Run("Bucket with SSE-KMS default encryption", func(t *testing.T) {
- // Test that default encryption config is properly stored and retrieved
- if config.SseAlgorithm != "aws:kms" {
- t.Errorf("Expected SSE algorithm aws:kms, got %s", config.SseAlgorithm)
- }
-
- if config.KmsKeyId != kmsKey.KeyID {
- t.Errorf("Expected KMS key ID %s, got %s", kmsKey.KeyID, config.KmsKeyId)
- }
- })
-
- t.Run("Default encryption headers generation", func(t *testing.T) {
- // Test generating default encryption headers for objects
- headers := GetDefaultEncryptionHeaders(config)
-
- if headers == nil {
- t.Fatal("Expected default headers, got nil")
- }
-
- expectedAlgorithm := headers["X-Amz-Server-Side-Encryption"]
- if expectedAlgorithm != "aws:kms" {
- t.Errorf("Expected X-Amz-Server-Side-Encryption header aws:kms, got %s", expectedAlgorithm)
- }
-
- expectedKeyID := headers["X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id"]
- if expectedKeyID != kmsKey.KeyID {
- t.Errorf("Expected X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id header %s, got %s", kmsKey.KeyID, expectedKeyID)
- }
- })
-
- t.Run("Default encryption detection", func(t *testing.T) {
- // Test IsDefaultEncryptionEnabled
- enabled := IsDefaultEncryptionEnabled(config)
- if !enabled {
- t.Error("Should detect default encryption as enabled")
- }
-
- // Test with nil config
- enabled = IsDefaultEncryptionEnabled(nil)
- if enabled {
- t.Error("Should detect default encryption as disabled for nil config")
- }
-
- // Test with empty config
- emptyConfig := &s3_pb.EncryptionConfiguration{}
- enabled = IsDefaultEncryptionEnabled(emptyConfig)
- if enabled {
- t.Error("Should detect default encryption as disabled for empty config")
- }
- })
-}
-
-// TestBucketEncryptionConfigValidation tests XML validation of bucket encryption configurations
-func TestBucketEncryptionConfigValidation(t *testing.T) {
- testCases := []struct {
- name string
- xml string
- expectError bool
- description string
- }{
- {
- name: "Valid SSE-S3 configuration",
- xml: `
-
-
- AES256
-
-
- `,
- expectError: false,
- description: "Basic SSE-S3 configuration should be valid",
- },
- {
- name: "Valid SSE-KMS configuration",
- xml: `
-
-
- aws:kms
- test-key-id
-
-
- `,
- expectError: false,
- description: "SSE-KMS configuration with key ID should be valid",
- },
- {
- name: "Valid SSE-KMS without key ID",
- xml: `
-
-
- aws:kms
-
-
- `,
- expectError: false,
- description: "SSE-KMS without key ID should use default key",
- },
- {
- name: "Invalid XML structure",
- xml: `
-
- AES256
-
- `,
- expectError: true,
- description: "Invalid XML structure should be rejected",
- },
- {
- name: "Empty configuration",
- xml: `
- `,
- expectError: true,
- description: "Empty configuration should be rejected",
- },
- {
- name: "Invalid algorithm",
- xml: `
-
-
- INVALID
-
-
- `,
- expectError: true,
- description: "Invalid algorithm should be rejected",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- config, err := encryptionConfigFromXMLBytes([]byte(tc.xml))
-
- if tc.expectError && err == nil {
- t.Errorf("Expected error for %s, but got none. %s", tc.name, tc.description)
- }
-
- if !tc.expectError && err != nil {
- t.Errorf("Expected no error for %s, but got: %v. %s", tc.name, err, tc.description)
- }
-
- if !tc.expectError && config != nil {
- // Validate the parsed configuration
- t.Logf("Successfully parsed config: Algorithm=%s, KeyID=%s",
- config.SseAlgorithm, config.KmsKeyId)
- }
- })
- }
-}
-
-// TestBucketEncryptionAPIOperations tests the bucket encryption API operations
-func TestBucketEncryptionAPIOperations(t *testing.T) {
- // Note: These tests would normally require a full S3 API server setup
- // For now, we test the individual components
-
- t.Run("PUT bucket encryption", func(t *testing.T) {
- xml := `
-
-
- aws:kms
- test-key-id
-
-
- `
-
- // Parse the XML to protobuf
- config, err := encryptionConfigFromXMLBytes([]byte(xml))
- if err != nil {
- t.Fatalf("Failed to parse encryption config: %v", err)
- }
-
- // Verify the parsed configuration
- if config.SseAlgorithm != "aws:kms" {
- t.Errorf("Expected algorithm aws:kms, got %s", config.SseAlgorithm)
- }
-
- if config.KmsKeyId != "test-key-id" {
- t.Errorf("Expected key ID test-key-id, got %s", config.KmsKeyId)
- }
-
- // Convert back to XML
- xmlBytes, err := encryptionConfigToXMLBytes(config)
- if err != nil {
- t.Fatalf("Failed to convert config to XML: %v", err)
- }
-
- // Verify round-trip
- if len(xmlBytes) == 0 {
- t.Error("Generated XML should not be empty")
- }
-
- // Parse again to verify
- roundTripConfig, err := encryptionConfigFromXMLBytes(xmlBytes)
- if err != nil {
- t.Fatalf("Failed to parse round-trip XML: %v", err)
- }
-
- if roundTripConfig.SseAlgorithm != config.SseAlgorithm {
- t.Error("Round-trip algorithm doesn't match")
- }
-
- if roundTripConfig.KmsKeyId != config.KmsKeyId {
- t.Error("Round-trip key ID doesn't match")
- }
- })
-
- t.Run("GET bucket encryption", func(t *testing.T) {
- // Test getting encryption configuration
- config := &s3_pb.EncryptionConfiguration{
- SseAlgorithm: "AES256",
- KmsKeyId: "",
- BucketKeyEnabled: false,
- }
-
- // Convert to XML for GET response
- xmlBytes, err := encryptionConfigToXMLBytes(config)
- if err != nil {
- t.Fatalf("Failed to convert config to XML: %v", err)
- }
-
- if len(xmlBytes) == 0 {
- t.Error("Generated XML should not be empty")
- }
-
- // Verify XML contains expected elements
- xmlStr := string(xmlBytes)
- if !strings.Contains(xmlStr, "AES256") {
- t.Error("XML should contain AES256 algorithm")
- }
- })
-
- t.Run("DELETE bucket encryption", func(t *testing.T) {
- // Test deleting encryption configuration
- // This would typically involve removing the configuration from metadata
-
- // Simulate checking if encryption is enabled after deletion
- enabled := IsDefaultEncryptionEnabled(nil)
- if enabled {
- t.Error("Encryption should be disabled after deletion")
- }
- })
-}
-
-// TestBucketEncryptionEdgeCases tests edge cases in bucket encryption
-func TestBucketEncryptionEdgeCases(t *testing.T) {
- t.Run("Large XML configuration", func(t *testing.T) {
- // Test with a large but valid XML
- largeXML := `
-
-
- aws:kms
- arn:aws:kms:us-east-1:123456789012:key/12345678-1234-1234-1234-123456789012
-
- true
-
- `
-
- config, err := encryptionConfigFromXMLBytes([]byte(largeXML))
- if err != nil {
- t.Fatalf("Failed to parse large XML: %v", err)
- }
-
- if config.SseAlgorithm != "aws:kms" {
- t.Error("Should parse large XML correctly")
- }
- })
-
- t.Run("XML with namespaces", func(t *testing.T) {
- // Test XML with namespaces
- namespacedXML := `
-
-
- AES256
-
-
- `
-
- config, err := encryptionConfigFromXMLBytes([]byte(namespacedXML))
- if err != nil {
- t.Fatalf("Failed to parse namespaced XML: %v", err)
- }
-
- if config.SseAlgorithm != "AES256" {
- t.Error("Should parse namespaced XML correctly")
- }
- })
-
- t.Run("Malformed XML", func(t *testing.T) {
- malformedXMLs := []string{
- `AES256`, // Unclosed tags
- ``, // Empty rule
- `not-xml-at-all`, // Not XML
- `AES256`, // Invalid namespace
- }
-
- for i, malformedXML := range malformedXMLs {
- t.Run(fmt.Sprintf("Malformed XML %d", i), func(t *testing.T) {
- _, err := encryptionConfigFromXMLBytes([]byte(malformedXML))
- if err == nil {
- t.Errorf("Expected error for malformed XML %d, but got none", i)
- }
- })
- }
- })
-}
-
-// TestGetDefaultEncryptionHeaders tests generation of default encryption headers
-func TestGetDefaultEncryptionHeaders(t *testing.T) {
- testCases := []struct {
- name string
- config *s3_pb.EncryptionConfiguration
- expectedHeaders map[string]string
- }{
- {
- name: "Nil configuration",
- config: nil,
- expectedHeaders: nil,
- },
- {
- name: "SSE-S3 configuration",
- config: &s3_pb.EncryptionConfiguration{
- SseAlgorithm: "AES256",
- },
- expectedHeaders: map[string]string{
- "X-Amz-Server-Side-Encryption": "AES256",
- },
- },
- {
- name: "SSE-KMS configuration with key",
- config: &s3_pb.EncryptionConfiguration{
- SseAlgorithm: "aws:kms",
- KmsKeyId: "test-key-id",
- },
- expectedHeaders: map[string]string{
- "X-Amz-Server-Side-Encryption": "aws:kms",
- "X-Amz-Server-Side-Encryption-Aws-Kms-Key-Id": "test-key-id",
- },
- },
- {
- name: "SSE-KMS configuration without key",
- config: &s3_pb.EncryptionConfiguration{
- SseAlgorithm: "aws:kms",
- },
- expectedHeaders: map[string]string{
- "X-Amz-Server-Side-Encryption": "aws:kms",
- },
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- headers := GetDefaultEncryptionHeaders(tc.config)
-
- if tc.expectedHeaders == nil && headers != nil {
- t.Error("Expected nil headers but got some")
- }
-
- if tc.expectedHeaders != nil && headers == nil {
- t.Error("Expected headers but got nil")
- }
-
- if tc.expectedHeaders != nil && headers != nil {
- for key, expectedValue := range tc.expectedHeaders {
- if actualValue, exists := headers[key]; !exists {
- t.Errorf("Expected header %s not found", key)
- } else if actualValue != expectedValue {
- t.Errorf("Header %s: expected %s, got %s", key, expectedValue, actualValue)
- }
- }
-
- // Check for unexpected headers
- for key := range headers {
- if _, expected := tc.expectedHeaders[key]; !expected {
- t.Errorf("Unexpected header found: %s", key)
- }
- }
- }
- })
- }
-}
diff --git a/weed/s3api/s3_sse_c.go b/weed/s3api/s3_sse_c.go
index 79cf96041..97990853f 100644
--- a/weed/s3api/s3_sse_c.go
+++ b/weed/s3api/s3_sse_c.go
@@ -58,9 +58,9 @@ var (
// SSECustomerKey represents a customer-provided encryption key for SSE-C
type SSECustomerKey struct {
- Algorithm string
- Key []byte
- KeyMD5 string
+ Algorithm string
+ Key []byte
+ KeyMD5 string
}
// IsSSECRequest checks if the request contains SSE-C headers
@@ -134,16 +134,6 @@ func validateAndParseSSECHeaders(algorithm, key, keyMD5 string) (*SSECustomerKey
}, nil
}
-// ValidateSSECHeaders validates SSE-C headers in the request
-func ValidateSSECHeaders(r *http.Request) error {
- algorithm := r.Header.Get(s3_constants.AmzServerSideEncryptionCustomerAlgorithm)
- key := r.Header.Get(s3_constants.AmzServerSideEncryptionCustomerKey)
- keyMD5 := r.Header.Get(s3_constants.AmzServerSideEncryptionCustomerKeyMD5)
-
- _, err := validateAndParseSSECHeaders(algorithm, key, keyMD5)
- return err
-}
-
// ParseSSECHeaders parses and validates SSE-C headers from the request
func ParseSSECHeaders(r *http.Request) (*SSECustomerKey, error) {
algorithm := r.Header.Get(s3_constants.AmzServerSideEncryptionCustomerAlgorithm)
diff --git a/weed/s3api/s3_sse_c_test.go b/weed/s3api/s3_sse_c_test.go
deleted file mode 100644
index 034f07a8e..000000000
--- a/weed/s3api/s3_sse_c_test.go
+++ /dev/null
@@ -1,407 +0,0 @@
-package s3api
-
-import (
- "bytes"
- "crypto/md5"
- "encoding/base64"
- "fmt"
- "io"
- "net/http"
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
-)
-
-func base64MD5(b []byte) string {
- s := md5.Sum(b)
- return base64.StdEncoding.EncodeToString(s[:])
-}
-
-func TestSSECHeaderValidation(t *testing.T) {
- // Test valid SSE-C headers
- req := &http.Request{Header: make(http.Header)}
-
- key := make([]byte, 32) // 256-bit key
- for i := range key {
- key[i] = byte(i)
- }
-
- keyBase64 := base64.StdEncoding.EncodeToString(key)
- md5sum := md5.Sum(key)
- keyMD5 := base64.StdEncoding.EncodeToString(md5sum[:])
-
- req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, "AES256")
- req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKey, keyBase64)
- req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKeyMD5, keyMD5)
-
- // Test validation
- err := ValidateSSECHeaders(req)
- if err != nil {
- t.Errorf("Expected valid headers, got error: %v", err)
- }
-
- // Test parsing
- customerKey, err := ParseSSECHeaders(req)
- if err != nil {
- t.Errorf("Expected successful parsing, got error: %v", err)
- }
-
- if customerKey == nil {
- t.Error("Expected customer key, got nil")
- }
-
- if customerKey.Algorithm != "AES256" {
- t.Errorf("Expected algorithm AES256, got %s", customerKey.Algorithm)
- }
-
- if !bytes.Equal(customerKey.Key, key) {
- t.Error("Key doesn't match original")
- }
-
- if customerKey.KeyMD5 != keyMD5 {
- t.Errorf("Expected key MD5 %s, got %s", keyMD5, customerKey.KeyMD5)
- }
-}
-
-func TestSSECCopySourceHeaders(t *testing.T) {
- // Test valid SSE-C copy source headers
- req := &http.Request{Header: make(http.Header)}
-
- key := make([]byte, 32) // 256-bit key
- for i := range key {
- key[i] = byte(i) + 1 // Different from regular test
- }
-
- keyBase64 := base64.StdEncoding.EncodeToString(key)
- md5sum2 := md5.Sum(key)
- keyMD5 := base64.StdEncoding.EncodeToString(md5sum2[:])
-
- req.Header.Set(s3_constants.AmzCopySourceServerSideEncryptionCustomerAlgorithm, "AES256")
- req.Header.Set(s3_constants.AmzCopySourceServerSideEncryptionCustomerKey, keyBase64)
- req.Header.Set(s3_constants.AmzCopySourceServerSideEncryptionCustomerKeyMD5, keyMD5)
-
- // Test parsing copy source headers
- customerKey, err := ParseSSECCopySourceHeaders(req)
- if err != nil {
- t.Errorf("Expected successful copy source parsing, got error: %v", err)
- }
-
- if customerKey == nil {
- t.Error("Expected customer key from copy source headers, got nil")
- }
-
- if customerKey.Algorithm != "AES256" {
- t.Errorf("Expected algorithm AES256, got %s", customerKey.Algorithm)
- }
-
- if !bytes.Equal(customerKey.Key, key) {
- t.Error("Copy source key doesn't match original")
- }
-
- // Test that regular headers don't interfere with copy source headers
- regularKey, err := ParseSSECHeaders(req)
- if err != nil {
- t.Errorf("Regular header parsing should not fail: %v", err)
- }
-
- if regularKey != nil {
- t.Error("Expected nil for regular headers when only copy source headers are present")
- }
-}
-
-func TestSSECHeaderValidationErrors(t *testing.T) {
- tests := []struct {
- name string
- algorithm string
- key string
- keyMD5 string
- wantErr error
- }{
- {
- name: "invalid algorithm",
- algorithm: "AES128",
- key: base64.StdEncoding.EncodeToString(make([]byte, 32)),
- keyMD5: base64MD5(make([]byte, 32)),
- wantErr: ErrInvalidEncryptionAlgorithm,
- },
- {
- name: "invalid key length",
- algorithm: "AES256",
- key: base64.StdEncoding.EncodeToString(make([]byte, 16)),
- keyMD5: base64MD5(make([]byte, 16)),
- wantErr: ErrInvalidEncryptionKey,
- },
- {
- name: "mismatched MD5",
- algorithm: "AES256",
- key: base64.StdEncoding.EncodeToString(make([]byte, 32)),
- keyMD5: "wrong==md5",
- wantErr: ErrSSECustomerKeyMD5Mismatch,
- },
- {
- name: "incomplete headers",
- algorithm: "AES256",
- key: "",
- keyMD5: "",
- wantErr: ErrInvalidRequest,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- req := &http.Request{Header: make(http.Header)}
-
- if tt.algorithm != "" {
- req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, tt.algorithm)
- }
- if tt.key != "" {
- req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKey, tt.key)
- }
- if tt.keyMD5 != "" {
- req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKeyMD5, tt.keyMD5)
- }
-
- err := ValidateSSECHeaders(req)
- if err != tt.wantErr {
- t.Errorf("Expected error %v, got %v", tt.wantErr, err)
- }
- })
- }
-}
-
-func TestSSECEncryptionDecryption(t *testing.T) {
- // Create customer key
- key := make([]byte, 32)
- for i := range key {
- key[i] = byte(i)
- }
-
- md5sumKey := md5.Sum(key)
- customerKey := &SSECustomerKey{
- Algorithm: "AES256",
- Key: key,
- KeyMD5: base64.StdEncoding.EncodeToString(md5sumKey[:]),
- }
-
- // Test data
- testData := []byte("Hello, World! This is a test of SSE-C encryption.")
-
- // Create encrypted reader
- dataReader := bytes.NewReader(testData)
- encryptedReader, iv, err := CreateSSECEncryptedReader(dataReader, customerKey)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader: %v", err)
- }
-
- // Read encrypted data
- encryptedData, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read encrypted data: %v", err)
- }
-
- // Verify data is actually encrypted (different from original)
- if bytes.Equal(encryptedData[16:], testData) { // Skip IV
- t.Error("Data doesn't appear to be encrypted")
- }
-
- // Create decrypted reader
- encryptedReader2 := bytes.NewReader(encryptedData)
- decryptedReader, err := CreateSSECDecryptedReader(encryptedReader2, customerKey, iv)
- if err != nil {
- t.Fatalf("Failed to create decrypted reader: %v", err)
- }
-
- // Read decrypted data
- decryptedData, err := io.ReadAll(decryptedReader)
- if err != nil {
- t.Fatalf("Failed to read decrypted data: %v", err)
- }
-
- // Verify decrypted data matches original
- if !bytes.Equal(decryptedData, testData) {
- t.Errorf("Decrypted data doesn't match original.\nOriginal: %s\nDecrypted: %s", testData, decryptedData)
- }
-}
-
-func TestSSECIsSSECRequest(t *testing.T) {
- // Test with SSE-C headers
- req := &http.Request{Header: make(http.Header)}
- req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, "AES256")
-
- if !IsSSECRequest(req) {
- t.Error("Expected IsSSECRequest to return true when SSE-C headers are present")
- }
-
- // Test without SSE-C headers
- req2 := &http.Request{Header: make(http.Header)}
- if IsSSECRequest(req2) {
- t.Error("Expected IsSSECRequest to return false when no SSE-C headers are present")
- }
-}
-
-// Test encryption with different data sizes (similar to s3tests)
-func TestSSECEncryptionVariousSizes(t *testing.T) {
- sizes := []int{1, 13, 1024, 1024 * 1024} // 1B, 13B, 1KB, 1MB
-
- for _, size := range sizes {
- t.Run(fmt.Sprintf("size_%d", size), func(t *testing.T) {
- // Create customer key
- key := make([]byte, 32)
- for i := range key {
- key[i] = byte(i + size) // Make key unique per test
- }
-
- md5sumDyn := md5.Sum(key)
- customerKey := &SSECustomerKey{
- Algorithm: "AES256",
- Key: key,
- KeyMD5: base64.StdEncoding.EncodeToString(md5sumDyn[:]),
- }
-
- // Create test data of specified size
- testData := make([]byte, size)
- for i := range testData {
- testData[i] = byte('A' + (i % 26)) // Pattern of A-Z
- }
-
- // Encrypt
- dataReader := bytes.NewReader(testData)
- encryptedReader, iv, err := CreateSSECEncryptedReader(dataReader, customerKey)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader: %v", err)
- }
-
- encryptedData, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read encrypted data: %v", err)
- }
-
- // Verify encrypted data has same size as original (IV is stored in metadata, not in stream)
- if len(encryptedData) != size {
- t.Errorf("Expected encrypted data length %d (same as original), got %d", size, len(encryptedData))
- }
-
- // Decrypt
- encryptedReader2 := bytes.NewReader(encryptedData)
- decryptedReader, err := CreateSSECDecryptedReader(encryptedReader2, customerKey, iv)
- if err != nil {
- t.Fatalf("Failed to create decrypted reader: %v", err)
- }
-
- decryptedData, err := io.ReadAll(decryptedReader)
- if err != nil {
- t.Fatalf("Failed to read decrypted data: %v", err)
- }
-
- // Verify decrypted data matches original
- if !bytes.Equal(decryptedData, testData) {
- t.Errorf("Decrypted data doesn't match original for size %d", size)
- }
- })
- }
-}
-
-func TestSSECEncryptionWithNilKey(t *testing.T) {
- testData := []byte("test data")
- dataReader := bytes.NewReader(testData)
-
- // Test encryption with nil key (should pass through)
- encryptedReader, iv, err := CreateSSECEncryptedReader(dataReader, nil)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader with nil key: %v", err)
- }
-
- result, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read from pass-through reader: %v", err)
- }
-
- if !bytes.Equal(result, testData) {
- t.Error("Data should pass through unchanged when key is nil")
- }
-
- // Test decryption with nil key (should pass through)
- dataReader2 := bytes.NewReader(testData)
- decryptedReader, err := CreateSSECDecryptedReader(dataReader2, nil, iv)
- if err != nil {
- t.Fatalf("Failed to create decrypted reader with nil key: %v", err)
- }
-
- result2, err := io.ReadAll(decryptedReader)
- if err != nil {
- t.Fatalf("Failed to read from pass-through reader: %v", err)
- }
-
- if !bytes.Equal(result2, testData) {
- t.Error("Data should pass through unchanged when key is nil")
- }
-}
-
-// TestSSECEncryptionSmallBuffers tests the fix for the critical bug where small buffers
-// could corrupt the data stream when reading in chunks smaller than the IV size
-func TestSSECEncryptionSmallBuffers(t *testing.T) {
- testData := []byte("This is a test message for small buffer reads")
-
- // Create customer key
- key := make([]byte, 32)
- for i := range key {
- key[i] = byte(i)
- }
-
- md5sumKey3 := md5.Sum(key)
- customerKey := &SSECustomerKey{
- Algorithm: "AES256",
- Key: key,
- KeyMD5: base64.StdEncoding.EncodeToString(md5sumKey3[:]),
- }
-
- // Create encrypted reader
- dataReader := bytes.NewReader(testData)
- encryptedReader, iv, err := CreateSSECEncryptedReader(dataReader, customerKey)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader: %v", err)
- }
-
- // Read with very small buffers (smaller than IV size of 16 bytes)
- var encryptedData []byte
- smallBuffer := make([]byte, 5) // Much smaller than 16-byte IV
-
- for {
- n, err := encryptedReader.Read(smallBuffer)
- if n > 0 {
- encryptedData = append(encryptedData, smallBuffer[:n]...)
- }
- if err == io.EOF {
- break
- }
- if err != nil {
- t.Fatalf("Error reading encrypted data: %v", err)
- }
- }
-
- // Verify we have some encrypted data (IV is in metadata, not in stream)
- if len(encryptedData) == 0 && len(testData) > 0 {
- t.Fatal("Expected encrypted data but got none")
- }
-
- // Expected size: same as original data (IV is stored in metadata, not in stream)
- if len(encryptedData) != len(testData) {
- t.Errorf("Expected encrypted data size %d (same as original), got %d", len(testData), len(encryptedData))
- }
-
- // Decrypt and verify
- encryptedReader2 := bytes.NewReader(encryptedData)
- decryptedReader, err := CreateSSECDecryptedReader(encryptedReader2, customerKey, iv)
- if err != nil {
- t.Fatalf("Failed to create decrypted reader: %v", err)
- }
-
- decryptedData, err := io.ReadAll(decryptedReader)
- if err != nil {
- t.Fatalf("Failed to read decrypted data: %v", err)
- }
-
- if !bytes.Equal(decryptedData, testData) {
- t.Errorf("Decrypted data doesn't match original.\nOriginal: %s\nDecrypted: %s", testData, decryptedData)
- }
-}
diff --git a/weed/s3api/s3_sse_copy_test.go b/weed/s3api/s3_sse_copy_test.go
deleted file mode 100644
index b377b45a9..000000000
--- a/weed/s3api/s3_sse_copy_test.go
+++ /dev/null
@@ -1,628 +0,0 @@
-package s3api
-
-import (
- "bytes"
- "io"
- "net/http"
- "strings"
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
-)
-
-// TestSSECObjectCopy tests copying SSE-C encrypted objects with different keys
-func TestSSECObjectCopy(t *testing.T) {
- // Original key for source object
- sourceKey := GenerateTestSSECKey(1)
- sourceCustomerKey := &SSECustomerKey{
- Algorithm: "AES256",
- Key: sourceKey.Key,
- KeyMD5: sourceKey.KeyMD5,
- }
-
- // Destination key for target object
- destKey := GenerateTestSSECKey(2)
- destCustomerKey := &SSECustomerKey{
- Algorithm: "AES256",
- Key: destKey.Key,
- KeyMD5: destKey.KeyMD5,
- }
-
- testData := "Hello, SSE-C copy world!"
-
- // Encrypt with source key
- encryptedReader, iv, err := CreateSSECEncryptedReader(strings.NewReader(testData), sourceCustomerKey)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader: %v", err)
- }
-
- encryptedData, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read encrypted data: %v", err)
- }
-
- // Test copy strategy determination
- sourceMetadata := make(map[string][]byte)
- StoreSSECIVInMetadata(sourceMetadata, iv)
- sourceMetadata[s3_constants.AmzServerSideEncryptionCustomerAlgorithm] = []byte("AES256")
- sourceMetadata[s3_constants.AmzServerSideEncryptionCustomerKeyMD5] = []byte(sourceKey.KeyMD5)
-
- t.Run("Same key copy (direct copy)", func(t *testing.T) {
- strategy, err := DetermineSSECCopyStrategy(sourceMetadata, sourceCustomerKey, sourceCustomerKey)
- if err != nil {
- t.Fatalf("Failed to determine copy strategy: %v", err)
- }
-
- if strategy != SSECCopyStrategyDirect {
- t.Errorf("Expected direct copy strategy for same key, got %v", strategy)
- }
- })
-
- t.Run("Different key copy (decrypt-encrypt)", func(t *testing.T) {
- strategy, err := DetermineSSECCopyStrategy(sourceMetadata, sourceCustomerKey, destCustomerKey)
- if err != nil {
- t.Fatalf("Failed to determine copy strategy: %v", err)
- }
-
- if strategy != SSECCopyStrategyDecryptEncrypt {
- t.Errorf("Expected decrypt-encrypt copy strategy for different keys, got %v", strategy)
- }
- })
-
- t.Run("Can direct copy check", func(t *testing.T) {
- // Same key should allow direct copy
- canDirect := CanDirectCopySSEC(sourceMetadata, sourceCustomerKey, sourceCustomerKey)
- if !canDirect {
- t.Error("Should allow direct copy with same key")
- }
-
- // Different key should not allow direct copy
- canDirect = CanDirectCopySSEC(sourceMetadata, sourceCustomerKey, destCustomerKey)
- if canDirect {
- t.Error("Should not allow direct copy with different keys")
- }
- })
-
- // Test actual copy operation (decrypt with source key, encrypt with dest key)
- t.Run("Full copy operation", func(t *testing.T) {
- // Decrypt with source key
- decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedData), sourceCustomerKey, iv)
- if err != nil {
- t.Fatalf("Failed to create decrypted reader: %v", err)
- }
-
- // Re-encrypt with destination key
- reEncryptedReader, destIV, err := CreateSSECEncryptedReader(decryptedReader, destCustomerKey)
- if err != nil {
- t.Fatalf("Failed to create re-encrypted reader: %v", err)
- }
-
- reEncryptedData, err := io.ReadAll(reEncryptedReader)
- if err != nil {
- t.Fatalf("Failed to read re-encrypted data: %v", err)
- }
-
- // Verify we can decrypt with destination key
- finalDecryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(reEncryptedData), destCustomerKey, destIV)
- if err != nil {
- t.Fatalf("Failed to create final decrypted reader: %v", err)
- }
-
- finalData, err := io.ReadAll(finalDecryptedReader)
- if err != nil {
- t.Fatalf("Failed to read final decrypted data: %v", err)
- }
-
- if string(finalData) != testData {
- t.Errorf("Expected %s, got %s", testData, string(finalData))
- }
- })
-}
-
-// TestSSEKMSObjectCopy tests copying SSE-KMS encrypted objects
-func TestSSEKMSObjectCopy(t *testing.T) {
- kmsKey := SetupTestKMS(t)
- defer kmsKey.Cleanup()
-
- testData := "Hello, SSE-KMS copy world!"
- encryptionContext := BuildEncryptionContext("test-bucket", "test-object", false)
-
- // Encrypt with SSE-KMS
- encryptedReader, sseKey, err := CreateSSEKMSEncryptedReader(strings.NewReader(testData), kmsKey.KeyID, encryptionContext)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader: %v", err)
- }
-
- encryptedData, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read encrypted data: %v", err)
- }
-
- t.Run("Same KMS key copy", func(t *testing.T) {
- // Decrypt with original key
- decryptedReader, err := CreateSSEKMSDecryptedReader(bytes.NewReader(encryptedData), sseKey)
- if err != nil {
- t.Fatalf("Failed to create decrypted reader: %v", err)
- }
-
- // Re-encrypt with same KMS key
- reEncryptedReader, newSseKey, err := CreateSSEKMSEncryptedReader(decryptedReader, kmsKey.KeyID, encryptionContext)
- if err != nil {
- t.Fatalf("Failed to create re-encrypted reader: %v", err)
- }
-
- reEncryptedData, err := io.ReadAll(reEncryptedReader)
- if err != nil {
- t.Fatalf("Failed to read re-encrypted data: %v", err)
- }
-
- // Verify we can decrypt with new key
- finalDecryptedReader, err := CreateSSEKMSDecryptedReader(bytes.NewReader(reEncryptedData), newSseKey)
- if err != nil {
- t.Fatalf("Failed to create final decrypted reader: %v", err)
- }
-
- finalData, err := io.ReadAll(finalDecryptedReader)
- if err != nil {
- t.Fatalf("Failed to read final decrypted data: %v", err)
- }
-
- if string(finalData) != testData {
- t.Errorf("Expected %s, got %s", testData, string(finalData))
- }
- })
-}
-
-// TestSSECToSSEKMSCopy tests cross-encryption copy (SSE-C to SSE-KMS)
-func TestSSECToSSEKMSCopy(t *testing.T) {
- // Setup SSE-C key
- ssecKey := GenerateTestSSECKey(1)
- ssecCustomerKey := &SSECustomerKey{
- Algorithm: "AES256",
- Key: ssecKey.Key,
- KeyMD5: ssecKey.KeyMD5,
- }
-
- // Setup SSE-KMS
- kmsKey := SetupTestKMS(t)
- defer kmsKey.Cleanup()
-
- testData := "Hello, cross-encryption copy world!"
-
- // Encrypt with SSE-C
- encryptedReader, ssecIV, err := CreateSSECEncryptedReader(strings.NewReader(testData), ssecCustomerKey)
- if err != nil {
- t.Fatalf("Failed to create SSE-C encrypted reader: %v", err)
- }
-
- encryptedData, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read SSE-C encrypted data: %v", err)
- }
-
- // Decrypt SSE-C data
- decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedData), ssecCustomerKey, ssecIV)
- if err != nil {
- t.Fatalf("Failed to create SSE-C decrypted reader: %v", err)
- }
-
- // Re-encrypt with SSE-KMS
- encryptionContext := BuildEncryptionContext("test-bucket", "test-object", false)
- reEncryptedReader, sseKmsKey, err := CreateSSEKMSEncryptedReader(decryptedReader, kmsKey.KeyID, encryptionContext)
- if err != nil {
- t.Fatalf("Failed to create SSE-KMS encrypted reader: %v", err)
- }
-
- reEncryptedData, err := io.ReadAll(reEncryptedReader)
- if err != nil {
- t.Fatalf("Failed to read SSE-KMS encrypted data: %v", err)
- }
-
- // Decrypt with SSE-KMS
- finalDecryptedReader, err := CreateSSEKMSDecryptedReader(bytes.NewReader(reEncryptedData), sseKmsKey)
- if err != nil {
- t.Fatalf("Failed to create SSE-KMS decrypted reader: %v", err)
- }
-
- finalData, err := io.ReadAll(finalDecryptedReader)
- if err != nil {
- t.Fatalf("Failed to read final decrypted data: %v", err)
- }
-
- if string(finalData) != testData {
- t.Errorf("Expected %s, got %s", testData, string(finalData))
- }
-}
-
-// TestSSEKMSToSSECCopy tests cross-encryption copy (SSE-KMS to SSE-C)
-func TestSSEKMSToSSECCopy(t *testing.T) {
- // Setup SSE-KMS
- kmsKey := SetupTestKMS(t)
- defer kmsKey.Cleanup()
-
- // Setup SSE-C key
- ssecKey := GenerateTestSSECKey(1)
- ssecCustomerKey := &SSECustomerKey{
- Algorithm: "AES256",
- Key: ssecKey.Key,
- KeyMD5: ssecKey.KeyMD5,
- }
-
- testData := "Hello, reverse cross-encryption copy world!"
- encryptionContext := BuildEncryptionContext("test-bucket", "test-object", false)
-
- // Encrypt with SSE-KMS
- encryptedReader, sseKmsKey, err := CreateSSEKMSEncryptedReader(strings.NewReader(testData), kmsKey.KeyID, encryptionContext)
- if err != nil {
- t.Fatalf("Failed to create SSE-KMS encrypted reader: %v", err)
- }
-
- encryptedData, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read SSE-KMS encrypted data: %v", err)
- }
-
- // Decrypt SSE-KMS data
- decryptedReader, err := CreateSSEKMSDecryptedReader(bytes.NewReader(encryptedData), sseKmsKey)
- if err != nil {
- t.Fatalf("Failed to create SSE-KMS decrypted reader: %v", err)
- }
-
- // Re-encrypt with SSE-C
- reEncryptedReader, reEncryptedIV, err := CreateSSECEncryptedReader(decryptedReader, ssecCustomerKey)
- if err != nil {
- t.Fatalf("Failed to create SSE-C encrypted reader: %v", err)
- }
-
- reEncryptedData, err := io.ReadAll(reEncryptedReader)
- if err != nil {
- t.Fatalf("Failed to read SSE-C encrypted data: %v", err)
- }
-
- // Decrypt with SSE-C
- finalDecryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(reEncryptedData), ssecCustomerKey, reEncryptedIV)
- if err != nil {
- t.Fatalf("Failed to create SSE-C decrypted reader: %v", err)
- }
-
- finalData, err := io.ReadAll(finalDecryptedReader)
- if err != nil {
- t.Fatalf("Failed to read final decrypted data: %v", err)
- }
-
- if string(finalData) != testData {
- t.Errorf("Expected %s, got %s", testData, string(finalData))
- }
-}
-
-// TestSSECopyWithCorruptedSource tests copy operations with corrupted source data
-func TestSSECopyWithCorruptedSource(t *testing.T) {
- ssecKey := GenerateTestSSECKey(1)
- ssecCustomerKey := &SSECustomerKey{
- Algorithm: "AES256",
- Key: ssecKey.Key,
- KeyMD5: ssecKey.KeyMD5,
- }
-
- testData := "Hello, corruption test!"
-
- // Encrypt data
- encryptedReader, iv, err := CreateSSECEncryptedReader(strings.NewReader(testData), ssecCustomerKey)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader: %v", err)
- }
-
- encryptedData, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read encrypted data: %v", err)
- }
-
- // Corrupt the encrypted data
- corruptedData := make([]byte, len(encryptedData))
- copy(corruptedData, encryptedData)
- if len(corruptedData) > s3_constants.AESBlockSize {
- // Corrupt a byte after the IV
- corruptedData[s3_constants.AESBlockSize] ^= 0xFF
- }
-
- // Try to decrypt corrupted data
- decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(corruptedData), ssecCustomerKey, iv)
- if err != nil {
- t.Fatalf("Failed to create decrypted reader for corrupted data: %v", err)
- }
-
- decryptedData, err := io.ReadAll(decryptedReader)
- if err != nil {
- // This is okay - corrupted data might cause read errors
- t.Logf("Read error for corrupted data (expected): %v", err)
- return
- }
-
- // If we can read it, the data should be different from original
- if string(decryptedData) == testData {
- t.Error("Decrypted corrupted data should not match original")
- }
-}
-
-// TestSSEKMSCopyStrategy tests SSE-KMS copy strategy determination
-func TestSSEKMSCopyStrategy(t *testing.T) {
- tests := []struct {
- name string
- srcMetadata map[string][]byte
- destKeyID string
- expectedStrategy SSEKMSCopyStrategy
- }{
- {
- name: "Unencrypted to unencrypted",
- srcMetadata: map[string][]byte{},
- destKeyID: "",
- expectedStrategy: SSEKMSCopyStrategyDirect,
- },
- {
- name: "Same KMS key",
- srcMetadata: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("aws:kms"),
- s3_constants.AmzServerSideEncryptionAwsKmsKeyId: []byte("test-key-123"),
- },
- destKeyID: "test-key-123",
- expectedStrategy: SSEKMSCopyStrategyDirect,
- },
- {
- name: "Different KMS keys",
- srcMetadata: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("aws:kms"),
- s3_constants.AmzServerSideEncryptionAwsKmsKeyId: []byte("test-key-123"),
- },
- destKeyID: "test-key-456",
- expectedStrategy: SSEKMSCopyStrategyDecryptEncrypt,
- },
- {
- name: "Encrypted to unencrypted",
- srcMetadata: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("aws:kms"),
- s3_constants.AmzServerSideEncryptionAwsKmsKeyId: []byte("test-key-123"),
- },
- destKeyID: "",
- expectedStrategy: SSEKMSCopyStrategyDecryptEncrypt,
- },
- {
- name: "Unencrypted to encrypted",
- srcMetadata: map[string][]byte{},
- destKeyID: "test-key-123",
- expectedStrategy: SSEKMSCopyStrategyDecryptEncrypt,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- strategy, err := DetermineSSEKMSCopyStrategy(tt.srcMetadata, tt.destKeyID)
- if err != nil {
- t.Fatalf("DetermineSSEKMSCopyStrategy failed: %v", err)
- }
- if strategy != tt.expectedStrategy {
- t.Errorf("Expected strategy %v, got %v", tt.expectedStrategy, strategy)
- }
- })
- }
-}
-
-// TestSSEKMSCopyHeaders tests SSE-KMS copy header parsing
-func TestSSEKMSCopyHeaders(t *testing.T) {
- tests := []struct {
- name string
- headers map[string]string
- expectedKeyID string
- expectedContext map[string]string
- expectedBucketKey bool
- expectError bool
- }{
- {
- name: "No SSE-KMS headers",
- headers: map[string]string{},
- expectedKeyID: "",
- expectedContext: nil,
- expectedBucketKey: false,
- expectError: false,
- },
- {
- name: "SSE-KMS with key ID",
- headers: map[string]string{
- s3_constants.AmzServerSideEncryption: "aws:kms",
- s3_constants.AmzServerSideEncryptionAwsKmsKeyId: "test-key-123",
- },
- expectedKeyID: "test-key-123",
- expectedContext: nil,
- expectedBucketKey: false,
- expectError: false,
- },
- {
- name: "SSE-KMS with all options",
- headers: map[string]string{
- s3_constants.AmzServerSideEncryption: "aws:kms",
- s3_constants.AmzServerSideEncryptionAwsKmsKeyId: "test-key-123",
- s3_constants.AmzServerSideEncryptionContext: "eyJ0ZXN0IjoidmFsdWUifQ==", // base64 of {"test":"value"}
- s3_constants.AmzServerSideEncryptionBucketKeyEnabled: "true",
- },
- expectedKeyID: "test-key-123",
- expectedContext: map[string]string{"test": "value"},
- expectedBucketKey: true,
- expectError: false,
- },
- {
- name: "Invalid key ID",
- headers: map[string]string{
- s3_constants.AmzServerSideEncryption: "aws:kms",
- s3_constants.AmzServerSideEncryptionAwsKmsKeyId: "invalid key id",
- },
- expectError: true,
- },
- {
- name: "Invalid encryption context",
- headers: map[string]string{
- s3_constants.AmzServerSideEncryption: "aws:kms",
- s3_constants.AmzServerSideEncryptionContext: "invalid-base64!",
- },
- expectError: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- req, _ := http.NewRequest("PUT", "/test", nil)
- for k, v := range tt.headers {
- req.Header.Set(k, v)
- }
-
- keyID, context, bucketKey, err := ParseSSEKMSCopyHeaders(req)
-
- if tt.expectError {
- if err == nil {
- t.Error("Expected error but got none")
- }
- return
- }
-
- if err != nil {
- t.Fatalf("Unexpected error: %v", err)
- }
-
- if keyID != tt.expectedKeyID {
- t.Errorf("Expected keyID %s, got %s", tt.expectedKeyID, keyID)
- }
-
- if !mapsEqual(context, tt.expectedContext) {
- t.Errorf("Expected context %v, got %v", tt.expectedContext, context)
- }
-
- if bucketKey != tt.expectedBucketKey {
- t.Errorf("Expected bucketKey %v, got %v", tt.expectedBucketKey, bucketKey)
- }
- })
- }
-}
-
-// TestSSEKMSDirectCopy tests direct copy scenarios
-func TestSSEKMSDirectCopy(t *testing.T) {
- tests := []struct {
- name string
- srcMetadata map[string][]byte
- destKeyID string
- canDirect bool
- }{
- {
- name: "Both unencrypted",
- srcMetadata: map[string][]byte{},
- destKeyID: "",
- canDirect: true,
- },
- {
- name: "Same key ID",
- srcMetadata: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("aws:kms"),
- s3_constants.AmzServerSideEncryptionAwsKmsKeyId: []byte("test-key-123"),
- },
- destKeyID: "test-key-123",
- canDirect: true,
- },
- {
- name: "Different key IDs",
- srcMetadata: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("aws:kms"),
- s3_constants.AmzServerSideEncryptionAwsKmsKeyId: []byte("test-key-123"),
- },
- destKeyID: "test-key-456",
- canDirect: false,
- },
- {
- name: "Source encrypted, dest unencrypted",
- srcMetadata: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("aws:kms"),
- s3_constants.AmzServerSideEncryptionAwsKmsKeyId: []byte("test-key-123"),
- },
- destKeyID: "",
- canDirect: false,
- },
- {
- name: "Source unencrypted, dest encrypted",
- srcMetadata: map[string][]byte{},
- destKeyID: "test-key-123",
- canDirect: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- canDirect := CanDirectCopySSEKMS(tt.srcMetadata, tt.destKeyID)
- if canDirect != tt.canDirect {
- t.Errorf("Expected canDirect %v, got %v", tt.canDirect, canDirect)
- }
- })
- }
-}
-
-// TestGetSourceSSEKMSInfo tests extraction of SSE-KMS info from metadata
-func TestGetSourceSSEKMSInfo(t *testing.T) {
- tests := []struct {
- name string
- metadata map[string][]byte
- expectedKeyID string
- expectedEncrypted bool
- }{
- {
- name: "No encryption",
- metadata: map[string][]byte{},
- expectedKeyID: "",
- expectedEncrypted: false,
- },
- {
- name: "SSE-KMS with key ID",
- metadata: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("aws:kms"),
- s3_constants.AmzServerSideEncryptionAwsKmsKeyId: []byte("test-key-123"),
- },
- expectedKeyID: "test-key-123",
- expectedEncrypted: true,
- },
- {
- name: "SSE-KMS without key ID (default key)",
- metadata: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("aws:kms"),
- },
- expectedKeyID: "",
- expectedEncrypted: true,
- },
- {
- name: "Non-KMS encryption",
- metadata: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("AES256"),
- },
- expectedKeyID: "",
- expectedEncrypted: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- keyID, encrypted := GetSourceSSEKMSInfo(tt.metadata)
- if keyID != tt.expectedKeyID {
- t.Errorf("Expected keyID %s, got %s", tt.expectedKeyID, keyID)
- }
- if encrypted != tt.expectedEncrypted {
- t.Errorf("Expected encrypted %v, got %v", tt.expectedEncrypted, encrypted)
- }
- })
- }
-}
-
-// Helper function to compare maps
-func mapsEqual(a, b map[string]string) bool {
- if len(a) != len(b) {
- return false
- }
- for k, v := range a {
- if b[k] != v {
- return false
- }
- }
- return true
-}
diff --git a/weed/s3api/s3_sse_error_test.go b/weed/s3api/s3_sse_error_test.go
deleted file mode 100644
index a344e2ef7..000000000
--- a/weed/s3api/s3_sse_error_test.go
+++ /dev/null
@@ -1,400 +0,0 @@
-package s3api
-
-import (
- "bytes"
- "fmt"
- "io"
- "net/http"
- "strings"
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
-)
-
-// TestSSECWrongKeyDecryption tests decryption with wrong SSE-C key
-func TestSSECWrongKeyDecryption(t *testing.T) {
- // Setup original key and encrypt data
- originalKey := GenerateTestSSECKey(1)
- testData := "Hello, SSE-C world!"
-
- encryptedReader, iv, err := CreateSSECEncryptedReader(strings.NewReader(testData), &SSECustomerKey{
- Algorithm: "AES256",
- Key: originalKey.Key,
- KeyMD5: originalKey.KeyMD5,
- })
- if err != nil {
- t.Fatalf("Failed to create encrypted reader: %v", err)
- }
-
- // Read encrypted data
- encryptedData, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read encrypted data: %v", err)
- }
-
- // Try to decrypt with wrong key
- wrongKey := GenerateTestSSECKey(2) // Different seed = different key
- decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedData), &SSECustomerKey{
- Algorithm: "AES256",
- Key: wrongKey.Key,
- KeyMD5: wrongKey.KeyMD5,
- }, iv)
- if err != nil {
- t.Fatalf("Failed to create decrypted reader: %v", err)
- }
-
- // Read decrypted data - should be garbage/different from original
- decryptedData, err := io.ReadAll(decryptedReader)
- if err != nil {
- t.Fatalf("Failed to read decrypted data: %v", err)
- }
-
- // Verify the decrypted data is NOT the same as original (wrong key used)
- if string(decryptedData) == testData {
- t.Error("Decryption with wrong key should not produce original data")
- }
-}
-
-// TestSSEKMSKeyNotFound tests handling of missing KMS key
-func TestSSEKMSKeyNotFound(t *testing.T) {
- // Note: The local KMS provider creates keys on-demand by design.
- // This test validates that when on-demand creation fails or is disabled,
- // appropriate errors are returned.
-
- // Test with an invalid key ID that would fail even on-demand creation
- invalidKeyID := "" // Empty key ID should fail
- encryptionContext := BuildEncryptionContext("test-bucket", "test-object", false)
-
- _, _, err := CreateSSEKMSEncryptedReader(strings.NewReader("test data"), invalidKeyID, encryptionContext)
-
- // Should get an error for invalid/empty key
- if err == nil {
- t.Error("Expected error for empty KMS key ID, got none")
- }
-
- // For local KMS with on-demand creation, we test what we can realistically test
- if err != nil {
- t.Logf("Got expected error for empty key ID: %v", err)
- }
-}
-
-// TestSSEHeadersWithoutEncryption tests inconsistent state where headers are present but no encryption
-func TestSSEHeadersWithoutEncryption(t *testing.T) {
- testCases := []struct {
- name string
- setupReq func() *http.Request
- }{
- {
- name: "SSE-C algorithm without key",
- setupReq: func() *http.Request {
- req := CreateTestHTTPRequest("PUT", "/bucket/object", nil)
- req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, "AES256")
- // Missing key and MD5
- return req
- },
- },
- {
- name: "SSE-C key without algorithm",
- setupReq: func() *http.Request {
- req := CreateTestHTTPRequest("PUT", "/bucket/object", nil)
- keyPair := GenerateTestSSECKey(1)
- req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKey, keyPair.KeyB64)
- // Missing algorithm
- return req
- },
- },
- {
- name: "SSE-KMS key ID without algorithm",
- setupReq: func() *http.Request {
- req := CreateTestHTTPRequest("PUT", "/bucket/object", nil)
- req.Header.Set(s3_constants.AmzServerSideEncryptionAwsKmsKeyId, "test-key-id")
- // Missing algorithm
- return req
- },
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- req := tc.setupReq()
-
- // Validate headers - should catch incomplete configurations
- if strings.Contains(tc.name, "SSE-C") {
- err := ValidateSSECHeaders(req)
- if err == nil {
- t.Error("Expected validation error for incomplete SSE-C headers")
- }
- }
- })
- }
-}
-
-// TestSSECInvalidKeyFormats tests various invalid SSE-C key formats
-func TestSSECInvalidKeyFormats(t *testing.T) {
- testCases := []struct {
- name string
- algorithm string
- key string
- keyMD5 string
- expectErr bool
- }{
- {
- name: "Invalid algorithm",
- algorithm: "AES128",
- key: "dGVzdGtleXRlc3RrZXl0ZXN0a2V5dGVzdGtleXRlc3RrZXk=", // 32 bytes base64
- keyMD5: "valid-md5-hash",
- expectErr: true,
- },
- {
- name: "Invalid key length (too short)",
- algorithm: "AES256",
- key: "c2hvcnRrZXk=", // "shortkey" base64 - too short
- keyMD5: "valid-md5-hash",
- expectErr: true,
- },
- {
- name: "Invalid key length (too long)",
- algorithm: "AES256",
- key: "dGVzdGtleXRlc3RrZXl0ZXN0a2V5dGVzdGtleXRlc3RrZXl0ZXN0a2V5dGVzdGtleQ==", // too long
- keyMD5: "valid-md5-hash",
- expectErr: true,
- },
- {
- name: "Invalid base64 key",
- algorithm: "AES256",
- key: "invalid-base64!",
- keyMD5: "valid-md5-hash",
- expectErr: true,
- },
- {
- name: "Invalid base64 MD5",
- algorithm: "AES256",
- key: "dGVzdGtleXRlc3RrZXl0ZXN0a2V5dGVzdGtleXRlc3RrZXk=",
- keyMD5: "invalid-base64!",
- expectErr: true,
- },
- {
- name: "Mismatched MD5",
- algorithm: "AES256",
- key: "dGVzdGtleXRlc3RrZXl0ZXN0a2V5dGVzdGtleXRlc3RrZXk=",
- keyMD5: "d29uZy1tZDUtaGFzaA==", // "wrong-md5-hash" base64
- expectErr: true,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- req := CreateTestHTTPRequest("PUT", "/bucket/object", nil)
- req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, tc.algorithm)
- req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKey, tc.key)
- req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKeyMD5, tc.keyMD5)
-
- err := ValidateSSECHeaders(req)
- if tc.expectErr && err == nil {
- t.Errorf("Expected error for %s, but got none", tc.name)
- }
- if !tc.expectErr && err != nil {
- t.Errorf("Expected no error for %s, but got: %v", tc.name, err)
- }
- })
- }
-}
-
-// TestSSEKMSInvalidConfigurations tests various invalid SSE-KMS configurations
-func TestSSEKMSInvalidConfigurations(t *testing.T) {
- testCases := []struct {
- name string
- setupRequest func() *http.Request
- expectError bool
- }{
- {
- name: "Invalid algorithm",
- setupRequest: func() *http.Request {
- req := CreateTestHTTPRequest("PUT", "/bucket/object", nil)
- req.Header.Set(s3_constants.AmzServerSideEncryption, "invalid-algorithm")
- return req
- },
- expectError: true,
- },
- {
- name: "Empty key ID",
- setupRequest: func() *http.Request {
- req := CreateTestHTTPRequest("PUT", "/bucket/object", nil)
- req.Header.Set(s3_constants.AmzServerSideEncryption, "aws:kms")
- req.Header.Set(s3_constants.AmzServerSideEncryptionAwsKmsKeyId, "")
- return req
- },
- expectError: false, // Empty key ID might be valid (use default)
- },
- {
- name: "Invalid key ID format",
- setupRequest: func() *http.Request {
- req := CreateTestHTTPRequest("PUT", "/bucket/object", nil)
- req.Header.Set(s3_constants.AmzServerSideEncryption, "aws:kms")
- req.Header.Set(s3_constants.AmzServerSideEncryptionAwsKmsKeyId, "invalid key id with spaces")
- return req
- },
- expectError: true,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- req := tc.setupRequest()
-
- _, err := ParseSSEKMSHeaders(req)
- if tc.expectError && err == nil {
- t.Errorf("Expected error for %s, but got none", tc.name)
- }
- if !tc.expectError && err != nil {
- t.Errorf("Expected no error for %s, but got: %v", tc.name, err)
- }
- })
- }
-}
-
-// TestSSEEmptyDataHandling tests handling of empty data with SSE
-func TestSSEEmptyDataHandling(t *testing.T) {
- t.Run("SSE-C with empty data", func(t *testing.T) {
- keyPair := GenerateTestSSECKey(1)
- customerKey := &SSECustomerKey{
- Algorithm: "AES256",
- Key: keyPair.Key,
- KeyMD5: keyPair.KeyMD5,
- }
-
- // Encrypt empty data
- encryptedReader, iv, err := CreateSSECEncryptedReader(strings.NewReader(""), customerKey)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader for empty data: %v", err)
- }
-
- encryptedData, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read encrypted empty data: %v", err)
- }
-
- // Should have IV for empty data
- if len(iv) != s3_constants.AESBlockSize {
- t.Error("IV should be present even for empty data")
- }
-
- // Decrypt and verify
- decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedData), customerKey, iv)
- if err != nil {
- t.Fatalf("Failed to create decrypted reader for empty data: %v", err)
- }
-
- decryptedData, err := io.ReadAll(decryptedReader)
- if err != nil {
- t.Fatalf("Failed to read decrypted empty data: %v", err)
- }
-
- if len(decryptedData) != 0 {
- t.Errorf("Expected empty decrypted data, got %d bytes", len(decryptedData))
- }
- })
-
- t.Run("SSE-KMS with empty data", func(t *testing.T) {
- kmsKey := SetupTestKMS(t)
- defer kmsKey.Cleanup()
-
- encryptionContext := BuildEncryptionContext("test-bucket", "test-object", false)
-
- // Encrypt empty data
- encryptedReader, sseKey, err := CreateSSEKMSEncryptedReader(strings.NewReader(""), kmsKey.KeyID, encryptionContext)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader for empty data: %v", err)
- }
-
- encryptedData, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read encrypted empty data: %v", err)
- }
-
- // Empty data should produce empty encrypted data (IV is stored in metadata)
- if len(encryptedData) != 0 {
- t.Errorf("Encrypted empty data should be empty, got %d bytes", len(encryptedData))
- }
-
- // Decrypt and verify
- decryptedReader, err := CreateSSEKMSDecryptedReader(bytes.NewReader(encryptedData), sseKey)
- if err != nil {
- t.Fatalf("Failed to create decrypted reader for empty data: %v", err)
- }
-
- decryptedData, err := io.ReadAll(decryptedReader)
- if err != nil {
- t.Fatalf("Failed to read decrypted empty data: %v", err)
- }
-
- if len(decryptedData) != 0 {
- t.Errorf("Expected empty decrypted data, got %d bytes", len(decryptedData))
- }
- })
-}
-
-// TestSSEConcurrentAccess tests SSE operations under concurrent access
-func TestSSEConcurrentAccess(t *testing.T) {
- keyPair := GenerateTestSSECKey(1)
- customerKey := &SSECustomerKey{
- Algorithm: "AES256",
- Key: keyPair.Key,
- KeyMD5: keyPair.KeyMD5,
- }
-
- const numGoroutines = 10
- done := make(chan bool, numGoroutines)
- errors := make(chan error, numGoroutines)
-
- // Run multiple encryption/decryption operations concurrently
- for i := 0; i < numGoroutines; i++ {
- go func(id int) {
- defer func() { done <- true }()
-
- testData := fmt.Sprintf("test data %d", id)
-
- // Encrypt
- encryptedReader, iv, err := CreateSSECEncryptedReader(strings.NewReader(testData), customerKey)
- if err != nil {
- errors <- fmt.Errorf("goroutine %d encrypt error: %v", id, err)
- return
- }
-
- encryptedData, err := io.ReadAll(encryptedReader)
- if err != nil {
- errors <- fmt.Errorf("goroutine %d read encrypted error: %v", id, err)
- return
- }
-
- // Decrypt
- decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedData), customerKey, iv)
- if err != nil {
- errors <- fmt.Errorf("goroutine %d decrypt error: %v", id, err)
- return
- }
-
- decryptedData, err := io.ReadAll(decryptedReader)
- if err != nil {
- errors <- fmt.Errorf("goroutine %d read decrypted error: %v", id, err)
- return
- }
-
- if string(decryptedData) != testData {
- errors <- fmt.Errorf("goroutine %d data mismatch: expected %s, got %s", id, testData, string(decryptedData))
- return
- }
- }(i)
- }
-
- // Wait for all goroutines to complete
- for i := 0; i < numGoroutines; i++ {
- <-done
- }
-
- // Check for errors
- close(errors)
- for err := range errors {
- t.Error(err)
- }
-}
diff --git a/weed/s3api/s3_sse_http_test.go b/weed/s3api/s3_sse_http_test.go
deleted file mode 100644
index 95f141ca7..000000000
--- a/weed/s3api/s3_sse_http_test.go
+++ /dev/null
@@ -1,401 +0,0 @@
-package s3api
-
-import (
- "bytes"
- "net/http"
- "net/http/httptest"
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
-)
-
-// TestPutObjectWithSSEC tests PUT object with SSE-C through HTTP handler
-func TestPutObjectWithSSEC(t *testing.T) {
- keyPair := GenerateTestSSECKey(1)
- testData := "Hello, SSE-C PUT object!"
-
- // Create HTTP request
- req := CreateTestHTTPRequest("PUT", "/test-bucket/test-object", []byte(testData))
- SetupTestSSECHeaders(req, keyPair)
- SetupTestMuxVars(req, map[string]string{
- "bucket": "test-bucket",
- "object": "test-object",
- })
-
- // Create response recorder
- w := CreateTestHTTPResponse()
-
- // Test header validation
- err := ValidateSSECHeaders(req)
- if err != nil {
- t.Fatalf("Header validation failed: %v", err)
- }
-
- // Parse SSE-C headers
- customerKey, err := ParseSSECHeaders(req)
- if err != nil {
- t.Fatalf("Failed to parse SSE-C headers: %v", err)
- }
-
- if customerKey == nil {
- t.Fatal("Expected customer key, got nil")
- }
-
- // Verify parsed key matches input
- if !bytes.Equal(customerKey.Key, keyPair.Key) {
- t.Error("Parsed key doesn't match input key")
- }
-
- if customerKey.KeyMD5 != keyPair.KeyMD5 {
- t.Errorf("Parsed key MD5 doesn't match: expected %s, got %s", keyPair.KeyMD5, customerKey.KeyMD5)
- }
-
- // Simulate setting response headers
- w.Header().Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, "AES256")
- w.Header().Set(s3_constants.AmzServerSideEncryptionCustomerKeyMD5, keyPair.KeyMD5)
-
- // Verify response headers
- AssertSSECHeaders(t, w, keyPair)
-}
-
-// TestGetObjectWithSSEC tests GET object with SSE-C through HTTP handler
-func TestGetObjectWithSSEC(t *testing.T) {
- keyPair := GenerateTestSSECKey(1)
-
- // Create HTTP request for GET
- req := CreateTestHTTPRequest("GET", "/test-bucket/test-object", nil)
- SetupTestSSECHeaders(req, keyPair)
- SetupTestMuxVars(req, map[string]string{
- "bucket": "test-bucket",
- "object": "test-object",
- })
-
- // Create response recorder
- w := CreateTestHTTPResponse()
-
- // Test that SSE-C is detected for GET requests
- if !IsSSECRequest(req) {
- t.Error("Should detect SSE-C request for GET with SSE-C headers")
- }
-
- // Validate headers
- err := ValidateSSECHeaders(req)
- if err != nil {
- t.Fatalf("Header validation failed: %v", err)
- }
-
- // Simulate response with SSE-C headers
- w.Header().Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, "AES256")
- w.Header().Set(s3_constants.AmzServerSideEncryptionCustomerKeyMD5, keyPair.KeyMD5)
- w.WriteHeader(http.StatusOK)
-
- // Verify response
- if w.Code != http.StatusOK {
- t.Errorf("Expected status 200, got %d", w.Code)
- }
-
- AssertSSECHeaders(t, w, keyPair)
-}
-
-// TestPutObjectWithSSEKMS tests PUT object with SSE-KMS through HTTP handler
-func TestPutObjectWithSSEKMS(t *testing.T) {
- kmsKey := SetupTestKMS(t)
- defer kmsKey.Cleanup()
-
- testData := "Hello, SSE-KMS PUT object!"
-
- // Create HTTP request
- req := CreateTestHTTPRequest("PUT", "/test-bucket/test-object", []byte(testData))
- SetupTestSSEKMSHeaders(req, kmsKey.KeyID)
- SetupTestMuxVars(req, map[string]string{
- "bucket": "test-bucket",
- "object": "test-object",
- })
-
- // Create response recorder
- w := CreateTestHTTPResponse()
-
- // Test that SSE-KMS is detected
- if !IsSSEKMSRequest(req) {
- t.Error("Should detect SSE-KMS request")
- }
-
- // Parse SSE-KMS headers
- sseKmsKey, err := ParseSSEKMSHeaders(req)
- if err != nil {
- t.Fatalf("Failed to parse SSE-KMS headers: %v", err)
- }
-
- if sseKmsKey == nil {
- t.Fatal("Expected SSE-KMS key, got nil")
- }
-
- if sseKmsKey.KeyID != kmsKey.KeyID {
- t.Errorf("Parsed key ID doesn't match: expected %s, got %s", kmsKey.KeyID, sseKmsKey.KeyID)
- }
-
- // Simulate setting response headers
- w.Header().Set(s3_constants.AmzServerSideEncryption, "aws:kms")
- w.Header().Set(s3_constants.AmzServerSideEncryptionAwsKmsKeyId, kmsKey.KeyID)
-
- // Verify response headers
- AssertSSEKMSHeaders(t, w, kmsKey.KeyID)
-}
-
-// TestGetObjectWithSSEKMS tests GET object with SSE-KMS through HTTP handler
-func TestGetObjectWithSSEKMS(t *testing.T) {
- kmsKey := SetupTestKMS(t)
- defer kmsKey.Cleanup()
-
- // Create HTTP request for GET (no SSE headers needed for GET)
- req := CreateTestHTTPRequest("GET", "/test-bucket/test-object", nil)
- SetupTestMuxVars(req, map[string]string{
- "bucket": "test-bucket",
- "object": "test-object",
- })
-
- // Create response recorder
- w := CreateTestHTTPResponse()
-
- // Simulate response with SSE-KMS headers (would come from stored metadata)
- w.Header().Set(s3_constants.AmzServerSideEncryption, "aws:kms")
- w.Header().Set(s3_constants.AmzServerSideEncryptionAwsKmsKeyId, kmsKey.KeyID)
- w.WriteHeader(http.StatusOK)
-
- // Verify response
- if w.Code != http.StatusOK {
- t.Errorf("Expected status 200, got %d", w.Code)
- }
-
- AssertSSEKMSHeaders(t, w, kmsKey.KeyID)
-}
-
-// TestSSECRangeRequestSupport tests that range requests are now supported for SSE-C
-func TestSSECRangeRequestSupport(t *testing.T) {
- keyPair := GenerateTestSSECKey(1)
-
- // Create HTTP request with Range header
- req := CreateTestHTTPRequest("GET", "/test-bucket/test-object", nil)
- req.Header.Set("Range", "bytes=0-100")
- SetupTestSSECHeaders(req, keyPair)
- SetupTestMuxVars(req, map[string]string{
- "bucket": "test-bucket",
- "object": "test-object",
- })
-
- // Create a mock proxy response with SSE-C headers
- proxyResponse := httptest.NewRecorder()
- proxyResponse.Header().Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, "AES256")
- proxyResponse.Header().Set(s3_constants.AmzServerSideEncryptionCustomerKeyMD5, keyPair.KeyMD5)
- proxyResponse.Header().Set("Content-Length", "1000")
-
- // Test the detection logic - these should all still work
-
- // Should detect as SSE-C request
- if !IsSSECRequest(req) {
- t.Error("Should detect SSE-C request")
- }
-
- // Should detect range request
- if req.Header.Get("Range") == "" {
- t.Error("Range header should be present")
- }
-
- // The combination should now be allowed and handled by the filer layer
- // Range requests with SSE-C are now supported since IV is stored in metadata
-}
-
-// TestSSEHeaderConflicts tests conflicting SSE headers
-func TestSSEHeaderConflicts(t *testing.T) {
- testCases := []struct {
- name string
- setupFn func(*http.Request)
- valid bool
- }{
- {
- name: "SSE-C and SSE-KMS conflict",
- setupFn: func(req *http.Request) {
- keyPair := GenerateTestSSECKey(1)
- SetupTestSSECHeaders(req, keyPair)
- SetupTestSSEKMSHeaders(req, "test-key-id")
- },
- valid: false,
- },
- {
- name: "Valid SSE-C only",
- setupFn: func(req *http.Request) {
- keyPair := GenerateTestSSECKey(1)
- SetupTestSSECHeaders(req, keyPair)
- },
- valid: true,
- },
- {
- name: "Valid SSE-KMS only",
- setupFn: func(req *http.Request) {
- SetupTestSSEKMSHeaders(req, "test-key-id")
- },
- valid: true,
- },
- {
- name: "No SSE headers",
- setupFn: func(req *http.Request) {
- // No SSE headers
- },
- valid: true,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- req := CreateTestHTTPRequest("PUT", "/test-bucket/test-object", []byte("test"))
- tc.setupFn(req)
-
- ssecDetected := IsSSECRequest(req)
- sseKmsDetected := IsSSEKMSRequest(req)
-
- // Both shouldn't be detected simultaneously
- if ssecDetected && sseKmsDetected {
- t.Error("Both SSE-C and SSE-KMS should not be detected simultaneously")
- }
-
- // Test validation if SSE-C is detected
- if ssecDetected {
- err := ValidateSSECHeaders(req)
- if tc.valid && err != nil {
- t.Errorf("Expected valid SSE-C headers, got error: %v", err)
- }
- if !tc.valid && err == nil && tc.name == "SSE-C and SSE-KMS conflict" {
- // This specific test case should probably be handled at a higher level
- t.Log("Conflict detection should be handled by higher-level validation")
- }
- }
- })
- }
-}
-
-// TestSSECopySourceHeaders tests copy operations with SSE headers
-func TestSSECopySourceHeaders(t *testing.T) {
- sourceKey := GenerateTestSSECKey(1)
- destKey := GenerateTestSSECKey(2)
-
- // Create copy request with both source and destination SSE-C headers
- req := CreateTestHTTPRequest("PUT", "/dest-bucket/dest-object", nil)
-
- // Set copy source headers
- SetupTestSSECCopyHeaders(req, sourceKey)
-
- // Set destination headers
- SetupTestSSECHeaders(req, destKey)
-
- // Set copy source
- req.Header.Set("X-Amz-Copy-Source", "/source-bucket/source-object")
-
- SetupTestMuxVars(req, map[string]string{
- "bucket": "dest-bucket",
- "object": "dest-object",
- })
-
- // Parse copy source headers
- copySourceKey, err := ParseSSECCopySourceHeaders(req)
- if err != nil {
- t.Fatalf("Failed to parse copy source headers: %v", err)
- }
-
- if copySourceKey == nil {
- t.Fatal("Expected copy source key, got nil")
- }
-
- if !bytes.Equal(copySourceKey.Key, sourceKey.Key) {
- t.Error("Copy source key doesn't match")
- }
-
- // Parse destination headers
- destCustomerKey, err := ParseSSECHeaders(req)
- if err != nil {
- t.Fatalf("Failed to parse destination headers: %v", err)
- }
-
- if destCustomerKey == nil {
- t.Fatal("Expected destination key, got nil")
- }
-
- if !bytes.Equal(destCustomerKey.Key, destKey.Key) {
- t.Error("Destination key doesn't match")
- }
-}
-
-// TestSSERequestValidation tests comprehensive request validation
-func TestSSERequestValidation(t *testing.T) {
- testCases := []struct {
- name string
- method string
- setupFn func(*http.Request)
- expectError bool
- errorType string
- }{
- {
- name: "Valid PUT with SSE-C",
- method: "PUT",
- setupFn: func(req *http.Request) {
- keyPair := GenerateTestSSECKey(1)
- SetupTestSSECHeaders(req, keyPair)
- },
- expectError: false,
- },
- {
- name: "Valid GET with SSE-C",
- method: "GET",
- setupFn: func(req *http.Request) {
- keyPair := GenerateTestSSECKey(1)
- SetupTestSSECHeaders(req, keyPair)
- },
- expectError: false,
- },
- {
- name: "Invalid SSE-C key format",
- method: "PUT",
- setupFn: func(req *http.Request) {
- req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, "AES256")
- req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKey, "invalid-key")
- req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKeyMD5, "invalid-md5")
- },
- expectError: true,
- errorType: "InvalidRequest",
- },
- {
- name: "Missing SSE-C key MD5",
- method: "PUT",
- setupFn: func(req *http.Request) {
- keyPair := GenerateTestSSECKey(1)
- req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerAlgorithm, "AES256")
- req.Header.Set(s3_constants.AmzServerSideEncryptionCustomerKey, keyPair.KeyB64)
- // Missing MD5
- },
- expectError: true,
- errorType: "InvalidRequest",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- req := CreateTestHTTPRequest(tc.method, "/test-bucket/test-object", []byte("test data"))
- tc.setupFn(req)
-
- SetupTestMuxVars(req, map[string]string{
- "bucket": "test-bucket",
- "object": "test-object",
- })
-
- // Test header validation
- if IsSSECRequest(req) {
- err := ValidateSSECHeaders(req)
- if tc.expectError && err == nil {
- t.Errorf("Expected error for %s, but got none", tc.name)
- }
- if !tc.expectError && err != nil {
- t.Errorf("Expected no error for %s, but got: %v", tc.name, err)
- }
- }
- })
- }
-}
diff --git a/weed/s3api/s3_sse_kms.go b/weed/s3api/s3_sse_kms.go
index fa9451a8f..b87e0bf1a 100644
--- a/weed/s3api/s3_sse_kms.go
+++ b/weed/s3api/s3_sse_kms.go
@@ -59,11 +59,6 @@ const (
// Bucket key cache TTL (moved to be used with per-bucket cache)
const BucketKeyCacheTTL = time.Hour
-// CreateSSEKMSEncryptedReader creates an encrypted reader using KMS envelope encryption
-func CreateSSEKMSEncryptedReader(r io.Reader, keyID string, encryptionContext map[string]string) (io.Reader, *SSEKMSKey, error) {
- return CreateSSEKMSEncryptedReaderWithBucketKey(r, keyID, encryptionContext, false)
-}
-
// CreateSSEKMSEncryptedReaderWithBucketKey creates an encrypted reader with optional S3 Bucket Keys optimization
func CreateSSEKMSEncryptedReaderWithBucketKey(r io.Reader, keyID string, encryptionContext map[string]string, bucketKeyEnabled bool) (io.Reader, *SSEKMSKey, error) {
if bucketKeyEnabled {
@@ -111,42 +106,6 @@ func CreateSSEKMSEncryptedReaderWithBucketKey(r io.Reader, keyID string, encrypt
return encryptedReader, sseKey, nil
}
-// CreateSSEKMSEncryptedReaderWithBaseIV creates an SSE-KMS encrypted reader using a provided base IV
-// This is used for multipart uploads where all chunks need to use the same base IV
-func CreateSSEKMSEncryptedReaderWithBaseIV(r io.Reader, keyID string, encryptionContext map[string]string, bucketKeyEnabled bool, baseIV []byte) (io.Reader, *SSEKMSKey, error) {
- if err := ValidateIV(baseIV, "base IV"); err != nil {
- return nil, nil, err
- }
-
- // Generate data key using common utility
- dataKeyResult, err := generateKMSDataKey(keyID, encryptionContext)
- if err != nil {
- return nil, nil, err
- }
-
- // Ensure we clear the plaintext data key from memory when done
- defer clearKMSDataKey(dataKeyResult)
-
- // Use the provided base IV instead of generating a new one
- iv := make([]byte, s3_constants.AESBlockSize)
- copy(iv, baseIV)
-
- // Create CTR mode cipher stream
- stream := cipher.NewCTR(dataKeyResult.Block, iv)
-
- // Create the SSE-KMS metadata using utility function
- sseKey := createSSEKMSKey(dataKeyResult, encryptionContext, bucketKeyEnabled, iv, 0)
-
- // The IV is stored in SSE key metadata, so the encrypted stream does not need to prepend the IV
- // This ensures correct Content-Length for clients
- encryptedReader := &cipher.StreamReader{S: stream, R: r}
-
- // Store the base IV in the SSE key for metadata storage
- sseKey.IV = iv
-
- return encryptedReader, sseKey, nil
-}
-
// CreateSSEKMSEncryptedReaderWithBaseIVAndOffset creates an SSE-KMS encrypted reader using a provided base IV and offset
// This is used for multipart uploads where all chunks need unique IVs to prevent IV reuse vulnerabilities
func CreateSSEKMSEncryptedReaderWithBaseIVAndOffset(r io.Reader, keyID string, encryptionContext map[string]string, bucketKeyEnabled bool, baseIV []byte, offset int64) (io.Reader, *SSEKMSKey, error) {
@@ -453,67 +412,6 @@ func CreateSSEKMSDecryptedReader(r io.Reader, sseKey *SSEKMSKey) (io.Reader, err
return decryptReader, nil
}
-// ParseSSEKMSHeaders parses SSE-KMS headers from an HTTP request
-func ParseSSEKMSHeaders(r *http.Request) (*SSEKMSKey, error) {
- sseAlgorithm := r.Header.Get(s3_constants.AmzServerSideEncryption)
-
- // Check if SSE-KMS is requested
- if sseAlgorithm == "" {
- return nil, nil // No SSE headers present
- }
- if sseAlgorithm != s3_constants.SSEAlgorithmKMS {
- return nil, fmt.Errorf("invalid SSE algorithm: %s", sseAlgorithm)
- }
-
- keyID := r.Header.Get(s3_constants.AmzServerSideEncryptionAwsKmsKeyId)
- encryptionContextHeader := r.Header.Get(s3_constants.AmzServerSideEncryptionContext)
- bucketKeyEnabledHeader := r.Header.Get(s3_constants.AmzServerSideEncryptionBucketKeyEnabled)
-
- // Parse encryption context if provided
- var encryptionContext map[string]string
- if encryptionContextHeader != "" {
- // Decode base64-encoded JSON encryption context
- contextBytes, err := base64.StdEncoding.DecodeString(encryptionContextHeader)
- if err != nil {
- return nil, fmt.Errorf("invalid encryption context format: %v", err)
- }
-
- if err := json.Unmarshal(contextBytes, &encryptionContext); err != nil {
- return nil, fmt.Errorf("invalid encryption context JSON: %v", err)
- }
- }
-
- // Parse bucket key enabled flag
- bucketKeyEnabled := strings.ToLower(bucketKeyEnabledHeader) == "true"
-
- sseKey := &SSEKMSKey{
- KeyID: keyID,
- EncryptionContext: encryptionContext,
- BucketKeyEnabled: bucketKeyEnabled,
- }
-
- // Validate the parsed key including key ID format
- if err := ValidateSSEKMSKeyInternal(sseKey); err != nil {
- return nil, err
- }
-
- return sseKey, nil
-}
-
-// ValidateSSEKMSKey validates an SSE-KMS key configuration
-func ValidateSSEKMSKeyInternal(sseKey *SSEKMSKey) error {
- if err := ValidateSSEKMSKey(sseKey); err != nil {
- return err
- }
-
- // An empty key ID is valid and means the default KMS key should be used.
- if sseKey.KeyID != "" && !isValidKMSKeyID(sseKey.KeyID) {
- return fmt.Errorf("invalid KMS key ID format: %s", sseKey.KeyID)
- }
-
- return nil
-}
-
// BuildEncryptionContext creates the encryption context for S3 objects
func BuildEncryptionContext(bucketName, objectKey string, useBucketKey bool) map[string]string {
return kms.BuildS3EncryptionContext(bucketName, objectKey, useBucketKey)
@@ -732,28 +630,6 @@ func IsSSEKMSEncrypted(metadata map[string][]byte) bool {
return false
}
-// IsAnySSEEncrypted checks if metadata indicates any type of SSE encryption
-func IsAnySSEEncrypted(metadata map[string][]byte) bool {
- if metadata == nil {
- return false
- }
-
- // Check for any SSE type
- if IsSSECEncrypted(metadata) {
- return true
- }
- if IsSSEKMSEncrypted(metadata) {
- return true
- }
-
- // Check for SSE-S3
- if sseAlgorithm, exists := metadata[s3_constants.AmzServerSideEncryption]; exists {
- return string(sseAlgorithm) == s3_constants.SSEAlgorithmAES256
- }
-
- return false
-}
-
// MapKMSErrorToS3Error maps KMS errors to appropriate S3 error codes
func MapKMSErrorToS3Error(err error) s3err.ErrorCode {
if err == nil {
@@ -990,21 +866,6 @@ func DetermineUnifiedCopyStrategy(state *EncryptionState, srcMetadata map[string
return CopyStrategyDirect, nil
}
-// DetectEncryptionState analyzes the source metadata and request headers to determine encryption state
-func DetectEncryptionState(srcMetadata map[string][]byte, r *http.Request, srcPath, dstPath string) *EncryptionState {
- state := &EncryptionState{
- SrcSSEC: IsSSECEncrypted(srcMetadata),
- SrcSSEKMS: IsSSEKMSEncrypted(srcMetadata),
- SrcSSES3: IsSSES3EncryptedInternal(srcMetadata),
- DstSSEC: IsSSECRequest(r),
- DstSSEKMS: IsSSEKMSRequest(r),
- DstSSES3: IsSSES3RequestInternal(r),
- SameObject: srcPath == dstPath,
- }
-
- return state
-}
-
// DetectEncryptionStateWithEntry analyzes the source entry and request headers to determine encryption state
// This version can detect multipart encrypted objects by examining chunks
func DetectEncryptionStateWithEntry(entry *filer_pb.Entry, r *http.Request, srcPath, dstPath string) *EncryptionState {
diff --git a/weed/s3api/s3_sse_kms_test.go b/weed/s3api/s3_sse_kms_test.go
deleted file mode 100644
index 487a239a5..000000000
--- a/weed/s3api/s3_sse_kms_test.go
+++ /dev/null
@@ -1,399 +0,0 @@
-package s3api
-
-import (
- "bytes"
- "encoding/json"
- "io"
- "strings"
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/kms"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
-)
-
-func TestSSEKMSEncryptionDecryption(t *testing.T) {
- kmsKey := SetupTestKMS(t)
- defer kmsKey.Cleanup()
-
- // Test data
- testData := "Hello, SSE-KMS world! This is a test of envelope encryption."
- testReader := strings.NewReader(testData)
-
- // Create encryption context
- encryptionContext := BuildEncryptionContext("test-bucket", "test-object", false)
-
- // Encrypt the data
- encryptedReader, sseKey, err := CreateSSEKMSEncryptedReader(testReader, kmsKey.KeyID, encryptionContext)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader: %v", err)
- }
-
- // Verify SSE key metadata
- if sseKey.KeyID != kmsKey.KeyID {
- t.Errorf("Expected key ID %s, got %s", kmsKey.KeyID, sseKey.KeyID)
- }
-
- if len(sseKey.EncryptedDataKey) == 0 {
- t.Error("Encrypted data key should not be empty")
- }
-
- if sseKey.EncryptionContext == nil {
- t.Error("Encryption context should not be nil")
- }
-
- // Read the encrypted data
- encryptedData, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read encrypted data: %v", err)
- }
-
- // Verify the encrypted data is different from original
- if string(encryptedData) == testData {
- t.Error("Encrypted data should be different from original data")
- }
-
- // The encrypted data should be same size as original (IV is stored in metadata, not in stream)
- if len(encryptedData) != len(testData) {
- t.Errorf("Encrypted data should be same size as original: expected %d, got %d", len(testData), len(encryptedData))
- }
-
- // Decrypt the data
- decryptedReader, err := CreateSSEKMSDecryptedReader(bytes.NewReader(encryptedData), sseKey)
- if err != nil {
- t.Fatalf("Failed to create decrypted reader: %v", err)
- }
-
- // Read the decrypted data
- decryptedData, err := io.ReadAll(decryptedReader)
- if err != nil {
- t.Fatalf("Failed to read decrypted data: %v", err)
- }
-
- // Verify the decrypted data matches the original
- if string(decryptedData) != testData {
- t.Errorf("Decrypted data does not match original.\nExpected: %s\nGot: %s", testData, string(decryptedData))
- }
-}
-
-func TestSSEKMSKeyValidation(t *testing.T) {
- tests := []struct {
- name string
- keyID string
- wantValid bool
- }{
- {
- name: "Valid UUID key ID",
- keyID: "12345678-1234-1234-1234-123456789012",
- wantValid: true,
- },
- {
- name: "Valid alias",
- keyID: "alias/my-test-key",
- wantValid: true,
- },
- {
- name: "Valid ARN",
- keyID: "arn:aws:kms:us-east-1:123456789012:key/12345678-1234-1234-1234-123456789012",
- wantValid: true,
- },
- {
- name: "Valid alias ARN",
- keyID: "arn:aws:kms:us-east-1:123456789012:alias/my-test-key",
- wantValid: true,
- },
-
- {
- name: "Valid test key format",
- keyID: "invalid-key-format",
- wantValid: true, // Now valid - following Minio's permissive approach
- },
- {
- name: "Valid short key",
- keyID: "12345678-1234",
- wantValid: true, // Now valid - following Minio's permissive approach
- },
- {
- name: "Invalid - leading space",
- keyID: " leading-space",
- wantValid: false,
- },
- {
- name: "Invalid - trailing space",
- keyID: "trailing-space ",
- wantValid: false,
- },
- {
- name: "Invalid - empty",
- keyID: "",
- wantValid: false,
- },
- {
- name: "Invalid - internal spaces",
- keyID: "invalid key id",
- wantValid: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- valid := isValidKMSKeyID(tt.keyID)
- if valid != tt.wantValid {
- t.Errorf("isValidKMSKeyID(%s) = %v, want %v", tt.keyID, valid, tt.wantValid)
- }
- })
- }
-}
-
-func TestSSEKMSMetadataSerialization(t *testing.T) {
- // Create test SSE key
- sseKey := &SSEKMSKey{
- KeyID: "test-key-id",
- EncryptedDataKey: []byte("encrypted-data-key"),
- EncryptionContext: map[string]string{
- "aws:s3:arn": "arn:aws:s3:::test-bucket/test-object",
- },
- BucketKeyEnabled: true,
- }
-
- // Serialize metadata
- serialized, err := SerializeSSEKMSMetadata(sseKey)
- if err != nil {
- t.Fatalf("Failed to serialize SSE-KMS metadata: %v", err)
- }
-
- // Verify it's valid JSON
- var jsonData map[string]interface{}
- if err := json.Unmarshal(serialized, &jsonData); err != nil {
- t.Fatalf("Serialized data is not valid JSON: %v", err)
- }
-
- // Deserialize metadata
- deserializedKey, err := DeserializeSSEKMSMetadata(serialized)
- if err != nil {
- t.Fatalf("Failed to deserialize SSE-KMS metadata: %v", err)
- }
-
- // Verify the deserialized data matches original
- if deserializedKey.KeyID != sseKey.KeyID {
- t.Errorf("KeyID mismatch: expected %s, got %s", sseKey.KeyID, deserializedKey.KeyID)
- }
-
- if !bytes.Equal(deserializedKey.EncryptedDataKey, sseKey.EncryptedDataKey) {
- t.Error("EncryptedDataKey mismatch")
- }
-
- if len(deserializedKey.EncryptionContext) != len(sseKey.EncryptionContext) {
- t.Error("EncryptionContext length mismatch")
- }
-
- for k, v := range sseKey.EncryptionContext {
- if deserializedKey.EncryptionContext[k] != v {
- t.Errorf("EncryptionContext mismatch for key %s: expected %s, got %s", k, v, deserializedKey.EncryptionContext[k])
- }
- }
-
- if deserializedKey.BucketKeyEnabled != sseKey.BucketKeyEnabled {
- t.Errorf("BucketKeyEnabled mismatch: expected %v, got %v", sseKey.BucketKeyEnabled, deserializedKey.BucketKeyEnabled)
- }
-}
-
-func TestBuildEncryptionContext(t *testing.T) {
- tests := []struct {
- name string
- bucket string
- object string
- useBucketKey bool
- expectedARN string
- }{
- {
- name: "Object-level encryption",
- bucket: "test-bucket",
- object: "test-object",
- useBucketKey: false,
- expectedARN: "arn:aws:s3:::test-bucket/test-object",
- },
- {
- name: "Bucket-level encryption",
- bucket: "test-bucket",
- object: "test-object",
- useBucketKey: true,
- expectedARN: "arn:aws:s3:::test-bucket",
- },
- {
- name: "Nested object path",
- bucket: "my-bucket",
- object: "folder/subfolder/file.txt",
- useBucketKey: false,
- expectedARN: "arn:aws:s3:::my-bucket/folder/subfolder/file.txt",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- context := BuildEncryptionContext(tt.bucket, tt.object, tt.useBucketKey)
-
- if context == nil {
- t.Fatal("Encryption context should not be nil")
- }
-
- arn, exists := context[kms.EncryptionContextS3ARN]
- if !exists {
- t.Error("Encryption context should contain S3 ARN")
- }
-
- if arn != tt.expectedARN {
- t.Errorf("Expected ARN %s, got %s", tt.expectedARN, arn)
- }
- })
- }
-}
-
-func TestKMSErrorMapping(t *testing.T) {
- tests := []struct {
- name string
- kmsError *kms.KMSError
- expectedErr string
- }{
- {
- name: "Key not found",
- kmsError: &kms.KMSError{
- Code: kms.ErrCodeNotFoundException,
- Message: "Key not found",
- },
- expectedErr: "KMSKeyNotFoundException",
- },
- {
- name: "Access denied",
- kmsError: &kms.KMSError{
- Code: kms.ErrCodeAccessDenied,
- Message: "Access denied",
- },
- expectedErr: "KMSAccessDeniedException",
- },
- {
- name: "Key unavailable",
- kmsError: &kms.KMSError{
- Code: kms.ErrCodeKeyUnavailable,
- Message: "Key is disabled",
- },
- expectedErr: "KMSKeyDisabledException",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- errorCode := MapKMSErrorToS3Error(tt.kmsError)
-
- // Get the actual error description
- apiError := s3err.GetAPIError(errorCode)
- if apiError.Code != tt.expectedErr {
- t.Errorf("Expected error code %s, got %s", tt.expectedErr, apiError.Code)
- }
- })
- }
-}
-
-// TestLargeDataEncryption tests encryption/decryption of larger data streams
-func TestSSEKMSLargeDataEncryption(t *testing.T) {
- kmsKey := SetupTestKMS(t)
- defer kmsKey.Cleanup()
-
- // Create a larger test dataset (1MB)
- testData := strings.Repeat("This is a test of SSE-KMS with larger data streams. ", 20000)
- testReader := strings.NewReader(testData)
-
- // Create encryption context
- encryptionContext := BuildEncryptionContext("large-bucket", "large-object", false)
-
- // Encrypt the data
- encryptedReader, sseKey, err := CreateSSEKMSEncryptedReader(testReader, kmsKey.KeyID, encryptionContext)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader: %v", err)
- }
-
- // Read the encrypted data
- encryptedData, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read encrypted data: %v", err)
- }
-
- // Decrypt the data
- decryptedReader, err := CreateSSEKMSDecryptedReader(bytes.NewReader(encryptedData), sseKey)
- if err != nil {
- t.Fatalf("Failed to create decrypted reader: %v", err)
- }
-
- // Read the decrypted data
- decryptedData, err := io.ReadAll(decryptedReader)
- if err != nil {
- t.Fatalf("Failed to read decrypted data: %v", err)
- }
-
- // Verify the decrypted data matches the original
- if string(decryptedData) != testData {
- t.Errorf("Decrypted data length: %d, original data length: %d", len(decryptedData), len(testData))
- t.Error("Decrypted large data does not match original")
- }
-
- t.Logf("Successfully encrypted/decrypted %d bytes of data", len(testData))
-}
-
-// TestValidateSSEKMSKey tests the ValidateSSEKMSKey function, which correctly handles empty key IDs
-func TestValidateSSEKMSKey(t *testing.T) {
- tests := []struct {
- name string
- sseKey *SSEKMSKey
- wantErr bool
- }{
- {
- name: "nil SSE-KMS key",
- sseKey: nil,
- wantErr: true,
- },
- {
- name: "empty key ID (valid - represents default KMS key)",
- sseKey: &SSEKMSKey{
- KeyID: "",
- EncryptionContext: map[string]string{"test": "value"},
- BucketKeyEnabled: false,
- },
- wantErr: false,
- },
- {
- name: "valid UUID key ID",
- sseKey: &SSEKMSKey{
- KeyID: "12345678-1234-1234-1234-123456789012",
- EncryptionContext: map[string]string{"test": "value"},
- BucketKeyEnabled: true,
- },
- wantErr: false,
- },
- {
- name: "valid alias",
- sseKey: &SSEKMSKey{
- KeyID: "alias/my-test-key",
- EncryptionContext: map[string]string{},
- BucketKeyEnabled: false,
- },
- wantErr: false,
- },
- {
- name: "valid flexible key ID format",
- sseKey: &SSEKMSKey{
- KeyID: "invalid-format",
- EncryptionContext: map[string]string{},
- BucketKeyEnabled: false,
- },
- wantErr: false, // Now valid - following Minio's permissive approach
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- err := ValidateSSEKMSKey(tt.sseKey)
- if (err != nil) != tt.wantErr {
- t.Errorf("ValidateSSEKMSKey() error = %v, wantErr %v", err, tt.wantErr)
- }
- })
- }
-}
diff --git a/weed/s3api/s3_sse_metadata_test.go b/weed/s3api/s3_sse_metadata_test.go
deleted file mode 100644
index c0c1360af..000000000
--- a/weed/s3api/s3_sse_metadata_test.go
+++ /dev/null
@@ -1,328 +0,0 @@
-package s3api
-
-import (
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
-)
-
-// TestSSECIsEncrypted tests detection of SSE-C encryption from metadata
-func TestSSECIsEncrypted(t *testing.T) {
- testCases := []struct {
- name string
- metadata map[string][]byte
- expected bool
- }{
- {
- name: "Empty metadata",
- metadata: CreateTestMetadata(),
- expected: false,
- },
- {
- name: "Valid SSE-C metadata",
- metadata: CreateTestMetadataWithSSEC(GenerateTestSSECKey(1)),
- expected: true,
- },
- {
- name: "SSE-C algorithm only",
- metadata: map[string][]byte{
- s3_constants.AmzServerSideEncryptionCustomerAlgorithm: []byte("AES256"),
- },
- expected: true,
- },
- {
- name: "SSE-C key MD5 only",
- metadata: map[string][]byte{
- s3_constants.AmzServerSideEncryptionCustomerKeyMD5: []byte("somemd5"),
- },
- expected: true,
- },
- {
- name: "Other encryption type (SSE-KMS)",
- metadata: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("aws:kms"),
- },
- expected: false,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- result := IsSSECEncrypted(tc.metadata)
- if result != tc.expected {
- t.Errorf("Expected %v, got %v", tc.expected, result)
- }
- })
- }
-}
-
-// TestSSEKMSIsEncrypted tests detection of SSE-KMS encryption from metadata
-func TestSSEKMSIsEncrypted(t *testing.T) {
- testCases := []struct {
- name string
- metadata map[string][]byte
- expected bool
- }{
- {
- name: "Empty metadata",
- metadata: CreateTestMetadata(),
- expected: false,
- },
- {
- name: "Valid SSE-KMS metadata",
- metadata: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("aws:kms"),
- s3_constants.AmzEncryptedDataKey: []byte("encrypted-key"),
- },
- expected: true,
- },
- {
- name: "SSE-KMS algorithm only",
- metadata: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("aws:kms"),
- },
- expected: true,
- },
- {
- name: "SSE-KMS encrypted data key only",
- metadata: map[string][]byte{
- s3_constants.AmzEncryptedDataKey: []byte("encrypted-key"),
- },
- expected: false, // Only encrypted data key without algorithm header should not be considered SSE-KMS
- },
- {
- name: "Other encryption type (SSE-C)",
- metadata: map[string][]byte{
- s3_constants.AmzServerSideEncryptionCustomerAlgorithm: []byte("AES256"),
- },
- expected: false,
- },
- {
- name: "SSE-S3 (AES256)",
- metadata: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("AES256"),
- },
- expected: false,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- result := IsSSEKMSEncrypted(tc.metadata)
- if result != tc.expected {
- t.Errorf("Expected %v, got %v", tc.expected, result)
- }
- })
- }
-}
-
-// TestSSETypeDiscrimination tests that SSE types don't interfere with each other
-func TestSSETypeDiscrimination(t *testing.T) {
- // Test SSE-C headers don't trigger SSE-KMS detection
- t.Run("SSE-C headers don't trigger SSE-KMS", func(t *testing.T) {
- req := CreateTestHTTPRequest("PUT", "/bucket/object", nil)
- keyPair := GenerateTestSSECKey(1)
- SetupTestSSECHeaders(req, keyPair)
-
- // Should detect SSE-C, not SSE-KMS
- if !IsSSECRequest(req) {
- t.Error("Should detect SSE-C request")
- }
- if IsSSEKMSRequest(req) {
- t.Error("Should not detect SSE-KMS request for SSE-C headers")
- }
- })
-
- // Test SSE-KMS headers don't trigger SSE-C detection
- t.Run("SSE-KMS headers don't trigger SSE-C", func(t *testing.T) {
- req := CreateTestHTTPRequest("PUT", "/bucket/object", nil)
- SetupTestSSEKMSHeaders(req, "test-key-id")
-
- // Should detect SSE-KMS, not SSE-C
- if IsSSECRequest(req) {
- t.Error("Should not detect SSE-C request for SSE-KMS headers")
- }
- if !IsSSEKMSRequest(req) {
- t.Error("Should detect SSE-KMS request")
- }
- })
-
- // Test metadata discrimination
- t.Run("Metadata type discrimination", func(t *testing.T) {
- ssecMetadata := CreateTestMetadataWithSSEC(GenerateTestSSECKey(1))
-
- // Should detect as SSE-C, not SSE-KMS
- if !IsSSECEncrypted(ssecMetadata) {
- t.Error("Should detect SSE-C encrypted metadata")
- }
- if IsSSEKMSEncrypted(ssecMetadata) {
- t.Error("Should not detect SSE-KMS for SSE-C metadata")
- }
- })
-}
-
-// TestSSECParseCorruptedMetadata tests handling of corrupted SSE-C metadata
-func TestSSECParseCorruptedMetadata(t *testing.T) {
- testCases := []struct {
- name string
- metadata map[string][]byte
- expectError bool
- errorMessage string
- }{
- {
- name: "Missing algorithm",
- metadata: map[string][]byte{
- s3_constants.AmzServerSideEncryptionCustomerKeyMD5: []byte("valid-md5"),
- },
- expectError: false, // Detection should still work with partial metadata
- },
- {
- name: "Invalid key MD5 format",
- metadata: map[string][]byte{
- s3_constants.AmzServerSideEncryptionCustomerAlgorithm: []byte("AES256"),
- s3_constants.AmzServerSideEncryptionCustomerKeyMD5: []byte("invalid-base64!"),
- },
- expectError: false, // Detection should work, validation happens later
- },
- {
- name: "Empty values",
- metadata: map[string][]byte{
- s3_constants.AmzServerSideEncryptionCustomerAlgorithm: []byte(""),
- s3_constants.AmzServerSideEncryptionCustomerKeyMD5: []byte(""),
- },
- expectError: false,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- // Test that detection doesn't panic on corrupted metadata
- result := IsSSECEncrypted(tc.metadata)
- // The detection should be robust and not crash
- t.Logf("Detection result for %s: %v", tc.name, result)
- })
- }
-}
-
-// TestSSEKMSParseCorruptedMetadata tests handling of corrupted SSE-KMS metadata
-func TestSSEKMSParseCorruptedMetadata(t *testing.T) {
- testCases := []struct {
- name string
- metadata map[string][]byte
- }{
- {
- name: "Invalid encrypted data key",
- metadata: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("aws:kms"),
- s3_constants.AmzEncryptedDataKey: []byte("invalid-base64!"),
- },
- },
- {
- name: "Invalid encryption context",
- metadata: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("aws:kms"),
- s3_constants.AmzEncryptionContextMeta: []byte("invalid-json"),
- },
- },
- {
- name: "Empty values",
- metadata: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte(""),
- s3_constants.AmzEncryptedDataKey: []byte(""),
- },
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- // Test that detection doesn't panic on corrupted metadata
- result := IsSSEKMSEncrypted(tc.metadata)
- t.Logf("Detection result for %s: %v", tc.name, result)
- })
- }
-}
-
-// TestSSEMetadataDeserialization tests SSE-KMS metadata deserialization with various inputs
-func TestSSEMetadataDeserialization(t *testing.T) {
- testCases := []struct {
- name string
- data []byte
- expectError bool
- }{
- {
- name: "Empty data",
- data: []byte{},
- expectError: true,
- },
- {
- name: "Invalid JSON",
- data: []byte("invalid-json"),
- expectError: true,
- },
- {
- name: "Valid JSON but wrong structure",
- data: []byte(`{"wrong": "structure"}`),
- expectError: false, // Our deserialization might be lenient
- },
- {
- name: "Null data",
- data: nil,
- expectError: true,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- _, err := DeserializeSSEKMSMetadata(tc.data)
- if tc.expectError && err == nil {
- t.Error("Expected error but got none")
- }
- if !tc.expectError && err != nil {
- t.Errorf("Expected no error but got: %v", err)
- }
- })
- }
-}
-
-// TestGeneralSSEDetection tests the general SSE detection that works across types
-func TestGeneralSSEDetection(t *testing.T) {
- testCases := []struct {
- name string
- metadata map[string][]byte
- expected bool
- }{
- {
- name: "No encryption",
- metadata: CreateTestMetadata(),
- expected: false,
- },
- {
- name: "SSE-C encrypted",
- metadata: CreateTestMetadataWithSSEC(GenerateTestSSECKey(1)),
- expected: true,
- },
- {
- name: "SSE-KMS encrypted",
- metadata: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("aws:kms"),
- },
- expected: true,
- },
- {
- name: "SSE-S3 encrypted",
- metadata: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("AES256"),
- },
- expected: true,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- result := IsAnySSEEncrypted(tc.metadata)
- if result != tc.expected {
- t.Errorf("Expected %v, got %v", tc.expected, result)
- }
- })
- }
-}
diff --git a/weed/s3api/s3_sse_multipart_test.go b/weed/s3api/s3_sse_multipart_test.go
deleted file mode 100644
index c4dc9a45a..000000000
--- a/weed/s3api/s3_sse_multipart_test.go
+++ /dev/null
@@ -1,569 +0,0 @@
-package s3api
-
-import (
- "bytes"
- "fmt"
- "io"
- "strings"
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
-)
-
-// TestSSECMultipartUpload tests SSE-C with multipart uploads
-func TestSSECMultipartUpload(t *testing.T) {
- keyPair := GenerateTestSSECKey(1)
- customerKey := &SSECustomerKey{
- Algorithm: "AES256",
- Key: keyPair.Key,
- KeyMD5: keyPair.KeyMD5,
- }
-
- // Test data larger than typical part size
- testData := strings.Repeat("Hello, SSE-C multipart world! ", 1000) // ~30KB
-
- t.Run("Single part encryption/decryption", func(t *testing.T) {
- // Encrypt the data
- encryptedReader, iv, err := CreateSSECEncryptedReader(strings.NewReader(testData), customerKey)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader: %v", err)
- }
-
- encryptedData, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read encrypted data: %v", err)
- }
-
- // Decrypt the data
- decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedData), customerKey, iv)
- if err != nil {
- t.Fatalf("Failed to create decrypted reader: %v", err)
- }
-
- decryptedData, err := io.ReadAll(decryptedReader)
- if err != nil {
- t.Fatalf("Failed to read decrypted data: %v", err)
- }
-
- if string(decryptedData) != testData {
- t.Error("Decrypted data doesn't match original")
- }
- })
-
- t.Run("Simulated multipart upload parts", func(t *testing.T) {
- // Simulate multiple parts (each part gets encrypted separately)
- partSize := 5 * 1024 // 5KB parts
- var encryptedParts [][]byte
- var partIVs [][]byte
-
- for i := 0; i < len(testData); i += partSize {
- end := i + partSize
- if end > len(testData) {
- end = len(testData)
- }
-
- partData := testData[i:end]
-
- // Each part is encrypted separately in multipart uploads
- encryptedReader, iv, err := CreateSSECEncryptedReader(strings.NewReader(partData), customerKey)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader for part %d: %v", i/partSize, err)
- }
-
- encryptedPart, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read encrypted part %d: %v", i/partSize, err)
- }
-
- encryptedParts = append(encryptedParts, encryptedPart)
- partIVs = append(partIVs, iv)
- }
-
- // Simulate reading back the multipart object
- var reconstructedData strings.Builder
-
- for i, encryptedPart := range encryptedParts {
- decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedPart), customerKey, partIVs[i])
- if err != nil {
- t.Fatalf("Failed to create decrypted reader for part %d: %v", i, err)
- }
-
- decryptedPart, err := io.ReadAll(decryptedReader)
- if err != nil {
- t.Fatalf("Failed to read decrypted part %d: %v", i, err)
- }
-
- reconstructedData.Write(decryptedPart)
- }
-
- if reconstructedData.String() != testData {
- t.Error("Reconstructed multipart data doesn't match original")
- }
- })
-
- t.Run("Multipart with different part sizes", func(t *testing.T) {
- partSizes := []int{1024, 2048, 4096, 8192} // Various part sizes
-
- for _, partSize := range partSizes {
- t.Run(fmt.Sprintf("PartSize_%d", partSize), func(t *testing.T) {
- var encryptedParts [][]byte
- var partIVs [][]byte
-
- for i := 0; i < len(testData); i += partSize {
- end := i + partSize
- if end > len(testData) {
- end = len(testData)
- }
-
- partData := testData[i:end]
-
- encryptedReader, iv, err := CreateSSECEncryptedReader(strings.NewReader(partData), customerKey)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader: %v", err)
- }
-
- encryptedPart, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read encrypted part: %v", err)
- }
-
- encryptedParts = append(encryptedParts, encryptedPart)
- partIVs = append(partIVs, iv)
- }
-
- // Verify reconstruction
- var reconstructedData strings.Builder
-
- for j, encryptedPart := range encryptedParts {
- decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedPart), customerKey, partIVs[j])
- if err != nil {
- t.Fatalf("Failed to create decrypted reader: %v", err)
- }
-
- decryptedPart, err := io.ReadAll(decryptedReader)
- if err != nil {
- t.Fatalf("Failed to read decrypted part: %v", err)
- }
-
- reconstructedData.Write(decryptedPart)
- }
-
- if reconstructedData.String() != testData {
- t.Errorf("Reconstructed data doesn't match original for part size %d", partSize)
- }
- })
- }
- })
-}
-
-// TestSSEKMSMultipartUpload tests SSE-KMS with multipart uploads
-func TestSSEKMSMultipartUpload(t *testing.T) {
- kmsKey := SetupTestKMS(t)
- defer kmsKey.Cleanup()
-
- // Test data larger than typical part size
- testData := strings.Repeat("Hello, SSE-KMS multipart world! ", 1000) // ~30KB
- encryptionContext := BuildEncryptionContext("test-bucket", "test-object", false)
-
- t.Run("Single part encryption/decryption", func(t *testing.T) {
- // Encrypt the data
- encryptedReader, sseKey, err := CreateSSEKMSEncryptedReader(strings.NewReader(testData), kmsKey.KeyID, encryptionContext)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader: %v", err)
- }
-
- encryptedData, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read encrypted data: %v", err)
- }
-
- // Decrypt the data
- decryptedReader, err := CreateSSEKMSDecryptedReader(bytes.NewReader(encryptedData), sseKey)
- if err != nil {
- t.Fatalf("Failed to create decrypted reader: %v", err)
- }
-
- decryptedData, err := io.ReadAll(decryptedReader)
- if err != nil {
- t.Fatalf("Failed to read decrypted data: %v", err)
- }
-
- if string(decryptedData) != testData {
- t.Error("Decrypted data doesn't match original")
- }
- })
-
- t.Run("Simulated multipart upload parts", func(t *testing.T) {
- // Simulate multiple parts (each part might use the same or different KMS operations)
- partSize := 5 * 1024 // 5KB parts
- var encryptedParts [][]byte
- var sseKeys []*SSEKMSKey
-
- for i := 0; i < len(testData); i += partSize {
- end := i + partSize
- if end > len(testData) {
- end = len(testData)
- }
-
- partData := testData[i:end]
-
- // Each part might get its own data key in KMS multipart uploads
- encryptedReader, sseKey, err := CreateSSEKMSEncryptedReader(strings.NewReader(partData), kmsKey.KeyID, encryptionContext)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader for part %d: %v", i/partSize, err)
- }
-
- encryptedPart, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read encrypted part %d: %v", i/partSize, err)
- }
-
- encryptedParts = append(encryptedParts, encryptedPart)
- sseKeys = append(sseKeys, sseKey)
- }
-
- // Simulate reading back the multipart object
- var reconstructedData strings.Builder
-
- for i, encryptedPart := range encryptedParts {
- decryptedReader, err := CreateSSEKMSDecryptedReader(bytes.NewReader(encryptedPart), sseKeys[i])
- if err != nil {
- t.Fatalf("Failed to create decrypted reader for part %d: %v", i, err)
- }
-
- decryptedPart, err := io.ReadAll(decryptedReader)
- if err != nil {
- t.Fatalf("Failed to read decrypted part %d: %v", i, err)
- }
-
- reconstructedData.Write(decryptedPart)
- }
-
- if reconstructedData.String() != testData {
- t.Error("Reconstructed multipart data doesn't match original")
- }
- })
-
- t.Run("Multipart consistency checks", func(t *testing.T) {
- // Test that all parts use the same KMS key ID but different data keys
- partSize := 5 * 1024
- var sseKeys []*SSEKMSKey
-
- for i := 0; i < len(testData); i += partSize {
- end := i + partSize
- if end > len(testData) {
- end = len(testData)
- }
-
- partData := testData[i:end]
-
- _, sseKey, err := CreateSSEKMSEncryptedReader(strings.NewReader(partData), kmsKey.KeyID, encryptionContext)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader: %v", err)
- }
-
- sseKeys = append(sseKeys, sseKey)
- }
-
- // Verify all parts use the same KMS key ID
- for i, sseKey := range sseKeys {
- if sseKey.KeyID != kmsKey.KeyID {
- t.Errorf("Part %d has wrong KMS key ID: expected %s, got %s", i, kmsKey.KeyID, sseKey.KeyID)
- }
- }
-
- // Verify each part has different encrypted data keys (they should be unique)
- for i := 0; i < len(sseKeys); i++ {
- for j := i + 1; j < len(sseKeys); j++ {
- if bytes.Equal(sseKeys[i].EncryptedDataKey, sseKeys[j].EncryptedDataKey) {
- t.Errorf("Parts %d and %d have identical encrypted data keys (should be unique)", i, j)
- }
- }
- }
- })
-}
-
-// TestMultipartSSEMixedScenarios tests edge cases with multipart and SSE
-func TestMultipartSSEMixedScenarios(t *testing.T) {
- t.Run("Empty parts handling", func(t *testing.T) {
- keyPair := GenerateTestSSECKey(1)
- customerKey := &SSECustomerKey{
- Algorithm: "AES256",
- Key: keyPair.Key,
- KeyMD5: keyPair.KeyMD5,
- }
-
- // Test empty part
- encryptedReader, iv, err := CreateSSECEncryptedReader(strings.NewReader(""), customerKey)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader for empty data: %v", err)
- }
-
- encryptedData, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read encrypted empty data: %v", err)
- }
-
- // Empty part should produce empty encrypted data, but still have a valid IV
- if len(encryptedData) != 0 {
- t.Errorf("Expected empty encrypted data for empty part, got %d bytes", len(encryptedData))
- }
- if len(iv) != s3_constants.AESBlockSize {
- t.Errorf("Expected IV of size %d, got %d", s3_constants.AESBlockSize, len(iv))
- }
-
- // Decrypt and verify
- decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedData), customerKey, iv)
- if err != nil {
- t.Fatalf("Failed to create decrypted reader for empty data: %v", err)
- }
-
- decryptedData, err := io.ReadAll(decryptedReader)
- if err != nil {
- t.Fatalf("Failed to read decrypted empty data: %v", err)
- }
-
- if len(decryptedData) != 0 {
- t.Errorf("Expected empty decrypted data, got %d bytes", len(decryptedData))
- }
- })
-
- t.Run("Single byte parts", func(t *testing.T) {
- keyPair := GenerateTestSSECKey(1)
- customerKey := &SSECustomerKey{
- Algorithm: "AES256",
- Key: keyPair.Key,
- KeyMD5: keyPair.KeyMD5,
- }
-
- testData := "ABCDEFGHIJ"
- var encryptedParts [][]byte
- var partIVs [][]byte
-
- // Encrypt each byte as a separate part
- for i, b := range []byte(testData) {
- partData := string(b)
-
- encryptedReader, iv, err := CreateSSECEncryptedReader(strings.NewReader(partData), customerKey)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader for byte %d: %v", i, err)
- }
-
- encryptedPart, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read encrypted byte %d: %v", i, err)
- }
-
- encryptedParts = append(encryptedParts, encryptedPart)
- partIVs = append(partIVs, iv)
- }
-
- // Reconstruct
- var reconstructedData strings.Builder
-
- for i, encryptedPart := range encryptedParts {
- decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedPart), customerKey, partIVs[i])
- if err != nil {
- t.Fatalf("Failed to create decrypted reader for byte %d: %v", i, err)
- }
-
- decryptedPart, err := io.ReadAll(decryptedReader)
- if err != nil {
- t.Fatalf("Failed to read decrypted byte %d: %v", i, err)
- }
-
- reconstructedData.Write(decryptedPart)
- }
-
- if reconstructedData.String() != testData {
- t.Errorf("Expected %s, got %s", testData, reconstructedData.String())
- }
- })
-
- t.Run("Very large parts", func(t *testing.T) {
- keyPair := GenerateTestSSECKey(1)
- customerKey := &SSECustomerKey{
- Algorithm: "AES256",
- Key: keyPair.Key,
- KeyMD5: keyPair.KeyMD5,
- }
-
- // Create a large part (1MB)
- largeData := make([]byte, 1024*1024)
- for i := range largeData {
- largeData[i] = byte(i % 256)
- }
-
- // Encrypt
- encryptedReader, iv, err := CreateSSECEncryptedReader(bytes.NewReader(largeData), customerKey)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader for large data: %v", err)
- }
-
- encryptedData, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read encrypted large data: %v", err)
- }
-
- // Decrypt
- decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedData), customerKey, iv)
- if err != nil {
- t.Fatalf("Failed to create decrypted reader for large data: %v", err)
- }
-
- decryptedData, err := io.ReadAll(decryptedReader)
- if err != nil {
- t.Fatalf("Failed to read decrypted large data: %v", err)
- }
-
- if !bytes.Equal(decryptedData, largeData) {
- t.Error("Large data doesn't match after encryption/decryption")
- }
- })
-}
-
-func TestSSECLargeObjectChunkReassembly(t *testing.T) {
- keyPair := GenerateTestSSECKey(1)
- customerKey := &SSECustomerKey{
- Algorithm: "AES256",
- Key: keyPair.Key,
- KeyMD5: keyPair.KeyMD5,
- }
-
- const chunkSize = 8 * 1024 * 1024 // matches putToFiler chunk size
- totalSize := chunkSize*2 + 3*1024*1024
- plaintext := make([]byte, totalSize)
- for i := range plaintext {
- plaintext[i] = byte(i % 251)
- }
-
- encryptedReader, iv, err := CreateSSECEncryptedReader(bytes.NewReader(plaintext), customerKey)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader: %v", err)
- }
- encryptedData, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read encrypted data: %v", err)
- }
-
- var reconstructed bytes.Buffer
- offset := int64(0)
- for offset < int64(len(encryptedData)) {
- end := offset + chunkSize
- if end > int64(len(encryptedData)) {
- end = int64(len(encryptedData))
- }
-
- chunkIV := make([]byte, len(iv))
- copy(chunkIV, iv)
- chunkReader := bytes.NewReader(encryptedData[offset:end])
- decryptedReader, decErr := CreateSSECDecryptedReaderWithOffset(chunkReader, customerKey, chunkIV, uint64(offset))
- if decErr != nil {
- t.Fatalf("Failed to create decrypted reader for offset %d: %v", offset, decErr)
- }
- decryptedChunk, decErr := io.ReadAll(decryptedReader)
- if decErr != nil {
- t.Fatalf("Failed to read decrypted chunk at offset %d: %v", offset, decErr)
- }
- reconstructed.Write(decryptedChunk)
- offset = end
- }
-
- if !bytes.Equal(reconstructed.Bytes(), plaintext) {
- t.Fatalf("Reconstructed data mismatch: expected %d bytes, got %d", len(plaintext), reconstructed.Len())
- }
-}
-
-// TestMultipartSSEPerformance tests performance characteristics of SSE with multipart
-func TestMultipartSSEPerformance(t *testing.T) {
- if testing.Short() {
- t.Skip("Skipping performance test in short mode")
- }
-
- t.Run("SSE-C performance with multiple parts", func(t *testing.T) {
- keyPair := GenerateTestSSECKey(1)
- customerKey := &SSECustomerKey{
- Algorithm: "AES256",
- Key: keyPair.Key,
- KeyMD5: keyPair.KeyMD5,
- }
-
- partSize := 64 * 1024 // 64KB parts
- numParts := 10
-
- for partNum := 0; partNum < numParts; partNum++ {
- partData := make([]byte, partSize)
- for i := range partData {
- partData[i] = byte((partNum + i) % 256)
- }
-
- // Encrypt
- encryptedReader, iv, err := CreateSSECEncryptedReader(bytes.NewReader(partData), customerKey)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader for part %d: %v", partNum, err)
- }
-
- encryptedData, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read encrypted data for part %d: %v", partNum, err)
- }
-
- // Decrypt
- decryptedReader, err := CreateSSECDecryptedReader(bytes.NewReader(encryptedData), customerKey, iv)
- if err != nil {
- t.Fatalf("Failed to create decrypted reader for part %d: %v", partNum, err)
- }
-
- decryptedData, err := io.ReadAll(decryptedReader)
- if err != nil {
- t.Fatalf("Failed to read decrypted data for part %d: %v", partNum, err)
- }
-
- if !bytes.Equal(decryptedData, partData) {
- t.Errorf("Data mismatch for part %d", partNum)
- }
- }
- })
-
- t.Run("SSE-KMS performance with multiple parts", func(t *testing.T) {
- kmsKey := SetupTestKMS(t)
- defer kmsKey.Cleanup()
-
- partSize := 64 * 1024 // 64KB parts
- numParts := 5 // Fewer parts for KMS due to overhead
- encryptionContext := BuildEncryptionContext("test-bucket", "test-object", false)
-
- for partNum := 0; partNum < numParts; partNum++ {
- partData := make([]byte, partSize)
- for i := range partData {
- partData[i] = byte((partNum + i) % 256)
- }
-
- // Encrypt
- encryptedReader, sseKey, err := CreateSSEKMSEncryptedReader(bytes.NewReader(partData), kmsKey.KeyID, encryptionContext)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader for part %d: %v", partNum, err)
- }
-
- encryptedData, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read encrypted data for part %d: %v", partNum, err)
- }
-
- // Decrypt
- decryptedReader, err := CreateSSEKMSDecryptedReader(bytes.NewReader(encryptedData), sseKey)
- if err != nil {
- t.Fatalf("Failed to create decrypted reader for part %d: %v", partNum, err)
- }
-
- decryptedData, err := io.ReadAll(decryptedReader)
- if err != nil {
- t.Fatalf("Failed to read decrypted data for part %d: %v", partNum, err)
- }
-
- if !bytes.Equal(decryptedData, partData) {
- t.Errorf("Data mismatch for part %d", partNum)
- }
- }
- })
-}
diff --git a/weed/s3api/s3_sse_s3.go b/weed/s3api/s3_sse_s3.go
index d9ea5a919..801221ed3 100644
--- a/weed/s3api/s3_sse_s3.go
+++ b/weed/s3api/s3_sse_s3.go
@@ -137,13 +137,6 @@ func CreateSSES3DecryptedReader(reader io.Reader, key *SSES3Key, iv []byte) (io.
return decryptReader, nil
}
-// GetSSES3Headers returns the headers for SSE-S3 encrypted objects
-func GetSSES3Headers() map[string]string {
- return map[string]string{
- s3_constants.AmzServerSideEncryption: SSES3Algorithm,
- }
-}
-
// SerializeSSES3Metadata serializes SSE-S3 metadata for storage using envelope encryption
func SerializeSSES3Metadata(key *SSES3Key) ([]byte, error) {
if err := ValidateSSES3Key(key); err != nil {
@@ -339,7 +332,7 @@ func (km *SSES3KeyManager) InitializeWithFiler(filerClient filer_pb.FilerClient)
v := util.GetViper()
cfgKEK := v.GetString(sseS3KEKConfigKey) // hex-encoded, drop-in for filer file
- cfgKey := v.GetString(sseS3KeyConfigKey) // any string, HKDF-derived
+ cfgKey := v.GetString(sseS3KeyConfigKey) // any string, HKDF-derived
if cfgKEK != "" && cfgKey != "" {
return fmt.Errorf("only one of %s and %s may be set, not both", sseS3KEKConfigKey, sseS3KeyConfigKey)
@@ -454,7 +447,6 @@ func (km *SSES3KeyManager) loadSuperKeyFromFiler() error {
return nil
}
-
// GetOrCreateKey gets an existing key or creates a new one
// With envelope encryption, we always generate a new DEK since we don't store them
func (km *SSES3KeyManager) GetOrCreateKey(keyID string) (*SSES3Key, error) {
@@ -532,14 +524,6 @@ func (km *SSES3KeyManager) StoreKey(key *SSES3Key) {
// The DEK is encrypted with the super key and stored in object metadata
}
-// GetKey is now a no-op since we don't cache keys
-// Keys are retrieved by decrypting the encrypted DEK from object metadata
-func (km *SSES3KeyManager) GetKey(keyID string) (*SSES3Key, bool) {
- // No-op: With envelope encryption, keys are not cached
- // Each object's metadata contains the encrypted DEK
- return nil, false
-}
-
// GetMasterKey returns a derived key from the master KEK for STS signing
// This uses HKDF to isolate the STS security domain from the SSE-S3 domain
func (km *SSES3KeyManager) GetMasterKey() []byte {
@@ -596,47 +580,6 @@ func InitializeGlobalSSES3KeyManager(filerClient *wdclient.FilerClient, grpcDial
return globalSSES3KeyManager.InitializeWithFiler(wrapper)
}
-// ProcessSSES3Request processes an SSE-S3 request and returns encryption metadata
-func ProcessSSES3Request(r *http.Request) (map[string][]byte, error) {
- if !IsSSES3RequestInternal(r) {
- return nil, nil
- }
-
- // Generate or retrieve encryption key
- keyManager := GetSSES3KeyManager()
- key, err := keyManager.GetOrCreateKey("")
- if err != nil {
- return nil, fmt.Errorf("get SSE-S3 key: %w", err)
- }
-
- // Serialize key metadata
- keyData, err := SerializeSSES3Metadata(key)
- if err != nil {
- return nil, fmt.Errorf("serialize SSE-S3 metadata: %w", err)
- }
-
- // Store key in manager
- keyManager.StoreKey(key)
-
- // Return metadata
- metadata := map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte(SSES3Algorithm),
- s3_constants.SeaweedFSSSES3Key: keyData,
- }
-
- return metadata, nil
-}
-
-// GetSSES3KeyFromMetadata extracts SSE-S3 key from object metadata
-func GetSSES3KeyFromMetadata(metadata map[string][]byte, keyManager *SSES3KeyManager) (*SSES3Key, error) {
- keyData, exists := metadata[s3_constants.SeaweedFSSSES3Key]
- if !exists {
- return nil, fmt.Errorf("SSE-S3 key not found in metadata")
- }
-
- return DeserializeSSES3Metadata(keyData, keyManager)
-}
-
// GetSSES3IV extracts the IV for single-part SSE-S3 objects
// Priority: 1) object-level metadata (for inline/small files), 2) first chunk metadata
func GetSSES3IV(entry *filer_pb.Entry, sseS3Key *SSES3Key, keyManager *SSES3KeyManager) ([]byte, error) {
diff --git a/weed/s3api/s3_sse_s3_test.go b/weed/s3api/s3_sse_s3_test.go
deleted file mode 100644
index af64850d9..000000000
--- a/weed/s3api/s3_sse_s3_test.go
+++ /dev/null
@@ -1,1079 +0,0 @@
-package s3api
-
-import (
- "bytes"
- "encoding/hex"
- "fmt"
- "io"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
- "github.com/seaweedfs/seaweedfs/weed/util"
-)
-
-// TestSSES3EncryptionDecryption tests basic SSE-S3 encryption and decryption
-func TestSSES3EncryptionDecryption(t *testing.T) {
- // Generate SSE-S3 key
- sseS3Key, err := GenerateSSES3Key()
- if err != nil {
- t.Fatalf("Failed to generate SSE-S3 key: %v", err)
- }
-
- // Test data
- testData := []byte("Hello, World! This is a test of SSE-S3 encryption.")
-
- // Create encrypted reader
- dataReader := bytes.NewReader(testData)
- encryptedReader, iv, err := CreateSSES3EncryptedReader(dataReader, sseS3Key)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader: %v", err)
- }
-
- // Read encrypted data
- encryptedData, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read encrypted data: %v", err)
- }
-
- // Verify data is actually encrypted (different from original)
- if bytes.Equal(encryptedData, testData) {
- t.Error("Data doesn't appear to be encrypted")
- }
-
- // Create decrypted reader
- encryptedReader2 := bytes.NewReader(encryptedData)
- decryptedReader, err := CreateSSES3DecryptedReader(encryptedReader2, sseS3Key, iv)
- if err != nil {
- t.Fatalf("Failed to create decrypted reader: %v", err)
- }
-
- // Read decrypted data
- decryptedData, err := io.ReadAll(decryptedReader)
- if err != nil {
- t.Fatalf("Failed to read decrypted data: %v", err)
- }
-
- // Verify decrypted data matches original
- if !bytes.Equal(decryptedData, testData) {
- t.Errorf("Decrypted data doesn't match original.\nOriginal: %s\nDecrypted: %s", testData, decryptedData)
- }
-}
-
-// TestSSES3IsRequestInternal tests detection of SSE-S3 requests
-func TestSSES3IsRequestInternal(t *testing.T) {
- testCases := []struct {
- name string
- headers map[string]string
- expected bool
- }{
- {
- name: "Valid SSE-S3 request",
- headers: map[string]string{
- s3_constants.AmzServerSideEncryption: "AES256",
- },
- expected: true,
- },
- {
- name: "No SSE headers",
- headers: map[string]string{},
- expected: false,
- },
- {
- name: "SSE-KMS request",
- headers: map[string]string{
- s3_constants.AmzServerSideEncryption: "aws:kms",
- },
- expected: false,
- },
- {
- name: "SSE-C request",
- headers: map[string]string{
- s3_constants.AmzServerSideEncryptionCustomerAlgorithm: "AES256",
- },
- expected: false,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- req := &http.Request{Header: make(http.Header)}
- for k, v := range tc.headers {
- req.Header.Set(k, v)
- }
-
- result := IsSSES3RequestInternal(req)
- if result != tc.expected {
- t.Errorf("Expected %v, got %v", tc.expected, result)
- }
- })
- }
-}
-
-// TestSSES3MetadataSerialization tests SSE-S3 metadata serialization and deserialization
-func TestSSES3MetadataSerialization(t *testing.T) {
- // Initialize global key manager
- globalSSES3KeyManager = NewSSES3KeyManager()
- defer func() {
- globalSSES3KeyManager = NewSSES3KeyManager()
- }()
-
- // Set up the key manager with a super key for testing
- keyManager := GetSSES3KeyManager()
- keyManager.superKey = make([]byte, 32)
- for i := range keyManager.superKey {
- keyManager.superKey[i] = byte(i)
- }
-
- // Generate SSE-S3 key
- sseS3Key, err := GenerateSSES3Key()
- if err != nil {
- t.Fatalf("Failed to generate SSE-S3 key: %v", err)
- }
-
- // Add IV to the key
- sseS3Key.IV = make([]byte, 16)
- for i := range sseS3Key.IV {
- sseS3Key.IV[i] = byte(i * 2)
- }
-
- // Serialize metadata
- serialized, err := SerializeSSES3Metadata(sseS3Key)
- if err != nil {
- t.Fatalf("Failed to serialize SSE-S3 metadata: %v", err)
- }
-
- if len(serialized) == 0 {
- t.Error("Serialized metadata is empty")
- }
-
- // Deserialize metadata
- deserializedKey, err := DeserializeSSES3Metadata(serialized, keyManager)
- if err != nil {
- t.Fatalf("Failed to deserialize SSE-S3 metadata: %v", err)
- }
-
- // Verify key matches
- if !bytes.Equal(deserializedKey.Key, sseS3Key.Key) {
- t.Error("Deserialized key doesn't match original key")
- }
-
- // Verify IV matches
- if !bytes.Equal(deserializedKey.IV, sseS3Key.IV) {
- t.Error("Deserialized IV doesn't match original IV")
- }
-
- // Verify algorithm matches
- if deserializedKey.Algorithm != sseS3Key.Algorithm {
- t.Errorf("Algorithm mismatch: expected %s, got %s", sseS3Key.Algorithm, deserializedKey.Algorithm)
- }
-
- // Verify key ID matches
- if deserializedKey.KeyID != sseS3Key.KeyID {
- t.Errorf("Key ID mismatch: expected %s, got %s", sseS3Key.KeyID, deserializedKey.KeyID)
- }
-}
-
-// TestDetectPrimarySSETypeS3 tests detection of SSE-S3 as primary encryption type
-func TestDetectPrimarySSETypeS3(t *testing.T) {
- s3a := &S3ApiServer{}
-
- testCases := []struct {
- name string
- entry *filer_pb.Entry
- expected string
- }{
- {
- name: "Single SSE-S3 chunk",
- entry: &filer_pb.Entry{
- Extended: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("AES256"),
- },
- Attributes: &filer_pb.FuseAttributes{},
- Chunks: []*filer_pb.FileChunk{
- {
- FileId: "1,123",
- Offset: 0,
- Size: 1024,
- SseType: filer_pb.SSEType_SSE_S3,
- SseMetadata: []byte("metadata"),
- },
- },
- },
- expected: s3_constants.SSETypeS3,
- },
- {
- name: "Multiple SSE-S3 chunks",
- entry: &filer_pb.Entry{
- Extended: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("AES256"),
- },
- Attributes: &filer_pb.FuseAttributes{},
- Chunks: []*filer_pb.FileChunk{
- {
- FileId: "1,123",
- Offset: 0,
- Size: 1024,
- SseType: filer_pb.SSEType_SSE_S3,
- SseMetadata: []byte("metadata1"),
- },
- {
- FileId: "2,456",
- Offset: 1024,
- Size: 1024,
- SseType: filer_pb.SSEType_SSE_S3,
- SseMetadata: []byte("metadata2"),
- },
- },
- },
- expected: s3_constants.SSETypeS3,
- },
- {
- name: "Mixed SSE-S3 and SSE-KMS chunks (SSE-S3 majority)",
- entry: &filer_pb.Entry{
- Extended: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("AES256"),
- },
- Attributes: &filer_pb.FuseAttributes{},
- Chunks: []*filer_pb.FileChunk{
- {
- FileId: "1,123",
- Offset: 0,
- Size: 1024,
- SseType: filer_pb.SSEType_SSE_S3,
- SseMetadata: []byte("metadata1"),
- },
- {
- FileId: "2,456",
- Offset: 1024,
- Size: 1024,
- SseType: filer_pb.SSEType_SSE_S3,
- SseMetadata: []byte("metadata2"),
- },
- {
- FileId: "3,789",
- Offset: 2048,
- Size: 1024,
- SseType: filer_pb.SSEType_SSE_KMS,
- SseMetadata: []byte("metadata3"),
- },
- },
- },
- expected: s3_constants.SSETypeS3,
- },
- {
- name: "No chunks, SSE-S3 metadata without KMS key ID",
- entry: &filer_pb.Entry{
- Extended: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("AES256"),
- },
- Attributes: &filer_pb.FuseAttributes{},
- Chunks: []*filer_pb.FileChunk{},
- },
- expected: s3_constants.SSETypeS3,
- },
- {
- name: "No chunks, SSE-KMS metadata with KMS key ID",
- entry: &filer_pb.Entry{
- Extended: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("AES256"),
- s3_constants.AmzServerSideEncryptionAwsKmsKeyId: []byte("test-key-id"),
- },
- Attributes: &filer_pb.FuseAttributes{},
- Chunks: []*filer_pb.FileChunk{},
- },
- expected: s3_constants.SSETypeKMS,
- },
- {
- name: "SSE-C chunks",
- entry: &filer_pb.Entry{
- Extended: map[string][]byte{
- s3_constants.AmzServerSideEncryptionCustomerAlgorithm: []byte("AES256"),
- },
- Attributes: &filer_pb.FuseAttributes{},
- Chunks: []*filer_pb.FileChunk{
- {
- FileId: "1,123",
- Offset: 0,
- Size: 1024,
- SseType: filer_pb.SSEType_SSE_C,
- SseMetadata: []byte("metadata"),
- },
- },
- },
- expected: s3_constants.SSETypeC,
- },
- {
- name: "Unencrypted",
- entry: &filer_pb.Entry{
- Extended: map[string][]byte{},
- Attributes: &filer_pb.FuseAttributes{},
- Chunks: []*filer_pb.FileChunk{
- {
- FileId: "1,123",
- Offset: 0,
- Size: 1024,
- },
- },
- },
- expected: "None",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- result := s3a.detectPrimarySSEType(tc.entry)
- if result != tc.expected {
- t.Errorf("Expected %s, got %s", tc.expected, result)
- }
- })
- }
-}
-
-// TestSSES3EncryptionWithBaseIV tests multipart encryption with base IV
-func TestSSES3EncryptionWithBaseIV(t *testing.T) {
- // Generate SSE-S3 key
- sseS3Key, err := GenerateSSES3Key()
- if err != nil {
- t.Fatalf("Failed to generate SSE-S3 key: %v", err)
- }
-
- // Generate base IV
- baseIV := make([]byte, 16)
- for i := range baseIV {
- baseIV[i] = byte(i)
- }
-
- // Test data for two parts
- testData1 := []byte("Part 1 of multipart upload test.")
- testData2 := []byte("Part 2 of multipart upload test.")
-
- // Encrypt part 1 at offset 0
- dataReader1 := bytes.NewReader(testData1)
- encryptedReader1, iv1, err := CreateSSES3EncryptedReaderWithBaseIV(dataReader1, sseS3Key, baseIV, 0)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader for part 1: %v", err)
- }
-
- encryptedData1, err := io.ReadAll(encryptedReader1)
- if err != nil {
- t.Fatalf("Failed to read encrypted data for part 1: %v", err)
- }
-
- // Encrypt part 2 at offset (simulating second part)
- dataReader2 := bytes.NewReader(testData2)
- offset2 := int64(len(testData1))
- encryptedReader2, iv2, err := CreateSSES3EncryptedReaderWithBaseIV(dataReader2, sseS3Key, baseIV, offset2)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader for part 2: %v", err)
- }
-
- encryptedData2, err := io.ReadAll(encryptedReader2)
- if err != nil {
- t.Fatalf("Failed to read encrypted data for part 2: %v", err)
- }
-
- // IVs should be different (offset-based)
- if bytes.Equal(iv1, iv2) {
- t.Error("IVs should be different for different offsets")
- }
-
- // Decrypt part 1
- decryptedReader1, err := CreateSSES3DecryptedReader(bytes.NewReader(encryptedData1), sseS3Key, iv1)
- if err != nil {
- t.Fatalf("Failed to create decrypted reader for part 1: %v", err)
- }
-
- decryptedData1, err := io.ReadAll(decryptedReader1)
- if err != nil {
- t.Fatalf("Failed to read decrypted data for part 1: %v", err)
- }
-
- // Decrypt part 2
- decryptedReader2, err := CreateSSES3DecryptedReader(bytes.NewReader(encryptedData2), sseS3Key, iv2)
- if err != nil {
- t.Fatalf("Failed to create decrypted reader for part 2: %v", err)
- }
-
- decryptedData2, err := io.ReadAll(decryptedReader2)
- if err != nil {
- t.Fatalf("Failed to read decrypted data for part 2: %v", err)
- }
-
- // Verify decrypted data matches original
- if !bytes.Equal(decryptedData1, testData1) {
- t.Errorf("Decrypted part 1 doesn't match original.\nOriginal: %s\nDecrypted: %s", testData1, decryptedData1)
- }
-
- if !bytes.Equal(decryptedData2, testData2) {
- t.Errorf("Decrypted part 2 doesn't match original.\nOriginal: %s\nDecrypted: %s", testData2, decryptedData2)
- }
-}
-
-// TestSSES3WrongKeyDecryption tests that wrong key fails decryption
-func TestSSES3WrongKeyDecryption(t *testing.T) {
- // Generate two different keys
- sseS3Key1, err := GenerateSSES3Key()
- if err != nil {
- t.Fatalf("Failed to generate SSE-S3 key 1: %v", err)
- }
-
- sseS3Key2, err := GenerateSSES3Key()
- if err != nil {
- t.Fatalf("Failed to generate SSE-S3 key 2: %v", err)
- }
-
- // Test data
- testData := []byte("Secret data encrypted with key 1")
-
- // Encrypt with key 1
- dataReader := bytes.NewReader(testData)
- encryptedReader, iv, err := CreateSSES3EncryptedReader(dataReader, sseS3Key1)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader: %v", err)
- }
-
- encryptedData, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read encrypted data: %v", err)
- }
-
- // Try to decrypt with key 2 (wrong key)
- decryptedReader, err := CreateSSES3DecryptedReader(bytes.NewReader(encryptedData), sseS3Key2, iv)
- if err != nil {
- t.Fatalf("Failed to create decrypted reader: %v", err)
- }
-
- decryptedData, err := io.ReadAll(decryptedReader)
- if err != nil {
- t.Fatalf("Failed to read decrypted data: %v", err)
- }
-
- // Decrypted data should NOT match original (wrong key produces garbage)
- if bytes.Equal(decryptedData, testData) {
- t.Error("Decryption with wrong key should not produce correct plaintext")
- }
-}
-
-// TestSSES3KeyGeneration tests SSE-S3 key generation
-func TestSSES3KeyGeneration(t *testing.T) {
- // Generate multiple keys
- keys := make([]*SSES3Key, 10)
- for i := range keys {
- key, err := GenerateSSES3Key()
- if err != nil {
- t.Fatalf("Failed to generate SSE-S3 key %d: %v", i, err)
- }
- keys[i] = key
-
- // Verify key properties
- if len(key.Key) != SSES3KeySize {
- t.Errorf("Key %d has wrong size: expected %d, got %d", i, SSES3KeySize, len(key.Key))
- }
-
- if key.Algorithm != SSES3Algorithm {
- t.Errorf("Key %d has wrong algorithm: expected %s, got %s", i, SSES3Algorithm, key.Algorithm)
- }
-
- if key.KeyID == "" {
- t.Errorf("Key %d has empty key ID", i)
- }
- }
-
- // Verify keys are unique
- for i := 0; i < len(keys); i++ {
- for j := i + 1; j < len(keys); j++ {
- if bytes.Equal(keys[i].Key, keys[j].Key) {
- t.Errorf("Keys %d and %d are identical (should be unique)", i, j)
- }
- if keys[i].KeyID == keys[j].KeyID {
- t.Errorf("Key IDs %d and %d are identical (should be unique)", i, j)
- }
- }
- }
-}
-
-// TestSSES3VariousSizes tests SSE-S3 encryption/decryption with various data sizes
-func TestSSES3VariousSizes(t *testing.T) {
- sizes := []int{1, 15, 16, 17, 100, 1024, 4096, 1048576}
-
- for _, size := range sizes {
- t.Run(fmt.Sprintf("size_%d", size), func(t *testing.T) {
- // Generate test data
- testData := make([]byte, size)
- for i := range testData {
- testData[i] = byte(i % 256)
- }
-
- // Generate key
- sseS3Key, err := GenerateSSES3Key()
- if err != nil {
- t.Fatalf("Failed to generate SSE-S3 key: %v", err)
- }
-
- // Encrypt
- dataReader := bytes.NewReader(testData)
- encryptedReader, iv, err := CreateSSES3EncryptedReader(dataReader, sseS3Key)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader: %v", err)
- }
-
- encryptedData, err := io.ReadAll(encryptedReader)
- if err != nil {
- t.Fatalf("Failed to read encrypted data: %v", err)
- }
-
- // Verify encrypted size matches original
- if len(encryptedData) != size {
- t.Errorf("Encrypted size mismatch: expected %d, got %d", size, len(encryptedData))
- }
-
- // Decrypt
- decryptedReader, err := CreateSSES3DecryptedReader(bytes.NewReader(encryptedData), sseS3Key, iv)
- if err != nil {
- t.Fatalf("Failed to create decrypted reader: %v", err)
- }
-
- decryptedData, err := io.ReadAll(decryptedReader)
- if err != nil {
- t.Fatalf("Failed to read decrypted data: %v", err)
- }
-
- // Verify
- if !bytes.Equal(decryptedData, testData) {
- t.Errorf("Decrypted data doesn't match original for size %d", size)
- }
- })
- }
-}
-
-// TestSSES3ResponseHeaders tests that SSE-S3 response headers are set correctly
-func TestSSES3ResponseHeaders(t *testing.T) {
- w := httptest.NewRecorder()
-
- // Simulate setting SSE-S3 response headers
- w.Header().Set(s3_constants.AmzServerSideEncryption, SSES3Algorithm)
-
- // Verify headers
- algorithm := w.Header().Get(s3_constants.AmzServerSideEncryption)
- if algorithm != "AES256" {
- t.Errorf("Expected algorithm AES256, got %s", algorithm)
- }
-
- // Should NOT have customer key headers
- if w.Header().Get(s3_constants.AmzServerSideEncryptionCustomerAlgorithm) != "" {
- t.Error("Should not have SSE-C customer algorithm header")
- }
-
- if w.Header().Get(s3_constants.AmzServerSideEncryptionCustomerKeyMD5) != "" {
- t.Error("Should not have SSE-C customer key MD5 header")
- }
-
- // Should NOT have KMS key ID
- if w.Header().Get(s3_constants.AmzServerSideEncryptionAwsKmsKeyId) != "" {
- t.Error("Should not have SSE-KMS key ID header")
- }
-}
-
-// TestSSES3IsEncryptedInternal tests detection of SSE-S3 encryption from metadata
-func TestSSES3IsEncryptedInternal(t *testing.T) {
- testCases := []struct {
- name string
- metadata map[string][]byte
- expected bool
- }{
- {
- name: "Empty metadata",
- metadata: map[string][]byte{},
- expected: false,
- },
- {
- name: "Valid SSE-S3 metadata with key",
- metadata: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("AES256"),
- s3_constants.SeaweedFSSSES3Key: []byte("test-key-data"),
- },
- expected: true,
- },
- {
- name: "SSE-S3 header without key (orphaned header - GitHub #7562)",
- metadata: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("AES256"),
- },
- expected: false, // Should not be considered encrypted without the key
- },
- {
- name: "SSE-KMS metadata",
- metadata: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("aws:kms"),
- },
- expected: false,
- },
- {
- name: "SSE-C metadata",
- metadata: map[string][]byte{
- s3_constants.AmzServerSideEncryptionCustomerAlgorithm: []byte("AES256"),
- },
- expected: false,
- },
- {
- name: "Key without header",
- metadata: map[string][]byte{
- s3_constants.SeaweedFSSSES3Key: []byte("test-key-data"),
- },
- expected: false, // Need both header and key
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- result := IsSSES3EncryptedInternal(tc.metadata)
- if result != tc.expected {
- t.Errorf("Expected %v, got %v", tc.expected, result)
- }
- })
- }
-}
-
-// TestSSES3InvalidMetadataDeserialization tests error handling for invalid metadata
-func TestSSES3InvalidMetadataDeserialization(t *testing.T) {
- keyManager := NewSSES3KeyManager()
- keyManager.superKey = make([]byte, 32)
-
- testCases := []struct {
- name string
- metadata []byte
- shouldError bool
- }{
- {
- name: "Empty metadata",
- metadata: []byte{},
- shouldError: true,
- },
- {
- name: "Invalid JSON",
- metadata: []byte("not valid json"),
- shouldError: true,
- },
- {
- name: "Missing keyId",
- metadata: []byte(`{"algorithm":"AES256"}`),
- shouldError: true,
- },
- {
- name: "Invalid base64 encrypted DEK",
- metadata: []byte(`{"keyId":"test","algorithm":"AES256","encryptedDEK":"not-valid-base64!","nonce":"dGVzdA=="}`),
- shouldError: true,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- _, err := DeserializeSSES3Metadata(tc.metadata, keyManager)
- if tc.shouldError && err == nil {
- t.Error("Expected error but got none")
- }
- if !tc.shouldError && err != nil {
- t.Errorf("Unexpected error: %v", err)
- }
- })
- }
-}
-
-// setViperKey is a test helper that sets a config key via its WEED_ env var.
-func setViperKey(t *testing.T, key, value string) {
- t.Helper()
- util.GetViper().SetDefault(key, "")
- t.Setenv("WEED_"+strings.ReplaceAll(strings.ToUpper(key), ".", "_"), value)
-}
-
-// TestSSES3KEKConfig tests that sse_s3.kek (hex format) is used as KEK
-func TestSSES3KEKConfig(t *testing.T) {
- testKey := make([]byte, 32)
- for i := range testKey {
- testKey[i] = byte(i + 50)
- }
- setViperKey(t, sseS3KEKConfigKey, hex.EncodeToString(testKey))
-
- km := NewSSES3KeyManager()
- err := km.InitializeWithFiler(nil)
- if err != nil {
- t.Fatalf("InitializeWithFiler failed: %v", err)
- }
-
- if !bytes.Equal(km.superKey, testKey) {
- t.Errorf("superKey mismatch: expected %x, got %x", testKey, km.superKey)
- }
-
- // Round-trip DEK encryption
- dek := make([]byte, 32)
- for i := range dek {
- dek[i] = byte(i)
- }
- encrypted, nonce, err := km.encryptKeyWithSuperKey(dek)
- if err != nil {
- t.Fatalf("encryptKeyWithSuperKey failed: %v", err)
- }
- decrypted, err := km.decryptKeyWithSuperKey(encrypted, nonce)
- if err != nil {
- t.Fatalf("decryptKeyWithSuperKey failed: %v", err)
- }
- if !bytes.Equal(decrypted, dek) {
- t.Error("round-trip DEK mismatch")
- }
-}
-
-// TestSSES3KEKConfigInvalidHex tests rejection of bad hex
-func TestSSES3KEKConfigInvalidHex(t *testing.T) {
- setViperKey(t, sseS3KEKConfigKey, "not-valid-hex")
-
- km := NewSSES3KeyManager()
- err := km.InitializeWithFiler(nil)
- if err == nil {
- t.Fatal("expected error for invalid hex, got nil")
- }
- if !strings.Contains(err.Error(), "hex-encoded") {
- t.Errorf("expected hex error, got: %v", err)
- }
-}
-
-// TestSSES3KEKConfigWrongSize tests rejection of wrong-size hex key
-func TestSSES3KEKConfigWrongSize(t *testing.T) {
- setViperKey(t, sseS3KEKConfigKey, hex.EncodeToString(make([]byte, 16)))
-
- km := NewSSES3KeyManager()
- err := km.InitializeWithFiler(nil)
- if err == nil {
- t.Fatal("expected error for wrong key size, got nil")
- }
- if !strings.Contains(err.Error(), "32 bytes") {
- t.Errorf("expected size error, got: %v", err)
- }
-}
-
-// TestSSES3KeyConfig tests that sse_s3.key (any string, HKDF) works
-func TestSSES3KeyConfig(t *testing.T) {
- setViperKey(t, sseS3KeyConfigKey, "my-secret-passphrase")
-
- km := NewSSES3KeyManager()
- err := km.InitializeWithFiler(nil)
- if err != nil {
- t.Fatalf("InitializeWithFiler failed: %v", err)
- }
-
- if len(km.superKey) != SSES3KeySize {
- t.Fatalf("expected %d-byte superKey, got %d", SSES3KeySize, len(km.superKey))
- }
-
- // Deterministic: same input → same output
- expected, err := deriveKeyFromSecret("my-secret-passphrase")
- if err != nil {
- t.Fatalf("deriveKeyFromSecret failed: %v", err)
- }
- if !bytes.Equal(km.superKey, expected) {
- t.Errorf("superKey mismatch: expected %x, got %x", expected, km.superKey)
- }
-}
-
-// TestSSES3KeyConfigDifferentSecrets tests different strings produce different keys
-func TestSSES3KeyConfigDifferentSecrets(t *testing.T) {
- k1, _ := deriveKeyFromSecret("secret-one")
- k2, _ := deriveKeyFromSecret("secret-two")
- if bytes.Equal(k1, k2) {
- t.Error("different secrets should produce different keys")
- }
-}
-
-// TestSSES3BothConfigsReject tests that setting both config keys is rejected
-func TestSSES3BothConfigsReject(t *testing.T) {
- setViperKey(t, sseS3KEKConfigKey, hex.EncodeToString(make([]byte, 32)))
- setViperKey(t, sseS3KeyConfigKey, "some-passphrase")
-
- km := NewSSES3KeyManager()
- err := km.InitializeWithFiler(nil)
- if err == nil {
- t.Fatal("expected error when both configs set, got nil")
- }
- if !strings.Contains(err.Error(), "only one") {
- t.Errorf("expected 'only one' error, got: %v", err)
- }
-}
-
-// TestGetSSES3Headers tests SSE-S3 header generation
-func TestGetSSES3Headers(t *testing.T) {
- headers := GetSSES3Headers()
-
- if len(headers) == 0 {
- t.Error("Expected headers to be non-empty")
- }
-
- algorithm, exists := headers[s3_constants.AmzServerSideEncryption]
- if !exists {
- t.Error("Expected AmzServerSideEncryption header to exist")
- }
-
- if algorithm != "AES256" {
- t.Errorf("Expected algorithm AES256, got %s", algorithm)
- }
-}
-
-// TestProcessSSES3Request tests processing of SSE-S3 requests
-func TestProcessSSES3Request(t *testing.T) {
- // Initialize global key manager
- globalSSES3KeyManager = NewSSES3KeyManager()
- defer func() {
- globalSSES3KeyManager = NewSSES3KeyManager()
- }()
-
- // Set up the key manager with a super key for testing
- keyManager := GetSSES3KeyManager()
- keyManager.superKey = make([]byte, 32)
- for i := range keyManager.superKey {
- keyManager.superKey[i] = byte(i)
- }
-
- // Create SSE-S3 request
- req := httptest.NewRequest("PUT", "/bucket/object", nil)
- req.Header.Set(s3_constants.AmzServerSideEncryption, "AES256")
-
- // Process request
- metadata, err := ProcessSSES3Request(req)
- if err != nil {
- t.Fatalf("Failed to process SSE-S3 request: %v", err)
- }
-
- if metadata == nil {
- t.Fatal("Expected metadata to be non-nil")
- }
-
- // Verify metadata contains SSE algorithm
- if sseAlgo, exists := metadata[s3_constants.AmzServerSideEncryption]; !exists {
- t.Error("Expected SSE algorithm in metadata")
- } else if string(sseAlgo) != "AES256" {
- t.Errorf("Expected AES256, got %s", string(sseAlgo))
- }
-
- // Verify metadata contains key data
- if _, exists := metadata[s3_constants.SeaweedFSSSES3Key]; !exists {
- t.Error("Expected SSE-S3 key data in metadata")
- }
-}
-
-// TestGetSSES3KeyFromMetadata tests extraction of SSE-S3 key from metadata
-func TestGetSSES3KeyFromMetadata(t *testing.T) {
- // Initialize global key manager
- globalSSES3KeyManager = NewSSES3KeyManager()
- defer func() {
- globalSSES3KeyManager = NewSSES3KeyManager()
- }()
-
- // Set up the key manager with a super key for testing
- keyManager := GetSSES3KeyManager()
- keyManager.superKey = make([]byte, 32)
- for i := range keyManager.superKey {
- keyManager.superKey[i] = byte(i)
- }
-
- // Generate and serialize key
- sseS3Key, err := GenerateSSES3Key()
- if err != nil {
- t.Fatalf("Failed to generate SSE-S3 key: %v", err)
- }
-
- sseS3Key.IV = make([]byte, 16)
- for i := range sseS3Key.IV {
- sseS3Key.IV[i] = byte(i)
- }
-
- serialized, err := SerializeSSES3Metadata(sseS3Key)
- if err != nil {
- t.Fatalf("Failed to serialize SSE-S3 metadata: %v", err)
- }
-
- metadata := map[string][]byte{
- s3_constants.SeaweedFSSSES3Key: serialized,
- }
-
- // Extract key
- extractedKey, err := GetSSES3KeyFromMetadata(metadata, keyManager)
- if err != nil {
- t.Fatalf("Failed to get SSE-S3 key from metadata: %v", err)
- }
-
- // Verify key matches
- if !bytes.Equal(extractedKey.Key, sseS3Key.Key) {
- t.Error("Extracted key doesn't match original key")
- }
-
- if !bytes.Equal(extractedKey.IV, sseS3Key.IV) {
- t.Error("Extracted IV doesn't match original IV")
- }
-}
-
-// TestSSES3EnvelopeEncryption tests that envelope encryption works correctly
-func TestSSES3EnvelopeEncryption(t *testing.T) {
- // Initialize key manager with a super key
- keyManager := NewSSES3KeyManager()
- keyManager.superKey = make([]byte, 32)
- for i := range keyManager.superKey {
- keyManager.superKey[i] = byte(i + 100)
- }
-
- // Generate a DEK
- dek := make([]byte, 32)
- for i := range dek {
- dek[i] = byte(i)
- }
-
- // Encrypt DEK with super key
- encryptedDEK, nonce, err := keyManager.encryptKeyWithSuperKey(dek)
- if err != nil {
- t.Fatalf("Failed to encrypt DEK: %v", err)
- }
-
- if len(encryptedDEK) == 0 {
- t.Error("Encrypted DEK is empty")
- }
-
- if len(nonce) == 0 {
- t.Error("Nonce is empty")
- }
-
- // Decrypt DEK with super key
- decryptedDEK, err := keyManager.decryptKeyWithSuperKey(encryptedDEK, nonce)
- if err != nil {
- t.Fatalf("Failed to decrypt DEK: %v", err)
- }
-
- // Verify DEK matches
- if !bytes.Equal(decryptedDEK, dek) {
- t.Error("Decrypted DEK doesn't match original DEK")
- }
-}
-
-// TestValidateSSES3Key tests SSE-S3 key validation
-func TestValidateSSES3Key(t *testing.T) {
- testCases := []struct {
- name string
- key *SSES3Key
- shouldError bool
- errorMsg string
- }{
- {
- name: "Nil key",
- key: nil,
- shouldError: true,
- errorMsg: "SSE-S3 key cannot be nil",
- },
- {
- name: "Valid key",
- key: &SSES3Key{
- Key: make([]byte, 32),
- KeyID: "test-key",
- Algorithm: "AES256",
- },
- shouldError: false,
- },
- {
- name: "Valid key with IV",
- key: &SSES3Key{
- Key: make([]byte, 32),
- KeyID: "test-key",
- Algorithm: "AES256",
- IV: make([]byte, 16),
- },
- shouldError: false,
- },
- {
- name: "Invalid key size (too small)",
- key: &SSES3Key{
- Key: make([]byte, 16),
- KeyID: "test-key",
- Algorithm: "AES256",
- },
- shouldError: true,
- errorMsg: "invalid SSE-S3 key size",
- },
- {
- name: "Invalid key size (too large)",
- key: &SSES3Key{
- Key: make([]byte, 64),
- KeyID: "test-key",
- Algorithm: "AES256",
- },
- shouldError: true,
- errorMsg: "invalid SSE-S3 key size",
- },
- {
- name: "Nil key bytes",
- key: &SSES3Key{
- Key: nil,
- KeyID: "test-key",
- Algorithm: "AES256",
- },
- shouldError: true,
- errorMsg: "SSE-S3 key bytes cannot be nil",
- },
- {
- name: "Empty key ID",
- key: &SSES3Key{
- Key: make([]byte, 32),
- KeyID: "",
- Algorithm: "AES256",
- },
- shouldError: true,
- errorMsg: "SSE-S3 key ID cannot be empty",
- },
- {
- name: "Invalid algorithm",
- key: &SSES3Key{
- Key: make([]byte, 32),
- KeyID: "test-key",
- Algorithm: "INVALID",
- },
- shouldError: true,
- errorMsg: "invalid SSE-S3 algorithm",
- },
- {
- name: "Invalid IV length",
- key: &SSES3Key{
- Key: make([]byte, 32),
- KeyID: "test-key",
- Algorithm: "AES256",
- IV: make([]byte, 8), // Wrong size
- },
- shouldError: true,
- errorMsg: "invalid SSE-S3 IV length",
- },
- {
- name: "Empty IV is allowed (set during encryption)",
- key: &SSES3Key{
- Key: make([]byte, 32),
- KeyID: "test-key",
- Algorithm: "AES256",
- IV: []byte{}, // Empty is OK
- },
- shouldError: false,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- err := ValidateSSES3Key(tc.key)
- if tc.shouldError {
- if err == nil {
- t.Error("Expected error but got none")
- } else if tc.errorMsg != "" && !strings.Contains(err.Error(), tc.errorMsg) {
- t.Errorf("Expected error containing %q, got: %v", tc.errorMsg, err)
- }
- } else {
- if err != nil {
- t.Errorf("Unexpected error: %v", err)
- }
- }
- })
- }
-}
diff --git a/weed/s3api/s3_validation_utils.go b/weed/s3api/s3_validation_utils.go
index f69fc9c26..16e63595c 100644
--- a/weed/s3api/s3_validation_utils.go
+++ b/weed/s3api/s3_validation_utils.go
@@ -58,14 +58,6 @@ func ValidateSSEKMSKey(sseKey *SSEKMSKey) error {
return nil
}
-// ValidateSSECKey validates that an SSE-C key is not nil
-func ValidateSSECKey(customerKey *SSECustomerKey) error {
- if customerKey == nil {
- return fmt.Errorf("SSE-C customer key cannot be nil")
- }
- return nil
-}
-
// ValidateSSES3Key validates that an SSE-S3 key has valid structure and contents
func ValidateSSES3Key(sseKey *SSES3Key) error {
if sseKey == nil {
diff --git a/weed/s3api/s3api_acl_helper.go b/weed/s3api/s3api_acl_helper.go
index 6cfa17f34..5c4804536 100644
--- a/weed/s3api/s3api_acl_helper.go
+++ b/weed/s3api/s3api_acl_helper.go
@@ -20,16 +20,6 @@ type AccountManager interface {
GetAccountIdByEmail(email string) string
}
-// GetAccountId get AccountId from request headers, AccountAnonymousId will be return if not presen
-func GetAccountId(r *http.Request) string {
- id := r.Header.Get(s3_constants.AmzAccountId)
- if len(id) == 0 {
- return s3_constants.AccountAnonymousId
- } else {
- return id
- }
-}
-
// ExtractAcl extracts the acl from the request body, or from the header if request body is empty
func ExtractAcl(r *http.Request, accountManager AccountManager, ownership, bucketOwnerId, ownerId, accountId string) (grants []*s3.Grant, errCode s3err.ErrorCode) {
if r.Body != nil && r.Body != http.NoBody {
@@ -318,83 +308,6 @@ func ValidateAndTransferGrants(accountManager AccountManager, grants []*s3.Grant
return result, s3err.ErrNone
}
-// DetermineReqGrants generates the grant set (Grants) according to accountId and reqPermission.
-func DetermineReqGrants(accountId, aclAction string) (grants []*s3.Grant) {
- // group grantee (AllUsers)
- grants = append(grants, &s3.Grant{
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- URI: &s3_constants.GranteeGroupAllUsers,
- },
- Permission: &aclAction,
- })
- grants = append(grants, &s3.Grant{
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- URI: &s3_constants.GranteeGroupAllUsers,
- },
- Permission: &s3_constants.PermissionFullControl,
- })
-
- // canonical grantee (accountId)
- grants = append(grants, &s3.Grant{
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeCanonicalUser,
- ID: &accountId,
- },
- Permission: &aclAction,
- })
- grants = append(grants, &s3.Grant{
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeCanonicalUser,
- ID: &accountId,
- },
- Permission: &s3_constants.PermissionFullControl,
- })
-
- // group grantee (AuthenticateUsers)
- if accountId != s3_constants.AccountAnonymousId {
- grants = append(grants, &s3.Grant{
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- URI: &s3_constants.GranteeGroupAuthenticatedUsers,
- },
- Permission: &aclAction,
- })
- grants = append(grants, &s3.Grant{
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- URI: &s3_constants.GranteeGroupAuthenticatedUsers,
- },
- Permission: &s3_constants.PermissionFullControl,
- })
- }
- return
-}
-
-func SetAcpOwnerHeader(r *http.Request, acpOwnerId string) {
- r.Header.Set(s3_constants.ExtAmzOwnerKey, acpOwnerId)
-}
-
-func GetAcpOwner(entryExtended map[string][]byte, defaultOwner string) string {
- ownerIdBytes, ok := entryExtended[s3_constants.ExtAmzOwnerKey]
- if ok && len(ownerIdBytes) > 0 {
- return string(ownerIdBytes)
- }
- return defaultOwner
-}
-
-func SetAcpGrantsHeader(r *http.Request, acpGrants []*s3.Grant) {
- if len(acpGrants) > 0 {
- a, err := json.Marshal(acpGrants)
- if err == nil {
- r.Header.Set(s3_constants.ExtAmzAclKey, string(a))
- } else {
- glog.Warning("Marshal acp grants err", err)
- }
- }
-}
-
// GetAcpGrants return grants parsed from entry
func GetAcpGrants(entryExtended map[string][]byte) []*s3.Grant {
acpBytes, ok := entryExtended[s3_constants.ExtAmzAclKey]
@@ -433,82 +346,3 @@ func AssembleEntryWithAcp(objectEntry *filer_pb.Entry, objectOwner string, grant
return s3err.ErrNone
}
-
-// GrantEquals Compare whether two Grants are equal in meaning, not completely
-// equal (compare Grantee.Type and the corresponding Value for equality, other
-// fields of Grantee are ignored)
-func GrantEquals(a, b *s3.Grant) bool {
- // grant
- if a == b {
- return true
- }
-
- if a == nil || b == nil {
- return false
- }
-
- // grant.Permission
- if a.Permission != b.Permission {
- if a.Permission == nil || b.Permission == nil {
- return false
- }
-
- if *a.Permission != *b.Permission {
- return false
- }
- }
-
- // grant.Grantee
- ag := a.Grantee
- bg := b.Grantee
- if ag != bg {
- if ag == nil || bg == nil {
- return false
- }
- // grantee.Type
- if ag.Type != bg.Type {
- if ag.Type == nil || bg.Type == nil {
- return false
- }
- if *ag.Type != *bg.Type {
- return false
- }
- }
- // value corresponding to granteeType
- if ag.Type != nil {
- switch *ag.Type {
- case s3_constants.GrantTypeGroup:
- if ag.URI != bg.URI {
- if ag.URI == nil || bg.URI == nil {
- return false
- }
-
- if *ag.URI != *bg.URI {
- return false
- }
- }
- case s3_constants.GrantTypeCanonicalUser:
- if ag.ID != bg.ID {
- if ag.ID == nil || bg.ID == nil {
- return false
- }
-
- if *ag.ID != *bg.ID {
- return false
- }
- }
- case s3_constants.GrantTypeAmazonCustomerByEmail:
- if ag.EmailAddress != bg.EmailAddress {
- if ag.EmailAddress == nil || bg.EmailAddress == nil {
- return false
- }
-
- if *ag.EmailAddress != *bg.EmailAddress {
- return false
- }
- }
- }
- }
- }
- return true
-}
diff --git a/weed/s3api/s3api_acl_helper_test.go b/weed/s3api/s3api_acl_helper_test.go
deleted file mode 100644
index d3a625ce2..000000000
--- a/weed/s3api/s3api_acl_helper_test.go
+++ /dev/null
@@ -1,710 +0,0 @@
-package s3api
-
-import (
- "bytes"
- "encoding/json"
- "io"
- "net/http"
- "testing"
-
- "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/service/s3"
- "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
- "github.com/seaweedfs/seaweedfs/weed/pb/iam_pb"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
-)
-
-var accountManager *IdentityAccessManagement
-
-func init() {
- accountManager = &IdentityAccessManagement{}
- _ = accountManager.loadS3ApiConfiguration(&iam_pb.S3ApiConfiguration{
- Accounts: []*iam_pb.Account{
- {
- Id: "accountA",
- DisplayName: "accountAName",
- EmailAddress: "accountA@example.com",
- },
- {
- Id: "accountB",
- DisplayName: "accountBName",
- EmailAddress: "accountB@example.com",
- },
- },
- })
-}
-
-func TestGetAccountId(t *testing.T) {
- req := &http.Request{
- Header: make(map[string][]string),
- }
- //case1
- //accountId: "admin"
- req.Header.Set(s3_constants.AmzAccountId, s3_constants.AccountAdminId)
- if GetAccountId(req) != s3_constants.AccountAdminId {
- t.Fatal("expect accountId: admin")
- }
-
- //case2
- //accountId: "anoymous"
- req.Header.Set(s3_constants.AmzAccountId, s3_constants.AccountAnonymousId)
- if GetAccountId(req) != s3_constants.AccountAnonymousId {
- t.Fatal("expect accountId: anonymous")
- }
-
- //case3
- //accountId is nil => "anonymous"
- req.Header.Del(s3_constants.AmzAccountId)
- if GetAccountId(req) != s3_constants.AccountAnonymousId {
- t.Fatal("expect accountId: anonymous")
- }
-}
-
-func TestExtractAcl(t *testing.T) {
- type Case struct {
- id int
- resultErrCode, expectErrCode s3err.ErrorCode
- resultGrants, expectGrants []*s3.Grant
- }
- testCases := make([]*Case, 0)
- accountAdminId := "admin"
- {
- //case1 (good case)
- //parse acp from request body
- req := &http.Request{
- Header: make(map[string][]string),
- }
- req.Body = io.NopCloser(bytes.NewReader([]byte(`
-
-
- admin
- admin
-
-
-
-
- admin
-
- FULL_CONTROL
-
-
-
- http://acs.amazonaws.com/groups/global/AllUsers
-
- FULL_CONTROL
-
-
-
- `)))
- objectWriter := "accountA"
- grants, errCode := ExtractAcl(req, accountManager, s3_constants.OwnershipObjectWriter, accountAdminId, accountAdminId, objectWriter)
- testCases = append(testCases, &Case{
- 1,
- errCode, s3err.ErrNone,
- grants, []*s3.Grant{
- {
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeCanonicalUser,
- ID: &accountAdminId,
- },
- Permission: &s3_constants.PermissionFullControl,
- },
- {
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- URI: &s3_constants.GranteeGroupAllUsers,
- },
- Permission: &s3_constants.PermissionFullControl,
- },
- },
- })
- }
-
- {
- //case2 (good case)
- //parse acp from header (cannedAcl)
- req := &http.Request{
- Header: make(map[string][]string),
- }
- req.Body = nil
- req.Header.Set(s3_constants.AmzCannedAcl, s3_constants.CannedAclPrivate)
- objectWriter := "accountA"
- grants, errCode := ExtractAcl(req, accountManager, s3_constants.OwnershipObjectWriter, accountAdminId, accountAdminId, objectWriter)
- testCases = append(testCases, &Case{
- 2,
- errCode, s3err.ErrNone,
- grants, []*s3.Grant{
- {
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeCanonicalUser,
- ID: &objectWriter,
- },
- Permission: &s3_constants.PermissionFullControl,
- },
- },
- })
- }
-
- {
- //case3 (bad case)
- //parse acp from request body (content is invalid)
- req := &http.Request{
- Header: make(map[string][]string),
- }
- req.Body = io.NopCloser(bytes.NewReader([]byte("zdfsaf")))
- req.Header.Set(s3_constants.AmzCannedAcl, s3_constants.CannedAclPrivate)
- objectWriter := "accountA"
- _, errCode := ExtractAcl(req, accountManager, s3_constants.OwnershipObjectWriter, accountAdminId, accountAdminId, objectWriter)
- testCases = append(testCases, &Case{
- id: 3,
- resultErrCode: errCode, expectErrCode: s3err.ErrInvalidRequest,
- })
- }
-
- //case4 (bad case)
- //parse acp from header (cannedAcl is invalid)
- req := &http.Request{
- Header: make(map[string][]string),
- }
- req.Body = nil
- req.Header.Set(s3_constants.AmzCannedAcl, "dfaksjfk")
- objectWriter := "accountA"
- _, errCode := ExtractAcl(req, accountManager, s3_constants.OwnershipObjectWriter, accountAdminId, "", objectWriter)
- testCases = append(testCases, &Case{
- id: 4,
- resultErrCode: errCode, expectErrCode: s3err.ErrInvalidRequest,
- })
-
- {
- //case5 (bad case)
- //parse acp from request body: owner is inconsistent
- req.Body = io.NopCloser(bytes.NewReader([]byte(`
-
-
- admin
- admin
-
-
-
-
- admin
-
- FULL_CONTROL
-
-
-
- http://acs.amazonaws.com/groups/global/AllUsers
-
- FULL_CONTROL
-
-
-
- `)))
- objectWriter = "accountA"
- _, errCode := ExtractAcl(req, accountManager, s3_constants.OwnershipObjectWriter, accountAdminId, objectWriter, objectWriter)
- testCases = append(testCases, &Case{
- id: 5,
- resultErrCode: errCode, expectErrCode: s3err.ErrAccessDenied,
- })
- }
-
- for _, tc := range testCases {
- if tc.resultErrCode != tc.expectErrCode {
- t.Fatalf("case[%d]: errorCode not expect", tc.id)
- }
- if !grantsEquals(tc.resultGrants, tc.expectGrants) {
- t.Fatalf("case[%d]: grants not expect", tc.id)
- }
- }
-}
-
-func TestParseAndValidateAclHeaders(t *testing.T) {
- type Case struct {
- id int
- resultOwner, expectOwner string
- resultErrCode, expectErrCode s3err.ErrorCode
- resultGrants, expectGrants []*s3.Grant
- }
- testCases := make([]*Case, 0)
- bucketOwner := "admin"
-
- {
- //case1 (good case)
- //parse custom acl
- req := &http.Request{
- Header: make(map[string][]string),
- }
- objectWriter := "accountA"
- req.Header.Set(s3_constants.AmzAclFullControl, `uri="http://acs.amazonaws.com/groups/global/AllUsers", id="anonymous", emailAddress="admin@example.com"`)
- ownerId, grants, errCode := ParseAndValidateAclHeaders(req, accountManager, s3_constants.OwnershipObjectWriter, bucketOwner, objectWriter, false)
- testCases = append(testCases, &Case{
- 1,
- ownerId, objectWriter,
- errCode, s3err.ErrNone,
- grants, []*s3.Grant{
- {
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- URI: &s3_constants.GranteeGroupAllUsers,
- },
- Permission: &s3_constants.PermissionFullControl,
- },
- {
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeCanonicalUser,
- ID: aws.String(s3_constants.AccountAnonymousId),
- },
- Permission: &s3_constants.PermissionFullControl,
- },
- {
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeCanonicalUser,
- ID: aws.String(s3_constants.AccountAdminId),
- },
- Permission: &s3_constants.PermissionFullControl,
- },
- },
- })
- }
- {
- //case2 (good case)
- //parse canned acl (ownership=ObjectWriter)
- req := &http.Request{
- Header: make(map[string][]string),
- }
- objectWriter := "accountA"
- req.Header.Set(s3_constants.AmzCannedAcl, s3_constants.CannedAclBucketOwnerFullControl)
- ownerId, grants, errCode := ParseAndValidateAclHeaders(req, accountManager, s3_constants.OwnershipObjectWriter, bucketOwner, objectWriter, false)
- testCases = append(testCases, &Case{
- 2,
- ownerId, objectWriter,
- errCode, s3err.ErrNone,
- grants, []*s3.Grant{
- {
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeCanonicalUser,
- ID: &objectWriter,
- },
- Permission: &s3_constants.PermissionFullControl,
- },
- {
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeCanonicalUser,
- ID: &bucketOwner,
- },
- Permission: &s3_constants.PermissionFullControl,
- },
- },
- })
- }
- {
- //case3 (good case)
- //parse canned acl (ownership=OwnershipBucketOwnerPreferred)
- req := &http.Request{
- Header: make(map[string][]string),
- }
- objectWriter := "accountA"
- req.Header.Set(s3_constants.AmzCannedAcl, s3_constants.CannedAclBucketOwnerFullControl)
- ownerId, grants, errCode := ParseAndValidateAclHeaders(req, accountManager, s3_constants.OwnershipBucketOwnerPreferred, bucketOwner, objectWriter, false)
- testCases = append(testCases, &Case{
- 3,
- ownerId, bucketOwner,
- errCode, s3err.ErrNone,
- grants, []*s3.Grant{
- {
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeCanonicalUser,
- ID: &bucketOwner,
- },
- Permission: &s3_constants.PermissionFullControl,
- },
- },
- })
- }
- {
- //case4 (bad case)
- //parse custom acl (grantee id not exists)
- req := &http.Request{
- Header: make(map[string][]string),
- }
- objectWriter := "accountA"
- req.Header.Set(s3_constants.AmzAclFullControl, `uri="http://acs.amazonaws.com/groups/global/AllUsers", id="notExistsAccount", emailAddress="admin@example.com"`)
- _, _, errCode := ParseAndValidateAclHeaders(req, accountManager, s3_constants.OwnershipObjectWriter, bucketOwner, objectWriter, false)
- testCases = append(testCases, &Case{
- id: 4,
- resultErrCode: errCode, expectErrCode: s3err.ErrInvalidRequest,
- })
- }
-
- {
- //case5 (bad case)
- //parse custom acl (invalid format)
- req := &http.Request{
- Header: make(map[string][]string),
- }
- objectWriter := "accountA"
- req.Header.Set(s3_constants.AmzAclFullControl, `uri="http:sfasf"`)
- _, _, errCode := ParseAndValidateAclHeaders(req, accountManager, s3_constants.OwnershipObjectWriter, bucketOwner, objectWriter, false)
- testCases = append(testCases, &Case{
- id: 5,
- resultErrCode: errCode, expectErrCode: s3err.ErrInvalidRequest,
- })
- }
-
- {
- //case6 (bad case)
- //parse canned acl (invalid value)
- req := &http.Request{
- Header: make(map[string][]string),
- }
- objectWriter := "accountA"
- req.Header.Set(s3_constants.AmzCannedAcl, `uri="http:sfasf"`)
- _, _, errCode := ParseAndValidateAclHeaders(req, accountManager, s3_constants.OwnershipObjectWriter, bucketOwner, objectWriter, false)
- testCases = append(testCases, &Case{
- id: 5,
- resultErrCode: errCode, expectErrCode: s3err.ErrInvalidRequest,
- })
- }
-
- for _, tc := range testCases {
- if tc.expectErrCode != tc.resultErrCode {
- t.Errorf("case[%d]: errCode unexpect", tc.id)
- }
- if tc.resultOwner != tc.expectOwner {
- t.Errorf("case[%d]: ownerId unexpect", tc.id)
- }
- if !grantsEquals(tc.resultGrants, tc.expectGrants) {
- t.Fatalf("case[%d]: grants not expect", tc.id)
- }
- }
-}
-
-func grantsEquals(a, b []*s3.Grant) bool {
- if len(a) != len(b) {
- return false
- }
- for i, grant := range a {
- if !GrantEquals(grant, b[i]) {
- return false
- }
- }
- return true
-}
-
-func TestDetermineReqGrants(t *testing.T) {
- {
- //case1: request account is anonymous
- accountId := s3_constants.AccountAnonymousId
- reqPermission := s3_constants.PermissionRead
-
- resultGrants := DetermineReqGrants(accountId, reqPermission)
- expectGrants := []*s3.Grant{
- {
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- URI: &s3_constants.GranteeGroupAllUsers,
- },
- Permission: &reqPermission,
- },
- {
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- URI: &s3_constants.GranteeGroupAllUsers,
- },
- Permission: &s3_constants.PermissionFullControl,
- },
- {
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeCanonicalUser,
- ID: &accountId,
- },
- Permission: &reqPermission,
- },
- {
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeCanonicalUser,
- ID: &accountId,
- },
- Permission: &s3_constants.PermissionFullControl,
- },
- }
- if !grantsEquals(resultGrants, expectGrants) {
- t.Fatalf("grants not expect")
- }
- }
- {
- //case2: request account is not anonymous (Iam authed)
- accountId := "accountX"
- reqPermission := s3_constants.PermissionRead
-
- resultGrants := DetermineReqGrants(accountId, reqPermission)
- expectGrants := []*s3.Grant{
- {
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- URI: &s3_constants.GranteeGroupAllUsers,
- },
- Permission: &reqPermission,
- },
- {
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- URI: &s3_constants.GranteeGroupAllUsers,
- },
- Permission: &s3_constants.PermissionFullControl,
- },
- {
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeCanonicalUser,
- ID: &accountId,
- },
- Permission: &reqPermission,
- },
- {
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeCanonicalUser,
- ID: &accountId,
- },
- Permission: &s3_constants.PermissionFullControl,
- },
- {
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- URI: &s3_constants.GranteeGroupAuthenticatedUsers,
- },
- Permission: &reqPermission,
- },
- {
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- URI: &s3_constants.GranteeGroupAuthenticatedUsers,
- },
- Permission: &s3_constants.PermissionFullControl,
- },
- }
- if !grantsEquals(resultGrants, expectGrants) {
- t.Fatalf("grants not expect")
- }
- }
-}
-
-func TestAssembleEntryWithAcp(t *testing.T) {
- defaultOwner := "admin"
-
- //case1
- //assemble with non-empty grants
- expectOwner := "accountS"
- expectGrants := []*s3.Grant{
- {
- Permission: &s3_constants.PermissionRead,
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- ID: aws.String(s3_constants.AccountAdminId),
- URI: &s3_constants.GranteeGroupAllUsers,
- },
- },
- }
- entry := &filer_pb.Entry{}
- AssembleEntryWithAcp(entry, expectOwner, expectGrants)
-
- resultOwner := GetAcpOwner(entry.Extended, defaultOwner)
- if resultOwner != expectOwner {
- t.Fatalf("owner not expect")
- }
-
- resultGrants := GetAcpGrants(entry.Extended)
- if !grantsEquals(resultGrants, expectGrants) {
- t.Fatal("grants not expect")
- }
-
- //case2
- //assemble with empty grants (override)
- AssembleEntryWithAcp(entry, "", nil)
- resultOwner = GetAcpOwner(entry.Extended, defaultOwner)
- if resultOwner != defaultOwner {
- t.Fatalf("owner not expect")
- }
-
- resultGrants = GetAcpGrants(entry.Extended)
- if len(resultGrants) != 0 {
- t.Fatal("grants not expect")
- }
-
-}
-
-func TestGrantEquals(t *testing.T) {
- testCases := map[bool]bool{
- GrantEquals(nil, nil): true,
-
- GrantEquals(&s3.Grant{}, nil): false,
-
- GrantEquals(&s3.Grant{}, &s3.Grant{}): true,
-
- GrantEquals(&s3.Grant{
- Permission: &s3_constants.PermissionRead,
- }, &s3.Grant{}): false,
-
- GrantEquals(&s3.Grant{
- Permission: &s3_constants.PermissionRead,
- }, &s3.Grant{
- Permission: &s3_constants.PermissionRead,
- }): true,
-
- GrantEquals(&s3.Grant{
- Permission: &s3_constants.PermissionRead,
- Grantee: &s3.Grantee{},
- }, &s3.Grant{
- Permission: &s3_constants.PermissionRead,
- Grantee: &s3.Grantee{},
- }): true,
-
- GrantEquals(&s3.Grant{
- Permission: &s3_constants.PermissionRead,
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- },
- }, &s3.Grant{
- Permission: &s3_constants.PermissionRead,
- Grantee: &s3.Grantee{},
- }): false,
-
- //type not present, compare other fields of grant is meaningless
- GrantEquals(&s3.Grant{
- Permission: &s3_constants.PermissionRead,
- Grantee: &s3.Grantee{
- ID: aws.String(s3_constants.AccountAdminId),
- //EmailAddress: &s3account.AccountAdmin.EmailAddress,
- },
- }, &s3.Grant{
- Permission: &s3_constants.PermissionRead,
- Grantee: &s3.Grantee{
- ID: aws.String(s3_constants.AccountAdminId),
- },
- }): true,
-
- GrantEquals(&s3.Grant{
- Permission: &s3_constants.PermissionRead,
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- },
- }, &s3.Grant{
- Permission: &s3_constants.PermissionRead,
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- },
- }): true,
-
- GrantEquals(&s3.Grant{
- Permission: &s3_constants.PermissionRead,
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- URI: &s3_constants.GranteeGroupAllUsers,
- },
- }, &s3.Grant{
- Permission: &s3_constants.PermissionRead,
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- URI: &s3_constants.GranteeGroupAllUsers,
- },
- }): true,
-
- GrantEquals(&s3.Grant{
- Permission: &s3_constants.PermissionWrite,
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- URI: &s3_constants.GranteeGroupAllUsers,
- },
- }, &s3.Grant{
- Permission: &s3_constants.PermissionRead,
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- URI: &s3_constants.GranteeGroupAllUsers,
- },
- }): false,
-
- GrantEquals(&s3.Grant{
- Permission: &s3_constants.PermissionRead,
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- ID: aws.String(s3_constants.AccountAdminId),
- },
- }, &s3.Grant{
- Permission: &s3_constants.PermissionRead,
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- ID: aws.String(s3_constants.AccountAdminId),
- },
- }): true,
-
- GrantEquals(&s3.Grant{
- Permission: &s3_constants.PermissionRead,
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- ID: aws.String(s3_constants.AccountAdminId),
- URI: &s3_constants.GranteeGroupAllUsers,
- },
- }, &s3.Grant{
- Permission: &s3_constants.PermissionRead,
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- ID: aws.String(s3_constants.AccountAdminId),
- },
- }): false,
-
- GrantEquals(&s3.Grant{
- Permission: &s3_constants.PermissionRead,
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- ID: aws.String(s3_constants.AccountAdminId),
- URI: &s3_constants.GranteeGroupAllUsers,
- },
- }, &s3.Grant{
- Permission: &s3_constants.PermissionRead,
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- URI: &s3_constants.GranteeGroupAllUsers,
- },
- }): true,
- }
-
- for tc, expect := range testCases {
- if tc != expect {
- t.Fatal("TestGrantEquals not expect!")
- }
- }
-}
-
-func TestSetAcpOwnerHeader(t *testing.T) {
- ownerId := "accountZ"
- req := &http.Request{
- Header: make(map[string][]string),
- }
- SetAcpOwnerHeader(req, ownerId)
-
- if req.Header.Get(s3_constants.ExtAmzOwnerKey) != ownerId {
- t.Fatalf("owner unexpect")
- }
-}
-
-func TestSetAcpGrantsHeader(t *testing.T) {
- req := &http.Request{
- Header: make(map[string][]string),
- }
- grants := []*s3.Grant{
- {
- Permission: &s3_constants.PermissionRead,
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- ID: aws.String(s3_constants.AccountAdminId),
- URI: &s3_constants.GranteeGroupAllUsers,
- },
- },
- }
- SetAcpGrantsHeader(req, grants)
-
- grantsJson, _ := json.Marshal(grants)
- if req.Header.Get(s3_constants.ExtAmzAclKey) != string(grantsJson) {
- t.Fatalf("owner unexpect")
- }
-}
diff --git a/weed/s3api/s3api_bucket_handlers.go b/weed/s3api/s3api_bucket_handlers.go
index 5abbd5d22..eec6a11ed 100644
--- a/weed/s3api/s3api_bucket_handlers.go
+++ b/weed/s3api/s3api_bucket_handlers.go
@@ -150,14 +150,6 @@ func isBucketOwnedByIdentity(entry *filer_pb.Entry, identity *Identity) bool {
return true
}
-// isBucketVisibleToIdentity is kept for backward compatibility with tests.
-// It checks if a bucket should be visible based on ownership only.
-// Deprecated: Use isBucketOwnedByIdentity instead. The ListBucketsHandler
-// now uses OR logic: a bucket is visible if user owns it OR has List permission.
-func isBucketVisibleToIdentity(entry *filer_pb.Entry, identity *Identity) bool {
- return isBucketOwnedByIdentity(entry, identity)
-}
-
func (s3a *S3ApiServer) PutBucketHandler(w http.ResponseWriter, r *http.Request) {
// collect parameters
diff --git a/weed/s3api/s3api_bucket_handlers_test.go b/weed/s3api/s3api_bucket_handlers_test.go
deleted file mode 100644
index ee79381b3..000000000
--- a/weed/s3api/s3api_bucket_handlers_test.go
+++ /dev/null
@@ -1,1085 +0,0 @@
-package s3api
-
-import (
- "encoding/json"
- "encoding/xml"
- "fmt"
- "net/http/httptest"
- "strings"
- "testing"
- "time"
-
- "github.com/aws/aws-sdk-go/service/s3"
- "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-func TestPutBucketAclCannedAclSupport(t *testing.T) {
- // Test that the ExtractAcl function can handle various canned ACLs
- // This tests the core functionality without requiring a fully initialized S3ApiServer
-
- testCases := []struct {
- name string
- cannedAcl string
- shouldWork bool
- description string
- }{
- {
- name: "private",
- cannedAcl: s3_constants.CannedAclPrivate,
- shouldWork: true,
- description: "private ACL should be accepted",
- },
- {
- name: "public-read",
- cannedAcl: s3_constants.CannedAclPublicRead,
- shouldWork: true,
- description: "public-read ACL should be accepted",
- },
- {
- name: "public-read-write",
- cannedAcl: s3_constants.CannedAclPublicReadWrite,
- shouldWork: true,
- description: "public-read-write ACL should be accepted",
- },
- {
- name: "authenticated-read",
- cannedAcl: s3_constants.CannedAclAuthenticatedRead,
- shouldWork: true,
- description: "authenticated-read ACL should be accepted",
- },
- {
- name: "bucket-owner-read",
- cannedAcl: s3_constants.CannedAclBucketOwnerRead,
- shouldWork: true,
- description: "bucket-owner-read ACL should be accepted",
- },
- {
- name: "bucket-owner-full-control",
- cannedAcl: s3_constants.CannedAclBucketOwnerFullControl,
- shouldWork: true,
- description: "bucket-owner-full-control ACL should be accepted",
- },
- {
- name: "invalid-acl",
- cannedAcl: "invalid-acl-value",
- shouldWork: false,
- description: "invalid ACL should be rejected",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- // Create a request with the specified canned ACL
- req := httptest.NewRequest("PUT", "/bucket?acl", nil)
- req.Header.Set(s3_constants.AmzCannedAcl, tc.cannedAcl)
- req.Header.Set(s3_constants.AmzAccountId, "test-account-123")
-
- // Create a mock IAM for testing
- mockIam := &mockIamInterface{}
-
- // Test the ACL extraction directly
- grants, errCode := ExtractAcl(req, mockIam, "", "test-account-123", "test-account-123", "test-account-123")
-
- if tc.shouldWork {
- assert.Equal(t, s3err.ErrNone, errCode, "Expected ACL parsing to succeed for %s", tc.cannedAcl)
- assert.NotEmpty(t, grants, "Expected grants to be generated for valid ACL %s", tc.cannedAcl)
- t.Logf("✓ PASS: %s - %s", tc.name, tc.description)
- } else {
- assert.NotEqual(t, s3err.ErrNone, errCode, "Expected ACL parsing to fail for invalid ACL %s", tc.cannedAcl)
- t.Logf("✓ PASS: %s - %s", tc.name, tc.description)
- }
- })
- }
-}
-
-// TestBucketWithoutACLIsNotPublicRead tests that buckets without ACLs are not public-read
-func TestBucketWithoutACLIsNotPublicRead(t *testing.T) {
- // Create a bucket config without ACL (like a freshly created bucket)
- config := &BucketConfig{
- Name: "test-bucket",
- IsPublicRead: false, // Should be explicitly false
- }
-
- // Verify that buckets without ACL are not public-read
- assert.False(t, config.IsPublicRead, "Bucket without ACL should not be public-read")
-}
-
-func TestBucketConfigInitialization(t *testing.T) {
- // Test that BucketConfig properly initializes IsPublicRead field
- config := &BucketConfig{
- Name: "test-bucket",
- IsPublicRead: false, // Explicitly set to false for private buckets
- }
-
- // Verify proper initialization
- assert.False(t, config.IsPublicRead, "Newly created bucket should not be public-read by default")
-}
-
-// TestUpdateBucketConfigCacheConsistency tests that updateBucketConfigCacheFromEntry
-// properly handles the IsPublicRead flag consistently with getBucketConfig
-func TestUpdateBucketConfigCacheConsistency(t *testing.T) {
- t.Run("bucket without ACL should have IsPublicRead=false", func(t *testing.T) {
- // Simulate an entry without ACL (like a freshly created bucket)
- entry := &filer_pb.Entry{
- Name: "test-bucket",
- Attributes: &filer_pb.FuseAttributes{
- FileMode: 0755,
- },
- // Extended is nil or doesn't contain ACL
- }
-
- // Test what updateBucketConfigCacheFromEntry would create
- config := &BucketConfig{
- Name: entry.Name,
- Entry: entry,
- IsPublicRead: false, // Should be explicitly false
- }
-
- // When Extended is nil, IsPublicRead should be false
- assert.False(t, config.IsPublicRead, "Bucket without Extended metadata should not be public-read")
-
- // When Extended exists but has no ACL key, IsPublicRead should also be false
- entry.Extended = make(map[string][]byte)
- entry.Extended["some-other-key"] = []byte("some-value")
-
- config = &BucketConfig{
- Name: entry.Name,
- Entry: entry,
- IsPublicRead: false, // Should be explicitly false
- }
-
- // Simulate the else branch: no ACL means private bucket
- if _, exists := entry.Extended[s3_constants.ExtAmzAclKey]; !exists {
- config.IsPublicRead = false
- }
-
- assert.False(t, config.IsPublicRead, "Bucket with Extended but no ACL should not be public-read")
- })
-
- t.Run("bucket with public-read ACL should have IsPublicRead=true", func(t *testing.T) {
- // Create a mock public-read ACL using AWS S3 SDK types
- publicReadGrants := []*s3.Grant{
- {
- Grantee: &s3.Grantee{
- Type: &s3_constants.GrantTypeGroup,
- URI: &s3_constants.GranteeGroupAllUsers,
- },
- Permission: &s3_constants.PermissionRead,
- },
- }
-
- aclBytes, err := json.Marshal(publicReadGrants)
- require.NoError(t, err)
-
- entry := &filer_pb.Entry{
- Name: "public-bucket",
- Extended: map[string][]byte{
- s3_constants.ExtAmzAclKey: aclBytes,
- },
- }
-
- config := &BucketConfig{
- Name: entry.Name,
- Entry: entry,
- IsPublicRead: false, // Start with false
- }
-
- // Simulate what updateBucketConfigCacheFromEntry would do
- if acl, exists := entry.Extended[s3_constants.ExtAmzAclKey]; exists {
- config.ACL = acl
- config.IsPublicRead = parseAndCachePublicReadStatus(acl)
- }
-
- assert.True(t, config.IsPublicRead, "Bucket with public-read ACL should be public-read")
- })
-}
-
-// mockIamInterface is a simple mock for testing
-type mockIamInterface struct{}
-
-func (m *mockIamInterface) GetAccountNameById(canonicalId string) string {
- return "test-user-" + canonicalId
-}
-
-func (m *mockIamInterface) GetAccountIdByEmail(email string) string {
- return "account-for-" + email
-}
-
-// TestListAllMyBucketsResultNamespace verifies that the ListAllMyBucketsResult
-// XML response includes the proper S3 namespace URI
-func TestListAllMyBucketsResultNamespace(t *testing.T) {
- // Create a sample ListAllMyBucketsResult response
- response := ListAllMyBucketsResult{
- Owner: CanonicalUser{
- ID: "test-owner-id",
- DisplayName: "test-owner",
- },
- Buckets: ListAllMyBucketsList{
- Bucket: []ListAllMyBucketsEntry{
- {
- Name: "test-bucket",
- CreationDate: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
- },
- },
- },
- }
-
- // Marshal the response to XML
- xmlData, err := xml.Marshal(response)
- require.NoError(t, err, "Failed to marshal XML response")
-
- xmlString := string(xmlData)
-
- // Verify that the XML contains the proper namespace
- assert.Contains(t, xmlString, `xmlns="http://s3.amazonaws.com/doc/2006-03-01/"`,
- "XML response should contain the S3 namespace URI")
-
- // Verify the root element has the correct name
- assert.Contains(t, xmlString, "", "XML should contain Owner element")
- assert.Contains(t, xmlString, "", "XML should contain Buckets element")
- assert.Contains(t, xmlString, "", "XML should contain Bucket element")
- assert.Contains(t, xmlString, "test-bucket", "XML should contain bucket name")
-
- t.Logf("Generated XML:\n%s", xmlString)
-}
-
-// TestListBucketsOwnershipFiltering tests that ListBucketsHandler properly filters
-// buckets based on ownership, allowing only bucket owners (or admins) to see their buckets
-func TestListBucketsOwnershipFiltering(t *testing.T) {
- testCases := []struct {
- name string
- buckets []testBucket
- requestIdentityId string
- requestIsAdmin bool
- expectedBucketNames []string
- description string
- }{
- {
- name: "non-admin sees only owned buckets",
- buckets: []testBucket{
- {name: "user1-bucket", ownerId: "user1"},
- {name: "user2-bucket", ownerId: "user2"},
- {name: "user1-bucket2", ownerId: "user1"},
- },
- requestIdentityId: "user1",
- requestIsAdmin: false,
- expectedBucketNames: []string{"user1-bucket", "user1-bucket2"},
- description: "Non-admin user should only see buckets they own",
- },
- {
- name: "admin sees all buckets",
- buckets: []testBucket{
- {name: "user1-bucket", ownerId: "user1"},
- {name: "user2-bucket", ownerId: "user2"},
- {name: "user3-bucket", ownerId: "user3"},
- },
- requestIdentityId: "admin",
- requestIsAdmin: true,
- expectedBucketNames: []string{"user1-bucket", "user2-bucket", "user3-bucket"},
- description: "Admin should see all buckets regardless of owner",
- },
- {
- name: "buckets without owner are hidden from non-admins",
- buckets: []testBucket{
- {name: "owned-bucket", ownerId: "user1"},
- {name: "unowned-bucket", ownerId: ""}, // No owner set
- },
- requestIdentityId: "user2",
- requestIsAdmin: false,
- expectedBucketNames: []string{},
- description: "Buckets without owner should be hidden from non-admin users",
- },
- {
- name: "unauthenticated user sees no buckets",
- buckets: []testBucket{
- {name: "owned-bucket", ownerId: "user1"},
- {name: "unowned-bucket", ownerId: ""},
- },
- requestIdentityId: "",
- requestIsAdmin: false,
- expectedBucketNames: []string{},
- description: "Unauthenticated requests should not see any buckets",
- },
- {
- name: "admin sees buckets regardless of ownership",
- buckets: []testBucket{
- {name: "user1-bucket", ownerId: "user1"},
- {name: "user2-bucket", ownerId: "user2"},
- {name: "unowned-bucket", ownerId: ""},
- },
- requestIdentityId: "admin",
- requestIsAdmin: true,
- expectedBucketNames: []string{"user1-bucket", "user2-bucket", "unowned-bucket"},
- description: "Admin should see all buckets regardless of ownership",
- },
- {
- name: "buckets with nil Extended metadata hidden from non-admins",
- buckets: []testBucket{
- {name: "bucket-no-extended", ownerId: "", nilExtended: true},
- {name: "bucket-with-owner", ownerId: "user1"},
- },
- requestIdentityId: "user1",
- requestIsAdmin: false,
- expectedBucketNames: []string{"bucket-with-owner"},
- description: "Buckets with nil Extended (no owner) should be hidden from non-admins",
- },
- {
- name: "user sees only their bucket among many",
- buckets: []testBucket{
- {name: "alice-bucket", ownerId: "alice"},
- {name: "bob-bucket", ownerId: "bob"},
- {name: "charlie-bucket", ownerId: "charlie"},
- {name: "alice-bucket2", ownerId: "alice"},
- },
- requestIdentityId: "bob",
- requestIsAdmin: false,
- expectedBucketNames: []string{"bob-bucket"},
- description: "User should see only their single bucket among many",
- },
- {
- name: "admin sees buckets without owners",
- buckets: []testBucket{
- {name: "owned-bucket", ownerId: "user1"},
- {name: "unowned-bucket", ownerId: ""},
- {name: "no-metadata-bucket", ownerId: "", nilExtended: true},
- },
- requestIdentityId: "admin",
- requestIsAdmin: true,
- expectedBucketNames: []string{"owned-bucket", "unowned-bucket", "no-metadata-bucket"},
- description: "Admin should see all buckets including those without owners",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- // Create mock entries
- entries := make([]*filer_pb.Entry, 0, len(tc.buckets))
- for _, bucket := range tc.buckets {
- entry := &filer_pb.Entry{
- Name: bucket.name,
- IsDirectory: true,
- Attributes: &filer_pb.FuseAttributes{
- Crtime: time.Now().Unix(),
- },
- }
-
- if !bucket.nilExtended {
- entry.Extended = make(map[string][]byte)
- if bucket.ownerId != "" {
- entry.Extended[s3_constants.AmzIdentityId] = []byte(bucket.ownerId)
- }
- }
-
- entries = append(entries, entry)
- }
-
- // Filter entries using the actual production code
- var filteredBuckets []string
- for _, entry := range entries {
- var identity *Identity
- if tc.requestIdentityId != "" {
- identity = mockIdentity(tc.requestIdentityId, tc.requestIsAdmin)
- }
- if isBucketVisibleToIdentity(entry, identity) {
- filteredBuckets = append(filteredBuckets, entry.Name)
- }
- }
-
- // Assert expected buckets match filtered buckets
- assert.ElementsMatch(t, tc.expectedBucketNames, filteredBuckets,
- "%s - Expected buckets: %v, Got: %v", tc.description, tc.expectedBucketNames, filteredBuckets)
- })
- }
-}
-
-// testBucket represents a bucket for testing with ownership metadata
-type testBucket struct {
- name string
- ownerId string
- nilExtended bool
-}
-
-// mockIdentity creates a mock Identity for testing bucket visibility
-func mockIdentity(name string, isAdmin bool) *Identity {
- identity := &Identity{
- Name: name,
- }
- if isAdmin {
- identity.Credentials = []*Credential{
- {
- AccessKey: "admin-key",
- SecretKey: "admin-secret",
- },
- }
- identity.Actions = []Action{Action(s3_constants.ACTION_ADMIN)}
- }
- return identity
-}
-
-// TestListBucketsOwnershipEdgeCases tests edge cases in ownership filtering
-func TestListBucketsOwnershipEdgeCases(t *testing.T) {
- t.Run("malformed owner id with special characters", func(t *testing.T) {
- entry := &filer_pb.Entry{
- Name: "test-bucket",
- IsDirectory: true,
- Extended: map[string][]byte{
- s3_constants.AmzIdentityId: []byte("user@domain.com"),
- },
- Attributes: &filer_pb.FuseAttributes{
- Crtime: time.Now().Unix(),
- },
- }
-
- identity := mockIdentity("user@domain.com", false)
-
- // Should match exactly even with special characters
- isVisible := isBucketVisibleToIdentity(entry, identity)
-
- assert.True(t, isVisible, "Should match owner ID with special characters exactly")
- })
-
- t.Run("owner id with unicode characters", func(t *testing.T) {
- unicodeOwnerId := "用户123"
- entry := &filer_pb.Entry{
- Name: "test-bucket",
- IsDirectory: true,
- Extended: map[string][]byte{
- s3_constants.AmzIdentityId: []byte(unicodeOwnerId),
- },
- Attributes: &filer_pb.FuseAttributes{
- Crtime: time.Now().Unix(),
- },
- }
-
- identity := mockIdentity(unicodeOwnerId, false)
-
- isVisible := isBucketVisibleToIdentity(entry, identity)
-
- assert.True(t, isVisible, "Should handle unicode owner IDs correctly")
- })
-
- t.Run("owner id with binary data", func(t *testing.T) {
- entry := &filer_pb.Entry{
- Name: "test-bucket",
- IsDirectory: true,
- Extended: map[string][]byte{
- s3_constants.AmzIdentityId: []byte{0x00, 0x01, 0x02, 0xFF},
- },
- Attributes: &filer_pb.FuseAttributes{
- Crtime: time.Now().Unix(),
- },
- }
-
- identity := mockIdentity("normaluser", false)
-
- // Should not panic when converting binary data to string
- assert.NotPanics(t, func() {
- isVisible := isBucketVisibleToIdentity(entry, identity)
- assert.False(t, isVisible, "Binary owner ID should not match normal user")
- })
- })
-
- t.Run("empty owner id in Extended", func(t *testing.T) {
- entry := &filer_pb.Entry{
- Name: "test-bucket",
- IsDirectory: true,
- Extended: map[string][]byte{
- s3_constants.AmzIdentityId: []byte(""),
- },
- Attributes: &filer_pb.FuseAttributes{
- Crtime: time.Now().Unix(),
- },
- }
-
- identity := mockIdentity("user1", false)
-
- isVisible := isBucketVisibleToIdentity(entry, identity)
-
- assert.False(t, isVisible, "Empty owner ID should be treated as unowned (hidden from non-admins)")
- })
-
- t.Run("nil Extended map safe access", func(t *testing.T) {
- entry := &filer_pb.Entry{
- Name: "test-bucket",
- IsDirectory: true,
- Extended: nil, // Explicitly nil
- Attributes: &filer_pb.FuseAttributes{
- Crtime: time.Now().Unix(),
- },
- }
-
- identity := mockIdentity("user1", false)
-
- // Should not panic with nil Extended map
- assert.NotPanics(t, func() {
- isVisible := isBucketVisibleToIdentity(entry, identity)
- assert.False(t, isVisible, "Nil Extended (no owner) should be hidden from non-admins")
- })
- })
-
- t.Run("very long owner id", func(t *testing.T) {
- longOwnerId := strings.Repeat("a", 10000)
- entry := &filer_pb.Entry{
- Name: "test-bucket",
- IsDirectory: true,
- Extended: map[string][]byte{
- s3_constants.AmzIdentityId: []byte(longOwnerId),
- },
- Attributes: &filer_pb.FuseAttributes{
- Crtime: time.Now().Unix(),
- },
- }
-
- identity := mockIdentity(longOwnerId, false)
-
- // Should handle very long owner IDs without panic
- assert.NotPanics(t, func() {
- isVisible := isBucketVisibleToIdentity(entry, identity)
- assert.True(t, isVisible, "Long owner ID should match correctly")
- })
- })
-}
-
-// TestListBucketsOwnershipWithPermissions tests that ownership filtering
-// works in conjunction with permission checks
-func TestListBucketsOwnershipWithPermissions(t *testing.T) {
- t.Run("ownership check before permission check", func(t *testing.T) {
- // Simulate scenario where ownership check filters first,
- // then permission check applies to remaining buckets
- entries := []*filer_pb.Entry{
- {
- Name: "owned-bucket",
- IsDirectory: true,
- Extended: map[string][]byte{
- s3_constants.AmzIdentityId: []byte("user1"),
- },
- Attributes: &filer_pb.FuseAttributes{Crtime: time.Now().Unix()},
- },
- {
- Name: "other-bucket",
- IsDirectory: true,
- Extended: map[string][]byte{
- s3_constants.AmzIdentityId: []byte("user2"),
- },
- Attributes: &filer_pb.FuseAttributes{Crtime: time.Now().Unix()},
- },
- }
-
- identity := mockIdentity("user1", false)
-
- // First pass: ownership filtering
- var afterOwnershipFilter []*filer_pb.Entry
- for _, entry := range entries {
- if isBucketVisibleToIdentity(entry, identity) {
- afterOwnershipFilter = append(afterOwnershipFilter, entry)
- }
- }
-
- // Only owned-bucket should remain after ownership filter
- assert.Len(t, afterOwnershipFilter, 1, "Only owned bucket should pass ownership filter")
- assert.Equal(t, "owned-bucket", afterOwnershipFilter[0].Name)
-
- // Permission checks would apply to afterOwnershipFilter entries
- // (not tested here as it depends on IAM system)
- })
-
- t.Run("admin bypasses ownership but not permissions", func(t *testing.T) {
- entries := []*filer_pb.Entry{
- {
- Name: "user1-bucket",
- IsDirectory: true,
- Extended: map[string][]byte{
- s3_constants.AmzIdentityId: []byte("user1"),
- },
- Attributes: &filer_pb.FuseAttributes{Crtime: time.Now().Unix()},
- },
- {
- Name: "user2-bucket",
- IsDirectory: true,
- Extended: map[string][]byte{
- s3_constants.AmzIdentityId: []byte("user2"),
- },
- Attributes: &filer_pb.FuseAttributes{Crtime: time.Now().Unix()},
- },
- }
-
- identity := mockIdentity("admin-user", true)
-
- // Admin bypasses ownership check
- var afterOwnershipFilter []*filer_pb.Entry
- for _, entry := range entries {
- if isBucketVisibleToIdentity(entry, identity) {
- afterOwnershipFilter = append(afterOwnershipFilter, entry)
- }
- }
-
- // Admin should see all buckets after ownership filter
- assert.Len(t, afterOwnershipFilter, 2, "Admin should see all buckets after ownership filter")
- // Note: Permission checks still apply to admins in actual implementation
- })
-}
-
-// TestListBucketsOwnershipCaseSensitivity tests case sensitivity in owner matching
-func TestListBucketsOwnershipCaseSensitivity(t *testing.T) {
- entry := &filer_pb.Entry{
- Name: "test-bucket",
- IsDirectory: true,
- Extended: map[string][]byte{
- s3_constants.AmzIdentityId: []byte("User1"),
- },
- Attributes: &filer_pb.FuseAttributes{
- Crtime: time.Now().Unix(),
- },
- }
-
- testCases := []struct {
- requestIdentityId string
- shouldMatch bool
- }{
- {"User1", true},
- {"user1", false}, // Case sensitive
- {"USER1", false}, // Case sensitive
- {"User2", false},
- }
-
- for _, tc := range testCases {
- t.Run(fmt.Sprintf("identity_%s", tc.requestIdentityId), func(t *testing.T) {
- identity := mockIdentity(tc.requestIdentityId, false)
- isVisible := isBucketVisibleToIdentity(entry, identity)
-
- if tc.shouldMatch {
- assert.True(t, isVisible, "Identity %s should match (case sensitive)", tc.requestIdentityId)
- } else {
- assert.False(t, isVisible, "Identity %s should not match (case sensitive)", tc.requestIdentityId)
- }
- })
- }
-}
-
-// TestListBucketsIssue7647 reproduces and verifies the fix for issue #7647
-// where an admin user with proper permissions could create buckets but couldn't list them
-func TestListBucketsIssue7647(t *testing.T) {
- t.Run("admin user can see their created buckets", func(t *testing.T) {
- // Simulate the exact scenario from issue #7647:
- // User "root" with ["Admin", "Read", "Write", "Tagging", "List"] permissions
-
- // Create identity for root user with Admin action
- rootIdentity := &Identity{
- Name: "root",
- Credentials: []*Credential{
- {
- AccessKey: "ROOTID",
- SecretKey: "ROOTSECRET",
- },
- },
- Actions: []Action{
- s3_constants.ACTION_ADMIN,
- s3_constants.ACTION_READ,
- s3_constants.ACTION_WRITE,
- s3_constants.ACTION_TAGGING,
- s3_constants.ACTION_LIST,
- },
- }
-
- // Create a bucket entry as if it was created by the root user
- bucketEntry := &filer_pb.Entry{
- Name: "test",
- IsDirectory: true,
- Extended: map[string][]byte{
- s3_constants.AmzIdentityId: []byte("root"),
- },
- Attributes: &filer_pb.FuseAttributes{
- Crtime: time.Now().Unix(),
- Mtime: time.Now().Unix(),
- },
- }
-
- // Test bucket visibility - should be visible to root (owner)
- isVisible := isBucketVisibleToIdentity(bucketEntry, rootIdentity)
- assert.True(t, isVisible, "Root user should see their own bucket")
-
- // Test that admin can also see buckets they don't own
- otherUserBucket := &filer_pb.Entry{
- Name: "other-bucket",
- IsDirectory: true,
- Extended: map[string][]byte{
- s3_constants.AmzIdentityId: []byte("otheruser"),
- },
- Attributes: &filer_pb.FuseAttributes{
- Crtime: time.Now().Unix(),
- Mtime: time.Now().Unix(),
- },
- }
-
- isVisible = isBucketVisibleToIdentity(otherUserBucket, rootIdentity)
- assert.True(t, isVisible, "Admin user should see all buckets, even ones they don't own")
-
- // Test permission check for List action
- canList := rootIdentity.CanDo(s3_constants.ACTION_LIST, "test", "")
- assert.True(t, canList, "Root user with List action should be able to list buckets")
- })
-
- t.Run("admin user sees buckets without owner metadata", func(t *testing.T) {
- // Admin users should see buckets even if they don't have owner metadata
- // (this can happen with legacy buckets or manual creation)
-
- rootIdentity := &Identity{
- Name: "root",
- Actions: []Action{
- s3_constants.ACTION_ADMIN,
- s3_constants.ACTION_LIST,
- },
- }
-
- bucketWithoutOwner := &filer_pb.Entry{
- Name: "legacy-bucket",
- IsDirectory: true,
- Extended: map[string][]byte{}, // No owner metadata
- Attributes: &filer_pb.FuseAttributes{
- Crtime: time.Now().Unix(),
- },
- }
-
- isVisible := isBucketVisibleToIdentity(bucketWithoutOwner, rootIdentity)
- assert.True(t, isVisible, "Admin should see buckets without owner metadata")
- })
-
- t.Run("non-admin user cannot see buckets without owner", func(t *testing.T) {
- // Non-admin users should not see buckets without owner metadata
-
- regularUser := &Identity{
- Name: "user1",
- Actions: []Action{
- s3_constants.ACTION_READ,
- s3_constants.ACTION_LIST,
- },
- }
-
- bucketWithoutOwner := &filer_pb.Entry{
- Name: "legacy-bucket",
- IsDirectory: true,
- Extended: map[string][]byte{}, // No owner metadata
- Attributes: &filer_pb.FuseAttributes{
- Crtime: time.Now().Unix(),
- },
- }
-
- isVisible := isBucketVisibleToIdentity(bucketWithoutOwner, regularUser)
- assert.False(t, isVisible, "Non-admin should not see buckets without owner metadata")
- })
-}
-
-// TestListBucketsIssue7796 reproduces and verifies the fix for issue #7796
-// where a user with bucket-specific List permission (e.g., "List:geoserver")
-// couldn't see buckets they have access to but don't own
-func TestListBucketsIssue7796(t *testing.T) {
- t.Run("user with bucket-specific List permission can see bucket they don't own", func(t *testing.T) {
- // Simulate the exact scenario from issue #7796:
- // User "geoserver" with ["List:geoserver", "Read:geoserver", "Write:geoserver", ...] permissions
- // But the bucket "geoserver" was created by a different user (e.g., admin)
-
- geoserverIdentity := &Identity{
- Name: "geoserver",
- Credentials: []*Credential{
- {
- AccessKey: "geoserver",
- SecretKey: "secret",
- },
- },
- Actions: []Action{
- Action("List:geoserver"),
- Action("Read:geoserver"),
- Action("Write:geoserver"),
- Action("Admin:geoserver"),
- Action("List:geoserver-ttl"),
- Action("Read:geoserver-ttl"),
- Action("Write:geoserver-ttl"),
- },
- }
-
- // Bucket was created by admin, not by geoserver user
- geoserverBucket := &filer_pb.Entry{
- Name: "geoserver",
- IsDirectory: true,
- Extended: map[string][]byte{
- s3_constants.AmzIdentityId: []byte("admin"), // Different owner
- },
- Attributes: &filer_pb.FuseAttributes{
- Crtime: time.Now().Unix(),
- Mtime: time.Now().Unix(),
- },
- }
-
- // Test ownership check - should return false (not owned by geoserver)
- isOwner := isBucketOwnedByIdentity(geoserverBucket, geoserverIdentity)
- assert.False(t, isOwner, "geoserver user should not be owner of bucket created by admin")
-
- // Test permission check - should return true (has List:geoserver permission)
- canList := geoserverIdentity.CanDo(s3_constants.ACTION_LIST, "geoserver", "")
- assert.True(t, canList, "geoserver user with List:geoserver should be able to list geoserver bucket")
-
- // Verify the combined visibility logic: ownership OR permission
- isVisible := isOwner || canList
- assert.True(t, isVisible, "Bucket should be visible due to permission (even though not owner)")
- })
-
- t.Run("user with bucket-specific permission sees bucket without owner metadata", func(t *testing.T) {
- // Bucket exists but has no owner metadata (legacy bucket or created before ownership tracking)
-
- geoserverIdentity := &Identity{
- Name: "geoserver",
- Actions: []Action{
- Action("List:geoserver"),
- Action("Read:geoserver"),
- },
- }
-
- bucketWithoutOwner := &filer_pb.Entry{
- Name: "geoserver",
- IsDirectory: true,
- Extended: map[string][]byte{}, // No owner metadata
- Attributes: &filer_pb.FuseAttributes{
- Crtime: time.Now().Unix(),
- },
- }
-
- // Not owner (no owner metadata)
- isOwner := isBucketOwnedByIdentity(bucketWithoutOwner, geoserverIdentity)
- assert.False(t, isOwner, "No owner metadata means not owned")
-
- // But has permission
- canList := geoserverIdentity.CanDo(s3_constants.ACTION_LIST, "geoserver", "")
- assert.True(t, canList, "Has explicit List:geoserver permission")
-
- // Verify the combined visibility logic: ownership OR permission
- isVisible := isOwner || canList
- assert.True(t, isVisible, "Bucket should be visible due to permission (even without owner metadata)")
- })
-
- t.Run("user cannot see bucket they neither own nor have permission for", func(t *testing.T) {
- // User has no ownership and no permission for the bucket
-
- geoserverIdentity := &Identity{
- Name: "geoserver",
- Actions: []Action{
- Action("List:geoserver"),
- Action("Read:geoserver"),
- },
- }
-
- otherBucket := &filer_pb.Entry{
- Name: "otherbucket",
- IsDirectory: true,
- Extended: map[string][]byte{
- s3_constants.AmzIdentityId: []byte("admin"),
- },
- Attributes: &filer_pb.FuseAttributes{
- Crtime: time.Now().Unix(),
- },
- }
-
- // Not owner
- isOwner := isBucketOwnedByIdentity(otherBucket, geoserverIdentity)
- assert.False(t, isOwner, "geoserver doesn't own otherbucket")
-
- // No permission for this bucket
- canList := geoserverIdentity.CanDo(s3_constants.ACTION_LIST, "otherbucket", "")
- assert.False(t, canList, "geoserver has no List permission for otherbucket")
-
- // Verify the combined visibility logic: ownership OR permission
- isVisible := isOwner || canList
- assert.False(t, isVisible, "Bucket should NOT be visible (neither owner nor has permission)")
- })
-
- t.Run("user with wildcard permission sees matching buckets", func(t *testing.T) {
- // User has "List:geo*" permission - should see any bucket starting with "geo"
-
- geoIdentity := &Identity{
- Name: "geouser",
- Actions: []Action{
- Action("List:geo*"),
- Action("Read:geo*"),
- },
- }
-
- geoBucket := &filer_pb.Entry{
- Name: "geoserver",
- IsDirectory: true,
- Extended: map[string][]byte{
- s3_constants.AmzIdentityId: []byte("admin"),
- },
- Attributes: &filer_pb.FuseAttributes{
- Crtime: time.Now().Unix(),
- },
- }
-
- geoTTLBucket := &filer_pb.Entry{
- Name: "geoserver-ttl",
- IsDirectory: true,
- Extended: map[string][]byte{
- s3_constants.AmzIdentityId: []byte("admin"),
- },
- Attributes: &filer_pb.FuseAttributes{
- Crtime: time.Now().Unix(),
- },
- }
-
- // Not owner of either bucket
- isOwnerGeo := isBucketOwnedByIdentity(geoBucket, geoIdentity)
- isOwnerGeoTTL := isBucketOwnedByIdentity(geoTTLBucket, geoIdentity)
- assert.False(t, isOwnerGeo)
- assert.False(t, isOwnerGeoTTL)
-
- // But has permission via wildcard
- canListGeo := geoIdentity.CanDo(s3_constants.ACTION_LIST, "geoserver", "")
- canListGeoTTL := geoIdentity.CanDo(s3_constants.ACTION_LIST, "geoserver-ttl", "")
- assert.True(t, canListGeo)
- assert.True(t, canListGeoTTL)
-
- // Verify the combined visibility logic for matching buckets
- assert.True(t, isOwnerGeo || canListGeo, "geoserver bucket should be visible via wildcard permission")
- assert.True(t, isOwnerGeoTTL || canListGeoTTL, "geoserver-ttl bucket should be visible via wildcard permission")
-
- // Should NOT have permission for unrelated buckets
- canListOther := geoIdentity.CanDo(s3_constants.ACTION_LIST, "otherbucket", "")
- assert.False(t, canListOther, "No permission for otherbucket")
- assert.False(t, false || canListOther, "otherbucket should NOT be visible (no ownership, no permission)")
- })
-
- t.Run("integration test: complete handler filtering logic", func(t *testing.T) {
- // This test simulates the complete filtering logic as used in ListBucketsHandler
- // to verify that the combination of ownership OR permission check works correctly
-
- // User "geoserver" with bucket-specific permissions (same as issue #7796)
- geoserverIdentity := &Identity{
- Name: "geoserver",
- Credentials: []*Credential{
- {AccessKey: "geoserver", SecretKey: "secret"},
- },
- Actions: []Action{
- Action("List:geoserver"),
- Action("Read:geoserver"),
- Action("Write:geoserver"),
- Action("Admin:geoserver"),
- Action("List:geoserver-ttl"),
- Action("Read:geoserver-ttl"),
- Action("Write:geoserver-ttl"),
- },
- }
-
- // Create test buckets with various ownership scenarios
- buckets := []*filer_pb.Entry{
- {
- // Bucket owned by admin but geoserver has permission
- Name: "geoserver",
- IsDirectory: true,
- Extended: map[string][]byte{s3_constants.AmzIdentityId: []byte("admin")},
- Attributes: &filer_pb.FuseAttributes{Crtime: time.Now().Unix()},
- },
- {
- // Bucket with no owner but geoserver has permission
- Name: "geoserver-ttl",
- IsDirectory: true,
- Extended: map[string][]byte{},
- Attributes: &filer_pb.FuseAttributes{Crtime: time.Now().Unix()},
- },
- {
- // Bucket owned by geoserver (should be visible via ownership)
- Name: "geoserver-owned",
- IsDirectory: true,
- Extended: map[string][]byte{s3_constants.AmzIdentityId: []byte("geoserver")},
- Attributes: &filer_pb.FuseAttributes{Crtime: time.Now().Unix()},
- },
- {
- // Bucket owned by someone else, no permission for geoserver
- Name: "otherbucket",
- IsDirectory: true,
- Extended: map[string][]byte{s3_constants.AmzIdentityId: []byte("otheruser")},
- Attributes: &filer_pb.FuseAttributes{Crtime: time.Now().Unix()},
- },
- }
-
- // Simulate the exact filtering logic from ListBucketsHandler
- var visibleBuckets []string
- for _, entry := range buckets {
- if !entry.IsDirectory {
- continue
- }
-
- // Check ownership
- isOwner := isBucketOwnedByIdentity(entry, geoserverIdentity)
-
- // Skip permission check if user is already the owner (optimization)
- if !isOwner {
- // Check permission
- hasPermission := geoserverIdentity.CanDo(s3_constants.ACTION_LIST, entry.Name, "")
- if !hasPermission {
- continue
- }
- }
-
- visibleBuckets = append(visibleBuckets, entry.Name)
- }
-
- // Expected: geoserver should see:
- // - "geoserver" (has List:geoserver permission, even though owned by admin)
- // - "geoserver-ttl" (has List:geoserver-ttl permission, even though no owner)
- // - "geoserver-owned" (owns this bucket)
- // NOT "otherbucket" (neither owns nor has permission)
- expectedBuckets := []string{"geoserver", "geoserver-ttl", "geoserver-owned"}
- assert.ElementsMatch(t, expectedBuckets, visibleBuckets,
- "geoserver should see buckets they own OR have permission for")
-
- // Verify "otherbucket" is NOT in the list
- assert.NotContains(t, visibleBuckets, "otherbucket",
- "geoserver should NOT see buckets they neither own nor have permission for")
- })
-}
-
-func TestListBucketsIssue8516PolicyBasedVisibility(t *testing.T) {
- iam := &IdentityAccessManagement{}
- require.NoError(t, iam.PutPolicy("listOnly", `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:ListBucket","Resource":"arn:aws:s3:::policy-bucket"}]}`))
-
- identity := &Identity{
- Name: "policy-user",
- Account: &AccountAdmin,
- PolicyNames: []string{"listOnly"},
- }
-
- req := httptest.NewRequest("GET", "http://s3.amazonaws.com/", nil)
- buckets := []*filer_pb.Entry{
- {
- Name: "policy-bucket",
- IsDirectory: true,
- Extended: map[string][]byte{s3_constants.AmzIdentityId: []byte("admin")},
- Attributes: &filer_pb.FuseAttributes{Crtime: time.Now().Unix()},
- },
- {
- Name: "other-bucket",
- IsDirectory: true,
- Extended: map[string][]byte{s3_constants.AmzIdentityId: []byte("admin")},
- Attributes: &filer_pb.FuseAttributes{Crtime: time.Now().Unix()},
- },
- }
-
- var visibleBuckets []string
- for _, entry := range buckets {
- isOwner := isBucketOwnedByIdentity(entry, identity)
- if !isOwner {
- if errCode := iam.VerifyActionPermission(req, identity, s3_constants.ACTION_LIST, entry.Name, ""); errCode != s3err.ErrNone {
- continue
- }
- }
- visibleBuckets = append(visibleBuckets, entry.Name)
- }
-
- assert.Equal(t, []string{"policy-bucket"}, visibleBuckets)
-}
diff --git a/weed/s3api/s3api_conditional_headers_test.go b/weed/s3api/s3api_conditional_headers_test.go
deleted file mode 100644
index 9cd220603..000000000
--- a/weed/s3api/s3api_conditional_headers_test.go
+++ /dev/null
@@ -1,984 +0,0 @@
-package s3api
-
-import (
- "bytes"
- "encoding/hex"
- "fmt"
- "net/http"
- "net/url"
- "testing"
- "time"
-
- "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
-)
-
-// TestConditionalHeadersWithExistingObjects tests conditional headers against existing objects
-// This addresses the PR feedback about missing test coverage for object existence scenarios
-func TestConditionalHeadersWithExistingObjects(t *testing.T) {
- bucket := "test-bucket"
- object := "/test-object"
-
- // Mock object with known ETag and modification time
- testObject := &filer_pb.Entry{
- Name: "test-object",
- Extended: map[string][]byte{
- s3_constants.ExtETagKey: []byte("\"abc123\""),
- },
- Attributes: &filer_pb.FuseAttributes{
- Mtime: time.Date(2024, 6, 15, 12, 0, 0, 0, time.UTC).Unix(), // June 15, 2024
- FileSize: 1024, // Add file size
- },
- Chunks: []*filer_pb.FileChunk{
- // Add a mock chunk to make calculateETagFromChunks work
- {
- FileId: "test-file-id",
- Offset: 0,
- Size: 1024,
- },
- },
- }
-
- // Test If-None-Match with existing object
- t.Run("IfNoneMatch_ObjectExists", func(t *testing.T) {
- // Test case 1: If-None-Match=* when object exists (should fail)
- t.Run("Asterisk_ShouldFail", func(t *testing.T) {
- getter := createMockEntryGetter(testObject)
- req := createTestPutRequest(bucket, object, "test content")
- req.Header.Set(s3_constants.IfNoneMatch, "*")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrPreconditionFailed {
- t.Errorf("Expected ErrPreconditionFailed when object exists with If-None-Match=*, got %v", errCode)
- }
- })
-
- // Test case 2: If-None-Match with matching ETag (should fail)
- t.Run("MatchingETag_ShouldFail", func(t *testing.T) {
- getter := createMockEntryGetter(testObject)
- req := createTestPutRequest(bucket, object, "test content")
- req.Header.Set(s3_constants.IfNoneMatch, "\"abc123\"")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrPreconditionFailed {
- t.Errorf("Expected ErrPreconditionFailed when ETag matches, got %v", errCode)
- }
- })
-
- // Test case 3: If-None-Match with non-matching ETag (should succeed)
- t.Run("NonMatchingETag_ShouldSucceed", func(t *testing.T) {
- getter := createMockEntryGetter(testObject)
- req := createTestPutRequest(bucket, object, "test content")
- req.Header.Set(s3_constants.IfNoneMatch, "\"xyz789\"")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrNone {
- t.Errorf("Expected ErrNone when ETag doesn't match, got %v", errCode)
- }
- })
-
- // Test case 4: If-None-Match with multiple ETags, one matching (should fail)
- t.Run("MultipleETags_OneMatches_ShouldFail", func(t *testing.T) {
- getter := createMockEntryGetter(testObject)
- req := createTestPutRequest(bucket, object, "test content")
- req.Header.Set(s3_constants.IfNoneMatch, "\"xyz789\", \"abc123\", \"def456\"")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrPreconditionFailed {
- t.Errorf("Expected ErrPreconditionFailed when one ETag matches, got %v", errCode)
- }
- })
-
- // Test case 5: If-None-Match with multiple ETags, none matching (should succeed)
- t.Run("MultipleETags_NoneMatch_ShouldSucceed", func(t *testing.T) {
- getter := createMockEntryGetter(testObject)
- req := createTestPutRequest(bucket, object, "test content")
- req.Header.Set(s3_constants.IfNoneMatch, "\"xyz789\", \"def456\", \"ghi123\"")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrNone {
- t.Errorf("Expected ErrNone when no ETags match, got %v", errCode)
- }
- })
- })
-
- // Test If-Match with existing object
- t.Run("IfMatch_ObjectExists", func(t *testing.T) {
- // Test case 1: If-Match with matching ETag (should succeed)
- t.Run("MatchingETag_ShouldSucceed", func(t *testing.T) {
- getter := createMockEntryGetter(testObject)
- req := createTestPutRequest(bucket, object, "test content")
- req.Header.Set(s3_constants.IfMatch, "\"abc123\"")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrNone {
- t.Errorf("Expected ErrNone when ETag matches, got %v", errCode)
- }
- })
-
- // Test case 2: If-Match with non-matching ETag (should fail)
- t.Run("NonMatchingETag_ShouldFail", func(t *testing.T) {
- getter := createMockEntryGetter(testObject)
- req := createTestPutRequest(bucket, object, "test content")
- req.Header.Set(s3_constants.IfMatch, "\"xyz789\"")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrPreconditionFailed {
- t.Errorf("Expected ErrPreconditionFailed when ETag doesn't match, got %v", errCode)
- }
- })
-
- // Test case 3: If-Match with multiple ETags, one matching (should succeed)
- t.Run("MultipleETags_OneMatches_ShouldSucceed", func(t *testing.T) {
- getter := createMockEntryGetter(testObject)
- req := createTestPutRequest(bucket, object, "test content")
- req.Header.Set(s3_constants.IfMatch, "\"xyz789\", \"abc123\"")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrNone {
- t.Errorf("Expected ErrNone when one ETag matches, got %v", errCode)
- }
- })
-
- // Test case 4: If-Match with wildcard * (should succeed if object exists)
- t.Run("Wildcard_ShouldSucceed", func(t *testing.T) {
- getter := createMockEntryGetter(testObject)
- req := createTestPutRequest(bucket, object, "test content")
- req.Header.Set(s3_constants.IfMatch, "*")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrNone {
- t.Errorf("Expected ErrNone when If-Match=* and object exists, got %v", errCode)
- }
- })
- })
-
- // Test If-Modified-Since with existing object
- t.Run("IfModifiedSince_ObjectExists", func(t *testing.T) {
- // Test case 1: If-Modified-Since with date before object modification (should succeed)
- t.Run("DateBefore_ShouldSucceed", func(t *testing.T) {
- getter := createMockEntryGetter(testObject)
- req := createTestPutRequest(bucket, object, "test content")
- dateBeforeModification := time.Date(2024, 6, 14, 12, 0, 0, 0, time.UTC)
- req.Header.Set(s3_constants.IfModifiedSince, dateBeforeModification.Format(time.RFC1123))
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrNone {
- t.Errorf("Expected ErrNone when object was modified after date, got %v", errCode)
- }
- })
-
- // Test case 2: If-Modified-Since with date after object modification (should fail)
- t.Run("DateAfter_ShouldFail", func(t *testing.T) {
- getter := createMockEntryGetter(testObject)
- req := createTestPutRequest(bucket, object, "test content")
- dateAfterModification := time.Date(2024, 6, 16, 12, 0, 0, 0, time.UTC)
- req.Header.Set(s3_constants.IfModifiedSince, dateAfterModification.Format(time.RFC1123))
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrPreconditionFailed {
- t.Errorf("Expected ErrPreconditionFailed when object wasn't modified since date, got %v", errCode)
- }
- })
-
- // Test case 3: If-Modified-Since with exact modification date (should fail - not after)
- t.Run("ExactDate_ShouldFail", func(t *testing.T) {
- getter := createMockEntryGetter(testObject)
- req := createTestPutRequest(bucket, object, "test content")
- exactDate := time.Date(2024, 6, 15, 12, 0, 0, 0, time.UTC)
- req.Header.Set(s3_constants.IfModifiedSince, exactDate.Format(time.RFC1123))
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrPreconditionFailed {
- t.Errorf("Expected ErrPreconditionFailed when object modification time equals header date, got %v", errCode)
- }
- })
- })
-
- // Test If-Unmodified-Since with existing object
- t.Run("IfUnmodifiedSince_ObjectExists", func(t *testing.T) {
- // Test case 1: If-Unmodified-Since with date after object modification (should succeed)
- t.Run("DateAfter_ShouldSucceed", func(t *testing.T) {
- getter := createMockEntryGetter(testObject)
- req := createTestPutRequest(bucket, object, "test content")
- dateAfterModification := time.Date(2024, 6, 16, 12, 0, 0, 0, time.UTC)
- req.Header.Set(s3_constants.IfUnmodifiedSince, dateAfterModification.Format(time.RFC1123))
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrNone {
- t.Errorf("Expected ErrNone when object wasn't modified after date, got %v", errCode)
- }
- })
-
- // Test case 2: If-Unmodified-Since with date before object modification (should fail)
- t.Run("DateBefore_ShouldFail", func(t *testing.T) {
- getter := createMockEntryGetter(testObject)
- req := createTestPutRequest(bucket, object, "test content")
- dateBeforeModification := time.Date(2024, 6, 14, 12, 0, 0, 0, time.UTC)
- req.Header.Set(s3_constants.IfUnmodifiedSince, dateBeforeModification.Format(time.RFC1123))
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrPreconditionFailed {
- t.Errorf("Expected ErrPreconditionFailed when object was modified after date, got %v", errCode)
- }
- })
- })
-}
-
-// TestConditionalHeadersForReads tests conditional headers for read operations (GET, HEAD)
-// This implements AWS S3 conditional reads behavior where different conditions return different status codes
-// See: https://docs.aws.amazon.com/AmazonS3/latest/userguide/conditional-reads.html
-func TestConditionalHeadersForReads(t *testing.T) {
- bucket := "test-bucket"
- object := "/test-read-object"
-
- // Mock existing object to test conditional headers against
- existingObject := &filer_pb.Entry{
- Name: "test-read-object",
- Extended: map[string][]byte{
- s3_constants.ExtETagKey: []byte("\"read123\""),
- },
- Attributes: &filer_pb.FuseAttributes{
- Mtime: time.Date(2024, 6, 15, 12, 0, 0, 0, time.UTC).Unix(),
- FileSize: 1024,
- },
- Chunks: []*filer_pb.FileChunk{
- {
- FileId: "read-file-id",
- Offset: 0,
- Size: 1024,
- },
- },
- }
-
- // Test conditional reads with existing object
- t.Run("ConditionalReads_ObjectExists", func(t *testing.T) {
- // Test If-None-Match with existing object (should return 304 Not Modified)
- t.Run("IfNoneMatch_ObjectExists_ShouldReturn304", func(t *testing.T) {
- getter := createMockEntryGetter(existingObject)
-
- req := createTestGetRequest(bucket, object)
- req.Header.Set(s3_constants.IfNoneMatch, "\"read123\"")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object)
- if errCode.ErrorCode != s3err.ErrNotModified {
- t.Errorf("Expected ErrNotModified when If-None-Match matches, got %v", errCode)
- }
- })
-
- // Test If-None-Match=* with existing object (should return 304 Not Modified)
- t.Run("IfNoneMatchAsterisk_ObjectExists_ShouldReturn304", func(t *testing.T) {
- getter := createMockEntryGetter(existingObject)
-
- req := createTestGetRequest(bucket, object)
- req.Header.Set(s3_constants.IfNoneMatch, "*")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object)
- if errCode.ErrorCode != s3err.ErrNotModified {
- t.Errorf("Expected ErrNotModified when If-None-Match=* with existing object, got %v", errCode)
- }
- })
-
- // Test If-None-Match with non-matching ETag (should succeed)
- t.Run("IfNoneMatch_NonMatchingETag_ShouldSucceed", func(t *testing.T) {
- getter := createMockEntryGetter(existingObject)
-
- req := createTestGetRequest(bucket, object)
- req.Header.Set(s3_constants.IfNoneMatch, "\"different-etag\"")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object)
- if errCode.ErrorCode != s3err.ErrNone {
- t.Errorf("Expected ErrNone when If-None-Match doesn't match, got %v", errCode)
- }
- })
-
- // Test If-Match with matching ETag (should succeed)
- t.Run("IfMatch_MatchingETag_ShouldSucceed", func(t *testing.T) {
- getter := createMockEntryGetter(existingObject)
-
- req := createTestGetRequest(bucket, object)
- req.Header.Set(s3_constants.IfMatch, "\"read123\"")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object)
- if errCode.ErrorCode != s3err.ErrNone {
- t.Errorf("Expected ErrNone when If-Match matches, got %v", errCode)
- }
- })
-
- // Test If-Match with non-matching ETag (should return 412 Precondition Failed)
- t.Run("IfMatch_NonMatchingETag_ShouldReturn412", func(t *testing.T) {
- getter := createMockEntryGetter(existingObject)
-
- req := createTestGetRequest(bucket, object)
- req.Header.Set(s3_constants.IfMatch, "\"different-etag\"")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object)
- if errCode.ErrorCode != s3err.ErrPreconditionFailed {
- t.Errorf("Expected ErrPreconditionFailed when If-Match doesn't match, got %v", errCode)
- }
- })
-
- // Test If-Match=* with existing object (should succeed)
- t.Run("IfMatchAsterisk_ObjectExists_ShouldSucceed", func(t *testing.T) {
- getter := createMockEntryGetter(existingObject)
-
- req := createTestGetRequest(bucket, object)
- req.Header.Set(s3_constants.IfMatch, "*")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object)
- if errCode.ErrorCode != s3err.ErrNone {
- t.Errorf("Expected ErrNone when If-Match=* with existing object, got %v", errCode)
- }
- })
-
- // Test If-Modified-Since (object modified after date - should succeed)
- t.Run("IfModifiedSince_ObjectModifiedAfter_ShouldSucceed", func(t *testing.T) {
- getter := createMockEntryGetter(existingObject)
-
- req := createTestGetRequest(bucket, object)
- req.Header.Set(s3_constants.IfModifiedSince, "Sat, 14 Jun 2024 12:00:00 GMT") // Before object mtime
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object)
- if errCode.ErrorCode != s3err.ErrNone {
- t.Errorf("Expected ErrNone when object modified after If-Modified-Since date, got %v", errCode)
- }
- })
-
- // Test If-Modified-Since (object not modified since date - should return 304)
- t.Run("IfModifiedSince_ObjectNotModified_ShouldReturn304", func(t *testing.T) {
- getter := createMockEntryGetter(existingObject)
-
- req := createTestGetRequest(bucket, object)
- req.Header.Set(s3_constants.IfModifiedSince, "Sun, 16 Jun 2024 12:00:00 GMT") // After object mtime
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object)
- if errCode.ErrorCode != s3err.ErrNotModified {
- t.Errorf("Expected ErrNotModified when object not modified since If-Modified-Since date, got %v", errCode)
- }
- })
-
- // Test If-Unmodified-Since (object not modified since date - should succeed)
- t.Run("IfUnmodifiedSince_ObjectNotModified_ShouldSucceed", func(t *testing.T) {
- getter := createMockEntryGetter(existingObject)
-
- req := createTestGetRequest(bucket, object)
- req.Header.Set(s3_constants.IfUnmodifiedSince, "Sun, 16 Jun 2024 12:00:00 GMT") // After object mtime
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object)
- if errCode.ErrorCode != s3err.ErrNone {
- t.Errorf("Expected ErrNone when object not modified since If-Unmodified-Since date, got %v", errCode)
- }
- })
-
- // Test If-Unmodified-Since (object modified since date - should return 412)
- t.Run("IfUnmodifiedSince_ObjectModified_ShouldReturn412", func(t *testing.T) {
- getter := createMockEntryGetter(existingObject)
-
- req := createTestGetRequest(bucket, object)
- req.Header.Set(s3_constants.IfUnmodifiedSince, "Fri, 14 Jun 2024 12:00:00 GMT") // Before object mtime
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object)
- if errCode.ErrorCode != s3err.ErrPreconditionFailed {
- t.Errorf("Expected ErrPreconditionFailed when object modified since If-Unmodified-Since date, got %v", errCode)
- }
- })
- })
-
- // Test conditional reads with non-existent object
- t.Run("ConditionalReads_ObjectNotExists", func(t *testing.T) {
- // Test If-None-Match with non-existent object (should succeed)
- t.Run("IfNoneMatch_ObjectNotExists_ShouldSucceed", func(t *testing.T) {
- getter := createMockEntryGetter(nil) // No object
-
- req := createTestGetRequest(bucket, object)
- req.Header.Set(s3_constants.IfNoneMatch, "\"any-etag\"")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object)
- if errCode.ErrorCode != s3err.ErrNone {
- t.Errorf("Expected ErrNone when object doesn't exist with If-None-Match, got %v", errCode)
- }
- })
-
- // Test If-Match with non-existent object (should return 412)
- t.Run("IfMatch_ObjectNotExists_ShouldReturn412", func(t *testing.T) {
- getter := createMockEntryGetter(nil) // No object
-
- req := createTestGetRequest(bucket, object)
- req.Header.Set(s3_constants.IfMatch, "\"any-etag\"")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object)
- if errCode.ErrorCode != s3err.ErrPreconditionFailed {
- t.Errorf("Expected ErrPreconditionFailed when object doesn't exist with If-Match, got %v", errCode)
- }
- })
-
- // Test If-Modified-Since with non-existent object (should succeed)
- t.Run("IfModifiedSince_ObjectNotExists_ShouldSucceed", func(t *testing.T) {
- getter := createMockEntryGetter(nil) // No object
-
- req := createTestGetRequest(bucket, object)
- req.Header.Set(s3_constants.IfModifiedSince, "Sat, 15 Jun 2024 12:00:00 GMT")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object)
- if errCode.ErrorCode != s3err.ErrNone {
- t.Errorf("Expected ErrNone when object doesn't exist with If-Modified-Since, got %v", errCode)
- }
- })
-
- // Test If-Unmodified-Since with non-existent object (should return 412)
- t.Run("IfUnmodifiedSince_ObjectNotExists_ShouldReturn412", func(t *testing.T) {
- getter := createMockEntryGetter(nil) // No object
-
- req := createTestGetRequest(bucket, object)
- req.Header.Set(s3_constants.IfUnmodifiedSince, "Sat, 15 Jun 2024 12:00:00 GMT")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object)
- if errCode.ErrorCode != s3err.ErrPreconditionFailed {
- t.Errorf("Expected ErrPreconditionFailed when object doesn't exist with If-Unmodified-Since, got %v", errCode)
- }
- })
- })
-}
-
-// Helper function to create a GET request for testing
-func createTestGetRequest(bucket, object string) *http.Request {
- return &http.Request{
- Method: "GET",
- Header: make(http.Header),
- URL: &url.URL{
- Path: fmt.Sprintf("/%s/%s", bucket, object),
- },
- }
-}
-
-// TestConditionalHeadersWithNonExistentObjects tests the original scenarios (object doesn't exist)
-func TestConditionalHeadersWithNonExistentObjects(t *testing.T) {
- s3a := NewS3ApiServerForTest()
- if s3a == nil {
- t.Skip("S3ApiServer not available for testing")
- }
-
- bucket := "test-bucket"
- object := "/test-object"
-
- // Test If-None-Match header when object doesn't exist
- t.Run("IfNoneMatch_ObjectDoesNotExist", func(t *testing.T) {
- // Test case 1: If-None-Match=* when object doesn't exist (should return ErrNone)
- t.Run("Asterisk_ShouldSucceed", func(t *testing.T) {
- getter := createMockEntryGetter(nil) // No object exists
- req := createTestPutRequest(bucket, object, "test content")
- req.Header.Set(s3_constants.IfNoneMatch, "*")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrNone {
- t.Errorf("Expected ErrNone when object doesn't exist, got %v", errCode)
- }
- })
-
- // Test case 2: If-None-Match with specific ETag when object doesn't exist
- t.Run("SpecificETag_ShouldSucceed", func(t *testing.T) {
- getter := createMockEntryGetter(nil) // No object exists
- req := createTestPutRequest(bucket, object, "test content")
- req.Header.Set(s3_constants.IfNoneMatch, "\"some-etag\"")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrNone {
- t.Errorf("Expected ErrNone when object doesn't exist, got %v", errCode)
- }
- })
- })
-
- // Test If-Match header when object doesn't exist
- t.Run("IfMatch_ObjectDoesNotExist", func(t *testing.T) {
- // Test case 1: If-Match with specific ETag when object doesn't exist (should fail - critical bug fix)
- t.Run("SpecificETag_ShouldFail", func(t *testing.T) {
- getter := createMockEntryGetter(nil) // No object exists
- req := createTestPutRequest(bucket, object, "test content")
- req.Header.Set(s3_constants.IfMatch, "\"some-etag\"")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrPreconditionFailed {
- t.Errorf("Expected ErrPreconditionFailed when object doesn't exist with If-Match header, got %v", errCode)
- }
- })
-
- // Test case 2: If-Match with wildcard * when object doesn't exist (should fail)
- t.Run("Wildcard_ShouldFail", func(t *testing.T) {
- getter := createMockEntryGetter(nil) // No object exists
- req := createTestPutRequest(bucket, object, "test content")
- req.Header.Set(s3_constants.IfMatch, "*")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrPreconditionFailed {
- t.Errorf("Expected ErrPreconditionFailed when object doesn't exist with If-Match=*, got %v", errCode)
- }
- })
- })
-
- // Test date format validation (works regardless of object existence)
- t.Run("DateFormatValidation", func(t *testing.T) {
- // Test case 1: Valid If-Modified-Since date format
- t.Run("IfModifiedSince_ValidFormat", func(t *testing.T) {
- getter := createMockEntryGetter(nil) // No object exists
- req := createTestPutRequest(bucket, object, "test content")
- req.Header.Set(s3_constants.IfModifiedSince, time.Now().Format(time.RFC1123))
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrNone {
- t.Errorf("Expected ErrNone with valid date format, got %v", errCode)
- }
- })
-
- // Test case 2: Invalid If-Modified-Since date format
- t.Run("IfModifiedSince_InvalidFormat", func(t *testing.T) {
- getter := createMockEntryGetter(nil) // No object exists
- req := createTestPutRequest(bucket, object, "test content")
- req.Header.Set(s3_constants.IfModifiedSince, "invalid-date")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrInvalidRequest {
- t.Errorf("Expected ErrInvalidRequest for invalid date format, got %v", errCode)
- }
- })
-
- // Test case 3: Invalid If-Unmodified-Since date format
- t.Run("IfUnmodifiedSince_InvalidFormat", func(t *testing.T) {
- getter := createMockEntryGetter(nil) // No object exists
- req := createTestPutRequest(bucket, object, "test content")
- req.Header.Set(s3_constants.IfUnmodifiedSince, "invalid-date")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrInvalidRequest {
- t.Errorf("Expected ErrInvalidRequest for invalid date format, got %v", errCode)
- }
- })
- })
-
- // Test no conditional headers
- t.Run("NoConditionalHeaders", func(t *testing.T) {
- getter := createMockEntryGetter(nil) // No object exists
- req := createTestPutRequest(bucket, object, "test content")
- // Don't set any conditional headers
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrNone {
- t.Errorf("Expected ErrNone when no conditional headers, got %v", errCode)
- }
- })
-}
-
-// TestETagMatching tests the etagMatches helper function
-func TestETagMatching(t *testing.T) {
- s3a := NewS3ApiServerForTest()
- if s3a == nil {
- t.Skip("S3ApiServer not available for testing")
- }
-
- testCases := []struct {
- name string
- headerValue string
- objectETag string
- expected bool
- }{
- {
- name: "ExactMatch",
- headerValue: "\"abc123\"",
- objectETag: "abc123",
- expected: true,
- },
- {
- name: "ExactMatchWithQuotes",
- headerValue: "\"abc123\"",
- objectETag: "\"abc123\"",
- expected: true,
- },
- {
- name: "NoMatch",
- headerValue: "\"abc123\"",
- objectETag: "def456",
- expected: false,
- },
- {
- name: "MultipleETags_FirstMatch",
- headerValue: "\"abc123\", \"def456\"",
- objectETag: "abc123",
- expected: true,
- },
- {
- name: "MultipleETags_SecondMatch",
- headerValue: "\"abc123\", \"def456\"",
- objectETag: "def456",
- expected: true,
- },
- {
- name: "MultipleETags_NoMatch",
- headerValue: "\"abc123\", \"def456\"",
- objectETag: "ghi789",
- expected: false,
- },
- {
- name: "WithSpaces",
- headerValue: " \"abc123\" , \"def456\" ",
- objectETag: "def456",
- expected: true,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- result := s3a.etagMatches(tc.headerValue, tc.objectETag)
- if result != tc.expected {
- t.Errorf("Expected %v, got %v for headerValue='%s', objectETag='%s'",
- tc.expected, result, tc.headerValue, tc.objectETag)
- }
- })
- }
-}
-
-// TestGetObjectETagWithMd5AndChunks tests the fix for issue #7274
-// When an object has both Attributes.Md5 and multiple chunks, getObjectETag should
-// prefer Attributes.Md5 to match the behavior of HeadObject and filer.ETag
-func TestGetObjectETagWithMd5AndChunks(t *testing.T) {
- s3a := NewS3ApiServerForTest()
- if s3a == nil {
- t.Skip("S3ApiServer not available for testing")
- }
-
- // Create an object with both Md5 and multiple chunks (like in issue #7274)
- // Md5: ZjcmMwrCVGNVgb4HoqHe9g== (base64) = 663726330ac254635581be07a2a1def6 (hex)
- md5HexString := "663726330ac254635581be07a2a1def6"
- md5Bytes, err := hex.DecodeString(md5HexString)
- if err != nil {
- t.Fatalf("failed to decode md5 hex string: %v", err)
- }
-
- entry := &filer_pb.Entry{
- Name: "test-multipart-object",
- Attributes: &filer_pb.FuseAttributes{
- Mtime: time.Now().Unix(),
- FileSize: 5597744,
- Md5: md5Bytes,
- },
- // Two chunks - if we only used ETagChunks, it would return format "hash-2"
- Chunks: []*filer_pb.FileChunk{
- {
- FileId: "chunk1",
- Offset: 0,
- Size: 4194304,
- ETag: "9+yCD2DGwMG5uKwAd+y04Q==",
- },
- {
- FileId: "chunk2",
- Offset: 4194304,
- Size: 1403440,
- ETag: "cs6SVSTgZ8W3IbIrAKmklg==",
- },
- },
- }
-
- // getObjectETag should return the Md5 in hex with quotes
- expectedETag := "\"" + md5HexString + "\""
- actualETag := s3a.getObjectETag(entry)
-
- if actualETag != expectedETag {
- t.Errorf("Expected ETag %s, got %s", expectedETag, actualETag)
- }
-
- // Now test that conditional headers work with this ETag
- bucket := "test-bucket"
- object := "/test-object"
-
- // Test If-Match with the Md5-based ETag (should succeed)
- t.Run("IfMatch_WithMd5BasedETag_ShouldSucceed", func(t *testing.T) {
- getter := createMockEntryGetter(entry)
- req := createTestGetRequest(bucket, object)
- // Client sends the ETag from HeadObject (without quotes)
- req.Header.Set(s3_constants.IfMatch, md5HexString)
-
- result := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object)
- if result.ErrorCode != s3err.ErrNone {
- t.Errorf("Expected ErrNone when If-Match uses Md5-based ETag, got %v (ETag was %s)", result.ErrorCode, actualETag)
- }
- })
-
- // Test If-Match with chunk-based ETag format (should fail - this was the old incorrect behavior)
- t.Run("IfMatch_WithChunkBasedETag_ShouldFail", func(t *testing.T) {
- getter := createMockEntryGetter(entry)
- req := createTestGetRequest(bucket, object)
- // If we incorrectly calculated ETag from chunks, it would be in format "hash-2"
- req.Header.Set(s3_constants.IfMatch, "123294de680f28bde364b81477549f7d-2")
-
- result := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object)
- if result.ErrorCode != s3err.ErrPreconditionFailed {
- t.Errorf("Expected ErrPreconditionFailed when If-Match uses chunk-based ETag format, got %v", result.ErrorCode)
- }
- })
-}
-
-// TestConditionalHeadersIntegration tests conditional headers with full integration
-func TestConditionalHeadersIntegration(t *testing.T) {
- // This would be a full integration test that requires a running SeaweedFS instance
- t.Skip("Integration test - requires running SeaweedFS instance")
-}
-
-// createTestPutRequest creates a test HTTP PUT request
-func createTestPutRequest(bucket, object, content string) *http.Request {
- req, _ := http.NewRequest("PUT", "/"+bucket+object, bytes.NewReader([]byte(content)))
- req.Header.Set("Content-Type", "application/octet-stream")
-
- // Set up mux vars to simulate the bucket and object extraction
- // In real tests, this would be handled by the gorilla mux router
- return req
-}
-
-// NewS3ApiServerForTest creates a minimal S3ApiServer for testing
-// Note: This is a simplified version for unit testing conditional logic
-func NewS3ApiServerForTest() *S3ApiServer {
- // In a real test environment, this would set up a proper S3ApiServer
- // with filer connection, etc. For unit testing conditional header logic,
- // we create a minimal instance
- return &S3ApiServer{
- option: &S3ApiServerOption{
- BucketsPath: "/buckets",
- },
- bucketConfigCache: NewBucketConfigCache(60 * time.Minute),
- }
-}
-
-// MockEntryGetter implements the simplified EntryGetter interface for testing
-// Only mocks the data access dependency - tests use production getObjectETag and etagMatches
-type MockEntryGetter struct {
- mockEntry *filer_pb.Entry
-}
-
-// Implement only the simplified EntryGetter interface
-func (m *MockEntryGetter) getEntry(parentDirectoryPath, entryName string) (*filer_pb.Entry, error) {
- if m.mockEntry != nil {
- return m.mockEntry, nil
- }
- return nil, filer_pb.ErrNotFound
-}
-
-// createMockEntryGetter creates a mock EntryGetter for testing
-func createMockEntryGetter(mockEntry *filer_pb.Entry) *MockEntryGetter {
- return &MockEntryGetter{
- mockEntry: mockEntry,
- }
-}
-
-// TestConditionalHeadersMultipartUpload tests conditional headers with multipart uploads
-// This verifies AWS S3 compatibility where conditional headers only apply to CompleteMultipartUpload
-func TestConditionalHeadersMultipartUpload(t *testing.T) {
- bucket := "test-bucket"
- object := "/test-multipart-object"
-
- // Mock existing object to test conditional headers against
- existingObject := &filer_pb.Entry{
- Name: "test-multipart-object",
- Extended: map[string][]byte{
- s3_constants.ExtETagKey: []byte("\"existing123\""),
- },
- Attributes: &filer_pb.FuseAttributes{
- Mtime: time.Date(2024, 6, 15, 12, 0, 0, 0, time.UTC).Unix(),
- FileSize: 2048,
- },
- Chunks: []*filer_pb.FileChunk{
- {
- FileId: "existing-file-id",
- Offset: 0,
- Size: 2048,
- },
- },
- }
-
- // Test CompleteMultipartUpload with If-None-Match: * (should fail when object exists)
- t.Run("CompleteMultipartUpload_IfNoneMatchAsterisk_ObjectExists_ShouldFail", func(t *testing.T) {
- getter := createMockEntryGetter(existingObject)
-
- // Create a mock CompleteMultipartUpload request with If-None-Match: *
- req := &http.Request{
- Method: "POST",
- Header: make(http.Header),
- URL: &url.URL{
- RawQuery: "uploadId=test-upload-id",
- },
- }
- req.Header.Set(s3_constants.IfNoneMatch, "*")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrPreconditionFailed {
- t.Errorf("Expected ErrPreconditionFailed when object exists with If-None-Match=*, got %v", errCode)
- }
- })
-
- // Test CompleteMultipartUpload with If-None-Match: * (should succeed when object doesn't exist)
- t.Run("CompleteMultipartUpload_IfNoneMatchAsterisk_ObjectNotExists_ShouldSucceed", func(t *testing.T) {
- getter := createMockEntryGetter(nil) // No existing object
-
- req := &http.Request{
- Method: "POST",
- Header: make(http.Header),
- URL: &url.URL{
- RawQuery: "uploadId=test-upload-id",
- },
- }
- req.Header.Set(s3_constants.IfNoneMatch, "*")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrNone {
- t.Errorf("Expected ErrNone when object doesn't exist with If-None-Match=*, got %v", errCode)
- }
- })
-
- // Test CompleteMultipartUpload with If-Match (should succeed when ETag matches)
- t.Run("CompleteMultipartUpload_IfMatch_ETagMatches_ShouldSucceed", func(t *testing.T) {
- getter := createMockEntryGetter(existingObject)
-
- req := &http.Request{
- Method: "POST",
- Header: make(http.Header),
- URL: &url.URL{
- RawQuery: "uploadId=test-upload-id",
- },
- }
- req.Header.Set(s3_constants.IfMatch, "\"existing123\"")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrNone {
- t.Errorf("Expected ErrNone when ETag matches, got %v", errCode)
- }
- })
-
- // Test CompleteMultipartUpload with If-Match (should fail when object doesn't exist)
- t.Run("CompleteMultipartUpload_IfMatch_ObjectNotExists_ShouldFail", func(t *testing.T) {
- getter := createMockEntryGetter(nil) // No existing object
-
- req := &http.Request{
- Method: "POST",
- Header: make(http.Header),
- URL: &url.URL{
- RawQuery: "uploadId=test-upload-id",
- },
- }
- req.Header.Set(s3_constants.IfMatch, "\"any-etag\"")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrPreconditionFailed {
- t.Errorf("Expected ErrPreconditionFailed when object doesn't exist with If-Match, got %v", errCode)
- }
- })
-
- // Test CompleteMultipartUpload with If-Match wildcard (should succeed when object exists)
- t.Run("CompleteMultipartUpload_IfMatchWildcard_ObjectExists_ShouldSucceed", func(t *testing.T) {
- getter := createMockEntryGetter(existingObject)
-
- req := &http.Request{
- Method: "POST",
- Header: make(http.Header),
- URL: &url.URL{
- RawQuery: "uploadId=test-upload-id",
- },
- }
- req.Header.Set(s3_constants.IfMatch, "*")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrNone {
- t.Errorf("Expected ErrNone when object exists with If-Match=*, got %v", errCode)
- }
- })
-}
-
-func TestConditionalHeadersTreatDeleteMarkerAsMissing(t *testing.T) {
- bucket := "test-bucket"
- object := "/deleted-object"
- deleteMarkerEntry := &filer_pb.Entry{
- Name: "deleted-object",
- Extended: map[string][]byte{
- s3_constants.ExtDeleteMarkerKey: []byte("true"),
- },
- Attributes: &filer_pb.FuseAttributes{
- Mtime: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC).Unix(),
- },
- }
-
- t.Run("WriteIfNoneMatchAsteriskSucceeds", func(t *testing.T) {
- getter := createMockEntryGetter(deleteMarkerEntry)
- req := createTestPutRequest(bucket, object, "new content")
- req.Header.Set(s3_constants.IfNoneMatch, "*")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrNone {
- t.Fatalf("expected ErrNone for delete marker with If-None-Match=*, got %v", errCode)
- }
- })
-
- t.Run("WriteIfMatchAsteriskFails", func(t *testing.T) {
- getter := createMockEntryGetter(deleteMarkerEntry)
- req := createTestPutRequest(bucket, object, "new content")
- req.Header.Set(s3_constants.IfMatch, "*")
-
- s3a := NewS3ApiServerForTest()
- errCode := s3a.checkConditionalHeadersWithGetter(getter, req, bucket, object)
- if errCode != s3err.ErrPreconditionFailed {
- t.Fatalf("expected ErrPreconditionFailed for delete marker with If-Match=*, got %v", errCode)
- }
- })
-
- t.Run("ReadIfMatchAsteriskFails", func(t *testing.T) {
- getter := createMockEntryGetter(deleteMarkerEntry)
- req := &http.Request{Method: http.MethodGet, Header: make(http.Header)}
- req.Header.Set(s3_constants.IfMatch, "*")
-
- s3a := NewS3ApiServerForTest()
- result := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object)
- if result.ErrorCode != s3err.ErrPreconditionFailed {
- t.Fatalf("expected ErrPreconditionFailed for read against delete marker with If-Match=*, got %v", result.ErrorCode)
- }
- if result.Entry != nil {
- t.Fatalf("expected no entry to be returned for delete marker, got %#v", result.Entry)
- }
- })
-}
diff --git a/weed/s3api/s3api_copy_size_calculation.go b/weed/s3api/s3api_copy_size_calculation.go
index a11c46cdf..eb8bbf0d8 100644
--- a/weed/s3api/s3api_copy_size_calculation.go
+++ b/weed/s3api/s3api_copy_size_calculation.go
@@ -4,7 +4,6 @@ import (
"net/http"
"github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
)
// CopySizeCalculator handles size calculations for different copy scenarios
@@ -78,12 +77,6 @@ func (calc *CopySizeCalculator) CalculateActualSize() int64 {
return calc.srcSize
}
-// CalculateEncryptedSize calculates the encrypted size for the given encryption type
-func (calc *CopySizeCalculator) CalculateEncryptedSize(encType EncryptionType) int64 {
- // With IV in metadata, encrypted size equals actual size
- return calc.CalculateActualSize()
-}
-
// getSourceEncryptionType determines the encryption type of the source object
func getSourceEncryptionType(metadata map[string][]byte) (EncryptionType, bool) {
if IsSSECEncrypted(metadata) {
@@ -169,22 +162,6 @@ func (calc *CopySizeCalculator) GetSizeTransitionInfo() *SizeTransitionInfo {
return info
}
-// String returns a string representation of the encryption type
-func (e EncryptionType) String() string {
- switch e {
- case EncryptionTypeNone:
- return "None"
- case EncryptionTypeSSEC:
- return s3_constants.SSETypeC
- case EncryptionTypeSSEKMS:
- return s3_constants.SSETypeKMS
- case EncryptionTypeSSES3:
- return s3_constants.SSETypeS3
- default:
- return "Unknown"
- }
-}
-
// OptimizedSizeCalculation provides size calculations optimized for different scenarios
type OptimizedSizeCalculation struct {
Strategy UnifiedCopyStrategy
diff --git a/weed/s3api/s3api_etag_quoting_test.go b/weed/s3api/s3api_etag_quoting_test.go
deleted file mode 100644
index 89223c9b3..000000000
--- a/weed/s3api/s3api_etag_quoting_test.go
+++ /dev/null
@@ -1,167 +0,0 @@
-package s3api
-
-import (
- "fmt"
- "testing"
- "time"
-
- "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
-)
-
-// TestReproIfMatchMismatch tests specifically for the scenario where internal ETag
-// is unquoted (common in SeaweedFS) but client sends quoted ETag in If-Match.
-func TestReproIfMatchMismatch(t *testing.T) {
- bucket := "test-bucket"
- object := "/test-key"
- etagValue := "37b51d194a7513e45b56f6524f2d51f2"
-
- // Scenario 1: Internal ETag is UNQUOTED (stored in Extended), Client sends QUOTED If-Match
- // This mirrors the behavior we enforced in filer_multipart.go
- t.Run("UnquotedInternal_QuotedHeader", func(t *testing.T) {
- entry := &filer_pb.Entry{
- Name: "test-key",
- Extended: map[string][]byte{
- s3_constants.ExtETagKey: []byte(etagValue), // Unquoted
- },
- Attributes: &filer_pb.FuseAttributes{
- Mtime: time.Now().Unix(),
- FileSize: 1024,
- },
- }
-
- getter := &MockEntryGetter{mockEntry: entry}
- req := createTestGetRequest(bucket, object)
- // Client sends quoted ETag
- req.Header.Set(s3_constants.IfMatch, "\""+etagValue+"\"")
-
- s3a := NewS3ApiServerForTest()
- result := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object)
-
- if result.ErrorCode != s3err.ErrNone {
- t.Errorf("Expected success (ErrNone) for unquoted internal ETag and quoted header, got %v. Internal ETag: %s", result.ErrorCode, string(entry.Extended[s3_constants.ExtETagKey]))
- }
- })
-
- // Scenario 2: Internal ETag is QUOTED (stored in Extended), Client sends QUOTED If-Match
- // This handles legacy or mixed content
- t.Run("QuotedInternal_QuotedHeader", func(t *testing.T) {
- entry := &filer_pb.Entry{
- Name: "test-key",
- Extended: map[string][]byte{
- s3_constants.ExtETagKey: []byte("\"" + etagValue + "\""), // Quoted
- },
- Attributes: &filer_pb.FuseAttributes{
- Mtime: time.Now().Unix(),
- FileSize: 1024,
- },
- }
-
- getter := &MockEntryGetter{mockEntry: entry}
- req := createTestGetRequest(bucket, object)
- req.Header.Set(s3_constants.IfMatch, "\""+etagValue+"\"")
-
- s3a := NewS3ApiServerForTest()
- result := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object)
-
- if result.ErrorCode != s3err.ErrNone {
- t.Errorf("Expected success (ErrNone) for quoted internal ETag and quoted header, got %v", result.ErrorCode)
- }
- })
-
- // Scenario 3: Internal ETag is from Md5 (QUOTED by getObjectETag), Client sends QUOTED If-Match
- t.Run("Md5Internal_QuotedHeader", func(t *testing.T) {
- // Mock Md5 attribute (16 bytes)
- md5Bytes := make([]byte, 16)
- copy(md5Bytes, []byte("1234567890123456")) // This doesn't match the hex string below, but getObjectETag formats it as hex
-
- // Expected ETag from Md5 is hex string of bytes
- expectedHex := fmt.Sprintf("%x", md5Bytes)
-
- entry := &filer_pb.Entry{
- Name: "test-key",
- Attributes: &filer_pb.FuseAttributes{
- Mtime: time.Now().Unix(),
- FileSize: 1024,
- Md5: md5Bytes,
- },
- }
-
- getter := &MockEntryGetter{mockEntry: entry}
- req := createTestGetRequest(bucket, object)
- req.Header.Set(s3_constants.IfMatch, "\""+expectedHex+"\"")
-
- s3a := NewS3ApiServerForTest()
- result := s3a.checkConditionalHeadersForReadsWithGetter(getter, req, bucket, object)
-
- if result.ErrorCode != s3err.ErrNone {
- t.Errorf("Expected success (ErrNone) for Md5 internal ETag and quoted header, got %v", result.ErrorCode)
- }
- })
-
- // Test getObjectETag specifically ensuring it returns quoted strings
- t.Run("getObjectETag_ShouldReturnQuoted", func(t *testing.T) {
- entry := &filer_pb.Entry{
- Name: "test-key",
- Extended: map[string][]byte{
- s3_constants.ExtETagKey: []byte("unquoted-etag"),
- },
- }
-
- s3a := NewS3ApiServerForTest()
- etag := s3a.getObjectETag(entry)
-
- expected := "\"unquoted-etag\""
- if etag != expected {
- t.Errorf("Expected quoted ETag %s, got %s", expected, etag)
- }
- })
-
- // Test getObjectETag fallback when Extended ETag is present but empty
- t.Run("getObjectETag_EmptyExtended_ShouldFallback", func(t *testing.T) {
- md5Bytes := []byte("1234567890123456")
- expectedHex := fmt.Sprintf("\"%x\"", md5Bytes)
-
- entry := &filer_pb.Entry{
- Name: "test-key-fallback",
- Extended: map[string][]byte{
- s3_constants.ExtETagKey: []byte(""), // Present but empty
- },
- Attributes: &filer_pb.FuseAttributes{
- Mtime: time.Now().Unix(),
- FileSize: 1024,
- Md5: md5Bytes,
- },
- }
-
- s3a := NewS3ApiServerForTest()
- etag := s3a.getObjectETag(entry)
-
- if etag != expectedHex {
- t.Errorf("Expected fallback ETag %s, got %s", expectedHex, etag)
- }
- })
-
- // Test newListEntry ETag behavior
- t.Run("newListEntry_ShouldReturnQuoted", func(t *testing.T) {
- entry := &filer_pb.Entry{
- Name: "test-key",
- Extended: map[string][]byte{
- s3_constants.ExtETagKey: []byte("unquoted-etag"),
- },
- Attributes: &filer_pb.FuseAttributes{
- Mtime: time.Now().Unix(),
- FileSize: 1024,
- },
- }
-
- s3a := NewS3ApiServerForTest()
- listEntry := newListEntry(s3a, entry, "", "bucket/dir", "test-key", "bucket/", false, false, false)
-
- expected := "\"unquoted-etag\""
- if listEntry.ETag != expected {
- t.Errorf("Expected quoted ETag %s, got %s", expected, listEntry.ETag)
- }
- })
-}
diff --git a/weed/s3api/s3api_key_rotation.go b/weed/s3api/s3api_key_rotation.go
deleted file mode 100644
index c99c13415..000000000
--- a/weed/s3api/s3api_key_rotation.go
+++ /dev/null
@@ -1,30 +0,0 @@
-package s3api
-
-import (
- "net/http"
-
- "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
-)
-
-// IsSameObjectCopy determines if this is a same-object copy operation
-func IsSameObjectCopy(r *http.Request, srcBucket, srcObject, dstBucket, dstObject string) bool {
- return srcBucket == dstBucket && srcObject == dstObject
-}
-
-// NeedsKeyRotation determines if the copy operation requires key rotation
-func NeedsKeyRotation(entry *filer_pb.Entry, r *http.Request) bool {
- // Check for SSE-C key rotation
- if IsSSECEncrypted(entry.Extended) && IsSSECRequest(r) {
- return true // Assume different keys for safety
- }
-
- // Check for SSE-KMS key rotation
- if IsSSEKMSEncrypted(entry.Extended) && IsSSEKMSRequest(r) {
- srcKeyID, _ := GetSourceSSEKMSInfo(entry.Extended)
- dstKeyID := r.Header.Get(s3_constants.AmzServerSideEncryptionAwsKmsKeyId)
- return srcKeyID != dstKeyID
- }
-
- return false
-}
diff --git a/weed/s3api/s3api_object_handlers.go b/weed/s3api/s3api_object_handlers.go
index 7a7538214..71d6bc26d 100644
--- a/weed/s3api/s3api_object_handlers.go
+++ b/weed/s3api/s3api_object_handlers.go
@@ -268,15 +268,6 @@ func mimeDetect(r *http.Request, dataReader io.Reader) io.ReadCloser {
return io.NopCloser(dataReader)
}
-func urlEscapeObject(object string) string {
- normalized := s3_constants.NormalizeObjectKey(object)
- // Ensure leading slash for filer paths
- if normalized != "" && !strings.HasPrefix(normalized, "/") {
- normalized = "/" + normalized
- }
- return urlPathEscape(normalized)
-}
-
func entryUrlEncode(dir string, entry string, encodingTypeUrl bool) (dirName string, entryName string, prefix string) {
if !encodingTypeUrl {
return dir, entry, entry
@@ -2895,59 +2886,6 @@ func (m *MultipartSSEReader) Close() error {
return lastErr
}
-// Read implements the io.Reader interface for SSERangeReader
-func (r *SSERangeReader) Read(p []byte) (n int, err error) {
- // Skip bytes iteratively (no recursion) until we reach the offset
- for r.skipped < r.offset {
- skipNeeded := r.offset - r.skipped
-
- // Lazily allocate skip buffer on first use, reuse thereafter
- if r.skipBuf == nil {
- // Use a fixed 32KB buffer for skipping (avoids per-call allocation)
- r.skipBuf = make([]byte, 32*1024)
- }
-
- // Determine how much to skip in this iteration
- bufSize := int64(len(r.skipBuf))
- if skipNeeded < bufSize {
- bufSize = skipNeeded
- }
-
- skipRead, skipErr := r.reader.Read(r.skipBuf[:bufSize])
- r.skipped += int64(skipRead)
-
- if skipErr != nil {
- return 0, skipErr
- }
-
- // Guard against infinite loop: io.Reader may return (0, nil)
- // which is permitted by the interface contract for non-empty buffers.
- // If we get zero bytes without an error, treat it as an unexpected EOF.
- if skipRead == 0 {
- return 0, io.ErrUnexpectedEOF
- }
- }
-
- // If we have a remaining limit and it's reached
- if r.remaining == 0 {
- return 0, io.EOF
- }
-
- // Calculate how much to read
- readSize := len(p)
- if r.remaining > 0 && int64(readSize) > r.remaining {
- readSize = int(r.remaining)
- }
-
- // Read the data
- n, err = r.reader.Read(p[:readSize])
- if r.remaining > 0 {
- r.remaining -= int64(n)
- }
-
- return n, err
-}
-
// PartBoundaryInfo holds information about a part's chunk boundaries
type PartBoundaryInfo struct {
PartNumber int `json:"part"`
diff --git a/weed/s3api/s3api_object_handlers_copy.go b/weed/s3api/s3api_object_handlers_copy.go
index 58f18a038..cac24c946 100644
--- a/weed/s3api/s3api_object_handlers_copy.go
+++ b/weed/s3api/s3api_object_handlers_copy.go
@@ -14,8 +14,6 @@ import (
"strings"
"time"
- "modernc.org/strutil"
-
"github.com/seaweedfs/seaweedfs/weed/filer"
"github.com/seaweedfs/seaweedfs/weed/glog"
"github.com/seaweedfs/seaweedfs/weed/operation"
@@ -797,58 +795,6 @@ func replaceDirective(reqHeader http.Header) (replaceMeta, replaceTagging bool)
return reqHeader.Get(s3_constants.AmzUserMetaDirective) == DirectiveReplace, reqHeader.Get(s3_constants.AmzObjectTaggingDirective) == DirectiveReplace
}
-func processMetadata(reqHeader, existing http.Header, replaceMeta, replaceTagging bool, getTags func(parentDirectoryPath string, entryName string) (tags map[string]string, err error), dir, name string) (err error) {
- if sc := reqHeader.Get(s3_constants.AmzStorageClass); len(sc) == 0 {
- if sc := existing.Get(s3_constants.AmzStorageClass); len(sc) > 0 {
- reqHeader.Set(s3_constants.AmzStorageClass, sc)
- }
- }
-
- if !replaceMeta {
- for header := range reqHeader {
- if strings.HasPrefix(header, s3_constants.AmzUserMetaPrefix) {
- delete(reqHeader, header)
- }
- }
- for k, v := range existing {
- if strings.HasPrefix(k, s3_constants.AmzUserMetaPrefix) {
- reqHeader[k] = v
- }
- }
- }
-
- if !replaceTagging {
- for header, _ := range reqHeader {
- if strings.HasPrefix(header, s3_constants.AmzObjectTagging) {
- delete(reqHeader, header)
- }
- }
-
- found := false
- for k, _ := range existing {
- if strings.HasPrefix(k, s3_constants.AmzObjectTaggingPrefix) {
- found = true
- break
- }
- }
-
- if found {
- tags, err := getTags(dir, name)
- if err != nil {
- return err
- }
-
- var tagArr []string
- for k, v := range tags {
- tagArr = append(tagArr, fmt.Sprintf("%s=%s", k, v))
- }
- tagStr := strutil.JoinFields(tagArr, "&")
- reqHeader.Set(s3_constants.AmzObjectTagging, tagStr)
- }
- }
- return
-}
-
func processMetadataBytes(reqHeader http.Header, existing map[string][]byte, replaceMeta, replaceTagging bool) (metadata map[string][]byte, err error) {
metadata = make(map[string][]byte)
@@ -2632,13 +2578,6 @@ func cleanupVersioningMetadata(metadata map[string][]byte) {
delete(metadata, s3_constants.ExtETagKey)
}
-// shouldCreateVersionForCopy determines whether a version should be created during a copy operation
-// based on the destination bucket's versioning state.
-// Returns true only if versioning is explicitly "Enabled", not "Suspended" or unconfigured.
-func shouldCreateVersionForCopy(versioningState string) bool {
- return versioningState == s3_constants.VersioningEnabled
-}
-
// isOrphanedSSES3Header checks if a header is an orphaned SSE-S3 encryption header.
// An orphaned header is one where the encryption indicator exists but the actual key is missing.
// This can happen when an object was previously encrypted but then copied without encryption,
diff --git a/weed/s3api/s3api_object_handlers_copy_test.go b/weed/s3api/s3api_object_handlers_copy_test.go
deleted file mode 100644
index 93d1475cd..000000000
--- a/weed/s3api/s3api_object_handlers_copy_test.go
+++ /dev/null
@@ -1,760 +0,0 @@
-package s3api
-
-import (
- "fmt"
- "net/http"
- "reflect"
- "sort"
- "strings"
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
- "github.com/seaweedfs/seaweedfs/weed/util"
-)
-
-type H map[string]string
-
-func (h H) String() string {
- pairs := make([]string, 0, len(h))
- for k, v := range h {
- pairs = append(pairs, fmt.Sprintf("%s : %s", k, v))
- }
- sort.Strings(pairs)
- join := strings.Join(pairs, "\n")
- return "\n" + join + "\n"
-}
-
-var processMetadataTestCases = []struct {
- caseId int
- request H
- existing H
- getTags H
- want H
-}{
- {
- 201,
- H{
- "User-Agent": "firefox",
- "X-Amz-Meta-My-Meta": "request",
- "X-Amz-Tagging": "A=B&a=b&type=request",
- },
- H{
- "X-Amz-Meta-My-Meta": "existing",
- "X-Amz-Tagging-A": "B",
- "X-Amz-Tagging-Type": "existing",
- },
- H{
- "A": "B",
- "a": "b",
- "type": "existing",
- },
- H{
- "User-Agent": "firefox",
- "X-Amz-Meta-My-Meta": "existing",
- "X-Amz-Tagging": "A=B&a=b&type=existing",
- },
- },
- {
- 202,
- H{
- "User-Agent": "firefox",
- "X-Amz-Meta-My-Meta": "request",
- "X-Amz-Tagging": "A=B&a=b&type=request",
- s3_constants.AmzUserMetaDirective: DirectiveReplace,
- },
- H{
- "X-Amz-Meta-My-Meta": "existing",
- "X-Amz-Tagging-A": "B",
- "X-Amz-Tagging-Type": "existing",
- },
- H{
- "A": "B",
- "a": "b",
- "type": "existing",
- },
- H{
- "User-Agent": "firefox",
- "X-Amz-Meta-My-Meta": "request",
- "X-Amz-Tagging": "A=B&a=b&type=existing",
- s3_constants.AmzUserMetaDirective: DirectiveReplace,
- },
- },
-
- {
- 203,
- H{
- "User-Agent": "firefox",
- "X-Amz-Meta-My-Meta": "request",
- "X-Amz-Tagging": "A=B&a=b&type=request",
- s3_constants.AmzObjectTaggingDirective: DirectiveReplace,
- },
- H{
- "X-Amz-Meta-My-Meta": "existing",
- "X-Amz-Tagging-A": "B",
- "X-Amz-Tagging-Type": "existing",
- },
- H{
- "A": "B",
- "a": "b",
- "type": "existing",
- },
- H{
- "User-Agent": "firefox",
- "X-Amz-Meta-My-Meta": "existing",
- "X-Amz-Tagging": "A=B&a=b&type=request",
- s3_constants.AmzObjectTaggingDirective: DirectiveReplace,
- },
- },
-
- {
- 204,
- H{
- "User-Agent": "firefox",
- "X-Amz-Meta-My-Meta": "request",
- "X-Amz-Tagging": "A=B&a=b&type=request",
- s3_constants.AmzUserMetaDirective: DirectiveReplace,
- s3_constants.AmzObjectTaggingDirective: DirectiveReplace,
- },
- H{
- "X-Amz-Meta-My-Meta": "existing",
- "X-Amz-Tagging-A": "B",
- "X-Amz-Tagging-a": "b",
- "X-Amz-Tagging-Type": "existing",
- },
- H{
- "A": "B",
- "a": "b",
- "type": "existing",
- },
- H{
- "User-Agent": "firefox",
- "X-Amz-Meta-My-Meta": "request",
- "X-Amz-Tagging": "A=B&a=b&type=request",
- s3_constants.AmzUserMetaDirective: DirectiveReplace,
- s3_constants.AmzObjectTaggingDirective: DirectiveReplace,
- },
- },
-
- {
- 205,
- H{
- "User-Agent": "firefox",
- "X-Amz-Meta-My-Meta": "request",
- "X-Amz-Tagging": "A=B&a=b&type=request",
- s3_constants.AmzUserMetaDirective: DirectiveReplace,
- s3_constants.AmzObjectTaggingDirective: DirectiveReplace,
- },
- H{},
- H{},
- H{
- "User-Agent": "firefox",
- "X-Amz-Meta-My-Meta": "request",
- "X-Amz-Tagging": "A=B&a=b&type=request",
- s3_constants.AmzUserMetaDirective: DirectiveReplace,
- s3_constants.AmzObjectTaggingDirective: DirectiveReplace,
- },
- },
-
- {
- 206,
- H{
- "User-Agent": "firefox",
- s3_constants.AmzUserMetaDirective: DirectiveReplace,
- s3_constants.AmzObjectTaggingDirective: DirectiveReplace,
- },
- H{
- "X-Amz-Meta-My-Meta": "existing",
- "X-Amz-Tagging-A": "B",
- "X-Amz-Tagging-a": "b",
- "X-Amz-Tagging-Type": "existing",
- },
- H{
- "A": "B",
- "a": "b",
- "type": "existing",
- },
- H{
- "User-Agent": "firefox",
- s3_constants.AmzUserMetaDirective: DirectiveReplace,
- s3_constants.AmzObjectTaggingDirective: DirectiveReplace,
- },
- },
-
- {
- 207,
- H{
- "User-Agent": "firefox",
- "X-Amz-Meta-My-Meta": "request",
- s3_constants.AmzUserMetaDirective: DirectiveReplace,
- s3_constants.AmzObjectTaggingDirective: DirectiveReplace,
- },
- H{
- "X-Amz-Meta-My-Meta": "existing",
- "X-Amz-Tagging-A": "B",
- "X-Amz-Tagging-a": "b",
- "X-Amz-Tagging-Type": "existing",
- },
- H{
- "A": "B",
- "a": "b",
- "type": "existing",
- },
- H{
- "User-Agent": "firefox",
- "X-Amz-Meta-My-Meta": "request",
- s3_constants.AmzUserMetaDirective: DirectiveReplace,
- s3_constants.AmzObjectTaggingDirective: DirectiveReplace,
- },
- },
-}
-var processMetadataBytesTestCases = []struct {
- caseId int
- request H
- existing H
- want H
-}{
- {
- 101,
- H{
- "User-Agent": "firefox",
- "X-Amz-Meta-My-Meta": "request",
- "X-Amz-Tagging": "A=B&a=b&type=request",
- },
- H{
- "X-Amz-Meta-My-Meta": "existing",
- "X-Amz-Tagging-A": "B",
- "X-Amz-Tagging-a": "b",
- "X-Amz-Tagging-type": "existing",
- },
- H{
- "X-Amz-Meta-My-Meta": "existing",
- "X-Amz-Tagging-A": "B",
- "X-Amz-Tagging-a": "b",
- "X-Amz-Tagging-type": "existing",
- },
- },
-
- {
- 102,
- H{
- "User-Agent": "firefox",
- "X-Amz-Meta-My-Meta": "request",
- "X-Amz-Tagging": "A=B&a=b&type=request",
- s3_constants.AmzUserMetaDirective: DirectiveReplace,
- },
- H{
- "X-Amz-Meta-My-Meta": "existing",
- "X-Amz-Tagging-A": "B",
- "X-Amz-Tagging-a": "b",
- "X-Amz-Tagging-type": "existing",
- },
- H{
- "X-Amz-Meta-My-Meta": "request",
- "X-Amz-Tagging-A": "B",
- "X-Amz-Tagging-a": "b",
- "X-Amz-Tagging-type": "existing",
- },
- },
-
- {
- 103,
- H{
- "User-Agent": "firefox",
- "X-Amz-Meta-My-Meta": "request",
- "X-Amz-Tagging": "A=B&a=b&type=request",
- s3_constants.AmzObjectTaggingDirective: DirectiveReplace,
- },
- H{
- "X-Amz-Meta-My-Meta": "existing",
- "X-Amz-Tagging-A": "B",
- "X-Amz-Tagging-a": "b",
- "X-Amz-Tagging-type": "existing",
- },
- H{
- "X-Amz-Meta-My-Meta": "existing",
- "X-Amz-Tagging-A": "B",
- "X-Amz-Tagging-a": "b",
- "X-Amz-Tagging-type": "request",
- },
- },
-
- {
- 104,
- H{
- "User-Agent": "firefox",
- "X-Amz-Meta-My-Meta": "request",
- "X-Amz-Tagging": "A=B&a=b&type=request",
- s3_constants.AmzUserMetaDirective: DirectiveReplace,
- s3_constants.AmzObjectTaggingDirective: DirectiveReplace,
- },
- H{
- "X-Amz-Meta-My-Meta": "existing",
- "X-Amz-Tagging-A": "B",
- "X-Amz-Tagging-a": "b",
- "X-Amz-Tagging-type": "existing",
- },
- H{
- "X-Amz-Meta-My-Meta": "request",
- "X-Amz-Tagging-A": "B",
- "X-Amz-Tagging-a": "b",
- "X-Amz-Tagging-type": "request",
- },
- },
-
- {
- 105,
- H{
- "User-Agent": "firefox",
- s3_constants.AmzUserMetaDirective: DirectiveReplace,
- s3_constants.AmzObjectTaggingDirective: DirectiveReplace,
- },
- H{
- "X-Amz-Meta-My-Meta": "existing",
- "X-Amz-Tagging-A": "B",
- "X-Amz-Tagging-a": "b",
- "X-Amz-Tagging-type": "existing",
- },
- H{},
- },
-
- {
- 107,
- H{
- "User-Agent": "firefox",
- "X-Amz-Meta-My-Meta": "request",
- "X-Amz-Tagging": "A=B&a=b&type=request",
- s3_constants.AmzUserMetaDirective: DirectiveReplace,
- s3_constants.AmzObjectTaggingDirective: DirectiveReplace,
- },
- H{},
- H{
- "X-Amz-Meta-My-Meta": "request",
- "X-Amz-Tagging-A": "B",
- "X-Amz-Tagging-a": "b",
- "X-Amz-Tagging-type": "request",
- },
- },
-
- {
- 108,
- H{
- "User-Agent": "firefox",
- "X-Amz-Meta-My-Meta": "request",
- "X-Amz-Tagging": "A=B&a=b&type=request*",
- s3_constants.AmzUserMetaDirective: DirectiveReplace,
- s3_constants.AmzObjectTaggingDirective: DirectiveReplace,
- },
- H{},
- H{},
- },
-}
-
-func TestProcessMetadata(t *testing.T) {
- for _, tc := range processMetadataTestCases {
- reqHeader := transferHToHeader(tc.request)
- existing := transferHToHeader(tc.existing)
- replaceMeta, replaceTagging := replaceDirective(reqHeader)
- err := processMetadata(reqHeader, existing, replaceMeta, replaceTagging, func(_ string, _ string) (tags map[string]string, err error) {
- return tc.getTags, nil
- }, "", "")
- if err != nil {
- t.Error(err)
- }
-
- result := transferHeaderToH(reqHeader)
- fmtTagging(result, tc.want)
-
- if !reflect.DeepEqual(result, tc.want) {
- t.Error(fmt.Errorf("\n### CaseID: %d ###"+
- "\nRequest:%v"+
- "\nExisting:%v"+
- "\nGetTags:%v"+
- "\nWant:%v"+
- "\nActual:%v",
- tc.caseId, tc.request, tc.existing, tc.getTags, tc.want, result))
- }
- }
-}
-
-func TestProcessMetadataBytes(t *testing.T) {
- for _, tc := range processMetadataBytesTestCases {
- reqHeader := transferHToHeader(tc.request)
- existing := transferHToBytesArr(tc.existing)
- replaceMeta, replaceTagging := replaceDirective(reqHeader)
- extends, _ := processMetadataBytes(reqHeader, existing, replaceMeta, replaceTagging)
-
- result := transferBytesArrToH(extends)
- fmtTagging(result, tc.want)
-
- if !reflect.DeepEqual(result, tc.want) {
- t.Error(fmt.Errorf("\n### CaseID: %d ###"+
- "\nRequest:%v"+
- "\nExisting:%v"+
- "\nWant:%v"+
- "\nActual:%v",
- tc.caseId, tc.request, tc.existing, tc.want, result))
- }
- }
-}
-
-func TestMergeCopyMetadataPreservesInternalFields(t *testing.T) {
- existing := map[string][]byte{
- s3_constants.SeaweedFSSSEKMSKey: []byte("kms-secret"),
- s3_constants.SeaweedFSSSEIV: []byte("iv"),
- "X-Amz-Meta-Old": []byte("old"),
- "X-Amz-Tagging-Old": []byte("old-tag"),
- s3_constants.AmzStorageClass: []byte("STANDARD"),
- }
- updated := map[string][]byte{
- "X-Amz-Meta-New": []byte("new"),
- "X-Amz-Tagging-New": []byte("new-tag"),
- s3_constants.AmzStorageClass: []byte("GLACIER"),
- }
-
- merged := mergeCopyMetadata(existing, updated)
-
- if got := string(merged[s3_constants.SeaweedFSSSEKMSKey]); got != "kms-secret" {
- t.Fatalf("expected internal KMS key to be preserved, got %q", got)
- }
- if got := string(merged[s3_constants.SeaweedFSSSEIV]); got != "iv" {
- t.Fatalf("expected internal IV to be preserved, got %q", got)
- }
- if _, ok := merged["X-Amz-Meta-Old"]; ok {
- t.Fatalf("expected stale user metadata to be removed, got %#v", merged)
- }
- if _, ok := merged["X-Amz-Tagging-Old"]; ok {
- t.Fatalf("expected stale tagging metadata to be removed, got %#v", merged)
- }
- if got := string(merged["X-Amz-Meta-New"]); got != "new" {
- t.Fatalf("expected replacement user metadata to be applied, got %q", got)
- }
- if got := string(merged["X-Amz-Tagging-New"]); got != "new-tag" {
- t.Fatalf("expected replacement tagging metadata to be applied, got %q", got)
- }
- if got := string(merged[s3_constants.AmzStorageClass]); got != "GLACIER" {
- t.Fatalf("expected storage class to be updated, got %q", got)
- }
-}
-
-func TestCopyEntryETagPrefersStoredETag(t *testing.T) {
- entry := &filer_pb.Entry{
- Extended: map[string][]byte{
- s3_constants.ExtETagKey: []byte("\"stored-etag\""),
- },
- Attributes: &filer_pb.FuseAttributes{},
- }
-
- if got := copyEntryETag(util.FullPath("/buckets/test-bucket/object.txt"), entry); got != "\"stored-etag\"" {
- t.Fatalf("copyEntryETag() = %q, want %q", got, "\"stored-etag\"")
- }
-}
-
-func fmtTagging(maps ...map[string]string) {
- for _, m := range maps {
- if tagging := m[s3_constants.AmzObjectTagging]; len(tagging) > 0 {
- split := strings.Split(tagging, "&")
- sort.Strings(split)
- m[s3_constants.AmzObjectTagging] = strings.Join(split, "&")
- }
- }
-}
-
-func transferHToHeader(data map[string]string) http.Header {
- header := http.Header{}
- for k, v := range data {
- header.Add(k, v)
- }
- return header
-}
-
-func transferHToBytesArr(data map[string]string) map[string][]byte {
- m := make(map[string][]byte, len(data))
- for k, v := range data {
- m[k] = []byte(v)
- }
- return m
-}
-
-func transferBytesArrToH(data map[string][]byte) H {
- m := make(map[string]string, len(data))
- for k, v := range data {
- m[k] = string(v)
- }
- return m
-}
-
-func transferHeaderToH(data map[string][]string) H {
- m := make(map[string]string, len(data))
- for k, v := range data {
- m[k] = v[len(v)-1]
- }
- return m
-}
-
-// TestShouldCreateVersionForCopy tests the production function that determines
-// whether a version should be created during a copy operation.
-// This addresses issue #7505 where copies were incorrectly creating versions for non-versioned buckets.
-func TestShouldCreateVersionForCopy(t *testing.T) {
- testCases := []struct {
- name string
- versioningState string
- expectedResult bool
- description string
- }{
- {
- name: "VersioningEnabled",
- versioningState: s3_constants.VersioningEnabled,
- expectedResult: true,
- description: "Should create versions in .versions/ directory when versioning is Enabled",
- },
- {
- name: "VersioningSuspended",
- versioningState: s3_constants.VersioningSuspended,
- expectedResult: false,
- description: "Should NOT create versions when versioning is Suspended",
- },
- {
- name: "VersioningNotConfigured",
- versioningState: "",
- expectedResult: false,
- description: "Should NOT create versions when versioning is not configured",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- // Call the actual production function
- result := shouldCreateVersionForCopy(tc.versioningState)
-
- if result != tc.expectedResult {
- t.Errorf("Test case %s failed: %s\nExpected shouldCreateVersionForCopy(%q)=%v, got %v",
- tc.name, tc.description, tc.versioningState, tc.expectedResult, result)
- }
- })
- }
-}
-
-// TestCleanupVersioningMetadata tests the production function that removes versioning metadata.
-// This ensures objects copied to non-versioned buckets don't carry invalid versioning metadata
-// or stale ETag values from the source.
-func TestCleanupVersioningMetadata(t *testing.T) {
- testCases := []struct {
- name string
- sourceMetadata map[string][]byte
- expectedKeys []string // Keys that should be present after cleanup
- removedKeys []string // Keys that should be removed
- }{
- {
- name: "RemovesAllVersioningMetadata",
- sourceMetadata: map[string][]byte{
- s3_constants.ExtVersionIdKey: []byte("version-123"),
- s3_constants.ExtDeleteMarkerKey: []byte("false"),
- s3_constants.ExtIsLatestKey: []byte("true"),
- s3_constants.ExtETagKey: []byte("\"abc123\""),
- "X-Amz-Meta-Custom": []byte("value"),
- },
- expectedKeys: []string{"X-Amz-Meta-Custom"},
- removedKeys: []string{s3_constants.ExtVersionIdKey, s3_constants.ExtDeleteMarkerKey, s3_constants.ExtIsLatestKey, s3_constants.ExtETagKey},
- },
- {
- name: "HandlesEmptyMetadata",
- sourceMetadata: map[string][]byte{},
- expectedKeys: []string{},
- removedKeys: []string{s3_constants.ExtVersionIdKey, s3_constants.ExtDeleteMarkerKey, s3_constants.ExtIsLatestKey, s3_constants.ExtETagKey},
- },
- {
- name: "PreservesNonVersioningMetadata",
- sourceMetadata: map[string][]byte{
- s3_constants.ExtVersionIdKey: []byte("version-456"),
- s3_constants.ExtETagKey: []byte("\"def456\""),
- "X-Amz-Meta-Custom": []byte("value1"),
- "X-Amz-Meta-Another": []byte("value2"),
- s3_constants.ExtIsLatestKey: []byte("true"),
- },
- expectedKeys: []string{"X-Amz-Meta-Custom", "X-Amz-Meta-Another"},
- removedKeys: []string{s3_constants.ExtVersionIdKey, s3_constants.ExtETagKey, s3_constants.ExtIsLatestKey},
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- // Create a copy of the source metadata
- dstMetadata := make(map[string][]byte)
- for k, v := range tc.sourceMetadata {
- dstMetadata[k] = v
- }
-
- // Call the actual production function
- cleanupVersioningMetadata(dstMetadata)
-
- // Verify expected keys are present
- for _, key := range tc.expectedKeys {
- if _, exists := dstMetadata[key]; !exists {
- t.Errorf("Expected key %s to be present in destination metadata", key)
- }
- }
-
- // Verify removed keys are absent
- for _, key := range tc.removedKeys {
- if _, exists := dstMetadata[key]; exists {
- t.Errorf("Expected key %s to be removed from destination metadata, but it's still present", key)
- }
- }
-
- // Verify the count matches to ensure no extra keys are present
- if len(dstMetadata) != len(tc.expectedKeys) {
- t.Errorf("Expected %d metadata keys, but got %d. Extra keys might be present.", len(tc.expectedKeys), len(dstMetadata))
- }
- })
- }
-}
-
-// TestCopyVersioningIntegration validates the metadata shaping that happens
-// before copy finalization for each destination versioning mode.
-func TestCopyVersioningIntegration(t *testing.T) {
- testCases := []struct {
- name string
- versioningState string
- sourceMetadata map[string][]byte
- expectVersionPath bool
- expectMetadataKeys []string
- }{
- {
- name: "EnabledPreservesMetadata",
- versioningState: s3_constants.VersioningEnabled,
- sourceMetadata: map[string][]byte{
- s3_constants.ExtVersionIdKey: []byte("v123"),
- "X-Amz-Meta-Custom": []byte("value"),
- },
- expectVersionPath: true,
- expectMetadataKeys: []string{
- s3_constants.ExtVersionIdKey,
- "X-Amz-Meta-Custom",
- },
- },
- {
- name: "SuspendedCleansVersionMetadataBeforeFinalize",
- versioningState: s3_constants.VersioningSuspended,
- sourceMetadata: map[string][]byte{
- s3_constants.ExtVersionIdKey: []byte("v123"),
- "X-Amz-Meta-Custom": []byte("value"),
- },
- expectVersionPath: false,
- expectMetadataKeys: []string{
- "X-Amz-Meta-Custom",
- },
- },
- {
- name: "NotConfiguredCleansMetadata",
- versioningState: "",
- sourceMetadata: map[string][]byte{
- s3_constants.ExtVersionIdKey: []byte("v123"),
- s3_constants.ExtDeleteMarkerKey: []byte("false"),
- "X-Amz-Meta-Custom": []byte("value"),
- },
- expectVersionPath: false,
- expectMetadataKeys: []string{
- "X-Amz-Meta-Custom",
- },
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- // Test version creation decision using production function
- shouldCreateVersion := shouldCreateVersionForCopy(tc.versioningState)
- if shouldCreateVersion != tc.expectVersionPath {
- t.Errorf("shouldCreateVersionForCopy(%q) = %v, expected %v",
- tc.versioningState, shouldCreateVersion, tc.expectVersionPath)
- }
-
- // Test metadata cleanup using production function
- metadata := make(map[string][]byte)
- for k, v := range tc.sourceMetadata {
- metadata[k] = v
- }
-
- if !shouldCreateVersion {
- cleanupVersioningMetadata(metadata)
- }
-
- // Verify only expected keys remain
- for _, expectedKey := range tc.expectMetadataKeys {
- if _, exists := metadata[expectedKey]; !exists {
- t.Errorf("Expected key %q to be present in metadata", expectedKey)
- }
- }
-
- // Verify the count matches (no extra keys)
- if len(metadata) != len(tc.expectMetadataKeys) {
- t.Errorf("Expected %d metadata keys, got %d", len(tc.expectMetadataKeys), len(metadata))
- }
- })
- }
-}
-
-// TestIsOrphanedSSES3Header tests detection of orphaned SSE-S3 headers.
-// This is a regression test for GitHub issue #7562 where copying from an
-// encrypted bucket to an unencrypted bucket left behind the encryption header
-// without the actual key, causing subsequent copy operations to fail.
-func TestIsOrphanedSSES3Header(t *testing.T) {
- testCases := []struct {
- name string
- headerKey string
- metadata map[string][]byte
- expected bool
- }{
- {
- name: "Not an encryption header",
- headerKey: "X-Amz-Meta-Custom",
- metadata: map[string][]byte{
- "X-Amz-Meta-Custom": []byte("value"),
- },
- expected: false,
- },
- {
- name: "SSE-S3 header with key present (valid)",
- headerKey: s3_constants.AmzServerSideEncryption,
- metadata: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("AES256"),
- s3_constants.SeaweedFSSSES3Key: []byte("key-data"),
- },
- expected: false,
- },
- {
- name: "SSE-S3 header without key (orphaned - GitHub #7562)",
- headerKey: s3_constants.AmzServerSideEncryption,
- metadata: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("AES256"),
- },
- expected: true,
- },
- {
- name: "SSE-KMS header (not SSE-S3)",
- headerKey: s3_constants.AmzServerSideEncryption,
- metadata: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("aws:kms"),
- },
- expected: false,
- },
- {
- name: "Different header key entirely",
- headerKey: s3_constants.SeaweedFSSSES3Key,
- metadata: map[string][]byte{
- s3_constants.AmzServerSideEncryption: []byte("AES256"),
- },
- expected: false,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- result := isOrphanedSSES3Header(tc.headerKey, tc.metadata)
- if result != tc.expected {
- t.Errorf("isOrphanedSSES3Header(%q, metadata) = %v, expected %v",
- tc.headerKey, result, tc.expected)
- }
- })
- }
-}
diff --git a/weed/s3api/s3api_object_handlers_delete_test.go b/weed/s3api/s3api_object_handlers_delete_test.go
deleted file mode 100644
index 5596d6130..000000000
--- a/weed/s3api/s3api_object_handlers_delete_test.go
+++ /dev/null
@@ -1,119 +0,0 @@
-package s3api
-
-import (
- "encoding/xml"
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
-)
-
-func TestValidateDeleteIfMatch(t *testing.T) {
- s3a := NewS3ApiServerForTest()
- existingEntry := &filer_pb.Entry{
- Extended: map[string][]byte{
- s3_constants.ExtETagKey: []byte("\"abc123\""),
- },
- }
- deleteMarkerEntry := &filer_pb.Entry{
- Extended: map[string][]byte{
- s3_constants.ExtDeleteMarkerKey: []byte("true"),
- },
- }
-
- testCases := []struct {
- name string
- entry *filer_pb.Entry
- ifMatch string
- missingCode s3err.ErrorCode
- expected s3err.ErrorCode
- }{
- {
- name: "matching etag succeeds",
- entry: existingEntry,
- ifMatch: "\"abc123\"",
- missingCode: s3err.ErrPreconditionFailed,
- expected: s3err.ErrNone,
- },
- {
- name: "wildcard succeeds for existing entry",
- entry: existingEntry,
- ifMatch: "*",
- missingCode: s3err.ErrPreconditionFailed,
- expected: s3err.ErrNone,
- },
- {
- name: "mismatched etag fails",
- entry: existingEntry,
- ifMatch: "\"other\"",
- missingCode: s3err.ErrPreconditionFailed,
- expected: s3err.ErrPreconditionFailed,
- },
- {
- name: "missing current object fails single delete",
- entry: nil,
- ifMatch: "*",
- missingCode: s3err.ErrPreconditionFailed,
- expected: s3err.ErrPreconditionFailed,
- },
- {
- name: "missing current object returns no such key for batch delete",
- entry: nil,
- ifMatch: "*",
- missingCode: s3err.ErrNoSuchKey,
- expected: s3err.ErrNoSuchKey,
- },
- {
- name: "current delete marker behaves like missing object",
- entry: normalizeConditionalTargetEntry(deleteMarkerEntry),
- ifMatch: "*",
- missingCode: s3err.ErrPreconditionFailed,
- expected: s3err.ErrPreconditionFailed,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- if errCode := s3a.validateDeleteIfMatch(tc.entry, tc.ifMatch, tc.missingCode); errCode != tc.expected {
- t.Fatalf("validateDeleteIfMatch() = %v, want %v", errCode, tc.expected)
- }
- })
- }
-}
-
-func TestDeleteObjectsRequestUnmarshalConditionalETags(t *testing.T) {
- var req DeleteObjectsRequest
- body := []byte(`
-
- true
-
-
-`)
-
- if err := xml.Unmarshal(body, &req); err != nil {
- t.Fatalf("xml.Unmarshal() error = %v", err)
- }
- if !req.Quiet {
- t.Fatalf("expected Quiet=true")
- }
- if len(req.Objects) != 2 {
- t.Fatalf("expected 2 objects, got %d", len(req.Objects))
- }
- if req.Objects[0].ETag != "*" {
- t.Fatalf("expected first object ETag to be '*', got %q", req.Objects[0].ETag)
- }
- if req.Objects[1].ETag != "\"abc123\"" {
- t.Fatalf("expected second object ETag to preserve quotes, got %q", req.Objects[1].ETag)
- }
- if req.Objects[1].VersionId != "3HL4kqCxf3vjVBH40Nrjfkd" {
- t.Fatalf("expected second object VersionId to unmarshal, got %q", req.Objects[1].VersionId)
- }
-}
diff --git a/weed/s3api/s3api_object_handlers_put.go b/weed/s3api/s3api_object_handlers_put.go
index 805d63133..adda8b1c7 100644
--- a/weed/s3api/s3api_object_handlers_put.go
+++ b/weed/s3api/s3api_object_handlers_put.go
@@ -1859,28 +1859,6 @@ func (s3a *S3ApiServer) validateConditionalHeaders(r *http.Request, headers cond
return s3err.ErrNone
}
-// checkConditionalHeadersWithGetter is a testable method that accepts a simple EntryGetter
-// Uses the production getObjectETag and etagMatches methods to ensure testing of real logic
-func (s3a *S3ApiServer) checkConditionalHeadersWithGetter(getter EntryGetter, r *http.Request, bucket, object string) s3err.ErrorCode {
- headers, errCode := parseConditionalHeaders(r)
- if errCode != s3err.ErrNone {
- return errCode
- }
- // Get object entry for conditional checks.
- bucketDir := "/buckets/" + bucket
- entry, entryErr := getter.getEntry(bucketDir, object)
- if entryErr != nil {
- if errors.Is(entryErr, filer_pb.ErrNotFound) {
- entry = nil
- } else {
- glog.Errorf("checkConditionalHeadersWithGetter: failed to get entry for %s/%s: %v", bucket, object, entryErr)
- return s3err.ErrInternalError
- }
- }
-
- return s3a.validateConditionalHeaders(r, headers, entry, bucket, object)
-}
-
// checkConditionalHeaders is the production method that uses the S3ApiServer as EntryGetter
func (s3a *S3ApiServer) checkConditionalHeaders(r *http.Request, bucket, object string) s3err.ErrorCode {
// Fast path: if no conditional headers are present, skip object resolution entirely.
@@ -2002,28 +1980,6 @@ func (s3a *S3ApiServer) validateConditionalHeadersForReads(r *http.Request, head
return ConditionalHeaderResult{ErrorCode: s3err.ErrNone, Entry: entry}
}
-// checkConditionalHeadersForReadsWithGetter is a testable method for read operations
-// Uses the production getObjectETag and etagMatches methods to ensure testing of real logic
-func (s3a *S3ApiServer) checkConditionalHeadersForReadsWithGetter(getter EntryGetter, r *http.Request, bucket, object string) ConditionalHeaderResult {
- headers, errCode := parseConditionalHeaders(r)
- if errCode != s3err.ErrNone {
- return ConditionalHeaderResult{ErrorCode: errCode}
- }
- // Get object entry for conditional checks.
- bucketDir := "/buckets/" + bucket
- entry, entryErr := getter.getEntry(bucketDir, object)
- if entryErr != nil {
- if errors.Is(entryErr, filer_pb.ErrNotFound) {
- entry = nil
- } else {
- glog.Errorf("checkConditionalHeadersForReadsWithGetter: failed to get entry for %s/%s: %v", bucket, object, entryErr)
- return ConditionalHeaderResult{ErrorCode: s3err.ErrInternalError}
- }
- }
-
- return s3a.validateConditionalHeadersForReads(r, headers, entry, bucket, object)
-}
-
// checkConditionalHeadersForReads is the production method that uses the S3ApiServer as EntryGetter
func (s3a *S3ApiServer) checkConditionalHeadersForReads(r *http.Request, bucket, object string) ConditionalHeaderResult {
// Fast path: if no conditional headers are present, skip object resolution entirely.
diff --git a/weed/s3api/s3api_object_handlers_put_test.go b/weed/s3api/s3api_object_handlers_put_test.go
deleted file mode 100644
index a5646bff7..000000000
--- a/weed/s3api/s3api_object_handlers_put_test.go
+++ /dev/null
@@ -1,341 +0,0 @@
-package s3api
-
-import (
- "encoding/xml"
- "errors"
- "fmt"
- "net/http"
- "net/http/httptest"
- "strings"
- "sync"
- "testing"
-
- "github.com/gorilla/mux"
- "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
- weed_server "github.com/seaweedfs/seaweedfs/weed/server"
- "github.com/seaweedfs/seaweedfs/weed/util/constants"
-)
-
-func TestFilerErrorToS3Error(t *testing.T) {
- tests := []struct {
- name string
- err error
- expectedErr s3err.ErrorCode
- }{
- {
- name: "nil error",
- err: nil,
- expectedErr: s3err.ErrNone,
- },
- {
- name: "MD5 mismatch error",
- err: errors.New(constants.ErrMsgBadDigest),
- expectedErr: s3err.ErrBadDigest,
- },
- {
- name: "Read only error (direct)",
- err: weed_server.ErrReadOnly,
- expectedErr: s3err.ErrAccessDenied,
- },
- {
- name: "Read only error (wrapped)",
- err: fmt.Errorf("create file /buckets/test/file.txt: %w", weed_server.ErrReadOnly),
- expectedErr: s3err.ErrAccessDenied,
- },
- {
- name: "Context canceled error",
- err: errors.New("rpc error: code = Canceled desc = context canceled"),
- expectedErr: s3err.ErrInvalidRequest,
- },
- {
- name: "Context canceled error (simple)",
- err: errors.New("context canceled"),
- expectedErr: s3err.ErrInvalidRequest,
- },
- {
- name: "Directory exists error (sentinel)",
- err: fmt.Errorf("CreateEntry /path: %w", filer_pb.ErrExistingIsDirectory),
- expectedErr: s3err.ErrExistingObjectIsDirectory,
- },
- {
- name: "Parent is file error (sentinel)",
- err: fmt.Errorf("CreateEntry /path: %w", filer_pb.ErrParentIsFile),
- expectedErr: s3err.ErrExistingObjectIsFile,
- },
- {
- name: "Existing is file error (sentinel)",
- err: fmt.Errorf("CreateEntry /path: %w", filer_pb.ErrExistingIsFile),
- expectedErr: s3err.ErrExistingObjectIsFile,
- },
- {
- name: "Entry name too long (sentinel)",
- err: fmt.Errorf("CreateEntry: %w", filer_pb.ErrEntryNameTooLong),
- expectedErr: s3err.ErrKeyTooLongError,
- },
- {
- name: "Entry name too long (bare sentinel)",
- err: filer_pb.ErrEntryNameTooLong,
- expectedErr: s3err.ErrKeyTooLongError,
- },
- {
- name: "Unknown error",
- err: errors.New("some random error"),
- expectedErr: s3err.ErrInternalError,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result := filerErrorToS3Error(tt.err)
- if result != tt.expectedErr {
- t.Errorf("filerErrorToS3Error(%v) = %v, want %v", tt.err, result, tt.expectedErr)
- }
- })
- }
-}
-
-// setupKeyLengthTestRouter creates a minimal router that maps requests directly
-// to the given handler with {bucket} and {object} mux vars, bypassing auth.
-func setupKeyLengthTestRouter(handler http.HandlerFunc) *mux.Router {
- router := mux.NewRouter()
- bucket := router.PathPrefix("/{bucket}").Subrouter()
- bucket.Path("/{object:.+}").HandlerFunc(handler)
- return router
-}
-
-func TestPutObjectHandler_KeyTooLong(t *testing.T) {
- s3a := &S3ApiServer{}
- router := setupKeyLengthTestRouter(s3a.PutObjectHandler)
-
- longKey := strings.Repeat("a", s3_constants.MaxS3ObjectKeyLength+1)
- req := httptest.NewRequest(http.MethodPut, "/bucket/"+longKey, nil)
- rr := httptest.NewRecorder()
- router.ServeHTTP(rr, req)
-
- if rr.Code != http.StatusBadRequest {
- t.Errorf("expected status %d, got %d", http.StatusBadRequest, rr.Code)
- }
- var errResp s3err.RESTErrorResponse
- if err := xml.Unmarshal(rr.Body.Bytes(), &errResp); err != nil {
- t.Fatalf("failed to parse error XML: %v", err)
- }
- if errResp.Code != "KeyTooLongError" {
- t.Errorf("expected error code KeyTooLongError, got %s", errResp.Code)
- }
-}
-
-func TestPutObjectHandler_KeyAtLimit(t *testing.T) {
- s3a := &S3ApiServer{}
-
- // Wrap handler to convert panics from uninitialized server state into 500
- // responses. The key length check runs early and writes 400 KeyTooLongError
- // before reaching any code that needs a fully initialized server. A panic
- // means the handler accepted the key and continued past the check.
- panicSafe := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- defer func() {
- if p := recover(); p != nil {
- w.WriteHeader(http.StatusInternalServerError)
- }
- }()
- s3a.PutObjectHandler(w, r)
- })
- router := setupKeyLengthTestRouter(panicSafe)
-
- atLimitKey := strings.Repeat("a", s3_constants.MaxS3ObjectKeyLength)
- req := httptest.NewRequest(http.MethodPut, "/bucket/"+atLimitKey, nil)
- rr := httptest.NewRecorder()
- router.ServeHTTP(rr, req)
-
- // Must NOT be KeyTooLongError — any other response (including 500 from
- // the minimal server hitting uninitialized state) proves the key passed.
- var errResp s3err.RESTErrorResponse
- if rr.Code == http.StatusBadRequest {
- if err := xml.Unmarshal(rr.Body.Bytes(), &errResp); err == nil && errResp.Code == "KeyTooLongError" {
- t.Errorf("key at exactly %d bytes should not be rejected as too long", s3_constants.MaxS3ObjectKeyLength)
- }
- }
-}
-
-func TestCopyObjectHandler_KeyTooLong(t *testing.T) {
- s3a := &S3ApiServer{}
- router := setupKeyLengthTestRouter(s3a.CopyObjectHandler)
-
- longKey := strings.Repeat("a", s3_constants.MaxS3ObjectKeyLength+1)
- req := httptest.NewRequest(http.MethodPut, "/bucket/"+longKey, nil)
- req.Header.Set("X-Amz-Copy-Source", "/src-bucket/src-object")
- rr := httptest.NewRecorder()
- router.ServeHTTP(rr, req)
-
- if rr.Code != http.StatusBadRequest {
- t.Errorf("expected status %d, got %d", http.StatusBadRequest, rr.Code)
- }
- var errResp s3err.RESTErrorResponse
- if err := xml.Unmarshal(rr.Body.Bytes(), &errResp); err != nil {
- t.Fatalf("failed to parse error XML: %v", err)
- }
- if errResp.Code != "KeyTooLongError" {
- t.Errorf("expected error code KeyTooLongError, got %s", errResp.Code)
- }
-}
-
-func TestNewMultipartUploadHandler_KeyTooLong(t *testing.T) {
- s3a := &S3ApiServer{}
- router := setupKeyLengthTestRouter(s3a.NewMultipartUploadHandler)
-
- longKey := strings.Repeat("a", s3_constants.MaxS3ObjectKeyLength+1)
- req := httptest.NewRequest(http.MethodPost, "/bucket/"+longKey+"?uploads", nil)
- rr := httptest.NewRecorder()
- router.ServeHTTP(rr, req)
-
- if rr.Code != http.StatusBadRequest {
- t.Errorf("expected status %d, got %d", http.StatusBadRequest, rr.Code)
- }
- var errResp s3err.RESTErrorResponse
- if err := xml.Unmarshal(rr.Body.Bytes(), &errResp); err != nil {
- t.Fatalf("failed to parse error XML: %v", err)
- }
- if errResp.Code != "KeyTooLongError" {
- t.Errorf("expected error code KeyTooLongError, got %s", errResp.Code)
- }
-}
-
-type testObjectWriteLockFactory struct {
- mu sync.Mutex
- locks map[string]*sync.Mutex
-}
-
-func (f *testObjectWriteLockFactory) newLock(bucket, object string) objectWriteLock {
- key := bucket + "|" + object
-
- f.mu.Lock()
- lock, ok := f.locks[key]
- if !ok {
- lock = &sync.Mutex{}
- f.locks[key] = lock
- }
- f.mu.Unlock()
-
- lock.Lock()
- return &testObjectWriteLock{unlock: lock.Unlock}
-}
-
-type testObjectWriteLock struct {
- once sync.Once
- unlock func()
-}
-
-func (l *testObjectWriteLock) StopShortLivedLock() error {
- l.once.Do(l.unlock)
- return nil
-}
-
-func TestWithObjectWriteLockSerializesConcurrentPreconditions(t *testing.T) {
- s3a := NewS3ApiServerForTest()
- lockFactory := &testObjectWriteLockFactory{
- locks: make(map[string]*sync.Mutex),
- }
- s3a.newObjectWriteLock = lockFactory.newLock
-
- const workers = 3
- const bucket = "test-bucket"
- const object = "/file.txt"
-
- start := make(chan struct{})
- results := make(chan s3err.ErrorCode, workers)
- var wg sync.WaitGroup
-
- var stateMu sync.Mutex
- objectExists := false
-
- for i := 0; i < workers; i++ {
- wg.Add(1)
- go func() {
- defer wg.Done()
- <-start
-
- errCode := s3a.withObjectWriteLock(bucket, object,
- func() s3err.ErrorCode {
- stateMu.Lock()
- defer stateMu.Unlock()
- if objectExists {
- return s3err.ErrPreconditionFailed
- }
- return s3err.ErrNone
- },
- func() s3err.ErrorCode {
- stateMu.Lock()
- defer stateMu.Unlock()
- objectExists = true
- return s3err.ErrNone
- },
- )
-
- results <- errCode
- }()
- }
-
- close(start)
- wg.Wait()
- close(results)
-
- var successCount int
- var preconditionFailedCount int
-
- for errCode := range results {
- switch errCode {
- case s3err.ErrNone:
- successCount++
- case s3err.ErrPreconditionFailed:
- preconditionFailedCount++
- default:
- t.Fatalf("unexpected error code: %v", errCode)
- }
- }
-
- if successCount != 1 {
- t.Fatalf("expected exactly one successful writer, got %d", successCount)
- }
- if preconditionFailedCount != workers-1 {
- t.Fatalf("expected %d precondition failures, got %d", workers-1, preconditionFailedCount)
- }
-}
-
-func TestResolveFileMode(t *testing.T) {
- tests := []struct {
- name string
- acl string
- defaultFileMode uint32
- expected uint32
- }{
- {"no acl, no default", "", 0, 0660},
- {"no acl, with default", "", 0644, 0644},
- {"private", s3_constants.CannedAclPrivate, 0, 0660},
- {"private overrides default", s3_constants.CannedAclPrivate, 0644, 0660},
- {"public-read", s3_constants.CannedAclPublicRead, 0, 0644},
- {"public-read overrides default", s3_constants.CannedAclPublicRead, 0666, 0644},
- {"public-read-write", s3_constants.CannedAclPublicReadWrite, 0, 0666},
- {"authenticated-read", s3_constants.CannedAclAuthenticatedRead, 0, 0644},
- {"bucket-owner-read", s3_constants.CannedAclBucketOwnerRead, 0, 0644},
- {"bucket-owner-full-control", s3_constants.CannedAclBucketOwnerFullControl, 0, 0660},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- s3a := &S3ApiServer{
- option: &S3ApiServerOption{
- DefaultFileMode: tt.defaultFileMode,
- },
- }
- req := httptest.NewRequest(http.MethodPut, "/bucket/object", nil)
- if tt.acl != "" {
- req.Header.Set(s3_constants.AmzCannedAcl, tt.acl)
- }
- got := s3a.resolveFileMode(req)
- if got != tt.expected {
- t.Errorf("resolveFileMode() = %04o, want %04o", got, tt.expected)
- }
- })
- }
-}
diff --git a/weed/s3api/s3api_object_handlers_test.go b/weed/s3api/s3api_object_handlers_test.go
deleted file mode 100644
index 5ca04c3ce..000000000
--- a/weed/s3api/s3api_object_handlers_test.go
+++ /dev/null
@@ -1,244 +0,0 @@
-package s3api
-
-import (
- "testing"
- "time"
-
- "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
- "github.com/stretchr/testify/assert"
-)
-
-func TestNewListEntryOwnerDisplayName(t *testing.T) {
- // Create S3ApiServer with a properly initialized IAM
- s3a := &S3ApiServer{
- iam: &IdentityAccessManagement{
- accounts: map[string]*Account{
- "testid": {Id: "testid", DisplayName: "M. Tester"},
- "userid123": {Id: "userid123", DisplayName: "John Doe"},
- },
- },
- }
-
- // Create test entry with owner metadata
- entry := &filer_pb.Entry{
- Name: "test-object",
- Attributes: &filer_pb.FuseAttributes{
- Mtime: time.Now().Unix(),
- FileSize: 1024,
- },
- Extended: map[string][]byte{
- s3_constants.ExtAmzOwnerKey: []byte("testid"),
- },
- }
-
- // Test that display name is correctly looked up from IAM
- listEntry := newListEntry(s3a, entry, "", "dir", "test-object", "/buckets/test/", true, false, false)
-
- assert.NotNil(t, listEntry.Owner, "Owner should be set when fetchOwner is true")
- assert.Equal(t, "testid", listEntry.Owner.ID, "Owner ID should match stored owner")
- assert.Equal(t, "M. Tester", listEntry.Owner.DisplayName, "Display name should be looked up from IAM")
-
- // Test with owner that doesn't exist in IAM (should fallback to ID)
- entry.Extended[s3_constants.ExtAmzOwnerKey] = []byte("unknown-user")
- listEntry = newListEntry(s3a, entry, "", "dir", "test-object", "/buckets/test/", true, false, false)
-
- assert.Equal(t, "unknown-user", listEntry.Owner.ID, "Owner ID should match stored owner")
- assert.Equal(t, "unknown-user", listEntry.Owner.DisplayName, "Display name should fallback to ID when not found in IAM")
-
- // Test with no owner metadata (should use anonymous)
- entry.Extended = make(map[string][]byte)
- listEntry = newListEntry(s3a, entry, "", "dir", "test-object", "/buckets/test/", true, false, false)
-
- assert.Equal(t, s3_constants.AccountAnonymousId, listEntry.Owner.ID, "Should use anonymous ID when no owner metadata")
- assert.Equal(t, "anonymous", listEntry.Owner.DisplayName, "Should use anonymous display name when no owner metadata")
-
- // Test with fetchOwner false (should not set owner)
- listEntry = newListEntry(s3a, entry, "", "dir", "test-object", "/buckets/test/", false, false, false)
-
- assert.Nil(t, listEntry.Owner, "Owner should not be set when fetchOwner is false")
-}
-
-func TestRemoveDuplicateSlashes(t *testing.T) {
- tests := []struct {
- name string
- path string
- expectedResult string
- }{
- {
- name: "empty",
- path: "",
- expectedResult: "",
- },
- {
- name: "slash",
- path: "/",
- expectedResult: "/",
- },
- {
- name: "object",
- path: "object",
- expectedResult: "object",
- },
- {
- name: "correct path",
- path: "/path/to/object",
- expectedResult: "/path/to/object",
- },
- {
- name: "path with duplicates",
- path: "///path//to/object//",
- expectedResult: "/path/to/object/",
- },
- }
-
- for _, tst := range tests {
- t.Run(tst.name, func(t *testing.T) {
- obj := removeDuplicateSlashes(tst.path)
- assert.Equal(t, tst.expectedResult, obj)
- })
- }
-}
-
-func TestS3ApiServer_toFilerPath(t *testing.T) {
- tests := []struct {
- name string
- args string
- want string
- }{
- {
- "simple",
- "/uploads/eaf10b3b-3b3a-4dcd-92a7-edf2a512276e/67b8b9bf-7cca-4cb6-9b34-22fcb4d6e27d/Bildschirmfoto 2022-09-19 um 21.38.37.png",
- "/uploads/eaf10b3b-3b3a-4dcd-92a7-edf2a512276e/67b8b9bf-7cca-4cb6-9b34-22fcb4d6e27d/Bildschirmfoto%202022-09-19%20um%2021.38.37.png",
- },
- {
- "double prefix",
- "//uploads/t.png",
- "/uploads/t.png",
- },
- {
- "triple prefix",
- "///uploads/t.png",
- "/uploads/t.png",
- },
- {
- "empty prefix",
- "uploads/t.png",
- "/uploads/t.png",
- },
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- assert.Equalf(t, tt.want, urlEscapeObject(tt.args), "clean %v", tt.args)
- })
- }
-}
-
-func TestPartNumberWithRangeHeader(t *testing.T) {
- tests := []struct {
- name string
- partStartOffset int64 // Part's start offset in the object
- partEndOffset int64 // Part's end offset in the object
- clientRangeHeader string
- expectedStart int64 // Expected absolute start offset
- expectedEnd int64 // Expected absolute end offset
- expectError bool
- }{
- {
- name: "No client range - full part",
- partStartOffset: 1000,
- partEndOffset: 1999,
- clientRangeHeader: "",
- expectedStart: 1000,
- expectedEnd: 1999,
- expectError: false,
- },
- {
- name: "Range within part - start and end",
- partStartOffset: 1000,
- partEndOffset: 1999, // Part size: 1000 bytes
- clientRangeHeader: "bytes=0-99",
- expectedStart: 1000, // 1000 + 0
- expectedEnd: 1099, // 1000 + 99
- expectError: false,
- },
- {
- name: "Range within part - start to end",
- partStartOffset: 1000,
- partEndOffset: 1999,
- clientRangeHeader: "bytes=100-",
- expectedStart: 1100, // 1000 + 100
- expectedEnd: 1999, // 1000 + 999 (end of part)
- expectError: false,
- },
- {
- name: "Range suffix - last 100 bytes",
- partStartOffset: 1000,
- partEndOffset: 1999, // Part size: 1000 bytes
- clientRangeHeader: "bytes=-100",
- expectedStart: 1900, // 1000 + (1000 - 100)
- expectedEnd: 1999, // 1000 + 999
- expectError: false,
- },
- {
- name: "Range suffix larger than part",
- partStartOffset: 1000,
- partEndOffset: 1999, // Part size: 1000 bytes
- clientRangeHeader: "bytes=-2000",
- expectedStart: 1000, // Start of part (clamped)
- expectedEnd: 1999, // End of part
- expectError: false,
- },
- {
- name: "Range start beyond part size",
- partStartOffset: 1000,
- partEndOffset: 1999,
- clientRangeHeader: "bytes=1000-1100",
- expectedStart: 0,
- expectedEnd: 0,
- expectError: true,
- },
- {
- name: "Range end clamped to part size",
- partStartOffset: 1000,
- partEndOffset: 1999,
- clientRangeHeader: "bytes=0-2000",
- expectedStart: 1000, // 1000 + 0
- expectedEnd: 1999, // Clamped to end of part
- expectError: false,
- },
- {
- name: "Single byte range at start",
- partStartOffset: 5000,
- partEndOffset: 9999, // Part size: 5000 bytes
- clientRangeHeader: "bytes=0-0",
- expectedStart: 5000,
- expectedEnd: 5000,
- expectError: false,
- },
- {
- name: "Single byte range in middle",
- partStartOffset: 5000,
- partEndOffset: 9999,
- clientRangeHeader: "bytes=100-100",
- expectedStart: 5100,
- expectedEnd: 5100,
- expectError: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Test the actual range adjustment logic from GetObjectHandler
- startOffset, endOffset, err := adjustRangeForPart(tt.partStartOffset, tt.partEndOffset, tt.clientRangeHeader)
-
- if tt.expectError {
- assert.Error(t, err, "Expected error for range %s", tt.clientRangeHeader)
- } else {
- assert.NoError(t, err, "Unexpected error for range %s: %v", tt.clientRangeHeader, err)
- assert.Equal(t, tt.expectedStart, startOffset, "Start offset mismatch")
- assert.Equal(t, tt.expectedEnd, endOffset, "End offset mismatch")
- }
- })
- }
-}
diff --git a/weed/s3api/s3api_sosapi.go b/weed/s3api/s3api_sosapi.go
index 53d7acdb4..673b60993 100644
--- a/weed/s3api/s3api_sosapi.go
+++ b/weed/s3api/s3api_sosapi.go
@@ -14,7 +14,6 @@ import (
"fmt"
"net/http"
"strconv"
- "strings"
"time"
"github.com/seaweedfs/seaweedfs/weed/glog"
@@ -97,13 +96,6 @@ func isSOSAPIObject(object string) bool {
}
}
-// isSOSAPIClient checks if the request comes from a SOSAPI-compatible client
-// by examining the User-Agent header.
-func isSOSAPIClient(r *http.Request) bool {
- userAgent := r.Header.Get("User-Agent")
- return strings.Contains(userAgent, sosAPIClientUserAgent)
-}
-
// generateSystemXML creates the system.xml response containing storage system
// capabilities and recommendations.
func generateSystemXML() ([]byte, error) {
diff --git a/weed/s3api/s3api_sosapi_test.go b/weed/s3api/s3api_sosapi_test.go
deleted file mode 100644
index c14bd16f6..000000000
--- a/weed/s3api/s3api_sosapi_test.go
+++ /dev/null
@@ -1,248 +0,0 @@
-package s3api
-
-import (
- "encoding/xml"
- "net/http/httptest"
- "strings"
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/pb/master_pb"
-)
-
-func TestIsSOSAPIObject(t *testing.T) {
- tests := []struct {
- name string
- object string
- expected bool
- }{
- {
- name: "system.xml should be detected",
- object: ".system-d26a9498-cb7c-4a87-a44a-8ae204f5ba6c/system.xml",
- expected: true,
- },
- {
- name: "capacity.xml should be detected",
- object: ".system-d26a9498-cb7c-4a87-a44a-8ae204f5ba6c/capacity.xml",
- expected: true,
- },
- {
- name: "regular object should not be detected",
- object: "myfile.txt",
- expected: false,
- },
- {
- name: "similar but different path should not be detected",
- object: ".system-other-uuid/system.xml",
- expected: false,
- },
- {
- name: "nested path should not be detected",
- object: "prefix/.system-d26a9498-cb7c-4a87-a44a-8ae204f5ba6c/system.xml",
- expected: false,
- },
- {
- name: "empty string should not be detected",
- object: "",
- expected: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result := isSOSAPIObject(tt.object)
- if result != tt.expected {
- t.Errorf("isSOSAPIObject(%q) = %v, want %v", tt.object, result, tt.expected)
- }
- })
- }
-}
-
-func TestIsSOSAPIClient(t *testing.T) {
- tests := []struct {
- name string
- userAgent string
- expected bool
- }{
- {
- name: "Veeam backup client should be detected",
- userAgent: "APN/1.0 Veeam/1.0 Backup/10.0",
- expected: true,
- },
- {
- name: "exact match should be detected",
- userAgent: "APN/1.0 Veeam/1.0",
- expected: true,
- },
- {
- name: "AWS CLI should not be detected",
- userAgent: "aws-cli/2.0.0 Python/3.8",
- expected: false,
- },
- {
- name: "empty user agent should not be detected",
- userAgent: "",
- expected: false,
- },
- {
- name: "partial match should not be detected",
- userAgent: "Veeam/1.0",
- expected: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- req := httptest.NewRequest("GET", "/bucket/object", nil)
- req.Header.Set("User-Agent", tt.userAgent)
- result := isSOSAPIClient(req)
- if result != tt.expected {
- t.Errorf("isSOSAPIClient() with User-Agent %q = %v, want %v", tt.userAgent, result, tt.expected)
- }
- })
- }
-}
-
-func TestGenerateSystemXML(t *testing.T) {
- xmlData, err := generateSystemXML()
- if err != nil {
- t.Fatalf("generateSystemXML() failed: %v", err)
- }
-
- // Verify it's valid XML
- var si SystemInfo
- if err := xml.Unmarshal(xmlData, &si); err != nil {
- t.Fatalf("generated XML is invalid: %v", err)
- }
-
- // Verify required fields
- if si.ProtocolVersion != sosAPIProtocolVersion {
- t.Errorf("ProtocolVersion = %q, want %q", si.ProtocolVersion, sosAPIProtocolVersion)
- }
-
- if !strings.Contains(si.ModelName, "SeaweedFS") {
- t.Errorf("ModelName = %q, should contain 'SeaweedFS'", si.ModelName)
- }
-
- if !si.ProtocolCapabilities.CapacityInfo {
- t.Error("ProtocolCapabilities.CapacityInfo should be true")
- }
-
- if si.SystemRecommendations == nil {
- t.Fatal("SystemRecommendations should not be nil")
- }
-
- if si.SystemRecommendations.KBBlockSize != sosAPIDefaultBlockSizeKB {
- t.Errorf("KBBlockSize = %d, want %d", si.SystemRecommendations.KBBlockSize, sosAPIDefaultBlockSizeKB)
- }
-}
-
-func TestSOSAPIObjectDetectionEdgeCases(t *testing.T) {
- edgeCases := []struct {
- object string
- expected bool
- }{
- // With leading slash
- {"/.system-d26a9498-cb7c-4a87-a44a-8ae204f5ba6c/system.xml", false},
- // URL encoded
- {".system-d26a9498-cb7c-4a87-a44a-8ae204f5ba6c%2Fsystem.xml", false},
- // Mixed case
- {".System-d26a9498-cb7c-4a87-a44a-8ae204f5ba6c/system.xml", false},
- // Extra slashes
- {".system-d26a9498-cb7c-4a87-a44a-8ae204f5ba6c//system.xml", false},
- // Correct paths
- {".system-d26a9498-cb7c-4a87-a44a-8ae204f5ba6c/system.xml", true},
- {".system-d26a9498-cb7c-4a87-a44a-8ae204f5ba6c/capacity.xml", true},
- }
-
- for _, tc := range edgeCases {
- result := isSOSAPIObject(tc.object)
- if result != tc.expected {
- t.Errorf("isSOSAPIObject(%q) = %v, want %v", tc.object, result, tc.expected)
- }
- }
-}
-
-func TestCollectBucketUsageFromTopology(t *testing.T) {
- topo := &master_pb.TopologyInfo{
- DataCenterInfos: []*master_pb.DataCenterInfo{
- {
- RackInfos: []*master_pb.RackInfo{
- {
- DataNodeInfos: []*master_pb.DataNodeInfo{
- {
- DiskInfos: map[string]*master_pb.DiskInfo{
- "hdd": {
- VolumeInfos: []*master_pb.VolumeInformationMessage{
- {Id: 1, Size: 100, Collection: "bucket1"},
- {Id: 2, Size: 200, Collection: "bucket2"},
- {Id: 3, Size: 300, Collection: "bucket1"},
- {Id: 1, Size: 100, Collection: "bucket1"}, // Duplicate (replica), should be ignored
- },
- },
- },
- },
- },
- },
- },
- },
- },
- }
-
- usage := collectBucketUsageFromTopology(topo, "bucket1")
- expected := int64(400) // 100 + 300
- if usage != expected {
- t.Errorf("collectBucketUsageFromTopology = %d, want %d", usage, expected)
- }
-
- usage2 := collectBucketUsageFromTopology(topo, "bucket2")
- expected2 := int64(200)
- if usage2 != expected2 {
- t.Errorf("collectBucketUsageFromTopology = %d, want %d", usage2, expected2)
- }
-}
-
-func TestCalculateClusterCapacity(t *testing.T) {
- topo := &master_pb.TopologyInfo{
- DataCenterInfos: []*master_pb.DataCenterInfo{
- {
- RackInfos: []*master_pb.RackInfo{
- {
- DataNodeInfos: []*master_pb.DataNodeInfo{
- {
- DiskInfos: map[string]*master_pb.DiskInfo{
- "hdd": {
- MaxVolumeCount: 100,
- FreeVolumeCount: 40,
- },
- },
- },
- {
- DiskInfos: map[string]*master_pb.DiskInfo{
- "hdd": {
- MaxVolumeCount: 200,
- FreeVolumeCount: 160,
- },
- },
- },
- },
- },
- },
- },
- },
- }
-
- volumeSizeLimitMb := uint64(1000) // 1GB
- volumeSizeBytes := int64(1000) * 1024 * 1024
-
- total, available := calculateClusterCapacity(topo, volumeSizeLimitMb)
-
- expectedTotal := int64(300) * volumeSizeBytes
- expectedAvailable := int64(200) * volumeSizeBytes
-
- if total != expectedTotal {
- t.Errorf("calculateClusterCapacity total = %d, want %d", total, expectedTotal)
- }
- if available != expectedAvailable {
- t.Errorf("calculateClusterCapacity available = %d, want %d", available, expectedAvailable)
- }
-}
diff --git a/weed/s3api/s3api_sse_chunk_metadata_test.go b/weed/s3api/s3api_sse_chunk_metadata_test.go
deleted file mode 100644
index ca38f44f4..000000000
--- a/weed/s3api/s3api_sse_chunk_metadata_test.go
+++ /dev/null
@@ -1,361 +0,0 @@
-package s3api
-
-import (
- "bytes"
- "crypto/aes"
- "crypto/cipher"
- "crypto/rand"
- "encoding/json"
- "io"
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
-)
-
-// TestSSEKMSChunkMetadataAssignment tests that SSE-KMS creates per-chunk metadata
-// with correct ChunkOffset values for each chunk (matching the fix in putToFiler)
-func TestSSEKMSChunkMetadataAssignment(t *testing.T) {
- kmsKey := SetupTestKMS(t)
- defer kmsKey.Cleanup()
-
- // Generate SSE-KMS key by encrypting test data (this gives us a real SSEKMSKey)
- encryptionContext := BuildEncryptionContext("test-bucket", "test-object", false)
- testData := "Test data for SSE-KMS chunk metadata validation"
- encryptedReader, sseKMSKey, err := CreateSSEKMSEncryptedReader(bytes.NewReader([]byte(testData)), kmsKey.KeyID, encryptionContext)
- if err != nil {
- t.Fatalf("Failed to create encrypted reader: %v", err)
- }
- // Read to complete encryption setup
- io.ReadAll(encryptedReader)
-
- // Serialize the base metadata (what putToFiler receives before chunking)
- baseMetadata, err := SerializeSSEKMSMetadata(sseKMSKey)
- if err != nil {
- t.Fatalf("Failed to serialize base SSE-KMS metadata: %v", err)
- }
-
- // Simulate multi-chunk upload scenario (what putToFiler does after UploadReaderInChunks)
- simulatedChunks := []*filer_pb.FileChunk{
- {FileId: "chunk1", Offset: 0, Size: 8 * 1024 * 1024}, // 8MB chunk at offset 0
- {FileId: "chunk2", Offset: 8 * 1024 * 1024, Size: 8 * 1024 * 1024}, // 8MB chunk at offset 8MB
- {FileId: "chunk3", Offset: 16 * 1024 * 1024, Size: 4 * 1024 * 1024}, // 4MB chunk at offset 16MB
- }
-
- // THIS IS THE CRITICAL FIX: Create per-chunk metadata (lines 421-443 in putToFiler)
- for _, chunk := range simulatedChunks {
- chunk.SseType = filer_pb.SSEType_SSE_KMS
-
- // Create a copy of the SSE-KMS key with chunk-specific offset
- chunkSSEKey := &SSEKMSKey{
- KeyID: sseKMSKey.KeyID,
- EncryptedDataKey: sseKMSKey.EncryptedDataKey,
- EncryptionContext: sseKMSKey.EncryptionContext,
- BucketKeyEnabled: sseKMSKey.BucketKeyEnabled,
- IV: sseKMSKey.IV,
- ChunkOffset: chunk.Offset, // Set chunk-specific offset
- }
-
- // Serialize per-chunk metadata
- chunkMetadata, serErr := SerializeSSEKMSMetadata(chunkSSEKey)
- if serErr != nil {
- t.Fatalf("Failed to serialize SSE-KMS metadata for chunk at offset %d: %v", chunk.Offset, serErr)
- }
- chunk.SseMetadata = chunkMetadata
- }
-
- // VERIFICATION 1: Each chunk should have different metadata (due to different ChunkOffset)
- metadataSet := make(map[string]bool)
- for i, chunk := range simulatedChunks {
- metadataStr := string(chunk.SseMetadata)
- if metadataSet[metadataStr] {
- t.Errorf("Chunk %d has duplicate metadata (should be unique per chunk)", i)
- }
- metadataSet[metadataStr] = true
-
- // Deserialize and verify ChunkOffset
- var metadata SSEKMSMetadata
- if err := json.Unmarshal(chunk.SseMetadata, &metadata); err != nil {
- t.Fatalf("Failed to deserialize chunk %d metadata: %v", i, err)
- }
-
- expectedOffset := chunk.Offset
- if metadata.PartOffset != expectedOffset {
- t.Errorf("Chunk %d: expected PartOffset=%d, got %d", i, expectedOffset, metadata.PartOffset)
- }
-
- t.Logf("✓ Chunk %d: PartOffset=%d (correct)", i, metadata.PartOffset)
- }
-
- // VERIFICATION 2: Verify metadata can be deserialized and has correct ChunkOffset
- for i, chunk := range simulatedChunks {
- // Deserialize chunk metadata
- deserializedKey, err := DeserializeSSEKMSMetadata(chunk.SseMetadata)
- if err != nil {
- t.Fatalf("Failed to deserialize chunk %d metadata: %v", i, err)
- }
-
- // Verify the deserialized key has correct ChunkOffset
- if deserializedKey.ChunkOffset != chunk.Offset {
- t.Errorf("Chunk %d: deserialized ChunkOffset=%d, expected %d",
- i, deserializedKey.ChunkOffset, chunk.Offset)
- }
-
- // Verify IV is set (should be inherited from base)
- if len(deserializedKey.IV) != aes.BlockSize {
- t.Errorf("Chunk %d: invalid IV length: %d", i, len(deserializedKey.IV))
- }
-
- // Verify KeyID matches
- if deserializedKey.KeyID != sseKMSKey.KeyID {
- t.Errorf("Chunk %d: KeyID mismatch", i)
- }
-
- t.Logf("✓ Chunk %d: metadata deserialized successfully (ChunkOffset=%d, KeyID=%s)",
- i, deserializedKey.ChunkOffset, deserializedKey.KeyID)
- }
-
- // VERIFICATION 3: Ensure base metadata is NOT reused (the bug we're preventing)
- var baseMetadataStruct SSEKMSMetadata
- if err := json.Unmarshal(baseMetadata, &baseMetadataStruct); err != nil {
- t.Fatalf("Failed to deserialize base metadata: %v", err)
- }
-
- // Base metadata should have ChunkOffset=0
- if baseMetadataStruct.PartOffset != 0 {
- t.Errorf("Base metadata should have PartOffset=0, got %d", baseMetadataStruct.PartOffset)
- }
-
- // Chunks 2 and 3 should NOT have the same metadata as base (proving we're not reusing)
- for i := 1; i < len(simulatedChunks); i++ {
- if bytes.Equal(simulatedChunks[i].SseMetadata, baseMetadata) {
- t.Errorf("CRITICAL BUG: Chunk %d reuses base metadata (should have per-chunk metadata)", i)
- }
- }
-
- t.Log("✓ All chunks have unique per-chunk metadata (bug prevented)")
-}
-
-// TestSSES3ChunkMetadataAssignment tests that SSE-S3 creates per-chunk metadata
-// with offset-adjusted IVs for each chunk (matching the fix in putToFiler)
-func TestSSES3ChunkMetadataAssignment(t *testing.T) {
- // Initialize global SSE-S3 key manager
- globalSSES3KeyManager = NewSSES3KeyManager()
- defer func() {
- globalSSES3KeyManager = NewSSES3KeyManager()
- }()
-
- keyManager := GetSSES3KeyManager()
- keyManager.superKey = make([]byte, 32)
- rand.Read(keyManager.superKey)
-
- // Generate SSE-S3 key
- sseS3Key, err := GenerateSSES3Key()
- if err != nil {
- t.Fatalf("Failed to generate SSE-S3 key: %v", err)
- }
-
- // Generate base IV
- baseIV := make([]byte, aes.BlockSize)
- rand.Read(baseIV)
- sseS3Key.IV = baseIV
-
- // Serialize base metadata (what putToFiler receives)
- baseMetadata, err := SerializeSSES3Metadata(sseS3Key)
- if err != nil {
- t.Fatalf("Failed to serialize base SSE-S3 metadata: %v", err)
- }
-
- // Simulate multi-chunk upload scenario (what putToFiler does after UploadReaderInChunks)
- simulatedChunks := []*filer_pb.FileChunk{
- {FileId: "chunk1", Offset: 0, Size: 8 * 1024 * 1024}, // 8MB chunk at offset 0
- {FileId: "chunk2", Offset: 8 * 1024 * 1024, Size: 8 * 1024 * 1024}, // 8MB chunk at offset 8MB
- {FileId: "chunk3", Offset: 16 * 1024 * 1024, Size: 4 * 1024 * 1024}, // 4MB chunk at offset 16MB
- }
-
- // THIS IS THE CRITICAL FIX: Create per-chunk metadata (lines 444-468 in putToFiler)
- for _, chunk := range simulatedChunks {
- chunk.SseType = filer_pb.SSEType_SSE_S3
-
- // Calculate chunk-specific IV using base IV and chunk offset
- chunkIV, _ := calculateIVWithOffset(sseS3Key.IV, chunk.Offset)
-
- // Create a copy of the SSE-S3 key with chunk-specific IV
- chunkSSEKey := &SSES3Key{
- Key: sseS3Key.Key,
- KeyID: sseS3Key.KeyID,
- Algorithm: sseS3Key.Algorithm,
- IV: chunkIV, // Use chunk-specific IV
- }
-
- // Serialize per-chunk metadata
- chunkMetadata, serErr := SerializeSSES3Metadata(chunkSSEKey)
- if serErr != nil {
- t.Fatalf("Failed to serialize SSE-S3 metadata for chunk at offset %d: %v", chunk.Offset, serErr)
- }
- chunk.SseMetadata = chunkMetadata
- }
-
- // VERIFICATION 1: Each chunk should have different metadata (due to different IVs)
- metadataSet := make(map[string]bool)
- for i, chunk := range simulatedChunks {
- metadataStr := string(chunk.SseMetadata)
- if metadataSet[metadataStr] {
- t.Errorf("Chunk %d has duplicate metadata (should be unique per chunk)", i)
- }
- metadataSet[metadataStr] = true
-
- // Deserialize and verify IV
- deserializedKey, err := DeserializeSSES3Metadata(chunk.SseMetadata, keyManager)
- if err != nil {
- t.Fatalf("Failed to deserialize chunk %d metadata: %v", i, err)
- }
-
- // Calculate expected IV for this chunk
- expectedIV, _ := calculateIVWithOffset(baseIV, chunk.Offset)
- if !bytes.Equal(deserializedKey.IV, expectedIV) {
- t.Errorf("Chunk %d: IV mismatch\nExpected: %x\nGot: %x",
- i, expectedIV[:8], deserializedKey.IV[:8])
- }
-
- t.Logf("✓ Chunk %d: IV correctly adjusted for offset=%d", i, chunk.Offset)
- }
-
- // VERIFICATION 2: Verify decryption works with per-chunk IVs
- for i, chunk := range simulatedChunks {
- // Deserialize chunk metadata
- deserializedKey, err := DeserializeSSES3Metadata(chunk.SseMetadata, keyManager)
- if err != nil {
- t.Fatalf("Failed to deserialize chunk %d metadata: %v", i, err)
- }
-
- // Simulate encryption/decryption with the chunk's IV
- testData := []byte("Test data for SSE-S3 chunk decryption verification")
- block, err := aes.NewCipher(deserializedKey.Key)
- if err != nil {
- t.Fatalf("Failed to create cipher: %v", err)
- }
-
- // Encrypt with chunk's IV
- ciphertext := make([]byte, len(testData))
- stream := cipher.NewCTR(block, deserializedKey.IV)
- stream.XORKeyStream(ciphertext, testData)
-
- // Decrypt with chunk's IV
- plaintext := make([]byte, len(ciphertext))
- block2, _ := aes.NewCipher(deserializedKey.Key)
- stream2 := cipher.NewCTR(block2, deserializedKey.IV)
- stream2.XORKeyStream(plaintext, ciphertext)
-
- if !bytes.Equal(plaintext, testData) {
- t.Errorf("Chunk %d: decryption failed", i)
- }
-
- t.Logf("✓ Chunk %d: encryption/decryption successful with chunk-specific IV", i)
- }
-
- // VERIFICATION 3: Ensure base IV is NOT reused for non-zero offset chunks (the bug we're preventing)
- for i := 1; i < len(simulatedChunks); i++ {
- if bytes.Equal(simulatedChunks[i].SseMetadata, baseMetadata) {
- t.Errorf("CRITICAL BUG: Chunk %d reuses base metadata (should have per-chunk metadata)", i)
- }
-
- // Verify chunk metadata has different IV than base IV
- deserializedKey, _ := DeserializeSSES3Metadata(simulatedChunks[i].SseMetadata, keyManager)
- if bytes.Equal(deserializedKey.IV, baseIV) {
- t.Errorf("CRITICAL BUG: Chunk %d uses base IV (should use offset-adjusted IV)", i)
- }
- }
-
- t.Log("✓ All chunks have unique per-chunk IVs (bug prevented)")
-}
-
-// TestSSEChunkMetadataComparison tests that the bug (reusing same metadata for all chunks)
-// would cause decryption failures, while the fix (per-chunk metadata) works correctly
-func TestSSEChunkMetadataComparison(t *testing.T) {
- // Generate test key and IV
- key := make([]byte, 32)
- rand.Read(key)
- baseIV := make([]byte, aes.BlockSize)
- rand.Read(baseIV)
-
- // Create test data for 3 chunks
- chunk0Data := []byte("Chunk 0 data at offset 0")
- chunk1Data := []byte("Chunk 1 data at offset 8MB")
- chunk2Data := []byte("Chunk 2 data at offset 16MB")
-
- chunkOffsets := []int64{0, 8 * 1024 * 1024, 16 * 1024 * 1024}
- chunkDataList := [][]byte{chunk0Data, chunk1Data, chunk2Data}
-
- // Scenario 1: BUG - Using same IV for all chunks (what the old code did)
- t.Run("Bug: Reusing base IV causes decryption failures", func(t *testing.T) {
- var encryptedChunks [][]byte
-
- // Encrypt each chunk with offset-adjusted IV (what encryption does)
- for i, offset := range chunkOffsets {
- adjustedIV, _ := calculateIVWithOffset(baseIV, offset)
- block, _ := aes.NewCipher(key)
- stream := cipher.NewCTR(block, adjustedIV)
-
- ciphertext := make([]byte, len(chunkDataList[i]))
- stream.XORKeyStream(ciphertext, chunkDataList[i])
- encryptedChunks = append(encryptedChunks, ciphertext)
- }
-
- // Try to decrypt with base IV (THE BUG)
- for i := range encryptedChunks {
- block, _ := aes.NewCipher(key)
- stream := cipher.NewCTR(block, baseIV) // BUG: Always using base IV
-
- plaintext := make([]byte, len(encryptedChunks[i]))
- stream.XORKeyStream(plaintext, encryptedChunks[i])
-
- if i == 0 {
- // Chunk 0 should work (offset 0 means base IV = adjusted IV)
- if !bytes.Equal(plaintext, chunkDataList[i]) {
- t.Errorf("Chunk 0 decryption failed (unexpected)")
- }
- } else {
- // Chunks 1 and 2 should FAIL (wrong IV)
- if bytes.Equal(plaintext, chunkDataList[i]) {
- t.Errorf("BUG NOT REPRODUCED: Chunk %d decrypted correctly with base IV (should fail)", i)
- } else {
- t.Logf("✓ Chunk %d: Correctly failed to decrypt with base IV (bug reproduced)", i)
- }
- }
- }
- })
-
- // Scenario 2: FIX - Using per-chunk offset-adjusted IVs (what the new code does)
- t.Run("Fix: Per-chunk IVs enable correct decryption", func(t *testing.T) {
- var encryptedChunks [][]byte
- var chunkIVs [][]byte
-
- // Encrypt each chunk with offset-adjusted IV
- for i, offset := range chunkOffsets {
- adjustedIV, _ := calculateIVWithOffset(baseIV, offset)
- chunkIVs = append(chunkIVs, adjustedIV)
-
- block, _ := aes.NewCipher(key)
- stream := cipher.NewCTR(block, adjustedIV)
-
- ciphertext := make([]byte, len(chunkDataList[i]))
- stream.XORKeyStream(ciphertext, chunkDataList[i])
- encryptedChunks = append(encryptedChunks, ciphertext)
- }
-
- // Decrypt with per-chunk IVs (THE FIX)
- for i := range encryptedChunks {
- block, _ := aes.NewCipher(key)
- stream := cipher.NewCTR(block, chunkIVs[i]) // FIX: Using per-chunk IV
-
- plaintext := make([]byte, len(encryptedChunks[i]))
- stream.XORKeyStream(plaintext, encryptedChunks[i])
-
- if !bytes.Equal(plaintext, chunkDataList[i]) {
- t.Errorf("Chunk %d decryption failed with per-chunk IV (unexpected)", i)
- } else {
- t.Logf("✓ Chunk %d: Successfully decrypted with per-chunk IV", i)
- }
- }
- })
-}
diff --git a/weed/s3api/s3api_streaming_copy.go b/weed/s3api/s3api_streaming_copy.go
deleted file mode 100644
index f50f715e3..000000000
--- a/weed/s3api/s3api_streaming_copy.go
+++ /dev/null
@@ -1,601 +0,0 @@
-package s3api
-
-import (
- "context"
- "crypto/md5"
- "crypto/sha256"
- "encoding/hex"
- "fmt"
- "hash"
- "io"
- "net/http"
-
- "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
- "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
- "github.com/seaweedfs/seaweedfs/weed/util"
-)
-
-// StreamingCopySpec defines the specification for streaming copy operations
-type StreamingCopySpec struct {
- SourceReader io.Reader
- TargetSize int64
- EncryptionSpec *EncryptionSpec
- CompressionSpec *CompressionSpec
- HashCalculation bool
- BufferSize int
-}
-
-// EncryptionSpec defines encryption parameters for streaming
-type EncryptionSpec struct {
- NeedsDecryption bool
- NeedsEncryption bool
- SourceKey interface{} // SSECustomerKey or SSEKMSKey
- DestinationKey interface{} // SSECustomerKey or SSEKMSKey
- SourceType EncryptionType
- DestinationType EncryptionType
- SourceMetadata map[string][]byte // Source metadata for IV extraction
- DestinationIV []byte // Generated IV for destination
-}
-
-// CompressionSpec defines compression parameters for streaming
-type CompressionSpec struct {
- IsCompressed bool
- CompressionType string
- NeedsDecompression bool
- NeedsCompression bool
-}
-
-// StreamingCopyManager handles streaming copy operations
-type StreamingCopyManager struct {
- s3a *S3ApiServer
- bufferSize int
-}
-
-// NewStreamingCopyManager creates a new streaming copy manager
-func NewStreamingCopyManager(s3a *S3ApiServer) *StreamingCopyManager {
- return &StreamingCopyManager{
- s3a: s3a,
- bufferSize: 64 * 1024, // 64KB default buffer
- }
-}
-
-// ExecuteStreamingCopy performs a streaming copy operation and returns the encryption spec
-// The encryption spec is needed for SSE-S3 to properly set destination metadata (fixes GitHub #7562)
-func (scm *StreamingCopyManager) ExecuteStreamingCopy(ctx context.Context, entry *filer_pb.Entry, r *http.Request, dstPath string, state *EncryptionState) ([]*filer_pb.FileChunk, *EncryptionSpec, error) {
- // Create streaming copy specification
- spec, err := scm.createStreamingSpec(entry, r, state)
- if err != nil {
- return nil, nil, fmt.Errorf("create streaming spec: %w", err)
- }
-
- // Create source reader from entry
- sourceReader, err := scm.createSourceReader(entry)
- if err != nil {
- return nil, nil, fmt.Errorf("create source reader: %w", err)
- }
- defer sourceReader.Close()
-
- spec.SourceReader = sourceReader
-
- // Create processing pipeline
- processedReader, err := scm.createProcessingPipeline(spec)
- if err != nil {
- return nil, nil, fmt.Errorf("create processing pipeline: %w", err)
- }
-
- // Stream to destination
- chunks, err := scm.streamToDestination(ctx, processedReader, spec, dstPath)
- if err != nil {
- return nil, nil, err
- }
-
- return chunks, spec.EncryptionSpec, nil
-}
-
-// createStreamingSpec creates a streaming specification based on copy parameters
-func (scm *StreamingCopyManager) createStreamingSpec(entry *filer_pb.Entry, r *http.Request, state *EncryptionState) (*StreamingCopySpec, error) {
- spec := &StreamingCopySpec{
- BufferSize: scm.bufferSize,
- HashCalculation: true,
- }
-
- // Calculate target size
- sizeCalc := NewCopySizeCalculator(entry, r)
- spec.TargetSize = sizeCalc.CalculateTargetSize()
-
- // Create encryption specification
- encSpec, err := scm.createEncryptionSpec(entry, r, state)
- if err != nil {
- return nil, err
- }
- spec.EncryptionSpec = encSpec
-
- // Create compression specification
- spec.CompressionSpec = scm.createCompressionSpec(entry, r)
-
- return spec, nil
-}
-
-// createEncryptionSpec creates encryption specification for streaming
-func (scm *StreamingCopyManager) createEncryptionSpec(entry *filer_pb.Entry, r *http.Request, state *EncryptionState) (*EncryptionSpec, error) {
- spec := &EncryptionSpec{
- NeedsDecryption: state.IsSourceEncrypted(),
- NeedsEncryption: state.IsTargetEncrypted(),
- SourceMetadata: entry.Extended, // Pass source metadata for IV extraction
- }
-
- // Set source encryption details
- if state.SrcSSEC {
- spec.SourceType = EncryptionTypeSSEC
- sourceKey, err := ParseSSECCopySourceHeaders(r)
- if err != nil {
- return nil, fmt.Errorf("parse SSE-C copy source headers: %w", err)
- }
- spec.SourceKey = sourceKey
- } else if state.SrcSSEKMS {
- spec.SourceType = EncryptionTypeSSEKMS
- // Extract SSE-KMS key from metadata
- if keyData, exists := entry.Extended[s3_constants.SeaweedFSSSEKMSKey]; exists {
- sseKey, err := DeserializeSSEKMSMetadata(keyData)
- if err != nil {
- return nil, fmt.Errorf("deserialize SSE-KMS metadata: %w", err)
- }
- spec.SourceKey = sseKey
- }
- } else if state.SrcSSES3 {
- spec.SourceType = EncryptionTypeSSES3
- // Extract SSE-S3 key from metadata
- if keyData, exists := entry.Extended[s3_constants.SeaweedFSSSES3Key]; exists {
- keyManager := GetSSES3KeyManager()
- sseKey, err := DeserializeSSES3Metadata(keyData, keyManager)
- if err != nil {
- return nil, fmt.Errorf("deserialize SSE-S3 metadata: %w", err)
- }
- spec.SourceKey = sseKey
- }
- }
-
- // Set destination encryption details
- if state.DstSSEC {
- spec.DestinationType = EncryptionTypeSSEC
- destKey, err := ParseSSECHeaders(r)
- if err != nil {
- return nil, fmt.Errorf("parse SSE-C headers: %w", err)
- }
- spec.DestinationKey = destKey
- } else if state.DstSSEKMS {
- spec.DestinationType = EncryptionTypeSSEKMS
- // Parse KMS parameters
- keyID, encryptionContext, bucketKeyEnabled, err := ParseSSEKMSCopyHeaders(r)
- if err != nil {
- return nil, fmt.Errorf("parse SSE-KMS copy headers: %w", err)
- }
-
- // Create SSE-KMS key for destination
- sseKey := &SSEKMSKey{
- KeyID: keyID,
- EncryptionContext: encryptionContext,
- BucketKeyEnabled: bucketKeyEnabled,
- }
- spec.DestinationKey = sseKey
- } else if state.DstSSES3 {
- spec.DestinationType = EncryptionTypeSSES3
- // Generate or retrieve SSE-S3 key
- keyManager := GetSSES3KeyManager()
- sseKey, err := keyManager.GetOrCreateKey("")
- if err != nil {
- return nil, fmt.Errorf("get SSE-S3 key: %w", err)
- }
- spec.DestinationKey = sseKey
- }
-
- return spec, nil
-}
-
-// createCompressionSpec creates compression specification for streaming
-func (scm *StreamingCopyManager) createCompressionSpec(entry *filer_pb.Entry, r *http.Request) *CompressionSpec {
- return &CompressionSpec{
- IsCompressed: isCompressedEntry(entry),
- // For now, we don't change compression during copy
- NeedsDecompression: false,
- NeedsCompression: false,
- }
-}
-
-// createSourceReader creates a reader for the source entry
-func (scm *StreamingCopyManager) createSourceReader(entry *filer_pb.Entry) (io.ReadCloser, error) {
- // Create a multi-chunk reader that streams from all chunks
- return scm.s3a.createMultiChunkReader(entry)
-}
-
-// createProcessingPipeline creates a processing pipeline for the copy operation
-func (scm *StreamingCopyManager) createProcessingPipeline(spec *StreamingCopySpec) (io.Reader, error) {
- reader := spec.SourceReader
-
- // Add decryption if needed
- if spec.EncryptionSpec.NeedsDecryption {
- decryptedReader, err := scm.createDecryptionReader(reader, spec.EncryptionSpec)
- if err != nil {
- return nil, fmt.Errorf("create decryption reader: %w", err)
- }
- reader = decryptedReader
- }
-
- // Add decompression if needed
- if spec.CompressionSpec.NeedsDecompression {
- decompressedReader, err := scm.createDecompressionReader(reader, spec.CompressionSpec)
- if err != nil {
- return nil, fmt.Errorf("create decompression reader: %w", err)
- }
- reader = decompressedReader
- }
-
- // Add compression if needed
- if spec.CompressionSpec.NeedsCompression {
- compressedReader, err := scm.createCompressionReader(reader, spec.CompressionSpec)
- if err != nil {
- return nil, fmt.Errorf("create compression reader: %w", err)
- }
- reader = compressedReader
- }
-
- // Add encryption if needed
- if spec.EncryptionSpec.NeedsEncryption {
- encryptedReader, err := scm.createEncryptionReader(reader, spec.EncryptionSpec)
- if err != nil {
- return nil, fmt.Errorf("create encryption reader: %w", err)
- }
- reader = encryptedReader
- }
-
- // Add hash calculation if needed
- if spec.HashCalculation {
- reader = scm.createHashReader(reader)
- }
-
- return reader, nil
-}
-
-// createDecryptionReader creates a decryption reader based on encryption type
-func (scm *StreamingCopyManager) createDecryptionReader(reader io.Reader, encSpec *EncryptionSpec) (io.Reader, error) {
- switch encSpec.SourceType {
- case EncryptionTypeSSEC:
- if sourceKey, ok := encSpec.SourceKey.(*SSECustomerKey); ok {
- // Get IV from metadata
- iv, err := GetSSECIVFromMetadata(encSpec.SourceMetadata)
- if err != nil {
- return nil, fmt.Errorf("get IV from metadata: %w", err)
- }
- return CreateSSECDecryptedReader(reader, sourceKey, iv)
- }
- return nil, fmt.Errorf("invalid SSE-C source key type")
-
- case EncryptionTypeSSEKMS:
- if sseKey, ok := encSpec.SourceKey.(*SSEKMSKey); ok {
- return CreateSSEKMSDecryptedReader(reader, sseKey)
- }
- return nil, fmt.Errorf("invalid SSE-KMS source key type")
-
- case EncryptionTypeSSES3:
- if sseKey, ok := encSpec.SourceKey.(*SSES3Key); ok {
- // For SSE-S3, the IV is stored within the SSES3Key metadata, not as separate metadata
- iv := sseKey.IV
- if len(iv) == 0 {
- return nil, fmt.Errorf("SSE-S3 key is missing IV for streaming copy")
- }
- return CreateSSES3DecryptedReader(reader, sseKey, iv)
- }
- return nil, fmt.Errorf("invalid SSE-S3 source key type")
-
- default:
- return reader, nil
- }
-}
-
-// createEncryptionReader creates an encryption reader based on encryption type
-func (scm *StreamingCopyManager) createEncryptionReader(reader io.Reader, encSpec *EncryptionSpec) (io.Reader, error) {
- switch encSpec.DestinationType {
- case EncryptionTypeSSEC:
- if destKey, ok := encSpec.DestinationKey.(*SSECustomerKey); ok {
- encryptedReader, iv, err := CreateSSECEncryptedReader(reader, destKey)
- if err != nil {
- return nil, err
- }
- // Store IV in destination metadata (this would need to be handled by caller)
- encSpec.DestinationIV = iv
- return encryptedReader, nil
- }
- return nil, fmt.Errorf("invalid SSE-C destination key type")
-
- case EncryptionTypeSSEKMS:
- if sseKey, ok := encSpec.DestinationKey.(*SSEKMSKey); ok {
- encryptedReader, updatedKey, err := CreateSSEKMSEncryptedReaderWithBucketKey(reader, sseKey.KeyID, sseKey.EncryptionContext, sseKey.BucketKeyEnabled)
- if err != nil {
- return nil, err
- }
- // Store IV from the updated key
- encSpec.DestinationIV = updatedKey.IV
- return encryptedReader, nil
- }
- return nil, fmt.Errorf("invalid SSE-KMS destination key type")
-
- case EncryptionTypeSSES3:
- if sseKey, ok := encSpec.DestinationKey.(*SSES3Key); ok {
- encryptedReader, iv, err := CreateSSES3EncryptedReader(reader, sseKey)
- if err != nil {
- return nil, err
- }
- // Store IV for metadata
- encSpec.DestinationIV = iv
- return encryptedReader, nil
- }
- return nil, fmt.Errorf("invalid SSE-S3 destination key type")
-
- default:
- return reader, nil
- }
-}
-
-// createDecompressionReader creates a decompression reader
-func (scm *StreamingCopyManager) createDecompressionReader(reader io.Reader, compSpec *CompressionSpec) (io.Reader, error) {
- if !compSpec.NeedsDecompression {
- return reader, nil
- }
-
- switch compSpec.CompressionType {
- case "gzip":
- // Use SeaweedFS's streaming gzip decompression
- pr, pw := io.Pipe()
- go func() {
- defer pw.Close()
- _, err := util.GunzipStream(pw, reader)
- if err != nil {
- pw.CloseWithError(fmt.Errorf("gzip decompression failed: %v", err))
- }
- }()
- return pr, nil
- default:
- // Unknown compression type, return as-is
- return reader, nil
- }
-}
-
-// createCompressionReader creates a compression reader
-func (scm *StreamingCopyManager) createCompressionReader(reader io.Reader, compSpec *CompressionSpec) (io.Reader, error) {
- if !compSpec.NeedsCompression {
- return reader, nil
- }
-
- switch compSpec.CompressionType {
- case "gzip":
- // Use SeaweedFS's streaming gzip compression
- pr, pw := io.Pipe()
- go func() {
- defer pw.Close()
- _, err := util.GzipStream(pw, reader)
- if err != nil {
- pw.CloseWithError(fmt.Errorf("gzip compression failed: %v", err))
- }
- }()
- return pr, nil
- default:
- // Unknown compression type, return as-is
- return reader, nil
- }
-}
-
-// HashReader wraps an io.Reader to calculate MD5 and SHA256 hashes
-type HashReader struct {
- reader io.Reader
- md5Hash hash.Hash
- sha256Hash hash.Hash
-}
-
-// NewHashReader creates a new hash calculating reader
-func NewHashReader(reader io.Reader) *HashReader {
- return &HashReader{
- reader: reader,
- md5Hash: md5.New(),
- sha256Hash: sha256.New(),
- }
-}
-
-// Read implements io.Reader and calculates hashes as data flows through
-func (hr *HashReader) Read(p []byte) (n int, err error) {
- n, err = hr.reader.Read(p)
- if n > 0 {
- // Update both hashes with the data read
- hr.md5Hash.Write(p[:n])
- hr.sha256Hash.Write(p[:n])
- }
- return n, err
-}
-
-// MD5Sum returns the current MD5 hash
-func (hr *HashReader) MD5Sum() []byte {
- return hr.md5Hash.Sum(nil)
-}
-
-// SHA256Sum returns the current SHA256 hash
-func (hr *HashReader) SHA256Sum() []byte {
- return hr.sha256Hash.Sum(nil)
-}
-
-// MD5Hex returns the MD5 hash as a hex string
-func (hr *HashReader) MD5Hex() string {
- return hex.EncodeToString(hr.MD5Sum())
-}
-
-// SHA256Hex returns the SHA256 hash as a hex string
-func (hr *HashReader) SHA256Hex() string {
- return hex.EncodeToString(hr.SHA256Sum())
-}
-
-// createHashReader creates a hash calculation reader
-func (scm *StreamingCopyManager) createHashReader(reader io.Reader) io.Reader {
- return NewHashReader(reader)
-}
-
-// streamToDestination streams the processed data to the destination
-func (scm *StreamingCopyManager) streamToDestination(ctx context.Context, reader io.Reader, spec *StreamingCopySpec, dstPath string) ([]*filer_pb.FileChunk, error) {
- // For now, we'll use the existing chunk-based approach
- // In a full implementation, this would stream directly to the destination
- // without creating intermediate chunks
-
- // This is a placeholder that converts back to chunk-based approach
- // A full streaming implementation would write directly to the destination
- return scm.streamToChunks(ctx, reader, spec, dstPath)
-}
-
-// streamToChunks converts streaming data back to chunks (temporary implementation)
-func (scm *StreamingCopyManager) streamToChunks(ctx context.Context, reader io.Reader, spec *StreamingCopySpec, dstPath string) ([]*filer_pb.FileChunk, error) {
- // This is a simplified implementation that reads the stream and creates chunks
- // A full implementation would be more sophisticated
-
- var chunks []*filer_pb.FileChunk
- buffer := make([]byte, spec.BufferSize)
- offset := int64(0)
-
- for {
- n, err := reader.Read(buffer)
- if n > 0 {
- // Create chunk for this data, setting SSE type and per-chunk metadata (including chunk-specific IVs for SSE-S3)
- chunk, chunkErr := scm.createChunkFromData(buffer[:n], offset, dstPath, spec.EncryptionSpec)
- if chunkErr != nil {
- return nil, fmt.Errorf("create chunk from data: %w", chunkErr)
- }
- chunks = append(chunks, chunk)
- offset += int64(n)
- }
-
- if err == io.EOF {
- break
- }
- if err != nil {
- return nil, fmt.Errorf("read stream: %w", err)
- }
- }
-
- return chunks, nil
-}
-
-// createChunkFromData creates a chunk from streaming data
-func (scm *StreamingCopyManager) createChunkFromData(data []byte, offset int64, dstPath string, encSpec *EncryptionSpec) (*filer_pb.FileChunk, error) {
- // Assign new volume
- assignResult, err := scm.s3a.assignNewVolume(dstPath)
- if err != nil {
- return nil, fmt.Errorf("assign volume: %w", err)
- }
-
- // Create chunk
- chunk := &filer_pb.FileChunk{
- Offset: offset,
- Size: uint64(len(data)),
- }
-
- // Set SSE type and metadata on chunk if destination is encrypted
- // This is critical for GetObject to know to decrypt the data - fixes GitHub #7562
- if encSpec != nil && encSpec.NeedsEncryption {
- switch encSpec.DestinationType {
- case EncryptionTypeSSEC:
- chunk.SseType = filer_pb.SSEType_SSE_C
- // SSE-C metadata is handled at object level, not per-chunk for streaming copy
- case EncryptionTypeSSEKMS:
- chunk.SseType = filer_pb.SSEType_SSE_KMS
- // SSE-KMS metadata is handled at object level, not per-chunk for streaming copy
- case EncryptionTypeSSES3:
- chunk.SseType = filer_pb.SSEType_SSE_S3
- // Create per-chunk SSE-S3 metadata with chunk-specific IV
- if sseKey, ok := encSpec.DestinationKey.(*SSES3Key); ok {
- // Calculate chunk-specific IV using base IV and chunk offset
- baseIV := encSpec.DestinationIV
- if len(baseIV) == 0 {
- return nil, fmt.Errorf("SSE-S3 encryption requires DestinationIV to be set for chunk at offset %d", offset)
- }
- chunkIV, _ := calculateIVWithOffset(baseIV, offset)
- // Create chunk key with the chunk-specific IV
- chunkSSEKey := &SSES3Key{
- Key: sseKey.Key,
- KeyID: sseKey.KeyID,
- Algorithm: sseKey.Algorithm,
- IV: chunkIV,
- }
- chunkMetadata, serErr := SerializeSSES3Metadata(chunkSSEKey)
- if serErr != nil {
- return nil, fmt.Errorf("failed to serialize chunk SSE-S3 metadata: %w", serErr)
- }
- chunk.SseMetadata = chunkMetadata
- }
- }
- }
-
- // Set file ID
- if err := scm.s3a.setChunkFileId(chunk, assignResult); err != nil {
- return nil, err
- }
-
- // Upload data
- if err := scm.s3a.uploadChunkData(data, assignResult, false); err != nil {
- return nil, fmt.Errorf("upload chunk data: %w", err)
- }
-
- return chunk, nil
-}
-
-// createMultiChunkReader creates a reader that streams from multiple chunks
-func (s3a *S3ApiServer) createMultiChunkReader(entry *filer_pb.Entry) (io.ReadCloser, error) {
- // Create a multi-reader that combines all chunks
- var readers []io.Reader
-
- for _, chunk := range entry.GetChunks() {
- chunkReader, err := s3a.createChunkReader(chunk)
- if err != nil {
- return nil, fmt.Errorf("create chunk reader: %w", err)
- }
- readers = append(readers, chunkReader)
- }
-
- multiReader := io.MultiReader(readers...)
- return &multiReadCloser{reader: multiReader}, nil
-}
-
-// createChunkReader creates a reader for a single chunk
-func (s3a *S3ApiServer) createChunkReader(chunk *filer_pb.FileChunk) (io.Reader, error) {
- // Get chunk URL
- srcUrl, err := s3a.lookupVolumeUrl(chunk.GetFileIdString())
- if err != nil {
- return nil, fmt.Errorf("lookup volume URL: %w", err)
- }
-
- // Create HTTP request for chunk data
- req, err := http.NewRequest("GET", srcUrl, nil)
- if err != nil {
- return nil, fmt.Errorf("create HTTP request: %w", err)
- }
-
- // Execute request
- resp, err := http.DefaultClient.Do(req)
- if err != nil {
- return nil, fmt.Errorf("execute HTTP request: %w", err)
- }
-
- if resp.StatusCode != http.StatusOK {
- resp.Body.Close()
- return nil, fmt.Errorf("HTTP request failed: %d", resp.StatusCode)
- }
-
- return resp.Body, nil
-}
-
-// multiReadCloser wraps a multi-reader with a close method
-type multiReadCloser struct {
- reader io.Reader
-}
-
-func (mrc *multiReadCloser) Read(p []byte) (int, error) {
- return mrc.reader.Read(p)
-}
-
-func (mrc *multiReadCloser) Close() error {
- return nil
-}
diff --git a/weed/s3api/s3err/audit_fluent.go b/weed/s3api/s3err/audit_fluent.go
index b63533f1c..ad101cca2 100644
--- a/weed/s3api/s3err/audit_fluent.go
+++ b/weed/s3api/s3err/audit_fluent.go
@@ -128,13 +128,6 @@ func getOperation(object string, r *http.Request) string {
return operation
}
-func GetAccessHttpLog(r *http.Request, statusCode int, s3errCode ErrorCode) AccessLogHTTP {
- return AccessLogHTTP{
- RequestURI: r.RequestURI,
- Referer: r.Header.Get("Referer"),
- }
-}
-
func GetAccessLog(r *http.Request, HTTPStatusCode int, s3errCode ErrorCode) *AccessLog {
bucket, key := s3_constants.GetBucketAndObject(r)
var errorCode string
diff --git a/weed/s3api/s3lifecycle/evaluator.go b/weed/s3api/s3lifecycle/evaluator.go
deleted file mode 100644
index 181b08e44..000000000
--- a/weed/s3api/s3lifecycle/evaluator.go
+++ /dev/null
@@ -1,127 +0,0 @@
-package s3lifecycle
-
-import "time"
-
-// Evaluate checks the given lifecycle rules against an object and returns
-// the highest-priority action that applies. The evaluation follows S3's
-// action priority:
-// 1. ExpiredObjectDeleteMarker (delete marker is sole version)
-// 2. NoncurrentVersionExpiration (non-current version age/count)
-// 3. Current version Expiration (Days or Date)
-//
-// AbortIncompleteMultipartUpload is evaluated separately since it applies
-// to uploads, not objects. Use EvaluateMPUAbort for that.
-func Evaluate(rules []Rule, obj ObjectInfo, now time.Time) EvalResult {
- // Phase 1: ExpiredObjectDeleteMarker
- if obj.IsDeleteMarker && obj.IsLatest && obj.NumVersions == 1 {
- for _, rule := range rules {
- if rule.Status != "Enabled" {
- continue
- }
- if !MatchesFilter(rule, obj) {
- continue
- }
- if rule.ExpiredObjectDeleteMarker {
- return EvalResult{Action: ActionExpireDeleteMarker, RuleID: rule.ID}
- }
- }
- }
-
- // Phase 2: NoncurrentVersionExpiration
- if !obj.IsLatest && !obj.SuccessorModTime.IsZero() {
- for _, rule := range rules {
- if ShouldExpireNoncurrentVersion(rule, obj, obj.NoncurrentIndex, now) {
- return EvalResult{Action: ActionDeleteVersion, RuleID: rule.ID}
- }
- }
- }
-
- // Phase 3: Current version Expiration
- if obj.IsLatest && !obj.IsDeleteMarker {
- for _, rule := range rules {
- if rule.Status != "Enabled" {
- continue
- }
- if !MatchesFilter(rule, obj) {
- continue
- }
- // Date-based expiration
- if !rule.ExpirationDate.IsZero() && !now.Before(rule.ExpirationDate) {
- return EvalResult{Action: ActionDeleteObject, RuleID: rule.ID}
- }
- // Days-based expiration
- if rule.ExpirationDays > 0 {
- expiryTime := expectedExpiryTime(obj.ModTime, rule.ExpirationDays)
- if !now.Before(expiryTime) {
- return EvalResult{Action: ActionDeleteObject, RuleID: rule.ID}
- }
- }
- }
- }
-
- return EvalResult{Action: ActionNone}
-}
-
-// ShouldExpireNoncurrentVersion checks whether a non-current version should
-// be expired considering both NoncurrentDays and NewerNoncurrentVersions.
-// noncurrentIndex is the 0-based position among non-current versions sorted
-// newest-first (0 = newest non-current version).
-func ShouldExpireNoncurrentVersion(rule Rule, obj ObjectInfo, noncurrentIndex int, now time.Time) bool {
- if rule.Status != "Enabled" {
- return false
- }
- if rule.NoncurrentVersionExpirationDays <= 0 {
- return false
- }
- if obj.IsLatest || obj.SuccessorModTime.IsZero() {
- return false
- }
- if !MatchesFilter(rule, obj) {
- return false
- }
-
- // Check age threshold.
- expiryTime := expectedExpiryTime(obj.SuccessorModTime, rule.NoncurrentVersionExpirationDays)
- if now.Before(expiryTime) {
- return false
- }
-
- // Check NewerNoncurrentVersions count threshold.
- if rule.NewerNoncurrentVersions > 0 && noncurrentIndex < rule.NewerNoncurrentVersions {
- return false
- }
-
- return true
-}
-
-// EvaluateMPUAbort finds the applicable AbortIncompleteMultipartUpload rule
-// for a multipart upload with the given key prefix and creation time.
-func EvaluateMPUAbort(rules []Rule, uploadKey string, createdAt time.Time, now time.Time) EvalResult {
- for _, rule := range rules {
- if rule.Status != "Enabled" {
- continue
- }
- if rule.AbortMPUDaysAfterInitiation <= 0 {
- continue
- }
- if !matchesPrefix(rule.Prefix, uploadKey) {
- continue
- }
- cutoff := expectedExpiryTime(createdAt, rule.AbortMPUDaysAfterInitiation)
- if !now.Before(cutoff) {
- return EvalResult{Action: ActionAbortMultipartUpload, RuleID: rule.ID}
- }
- }
- return EvalResult{Action: ActionNone}
-}
-
-// expectedExpiryTime computes the expiration time given a reference time and
-// a number of days. Following S3 semantics, expiration happens at midnight UTC
-// of the day after the specified number of days.
-func expectedExpiryTime(refTime time.Time, days int) time.Time {
- if days == 0 {
- return refTime
- }
- t := refTime.UTC().Add(time.Duration(days+1) * 24 * time.Hour)
- return t.Truncate(24 * time.Hour)
-}
diff --git a/weed/s3api/s3lifecycle/evaluator_test.go b/weed/s3api/s3lifecycle/evaluator_test.go
deleted file mode 100644
index aa58e4bc8..000000000
--- a/weed/s3api/s3lifecycle/evaluator_test.go
+++ /dev/null
@@ -1,495 +0,0 @@
-package s3lifecycle
-
-import (
- "testing"
- "time"
-)
-
-var now = time.Date(2026, 3, 27, 12, 0, 0, 0, time.UTC)
-
-func TestEvaluate_ExpirationDays(t *testing.T) {
- rules := []Rule{{
- ID: "expire-30d", Status: "Enabled",
- ExpirationDays: 30,
- }}
-
- t.Run("object_older_than_days_is_expired", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "data/file.txt", IsLatest: true,
- ModTime: now.Add(-31 * 24 * time.Hour),
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionDeleteObject, result.Action)
- assertEqual(t, "expire-30d", result.RuleID)
- })
-
- t.Run("object_younger_than_days_is_not_expired", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "data/file.txt", IsLatest: true,
- ModTime: now.Add(-10 * 24 * time.Hour),
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionNone, result.Action)
- })
-
- t.Run("non_latest_version_not_affected_by_expiration_days", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "data/file.txt", IsLatest: false,
- ModTime: now.Add(-60 * 24 * time.Hour),
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionNone, result.Action)
- })
-
- t.Run("delete_marker_not_affected_by_expiration_days", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "data/file.txt", IsLatest: true, IsDeleteMarker: true,
- ModTime: now.Add(-60 * 24 * time.Hour), NumVersions: 3,
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionNone, result.Action)
- })
-}
-
-func TestEvaluate_ExpirationDate(t *testing.T) {
- expirationDate := time.Date(2026, 3, 15, 0, 0, 0, 0, time.UTC)
- rules := []Rule{{
- ID: "expire-date", Status: "Enabled",
- ExpirationDate: expirationDate,
- }}
-
- t.Run("object_expired_after_date", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "file.txt", IsLatest: true,
- ModTime: now.Add(-60 * 24 * time.Hour),
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionDeleteObject, result.Action)
- })
-
- t.Run("object_not_expired_before_date", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "file.txt", IsLatest: true,
- ModTime: now.Add(-1 * time.Hour),
- }
- beforeDate := time.Date(2026, 3, 10, 0, 0, 0, 0, time.UTC)
- result := Evaluate(rules, obj, beforeDate)
- assertAction(t, ActionNone, result.Action)
- })
-}
-
-func TestEvaluate_ExpiredObjectDeleteMarker(t *testing.T) {
- rules := []Rule{{
- ID: "cleanup-markers", Status: "Enabled",
- ExpiredObjectDeleteMarker: true,
- }}
-
- t.Run("sole_delete_marker_is_expired", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "file.txt", IsLatest: true, IsDeleteMarker: true,
- NumVersions: 1,
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionExpireDeleteMarker, result.Action)
- })
-
- t.Run("delete_marker_with_other_versions_not_expired", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "file.txt", IsLatest: true, IsDeleteMarker: true,
- NumVersions: 3,
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionNone, result.Action)
- })
-
- t.Run("non_latest_delete_marker_not_expired", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "file.txt", IsLatest: false, IsDeleteMarker: true,
- NumVersions: 1,
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionNone, result.Action)
- })
-
- t.Run("non_delete_marker_not_affected", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "file.txt", IsLatest: true, IsDeleteMarker: false,
- NumVersions: 1,
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionNone, result.Action)
- })
-}
-
-func TestEvaluate_NoncurrentVersionExpiration(t *testing.T) {
- rules := []Rule{{
- ID: "expire-noncurrent", Status: "Enabled",
- NoncurrentVersionExpirationDays: 30,
- }}
-
- t.Run("old_noncurrent_version_is_expired", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "file.txt", IsLatest: false,
- SuccessorModTime: now.Add(-45 * 24 * time.Hour),
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionDeleteVersion, result.Action)
- })
-
- t.Run("recent_noncurrent_version_is_not_expired", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "file.txt", IsLatest: false,
- SuccessorModTime: now.Add(-10 * 24 * time.Hour),
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionNone, result.Action)
- })
-
- t.Run("latest_version_not_affected", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "file.txt", IsLatest: true,
- ModTime: now.Add(-60 * 24 * time.Hour),
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionNone, result.Action)
- })
-}
-
-func TestShouldExpireNoncurrentVersion(t *testing.T) {
- rule := Rule{
- ID: "noncurrent-rule", Status: "Enabled",
- NoncurrentVersionExpirationDays: 30,
- NewerNoncurrentVersions: 2,
- }
-
- t.Run("old_version_beyond_count_is_expired", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "file.txt", IsLatest: false,
- SuccessorModTime: now.Add(-45 * 24 * time.Hour),
- }
- // noncurrentIndex=2 means this is the 3rd noncurrent version (0-indexed)
- // With NewerNoncurrentVersions=2, indices 0 and 1 are kept.
- if !ShouldExpireNoncurrentVersion(rule, obj, 2, now) {
- t.Error("expected version at index 2 to be expired")
- }
- })
-
- t.Run("old_version_within_count_is_kept", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "file.txt", IsLatest: false,
- SuccessorModTime: now.Add(-45 * 24 * time.Hour),
- }
- // noncurrentIndex=1 is within the keep threshold (NewerNoncurrentVersions=2).
- if ShouldExpireNoncurrentVersion(rule, obj, 1, now) {
- t.Error("expected version at index 1 to be kept")
- }
- })
-
- t.Run("recent_version_beyond_count_is_kept", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "file.txt", IsLatest: false,
- SuccessorModTime: now.Add(-5 * 24 * time.Hour),
- }
- // Even at index 5 (beyond count), if too young, it's kept.
- if ShouldExpireNoncurrentVersion(rule, obj, 5, now) {
- t.Error("expected recent version to be kept regardless of index")
- }
- })
-
- t.Run("disabled_rule_never_expires", func(t *testing.T) {
- disabled := Rule{
- ID: "disabled", Status: "Disabled",
- NoncurrentVersionExpirationDays: 1,
- }
- obj := ObjectInfo{
- Key: "file.txt", IsLatest: false,
- SuccessorModTime: now.Add(-365 * 24 * time.Hour),
- }
- if ShouldExpireNoncurrentVersion(disabled, obj, 10, now) {
- t.Error("disabled rule should never expire")
- }
- })
-}
-
-func TestEvaluate_PrefixFilter(t *testing.T) {
- rules := []Rule{{
- ID: "logs-only", Status: "Enabled",
- Prefix: "logs/",
- ExpirationDays: 7,
- }}
-
- t.Run("matching_prefix", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "logs/app.log", IsLatest: true,
- ModTime: now.Add(-10 * 24 * time.Hour),
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionDeleteObject, result.Action)
- })
-
- t.Run("non_matching_prefix", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "data/file.txt", IsLatest: true,
- ModTime: now.Add(-10 * 24 * time.Hour),
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionNone, result.Action)
- })
-}
-
-func TestEvaluate_TagFilter(t *testing.T) {
- rules := []Rule{{
- ID: "temp-only", Status: "Enabled",
- ExpirationDays: 1,
- FilterTags: map[string]string{"env": "temp"},
- }}
-
- t.Run("matching_tags", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "file.txt", IsLatest: true,
- ModTime: now.Add(-5 * 24 * time.Hour),
- Tags: map[string]string{"env": "temp", "project": "foo"},
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionDeleteObject, result.Action)
- })
-
- t.Run("missing_tag", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "file.txt", IsLatest: true,
- ModTime: now.Add(-5 * 24 * time.Hour),
- Tags: map[string]string{"project": "foo"},
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionNone, result.Action)
- })
-
- t.Run("wrong_tag_value", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "file.txt", IsLatest: true,
- ModTime: now.Add(-5 * 24 * time.Hour),
- Tags: map[string]string{"env": "prod"},
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionNone, result.Action)
- })
-
- t.Run("nil_object_tags", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "file.txt", IsLatest: true,
- ModTime: now.Add(-5 * 24 * time.Hour),
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionNone, result.Action)
- })
-}
-
-func TestEvaluate_SizeFilter(t *testing.T) {
- rules := []Rule{{
- ID: "large-files", Status: "Enabled",
- ExpirationDays: 7,
- FilterSizeGreaterThan: 1024 * 1024, // > 1 MB
- FilterSizeLessThan: 100 * 1024 * 1024, // < 100 MB
- }}
-
- t.Run("matching_size", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "file.bin", IsLatest: true,
- ModTime: now.Add(-10 * 24 * time.Hour),
- Size: 10 * 1024 * 1024, // 10 MB
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionDeleteObject, result.Action)
- })
-
- t.Run("too_small", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "file.bin", IsLatest: true,
- ModTime: now.Add(-10 * 24 * time.Hour),
- Size: 512, // 512 bytes
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionNone, result.Action)
- })
-
- t.Run("too_large", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "file.bin", IsLatest: true,
- ModTime: now.Add(-10 * 24 * time.Hour),
- Size: 200 * 1024 * 1024, // 200 MB
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionNone, result.Action)
- })
-}
-
-func TestEvaluate_CombinedFilters(t *testing.T) {
- rules := []Rule{{
- ID: "combined", Status: "Enabled",
- Prefix: "logs/",
- ExpirationDays: 7,
- FilterTags: map[string]string{"env": "dev"},
- FilterSizeGreaterThan: 100,
- }}
-
- t.Run("all_filters_match", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "logs/app.log", IsLatest: true,
- ModTime: now.Add(-10 * 24 * time.Hour),
- Size: 1024,
- Tags: map[string]string{"env": "dev"},
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionDeleteObject, result.Action)
- })
-
- t.Run("prefix_doesnt_match", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "data/app.log", IsLatest: true,
- ModTime: now.Add(-10 * 24 * time.Hour),
- Size: 1024,
- Tags: map[string]string{"env": "dev"},
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionNone, result.Action)
- })
-
- t.Run("tag_doesnt_match", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "logs/app.log", IsLatest: true,
- ModTime: now.Add(-10 * 24 * time.Hour),
- Size: 1024,
- Tags: map[string]string{"env": "prod"},
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionNone, result.Action)
- })
-
- t.Run("size_doesnt_match", func(t *testing.T) {
- obj := ObjectInfo{
- Key: "logs/app.log", IsLatest: true,
- ModTime: now.Add(-10 * 24 * time.Hour),
- Size: 50, // too small
- Tags: map[string]string{"env": "dev"},
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionNone, result.Action)
- })
-}
-
-func TestEvaluate_DisabledRule(t *testing.T) {
- rules := []Rule{{
- ID: "disabled", Status: "Disabled",
- ExpirationDays: 1,
- }}
- obj := ObjectInfo{
- Key: "file.txt", IsLatest: true,
- ModTime: now.Add(-365 * 24 * time.Hour),
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionNone, result.Action)
-}
-
-func TestEvaluate_MultipleRules_Priority(t *testing.T) {
- t.Run("delete_marker_takes_priority_over_expiration", func(t *testing.T) {
- rules := []Rule{
- {ID: "expire", Status: "Enabled", ExpirationDays: 1},
- {ID: "marker", Status: "Enabled", ExpiredObjectDeleteMarker: true},
- }
- obj := ObjectInfo{
- Key: "file.txt", IsLatest: true, IsDeleteMarker: true,
- NumVersions: 1, ModTime: now.Add(-10 * 24 * time.Hour),
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionExpireDeleteMarker, result.Action)
- assertEqual(t, "marker", result.RuleID)
- })
-
- t.Run("first_matching_expiration_rule_wins", func(t *testing.T) {
- rules := []Rule{
- {ID: "rule1", Status: "Enabled", ExpirationDays: 30, Prefix: "logs/"},
- {ID: "rule2", Status: "Enabled", ExpirationDays: 7},
- }
- obj := ObjectInfo{
- Key: "logs/app.log", IsLatest: true,
- ModTime: now.Add(-31 * 24 * time.Hour),
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionDeleteObject, result.Action)
- assertEqual(t, "rule1", result.RuleID)
- })
-}
-
-func TestEvaluate_EmptyPrefix(t *testing.T) {
- rules := []Rule{{
- ID: "all", Status: "Enabled",
- ExpirationDays: 30,
- }}
- obj := ObjectInfo{
- Key: "any/path/file.txt", IsLatest: true,
- ModTime: now.Add(-31 * 24 * time.Hour),
- }
- result := Evaluate(rules, obj, now)
- assertAction(t, ActionDeleteObject, result.Action)
-}
-
-func TestEvaluateMPUAbort(t *testing.T) {
- rules := []Rule{{
- ID: "abort-mpu", Status: "Enabled",
- AbortMPUDaysAfterInitiation: 7,
- }}
-
- t.Run("old_upload_is_aborted", func(t *testing.T) {
- result := EvaluateMPUAbort(rules, "uploads/file.bin", now.Add(-10*24*time.Hour), now)
- assertAction(t, ActionAbortMultipartUpload, result.Action)
- })
-
- t.Run("recent_upload_is_not_aborted", func(t *testing.T) {
- result := EvaluateMPUAbort(rules, "uploads/file.bin", now.Add(-3*24*time.Hour), now)
- assertAction(t, ActionNone, result.Action)
- })
-
- t.Run("prefix_scoped_abort", func(t *testing.T) {
- prefixRules := []Rule{{
- ID: "abort-logs", Status: "Enabled",
- Prefix: "logs/",
- AbortMPUDaysAfterInitiation: 1,
- }}
- result := EvaluateMPUAbort(prefixRules, "data/file.bin", now.Add(-5*24*time.Hour), now)
- assertAction(t, ActionNone, result.Action)
- })
-}
-
-func TestExpectedExpiryTime(t *testing.T) {
- ref := time.Date(2026, 3, 1, 15, 30, 0, 0, time.UTC)
-
- t.Run("30_days", func(t *testing.T) {
- // S3 spec: expires at midnight UTC of day 32 (ref + 31 days, truncated).
- expiry := expectedExpiryTime(ref, 30)
- expected := time.Date(2026, 4, 1, 0, 0, 0, 0, time.UTC)
- if !expiry.Equal(expected) {
- t.Errorf("expected %v, got %v", expected, expiry)
- }
- })
-
- t.Run("zero_days_returns_ref", func(t *testing.T) {
- expiry := expectedExpiryTime(ref, 0)
- if !expiry.Equal(ref) {
- t.Errorf("expected %v, got %v", ref, expiry)
- }
- })
-}
-
-func assertAction(t *testing.T, expected, actual Action) {
- t.Helper()
- if expected != actual {
- t.Errorf("expected action %d, got %d", expected, actual)
- }
-}
-
-func assertEqual(t *testing.T, expected, actual string) {
- t.Helper()
- if expected != actual {
- t.Errorf("expected %q, got %q", expected, actual)
- }
-}
diff --git a/weed/s3api/s3lifecycle/filter.go b/weed/s3api/s3lifecycle/filter.go
deleted file mode 100644
index 394425d60..000000000
--- a/weed/s3api/s3lifecycle/filter.go
+++ /dev/null
@@ -1,56 +0,0 @@
-package s3lifecycle
-
-import "strings"
-
-// MatchesFilter checks if an object matches the rule's filter criteria
-// (prefix, tags, and size constraints).
-func MatchesFilter(rule Rule, obj ObjectInfo) bool {
- if !matchesPrefix(rule.Prefix, obj.Key) {
- return false
- }
- if !matchesTags(rule.FilterTags, obj.Tags) {
- return false
- }
- if !matchesSize(rule.FilterSizeGreaterThan, rule.FilterSizeLessThan, obj.Size) {
- return false
- }
- return true
-}
-
-// matchesPrefix returns true if the object key starts with the given prefix.
-// An empty prefix matches all keys.
-func matchesPrefix(prefix, key string) bool {
- if prefix == "" {
- return true
- }
- return strings.HasPrefix(key, prefix)
-}
-
-// matchesTags returns true if all rule tags are present in the object's tags
-// with matching values. An empty or nil rule tag set matches all objects.
-func matchesTags(ruleTags, objTags map[string]string) bool {
- if len(ruleTags) == 0 {
- return true
- }
- if len(objTags) == 0 {
- return false
- }
- for k, v := range ruleTags {
- if objVal, ok := objTags[k]; !ok || objVal != v {
- return false
- }
- }
- return true
-}
-
-// matchesSize returns true if the object's size falls within the specified
-// bounds. Zero values mean no constraint on that side.
-func matchesSize(greaterThan, lessThan, objSize int64) bool {
- if greaterThan > 0 && objSize <= greaterThan {
- return false
- }
- if lessThan > 0 && objSize >= lessThan {
- return false
- }
- return true
-}
diff --git a/weed/s3api/s3lifecycle/filter_test.go b/weed/s3api/s3lifecycle/filter_test.go
deleted file mode 100644
index c8bcfeb10..000000000
--- a/weed/s3api/s3lifecycle/filter_test.go
+++ /dev/null
@@ -1,79 +0,0 @@
-package s3lifecycle
-
-import "testing"
-
-func TestMatchesPrefix(t *testing.T) {
- tests := []struct {
- name string
- prefix string
- key string
- want bool
- }{
- {"empty_prefix_matches_all", "", "any/key.txt", true},
- {"exact_prefix_match", "logs/", "logs/app.log", true},
- {"prefix_mismatch", "logs/", "data/file.txt", false},
- {"key_shorter_than_prefix", "very/long/prefix/", "short", false},
- {"prefix_equals_key", "exact", "exact", true},
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := matchesPrefix(tt.prefix, tt.key); got != tt.want {
- t.Errorf("matchesPrefix(%q, %q) = %v, want %v", tt.prefix, tt.key, got, tt.want)
- }
- })
- }
-}
-
-func TestMatchesTags(t *testing.T) {
- tests := []struct {
- name string
- ruleTags map[string]string
- objTags map[string]string
- want bool
- }{
- {"nil_rule_tags_match_all", nil, map[string]string{"a": "1"}, true},
- {"empty_rule_tags_match_all", map[string]string{}, map[string]string{"a": "1"}, true},
- {"nil_obj_tags_no_match", map[string]string{"a": "1"}, nil, false},
- {"single_tag_match", map[string]string{"env": "dev"}, map[string]string{"env": "dev", "foo": "bar"}, true},
- {"single_tag_value_mismatch", map[string]string{"env": "dev"}, map[string]string{"env": "prod"}, false},
- {"single_tag_key_missing", map[string]string{"env": "dev"}, map[string]string{"foo": "bar"}, false},
- {"multi_tag_all_match", map[string]string{"env": "dev", "tier": "hot"}, map[string]string{"env": "dev", "tier": "hot", "extra": "x"}, true},
- {"multi_tag_partial_match", map[string]string{"env": "dev", "tier": "hot"}, map[string]string{"env": "dev"}, false},
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := matchesTags(tt.ruleTags, tt.objTags); got != tt.want {
- t.Errorf("matchesTags() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func TestMatchesSize(t *testing.T) {
- tests := []struct {
- name string
- greaterThan int64
- lessThan int64
- objSize int64
- want bool
- }{
- {"no_constraints", 0, 0, 1000, true},
- {"only_greater_than_pass", 100, 0, 200, true},
- {"only_greater_than_fail", 100, 0, 50, false},
- {"only_greater_than_equal_fail", 100, 0, 100, false},
- {"only_less_than_pass", 0, 1000, 500, true},
- {"only_less_than_fail", 0, 1000, 2000, false},
- {"only_less_than_equal_fail", 0, 1000, 1000, false},
- {"both_constraints_pass", 100, 1000, 500, true},
- {"both_constraints_too_small", 100, 1000, 50, false},
- {"both_constraints_too_large", 100, 1000, 2000, false},
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if got := matchesSize(tt.greaterThan, tt.lessThan, tt.objSize); got != tt.want {
- t.Errorf("matchesSize(%d, %d, %d) = %v, want %v",
- tt.greaterThan, tt.lessThan, tt.objSize, got, tt.want)
- }
- })
- }
-}
diff --git a/weed/s3api/s3lifecycle/tags.go b/weed/s3api/s3lifecycle/tags.go
index 57092ed56..49bbaab66 100644
--- a/weed/s3api/s3lifecycle/tags.go
+++ b/weed/s3api/s3lifecycle/tags.go
@@ -1,34 +1,3 @@
package s3lifecycle
-import "strings"
-
const tagPrefix = "X-Amz-Tagging-"
-
-// ExtractTags extracts S3 object tags from a filer entry's Extended metadata.
-// Tags are stored with the key prefix "X-Amz-Tagging-" followed by the tag key.
-func ExtractTags(extended map[string][]byte) map[string]string {
- if len(extended) == 0 {
- return nil
- }
- var tags map[string]string
- for k, v := range extended {
- if strings.HasPrefix(k, tagPrefix) {
- if tags == nil {
- tags = make(map[string]string)
- }
- tags[k[len(tagPrefix):]] = string(v)
- }
- }
- return tags
-}
-
-// HasTagRules returns true if any enabled rule in the set uses tag-based filtering.
-// This is used as an optimization to skip tag extraction when no rules need it.
-func HasTagRules(rules []Rule) bool {
- for _, r := range rules {
- if r.Status == "Enabled" && len(r.FilterTags) > 0 {
- return true
- }
- }
- return false
-}
diff --git a/weed/s3api/s3lifecycle/tags_test.go b/weed/s3api/s3lifecycle/tags_test.go
deleted file mode 100644
index 0eb198c5f..000000000
--- a/weed/s3api/s3lifecycle/tags_test.go
+++ /dev/null
@@ -1,89 +0,0 @@
-package s3lifecycle
-
-import "testing"
-
-func TestExtractTags(t *testing.T) {
- t.Run("extracts_tags_with_prefix", func(t *testing.T) {
- extended := map[string][]byte{
- "X-Amz-Tagging-env": []byte("prod"),
- "X-Amz-Tagging-project": []byte("foo"),
- "Content-Type": []byte("text/plain"),
- "X-Amz-Meta-Custom": []byte("value"),
- }
- tags := ExtractTags(extended)
- if len(tags) != 2 {
- t.Fatalf("expected 2 tags, got %d", len(tags))
- }
- if tags["env"] != "prod" {
- t.Errorf("expected env=prod, got %q", tags["env"])
- }
- if tags["project"] != "foo" {
- t.Errorf("expected project=foo, got %q", tags["project"])
- }
- })
-
- t.Run("nil_extended_returns_nil", func(t *testing.T) {
- tags := ExtractTags(nil)
- if tags != nil {
- t.Errorf("expected nil, got %v", tags)
- }
- })
-
- t.Run("no_tags_returns_nil", func(t *testing.T) {
- extended := map[string][]byte{
- "Content-Type": []byte("text/plain"),
- }
- tags := ExtractTags(extended)
- if tags != nil {
- t.Errorf("expected nil, got %v", tags)
- }
- })
-
- t.Run("empty_tag_value", func(t *testing.T) {
- extended := map[string][]byte{
- "X-Amz-Tagging-empty": []byte(""),
- }
- tags := ExtractTags(extended)
- if len(tags) != 1 {
- t.Fatalf("expected 1 tag, got %d", len(tags))
- }
- if tags["empty"] != "" {
- t.Errorf("expected empty value, got %q", tags["empty"])
- }
- })
-}
-
-func TestHasTagRules(t *testing.T) {
- t.Run("has_tag_rules", func(t *testing.T) {
- rules := []Rule{
- {Status: "Enabled", FilterTags: map[string]string{"env": "dev"}},
- }
- if !HasTagRules(rules) {
- t.Error("expected true")
- }
- })
-
- t.Run("no_tag_rules", func(t *testing.T) {
- rules := []Rule{
- {Status: "Enabled", ExpirationDays: 30},
- }
- if HasTagRules(rules) {
- t.Error("expected false")
- }
- })
-
- t.Run("disabled_tag_rule", func(t *testing.T) {
- rules := []Rule{
- {Status: "Disabled", FilterTags: map[string]string{"env": "dev"}},
- }
- if HasTagRules(rules) {
- t.Error("expected false for disabled rule")
- }
- })
-
- t.Run("empty_rules", func(t *testing.T) {
- if HasTagRules(nil) {
- t.Error("expected false for nil rules")
- }
- })
-}
diff --git a/weed/s3api/s3lifecycle/version_time.go b/weed/s3api/s3lifecycle/version_time.go
index fb6cfbbf5..a9d2e9ae2 100644
--- a/weed/s3api/s3lifecycle/version_time.go
+++ b/weed/s3api/s3lifecycle/version_time.go
@@ -1,99 +1,6 @@
package s3lifecycle
-import (
- "math"
- "strconv"
- "time"
-)
-
// versionIdFormatThreshold distinguishes old vs new format version IDs.
// New format (inverted timestamps) produces values above this threshold;
// old format (raw timestamps) produces values below it.
const versionIdFormatThreshold = 0x4000000000000000
-
-// GetVersionTimestamp extracts the actual timestamp from a SeaweedFS version ID,
-// handling both old (raw nanosecond) and new (inverted nanosecond) formats.
-// Returns zero time if the version ID is invalid or "null".
-func GetVersionTimestamp(versionId string) time.Time {
- ns := getVersionTimestampNanos(versionId)
- if ns == 0 {
- return time.Time{}
- }
- return time.Unix(0, ns)
-}
-
-// getVersionTimestampNanos extracts the raw nanosecond timestamp from a version ID.
-func getVersionTimestampNanos(versionId string) int64 {
- if len(versionId) < 16 || versionId == "null" {
- return 0
- }
- timestampPart, err := strconv.ParseUint(versionId[:16], 16, 64)
- if err != nil {
- return 0
- }
- if timestampPart > math.MaxInt64 {
- return 0
- }
- if timestampPart > versionIdFormatThreshold {
- // New format: inverted timestamp, convert back.
- return int64(math.MaxInt64 - timestampPart)
- }
- return int64(timestampPart)
-}
-
-// isNewFormatVersionId returns true if the version ID uses inverted timestamps.
-func isNewFormatVersionId(versionId string) bool {
- if len(versionId) < 16 || versionId == "null" {
- return false
- }
- timestampPart, err := strconv.ParseUint(versionId[:16], 16, 64)
- if err != nil {
- return false
- }
- return timestampPart > versionIdFormatThreshold && timestampPart <= math.MaxInt64
-}
-
-// CompareVersionIds compares two version IDs for sorting (newest first).
-// Returns negative if a is newer, positive if b is newer, 0 if equal.
-// Handles both old and new format version IDs and uses full lexicographic
-// comparison (not just timestamps) to break ties from the random suffix.
-func CompareVersionIds(a, b string) int {
- if a == b {
- return 0
- }
- if a == "null" {
- return 1
- }
- if b == "null" {
- return -1
- }
-
- aIsNew := isNewFormatVersionId(a)
- bIsNew := isNewFormatVersionId(b)
-
- if aIsNew == bIsNew {
- if aIsNew {
- // New format: smaller hex = newer (inverted timestamps).
- if a < b {
- return -1
- }
- return 1
- }
- // Old format: smaller hex = older.
- if a < b {
- return 1
- }
- return -1
- }
-
- // Mixed formats: compare by actual timestamp.
- aTime := getVersionTimestampNanos(a)
- bTime := getVersionTimestampNanos(b)
- if aTime > bTime {
- return -1
- }
- if aTime < bTime {
- return 1
- }
- return 0
-}
diff --git a/weed/s3api/s3lifecycle/version_time_test.go b/weed/s3api/s3lifecycle/version_time_test.go
deleted file mode 100644
index 460cbec58..000000000
--- a/weed/s3api/s3lifecycle/version_time_test.go
+++ /dev/null
@@ -1,74 +0,0 @@
-package s3lifecycle
-
-import (
- "fmt"
- "math"
- "testing"
- "time"
-)
-
-func TestGetVersionTimestamp(t *testing.T) {
- t.Run("new_format_inverted_timestamp", func(t *testing.T) {
- // Simulate a new-format version ID (inverted timestamp above threshold).
- now := time.Now()
- inverted := math.MaxInt64 - now.UnixNano()
- versionId := fmt.Sprintf("%016x", inverted) + "0000000000000000"
-
- got := GetVersionTimestamp(versionId)
- // Should recover the original timestamp within 1 second.
- diff := got.Sub(now)
- if diff < -time.Second || diff > time.Second {
- t.Errorf("timestamp diff too large: %v (got %v, want ~%v)", diff, got, now)
- }
- })
-
- t.Run("old_format_raw_timestamp", func(t *testing.T) {
- // Simulate an old-format version ID (raw nanosecond timestamp below threshold).
- // Use a timestamp from 2023 which would be below threshold.
- ts := time.Date(2023, 6, 15, 12, 0, 0, 0, time.UTC)
- versionId := fmt.Sprintf("%016x", ts.UnixNano()) + "abcdef0123456789"
-
- got := GetVersionTimestamp(versionId)
- if !got.Equal(ts) {
- t.Errorf("expected %v, got %v", ts, got)
- }
- })
-
- t.Run("null_version_id", func(t *testing.T) {
- got := GetVersionTimestamp("null")
- if !got.IsZero() {
- t.Errorf("expected zero time for null version, got %v", got)
- }
- })
-
- t.Run("empty_version_id", func(t *testing.T) {
- got := GetVersionTimestamp("")
- if !got.IsZero() {
- t.Errorf("expected zero time for empty version, got %v", got)
- }
- })
-
- t.Run("short_version_id", func(t *testing.T) {
- got := GetVersionTimestamp("abc")
- if !got.IsZero() {
- t.Errorf("expected zero time for short version, got %v", got)
- }
- })
-
- t.Run("high_bit_overflow_returns_zero", func(t *testing.T) {
- // Version ID with first 16 hex chars > math.MaxInt64 should return zero,
- // not a wrapped negative timestamp.
- versionId := "80000000000000000000000000000000"
- got := GetVersionTimestamp(versionId)
- if !got.IsZero() {
- t.Errorf("expected zero time for overflow version ID, got %v", got)
- }
- })
-
- t.Run("invalid_hex", func(t *testing.T) {
- got := GetVersionTimestamp("zzzzzzzzzzzzzzzz0000000000000000")
- if !got.IsZero() {
- t.Errorf("expected zero time for invalid hex, got %v", got)
- }
- })
-}
diff --git a/weed/s3api/s3tables/filer_ops.go b/weed/s3api/s3tables/filer_ops.go
index 7edb8a2a5..7a0ad66ff 100644
--- a/weed/s3api/s3tables/filer_ops.go
+++ b/weed/s3api/s3tables/filer_ops.go
@@ -50,46 +50,6 @@ func (h *S3TablesHandler) ensureDirectory(ctx context.Context, client filer_pb.S
return err
}
-// upsertFile creates or updates a small file with the given content
-func (h *S3TablesHandler) upsertFile(ctx context.Context, client filer_pb.SeaweedFilerClient, path string, data []byte) error {
- dir, name := splitPath(path)
- now := time.Now().Unix()
- resp, err := filer_pb.LookupEntry(ctx, client, &filer_pb.LookupDirectoryEntryRequest{
- Directory: dir,
- Name: name,
- })
- if err != nil {
- if !errors.Is(err, filer_pb.ErrNotFound) {
- return err
- }
- return filer_pb.CreateEntry(ctx, client, &filer_pb.CreateEntryRequest{
- Directory: dir,
- Entry: &filer_pb.Entry{
- Name: name,
- Content: data,
- Attributes: &filer_pb.FuseAttributes{
- Mtime: now,
- Crtime: now,
- FileMode: uint32(0644),
- FileSize: uint64(len(data)),
- },
- },
- })
- }
-
- entry := resp.Entry
- if entry.Attributes == nil {
- entry.Attributes = &filer_pb.FuseAttributes{}
- }
- entry.Attributes.Mtime = now
- entry.Attributes.FileSize = uint64(len(data))
- entry.Content = data
- return filer_pb.UpdateEntry(ctx, client, &filer_pb.UpdateEntryRequest{
- Directory: dir,
- Entry: entry,
- })
-}
-
// deleteEntryIfExists removes an entry if it exists, ignoring missing errors
func (h *S3TablesHandler) deleteEntryIfExists(ctx context.Context, client filer_pb.SeaweedFilerClient, path string) error {
dir, name := splitPath(path)
diff --git a/weed/s3api/s3tables/iceberg_layout.go b/weed/s3api/s3tables/iceberg_layout.go
index a71fb221d..a754b5d06 100644
--- a/weed/s3api/s3tables/iceberg_layout.go
+++ b/weed/s3api/s3tables/iceberg_layout.go
@@ -1,14 +1,9 @@
package s3tables
import (
- "context"
- "encoding/json"
- "errors"
pathpkg "path"
"regexp"
"strings"
-
- "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
)
// Iceberg file layout validation
@@ -307,130 +302,3 @@ func (v *TableBucketFileValidator) ValidateTableBucketUpload(fullPath string) er
return v.layoutValidator.ValidateFilePath(tableRelativePath)
}
-
-// IsTableBucketPath checks if a path is under the table buckets directory
-func IsTableBucketPath(fullPath string) bool {
- return strings.HasPrefix(fullPath, TablesPath+"/")
-}
-
-// GetTableInfoFromPath extracts bucket, namespace, and table names from a table bucket path
-// Returns empty strings if the path doesn't contain enough components
-func GetTableInfoFromPath(fullPath string) (bucket, namespace, table string) {
- if !strings.HasPrefix(fullPath, TablesPath+"/") {
- return "", "", ""
- }
-
- relativePath := strings.TrimPrefix(fullPath, TablesPath+"/")
- parts := strings.SplitN(relativePath, "/", 4)
-
- if len(parts) >= 1 {
- bucket = parts[0]
- }
- if len(parts) >= 2 {
- namespace = parts[1]
- }
- if len(parts) >= 3 {
- table = parts[2]
- }
-
- return
-}
-
-// ValidateTableBucketUploadWithClient validates upload and checks that the table exists and is ICEBERG format
-func (v *TableBucketFileValidator) ValidateTableBucketUploadWithClient(
- ctx context.Context,
- client filer_pb.SeaweedFilerClient,
- fullPath string,
-) error {
- // If not a table bucket path, nothing more to check
- if !IsTableBucketPath(fullPath) {
- return nil
- }
-
- // Get table info and verify it exists
- bucket, namespace, table := GetTableInfoFromPath(fullPath)
- if bucket == "" || namespace == "" || table == "" {
- return nil // Not deep enough to need validation
- }
-
- if strings.HasPrefix(bucket, ".") {
- return nil
- }
-
- resp, err := filer_pb.LookupEntry(ctx, client, &filer_pb.LookupDirectoryEntryRequest{
- Directory: TablesPath,
- Name: bucket,
- })
- if err != nil {
- if errors.Is(err, filer_pb.ErrNotFound) {
- return nil
- }
- return &IcebergLayoutError{
- Code: ErrCodeInvalidIcebergLayout,
- Message: "failed to verify table bucket: " + err.Error(),
- }
- }
- if resp == nil || !IsTableBucketEntry(resp.Entry) {
- return nil
- }
-
- // Now check basic layout once we know this is a table bucket path.
- if err := v.ValidateTableBucketUpload(fullPath); err != nil {
- return err
- }
-
- // Verify the table exists and has ICEBERG format by checking its metadata
- tablePath := GetTablePath(bucket, namespace, table)
- dir, name := splitPath(tablePath)
-
- resp, err = filer_pb.LookupEntry(ctx, client, &filer_pb.LookupDirectoryEntryRequest{
- Directory: dir,
- Name: name,
- })
- if err != nil {
- // Distinguish between "not found" and other errors
- if errors.Is(err, filer_pb.ErrNotFound) {
- return &IcebergLayoutError{
- Code: ErrCodeInvalidIcebergLayout,
- Message: "table does not exist",
- }
- }
- return &IcebergLayoutError{
- Code: ErrCodeInvalidIcebergLayout,
- Message: "failed to verify table existence: " + err.Error(),
- }
- }
-
- // Check if table has metadata indicating ICEBERG format
- if resp.Entry == nil || resp.Entry.Extended == nil {
- return &IcebergLayoutError{
- Code: ErrCodeInvalidIcebergLayout,
- Message: "table is not a valid ICEBERG table (missing metadata)",
- }
- }
-
- metadataBytes, ok := resp.Entry.Extended[ExtendedKeyMetadata]
- if !ok {
- return &IcebergLayoutError{
- Code: ErrCodeInvalidIcebergLayout,
- Message: "table is not in ICEBERG format (missing format metadata)",
- }
- }
-
- var metadata tableMetadataInternal
- if err := json.Unmarshal(metadataBytes, &metadata); err != nil {
- return &IcebergLayoutError{
- Code: ErrCodeInvalidIcebergLayout,
- Message: "failed to parse table metadata: " + err.Error(),
- }
- }
- const TableFormatIceberg = "ICEBERG"
- if metadata.Format != TableFormatIceberg {
- return &IcebergLayoutError{
- Code: ErrCodeInvalidIcebergLayout,
- Message: "table is not in " + TableFormatIceberg + " format",
- }
- }
-
- return nil
-}
diff --git a/weed/s3api/s3tables/iceberg_layout_test.go b/weed/s3api/s3tables/iceberg_layout_test.go
deleted file mode 100644
index d68b77b46..000000000
--- a/weed/s3api/s3tables/iceberg_layout_test.go
+++ /dev/null
@@ -1,186 +0,0 @@
-package s3tables
-
-import (
- "testing"
-)
-
-func TestIcebergLayoutValidator_ValidateFilePath(t *testing.T) {
- v := NewIcebergLayoutValidator()
-
- tests := []struct {
- name string
- path string
- wantErr bool
- }{
- // Valid metadata files
- {"valid metadata v1", "metadata/v1.metadata.json", false},
- {"valid metadata v123", "metadata/v123.metadata.json", false},
- {"valid snapshot manifest", "metadata/snap-123-1-abc12345-1234-5678-9abc-def012345678.avro", false},
- {"valid manifest file", "metadata/abc12345-1234-5678-9abc-def012345678-m0.avro", false},
- {"valid general manifest", "metadata/abc12345-1234-5678-9abc-def012345678.avro", false},
- {"valid version hint", "metadata/version-hint.text", false},
- {"valid uuid metadata", "metadata/abc12345-1234-5678-9abc-def012345678.metadata.json", false},
- {"valid trino stats", "metadata/20260208_212535_00007_bn4hb-d3599c32-1709-4b94-b6b2-1957b6d6db04.stats", false},
-
- // Valid data files
- {"valid parquet file", "data/file.parquet", false},
- {"valid orc file", "data/file.orc", false},
- {"valid avro data file", "data/file.avro", false},
- {"valid parquet with path", "data/00000-0-abc12345.parquet", false},
-
- // Valid partitioned data
- {"valid partitioned parquet", "data/year=2024/file.parquet", false},
- {"valid multi-partition", "data/year=2024/month=01/file.parquet", false},
- {"valid bucket subdirectory", "data/bucket0/file.parquet", false},
-
- // Directories only
- {"metadata directory bare", "metadata", true},
- {"data directory bare", "data", true},
- {"metadata directory with slash", "metadata/", false},
- {"data directory with slash", "data/", false},
-
- // Invalid paths
- {"empty path", "", true},
- {"invalid top dir", "invalid/file.parquet", true},
- {"root file", "file.parquet", true},
- {"invalid metadata file", "metadata/random.txt", true},
- {"nested metadata directory", "metadata/nested/v1.metadata.json", true},
- {"nested metadata directory no file", "metadata/nested/", true},
- {"metadata subdir no slash", "metadata/nested", true},
- {"invalid data file", "data/file.csv", true},
- {"invalid data file json", "data/file.json", true},
-
- // Partition/subdirectory without trailing slashes
- {"partition directory no slash", "data/year=2024", false},
- {"data subdirectory no slash", "data/my_subdir", false},
- {"multi-level partition", "data/event_date=2025-01-01/hour=00/file.parquet", false},
- {"multi-level partition directory", "data/event_date=2025-01-01/hour=00/", false},
- {"multi-level partition directory no slash", "data/event_date=2025-01-01/hour=00", false},
-
- // Double slashes
- {"data double slash", "data//file.parquet", true},
- {"data redundant slash", "data/year=2024//file.parquet", true},
- {"metadata redundant slash", "metadata//v1.metadata.json", true},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- err := v.ValidateFilePath(tt.path)
- if (err != nil) != tt.wantErr {
- t.Errorf("ValidateFilePath(%q) error = %v, wantErr %v", tt.path, err, tt.wantErr)
- }
- })
- }
-}
-
-func TestIcebergLayoutValidator_PartitionPaths(t *testing.T) {
- v := NewIcebergLayoutValidator()
-
- validPaths := []string{
- "data/year=2024/file.parquet",
- "data/date=2024-01-15/file.parquet",
- "data/category=electronics/file.parquet",
- "data/user_id=12345/file.parquet",
- "data/region=us-east-1/file.parquet",
- "data/year=2024/month=01/day=15/file.parquet",
- }
-
- for _, path := range validPaths {
- if err := v.ValidateFilePath(path); err != nil {
- t.Errorf("ValidateFilePath(%q) should be valid, got error: %v", path, err)
- }
- }
-}
-
-func TestTableBucketFileValidator_ValidateTableBucketUpload(t *testing.T) {
- v := NewTableBucketFileValidator()
-
- tests := []struct {
- name string
- path string
- wantErr bool
- }{
- // Non-table bucket paths should pass (no validation)
- {"regular bucket path", "/buckets/mybucket/file.txt", false},
- {"filer path", "/home/user/file.txt", false},
-
- // Table bucket structure paths (creating directories)
- {"table bucket root", "/buckets/mybucket", false},
- {"namespace dir", "/buckets/mybucket/myns", false},
- {"table dir", "/buckets/mybucket/myns/mytable", false},
- {"table dir trailing slash", "/buckets/mybucket/myns/mytable/", false},
-
- // Valid table bucket file uploads
- {"valid parquet upload", "/buckets/mybucket/myns/mytable/data/file.parquet", false},
- {"valid metadata upload", "/buckets/mybucket/myns/mytable/metadata/v1.metadata.json", false},
- {"valid trino stats upload", "/buckets/mybucket/myns/mytable/metadata/20260208_212535_00007_bn4hb-d3599c32-1709-4b94-b6b2-1957b6d6db04.stats", false},
- {"valid partitioned data", "/buckets/mybucket/myns/mytable/data/year=2024/file.parquet", false},
-
- // Invalid table bucket file uploads
- {"invalid file type", "/buckets/mybucket/myns/mytable/data/file.csv", true},
- {"invalid top-level dir", "/buckets/mybucket/myns/mytable/invalid/file.parquet", true},
- {"root file in table", "/buckets/mybucket/myns/mytable/file.parquet", true},
-
- // Empty segment cases
- {"empty bucket", "/buckets//myns/mytable/data/file.parquet", true},
- {"empty namespace", "/buckets/mybucket//mytable/data/file.parquet", true},
- {"empty table", "/buckets/mybucket/myns//data/file.parquet", true},
- {"empty bucket dir", "/buckets//", true},
- {"empty namespace dir", "/buckets/mybucket//", true},
- {"table double slash bypass", "/buckets/mybucket/myns/mytable//data/file.parquet", true},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- err := v.ValidateTableBucketUpload(tt.path)
- if (err != nil) != tt.wantErr {
- t.Errorf("ValidateTableBucketUpload(%q) error = %v, wantErr %v", tt.path, err, tt.wantErr)
- }
- })
- }
-}
-
-func TestIsTableBucketPath(t *testing.T) {
- tests := []struct {
- path string
- want bool
- }{
- {"/buckets/mybucket", true},
- {"/buckets/mybucket/ns/table/data/file.parquet", true},
- {"/home/user/file.txt", false},
- {"buckets/mybucket", false}, // missing leading slash
- }
-
- for _, tt := range tests {
- t.Run(tt.path, func(t *testing.T) {
- if got := IsTableBucketPath(tt.path); got != tt.want {
- t.Errorf("IsTableBucketPath(%q) = %v, want %v", tt.path, got, tt.want)
- }
- })
- }
-}
-
-func TestGetTableInfoFromPath(t *testing.T) {
- tests := []struct {
- path string
- wantBucket string
- wantNamespace string
- wantTable string
- }{
- {"/buckets/mybucket/myns/mytable/data/file.parquet", "mybucket", "myns", "mytable"},
- {"/buckets/mybucket/myns/mytable", "mybucket", "myns", "mytable"},
- {"/buckets/mybucket/myns", "mybucket", "myns", ""},
- {"/buckets/mybucket", "mybucket", "", ""},
- {"/home/user/file.txt", "", "", ""},
- }
-
- for _, tt := range tests {
- t.Run(tt.path, func(t *testing.T) {
- bucket, namespace, table := GetTableInfoFromPath(tt.path)
- if bucket != tt.wantBucket || namespace != tt.wantNamespace || table != tt.wantTable {
- t.Errorf("GetTableInfoFromPath(%q) = (%q, %q, %q), want (%q, %q, %q)",
- tt.path, bucket, namespace, table, tt.wantBucket, tt.wantNamespace, tt.wantTable)
- }
- })
- }
-}
diff --git a/weed/s3api/s3tables/permissions.go b/weed/s3api/s3tables/permissions.go
index 4ce198b6d..e5cb45a01 100644
--- a/weed/s3api/s3tables/permissions.go
+++ b/weed/s3api/s3tables/permissions.go
@@ -90,17 +90,6 @@ type PolicyContext struct {
DefaultAllow bool
}
-// CheckPermissionWithResource checks if a principal has permission to perform an operation on a specific resource
-func CheckPermissionWithResource(operation, principal, owner, resourcePolicy, resourceARN string) bool {
- return CheckPermissionWithContext(operation, principal, owner, resourcePolicy, resourceARN, nil)
-}
-
-// CheckPermission checks if a principal has permission to perform an operation
-// (without resource-specific validation - for backward compatibility)
-func CheckPermission(operation, principal, owner, resourcePolicy string) bool {
- return CheckPermissionWithContext(operation, principal, owner, resourcePolicy, "", nil)
-}
-
// CheckPermissionWithContext checks permission with optional resource and condition context.
func CheckPermissionWithContext(operation, principal, owner, resourcePolicy, resourceARN string, ctx *PolicyContext) bool {
// Deny access if identities are empty
@@ -415,113 +404,6 @@ func matchesResourcePattern(pattern, resourceARN string) bool {
return wildcard.MatchesWildcard(pattern, resourceARN)
}
-// Helper functions for specific permissions
-
-// CanCreateTableBucket checks if principal can create table buckets
-func CanCreateTableBucket(principal, owner, resourcePolicy string) bool {
- return CheckPermission("CreateTableBucket", principal, owner, resourcePolicy)
-}
-
-// CanGetTableBucket checks if principal can get table bucket details
-func CanGetTableBucket(principal, owner, resourcePolicy string) bool {
- return CheckPermission("GetTableBucket", principal, owner, resourcePolicy)
-}
-
-// CanListTableBuckets checks if principal can list table buckets
-func CanListTableBuckets(principal, owner, resourcePolicy string) bool {
- return CheckPermission("ListTableBuckets", principal, owner, resourcePolicy)
-}
-
-// CanDeleteTableBucket checks if principal can delete table buckets
-func CanDeleteTableBucket(principal, owner, resourcePolicy string) bool {
- return CheckPermission("DeleteTableBucket", principal, owner, resourcePolicy)
-}
-
-// CanPutTableBucketPolicy checks if principal can put table bucket policies
-func CanPutTableBucketPolicy(principal, owner, resourcePolicy string) bool {
- return CheckPermission("PutTableBucketPolicy", principal, owner, resourcePolicy)
-}
-
-// CanGetTableBucketPolicy checks if principal can get table bucket policies
-func CanGetTableBucketPolicy(principal, owner, resourcePolicy string) bool {
- return CheckPermission("GetTableBucketPolicy", principal, owner, resourcePolicy)
-}
-
-// CanDeleteTableBucketPolicy checks if principal can delete table bucket policies
-func CanDeleteTableBucketPolicy(principal, owner, resourcePolicy string) bool {
- return CheckPermission("DeleteTableBucketPolicy", principal, owner, resourcePolicy)
-}
-
-// CanCreateNamespace checks if principal can create namespaces
-func CanCreateNamespace(principal, owner, resourcePolicy string) bool {
- return CheckPermission("CreateNamespace", principal, owner, resourcePolicy)
-}
-
-// CanGetNamespace checks if principal can get namespace details
-func CanGetNamespace(principal, owner, resourcePolicy string) bool {
- return CheckPermission("GetNamespace", principal, owner, resourcePolicy)
-}
-
-// CanListNamespaces checks if principal can list namespaces
-func CanListNamespaces(principal, owner, resourcePolicy string) bool {
- return CheckPermission("ListNamespaces", principal, owner, resourcePolicy)
-}
-
-// CanDeleteNamespace checks if principal can delete namespaces
-func CanDeleteNamespace(principal, owner, resourcePolicy string) bool {
- return CheckPermission("DeleteNamespace", principal, owner, resourcePolicy)
-}
-
-// CanCreateTable checks if principal can create tables
-func CanCreateTable(principal, owner, resourcePolicy string) bool {
- return CheckPermission("CreateTable", principal, owner, resourcePolicy)
-}
-
-// CanGetTable checks if principal can get table details
-func CanGetTable(principal, owner, resourcePolicy string) bool {
- return CheckPermission("GetTable", principal, owner, resourcePolicy)
-}
-
-// CanListTables checks if principal can list tables
-func CanListTables(principal, owner, resourcePolicy string) bool {
- return CheckPermission("ListTables", principal, owner, resourcePolicy)
-}
-
-// CanDeleteTable checks if principal can delete tables
-func CanDeleteTable(principal, owner, resourcePolicy string) bool {
- return CheckPermission("DeleteTable", principal, owner, resourcePolicy)
-}
-
-// CanPutTablePolicy checks if principal can put table policies
-func CanPutTablePolicy(principal, owner, resourcePolicy string) bool {
- return CheckPermission("PutTablePolicy", principal, owner, resourcePolicy)
-}
-
-// CanGetTablePolicy checks if principal can get table policies
-func CanGetTablePolicy(principal, owner, resourcePolicy string) bool {
- return CheckPermission("GetTablePolicy", principal, owner, resourcePolicy)
-}
-
-// CanDeleteTablePolicy checks if principal can delete table policies
-func CanDeleteTablePolicy(principal, owner, resourcePolicy string) bool {
- return CheckPermission("DeleteTablePolicy", principal, owner, resourcePolicy)
-}
-
-// CanTagResource checks if principal can tag a resource
-func CanTagResource(principal, owner, resourcePolicy string) bool {
- return CheckPermission("TagResource", principal, owner, resourcePolicy)
-}
-
-// CanUntagResource checks if principal can untag a resource
-func CanUntagResource(principal, owner, resourcePolicy string) bool {
- return CheckPermission("UntagResource", principal, owner, resourcePolicy)
-}
-
-// CanManageTags checks if principal can manage tags (tag or untag)
-func CanManageTags(principal, owner, resourcePolicy string) bool {
- return CanTagResource(principal, owner, resourcePolicy) || CanUntagResource(principal, owner, resourcePolicy)
-}
-
// AuthError represents an authorization error
type AuthError struct {
Operation string
diff --git a/weed/s3api/s3tables/utils.go b/weed/s3api/s3tables/utils.go
index ff5dd0fe2..2aedefa2b 100644
--- a/weed/s3api/s3tables/utils.go
+++ b/weed/s3api/s3tables/utils.go
@@ -200,11 +200,6 @@ func validateBucketName(name string) error {
return nil
}
-// ValidateBucketName validates bucket name and returns an error if invalid.
-func ValidateBucketName(name string) error {
- return validateBucketName(name)
-}
-
// BuildBucketARN builds a bucket ARN with the provided region and account ID.
// If region is empty, the ARN will omit the region field.
func BuildBucketARN(region, accountID, bucketName string) (string, error) {
@@ -367,11 +362,6 @@ func validateNamespace(namespace []string) (string, error) {
return flattenNamespace(parts), nil
}
-// ValidateNamespace is a wrapper to validate namespace for other packages.
-func ValidateNamespace(namespace []string) (string, error) {
- return validateNamespace(namespace)
-}
-
// ParseNamespace parses a namespace string into namespace parts.
func ParseNamespace(namespace string) ([]string, error) {
return normalizeNamespace([]string{namespace})
diff --git a/weed/server/common.go b/weed/server/common.go
index 32662ada9..9a6b2a7da 100644
--- a/weed/server/common.go
+++ b/weed/server/common.go
@@ -19,7 +19,6 @@ import (
"github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
"github.com/seaweedfs/seaweedfs/weed/util/request_id"
- "github.com/seaweedfs/seaweedfs/weed/util/version"
"google.golang.org/grpc/metadata"
"github.com/seaweedfs/seaweedfs/weed/filer"
@@ -237,25 +236,6 @@ func parseURLPath(path string) (vid, fid, filename, ext string, isVolumeIdOnly b
return
}
-func statsHealthHandler(w http.ResponseWriter, r *http.Request) {
- m := make(map[string]interface{})
- m["Version"] = version.Version()
- writeJsonQuiet(w, r, http.StatusOK, m)
-}
-func statsCounterHandler(w http.ResponseWriter, r *http.Request) {
- m := make(map[string]interface{})
- m["Version"] = version.Version()
- m["Counters"] = serverStats
- writeJsonQuiet(w, r, http.StatusOK, m)
-}
-
-func statsMemoryHandler(w http.ResponseWriter, r *http.Request) {
- m := make(map[string]interface{})
- m["Version"] = version.Version()
- m["Memory"] = stats.MemStat()
- writeJsonQuiet(w, r, http.StatusOK, m)
-}
-
var StaticFS fs.FS
func handleStaticResources(defaultMux *http.ServeMux) {
diff --git a/weed/server/filer_server_handlers_proxy.go b/weed/server/filer_server_handlers_proxy.go
index 31ee47cdb..cdbb95321 100644
--- a/weed/server/filer_server_handlers_proxy.go
+++ b/weed/server/filer_server_handlers_proxy.go
@@ -5,7 +5,6 @@ import (
"sync"
"github.com/seaweedfs/seaweedfs/weed/glog"
- "github.com/seaweedfs/seaweedfs/weed/security"
util_http "github.com/seaweedfs/seaweedfs/weed/util/http"
"github.com/seaweedfs/seaweedfs/weed/util/mem"
"github.com/seaweedfs/seaweedfs/weed/util/request_id"
@@ -25,26 +24,6 @@ var (
proxySemaphores sync.Map // host -> chan struct{}
)
-func (fs *FilerServer) maybeAddVolumeJwtAuthorization(r *http.Request, fileId string, isWrite bool) {
- encodedJwt := fs.maybeGetVolumeJwtAuthorizationToken(fileId, isWrite)
-
- if encodedJwt == "" {
- return
- }
-
- r.Header.Set("Authorization", "BEARER "+string(encodedJwt))
-}
-
-func (fs *FilerServer) maybeGetVolumeJwtAuthorizationToken(fileId string, isWrite bool) string {
- var encodedJwt security.EncodedJwt
- if isWrite {
- encodedJwt = security.GenJwtForVolumeServer(fs.volumeGuard.SigningKey, fs.volumeGuard.ExpiresAfterSec, fileId)
- } else {
- encodedJwt = security.GenJwtForVolumeServer(fs.volumeGuard.ReadSigningKey, fs.volumeGuard.ReadExpiresAfterSec, fileId)
- }
- return string(encodedJwt)
-}
-
func acquireProxySemaphore(ctx context.Context, host string) error {
v, _ := proxySemaphores.LoadOrStore(host, make(chan struct{}, proxyReadConcurrencyPerVolumeServer))
sem := v.(chan struct{})
diff --git a/weed/server/filer_server_handlers_write_cipher.go b/weed/server/filer_server_handlers_write_cipher.go
deleted file mode 100644
index 2a3fb6b68..000000000
--- a/weed/server/filer_server_handlers_write_cipher.go
+++ /dev/null
@@ -1,107 +0,0 @@
-package weed_server
-
-import (
- "bytes"
- "context"
- "fmt"
- "net/http"
- "strings"
- "time"
-
- "github.com/seaweedfs/seaweedfs/weed/filer"
- "github.com/seaweedfs/seaweedfs/weed/glog"
- "github.com/seaweedfs/seaweedfs/weed/operation"
- "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
- "github.com/seaweedfs/seaweedfs/weed/storage/needle"
- "github.com/seaweedfs/seaweedfs/weed/util"
-)
-
-// handling single chunk POST or PUT upload
-func (fs *FilerServer) encrypt(ctx context.Context, w http.ResponseWriter, r *http.Request, so *operation.StorageOption) (filerResult *FilerPostResult, err error) {
-
- fileId, urlLocation, auth, err := fs.assignNewFileInfo(ctx, so)
-
- if err != nil || fileId == "" || urlLocation == "" {
- return nil, fmt.Errorf("fail to allocate volume for %s, collection:%s, datacenter:%s", r.URL.Path, so.Collection, so.DataCenter)
- }
-
- glog.V(4).InfofCtx(ctx, "write %s to %v", r.URL.Path, urlLocation)
-
- // Note: encrypt(gzip(data)), encrypt data first, then gzip
-
- sizeLimit := int64(fs.option.MaxMB) * 1024 * 1024
-
- bytesBuffer := bufPool.Get().(*bytes.Buffer)
- defer bufPool.Put(bytesBuffer)
-
- pu, err := needle.ParseUpload(r, sizeLimit, bytesBuffer)
- uncompressedData := pu.Data
- if pu.IsGzipped {
- uncompressedData = pu.UncompressedData
- }
- if pu.MimeType == "" {
- pu.MimeType = http.DetectContentType(uncompressedData)
- // println("detect2 mimetype to", pu.MimeType)
- }
-
- uploadOption := &operation.UploadOption{
- UploadUrl: urlLocation,
- Filename: pu.FileName,
- Cipher: true,
- IsInputCompressed: false,
- MimeType: pu.MimeType,
- PairMap: pu.PairMap,
- Jwt: auth,
- }
-
- uploader, uploaderErr := operation.NewUploader()
- if uploaderErr != nil {
- return nil, fmt.Errorf("uploader initialization error: %w", uploaderErr)
- }
-
- uploadResult, uploadError := uploader.UploadData(ctx, uncompressedData, uploadOption)
- if uploadError != nil {
- return nil, fmt.Errorf("upload to volume server: %w", uploadError)
- }
-
- // Save to chunk manifest structure
- fileChunks := []*filer_pb.FileChunk{uploadResult.ToPbFileChunk(fileId, 0, time.Now().UnixNano())}
-
- // fmt.Printf("uploaded: %+v\n", uploadResult)
-
- path := r.URL.Path
- if strings.HasSuffix(path, "/") {
- if pu.FileName != "" {
- path += pu.FileName
- }
- }
-
- entry := &filer.Entry{
- FullPath: util.FullPath(path),
- Attr: filer.Attr{
- Mtime: time.Now(),
- Crtime: time.Now(),
- Mode: 0660,
- Uid: OS_UID,
- Gid: OS_GID,
- TtlSec: so.TtlSeconds,
- Mime: pu.MimeType,
- Md5: util.Base64Md5ToBytes(pu.ContentMd5),
- },
- Chunks: fileChunks,
- }
-
- filerResult = &FilerPostResult{
- Name: pu.FileName,
- Size: int64(pu.OriginalDataSize),
- }
-
- if dbErr := fs.filer.CreateEntry(ctx, entry, false, false, nil, false, so.MaxFileNameLength); dbErr != nil {
- fs.filer.DeleteUncommittedChunks(ctx, entry.GetChunks())
- err = dbErr
- filerResult.Error = dbErr.Error()
- return
- }
-
- return
-}
diff --git a/weed/server/filer_server_handlers_write_upload.go b/weed/server/filer_server_handlers_write_upload.go
index 2f67e7860..40a5ca4f5 100644
--- a/weed/server/filer_server_handlers_write_upload.go
+++ b/weed/server/filer_server_handlers_write_upload.go
@@ -196,10 +196,6 @@ func (fs *FilerServer) doUpload(ctx context.Context, urlLocation string, limited
return uploadResult, err, data
}
-func (fs *FilerServer) dataToChunk(ctx context.Context, fileName, contentType string, data []byte, chunkOffset int64, so *operation.StorageOption) ([]*filer_pb.FileChunk, error) {
- return fs.dataToChunkWithSSE(ctx, nil, fileName, contentType, data, chunkOffset, so)
-}
-
func (fs *FilerServer) dataToChunkWithSSE(ctx context.Context, r *http.Request, fileName, contentType string, data []byte, chunkOffset int64, so *operation.StorageOption) ([]*filer_pb.FileChunk, error) {
dataReader := util.NewBytesReader(data)
diff --git a/weed/server/postgres/server.go b/weed/server/postgres/server.go
index f35d3704e..1ac4d8b3e 100644
--- a/weed/server/postgres/server.go
+++ b/weed/server/postgres/server.go
@@ -697,8 +697,3 @@ func (s *PostgreSQLServer) cleanupIdleSessions() {
}
}
}
-
-// GetAddress returns the server address
-func (s *PostgreSQLServer) GetAddress() string {
- return fmt.Sprintf("%s:%d", s.config.Host, s.config.Port)
-}
diff --git a/weed/server/volume_grpc_client_to_master.go b/weed/server/volume_grpc_client_to_master.go
index e2523543a..2c484e7ce 100644
--- a/weed/server/volume_grpc_client_to_master.go
+++ b/weed/server/volume_grpc_client_to_master.go
@@ -106,10 +106,6 @@ func (vs *VolumeServer) StopHeartbeat() (isAlreadyStopping bool) {
return false
}
-func (vs *VolumeServer) doHeartbeat(masterAddress pb.ServerAddress, grpcDialOption grpc.DialOption, sleepInterval time.Duration) (newLeader pb.ServerAddress, err error) {
- return vs.doHeartbeatWithRetry(masterAddress, grpcDialOption, sleepInterval, 0)
-}
-
func (vs *VolumeServer) doHeartbeatWithRetry(masterAddress pb.ServerAddress, grpcDialOption grpc.DialOption, sleepInterval time.Duration, duplicateRetryCount int) (newLeader pb.ServerAddress, err error) {
ctx, cancel := context.WithCancel(context.Background())
diff --git a/weed/server/volume_server_handlers_admin.go b/weed/server/volume_server_handlers_admin.go
index a54369277..dfb90befd 100644
--- a/weed/server/volume_server_handlers_admin.go
+++ b/weed/server/volume_server_handlers_admin.go
@@ -50,19 +50,3 @@ func (vs *VolumeServer) statusHandler(w http.ResponseWriter, r *http.Request) {
m["Volumes"] = vs.store.VolumeInfos()
writeJsonQuiet(w, r, http.StatusOK, m)
}
-
-func (vs *VolumeServer) statsDiskHandler(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Server", "SeaweedFS Volume "+version.VERSION)
- m := make(map[string]interface{})
- m["Version"] = version.Version()
- var ds []*volume_server_pb.DiskStatus
- for _, loc := range vs.store.Locations {
- if dir, e := filepath.Abs(loc.Directory); e == nil {
- newDiskStatus := stats.NewDiskStatus(dir)
- newDiskStatus.DiskType = loc.DiskType.String()
- ds = append(ds, newDiskStatus)
- }
- }
- m["DiskStatuses"] = ds
- writeJsonQuiet(w, r, http.StatusOK, m)
-}
diff --git a/weed/server/volume_server_handlers_write.go b/weed/server/volume_server_handlers_write.go
index 44a2abc34..418f3c235 100644
--- a/weed/server/volume_server_handlers_write.go
+++ b/weed/server/volume_server_handlers_write.go
@@ -160,11 +160,3 @@ func SetEtag(w http.ResponseWriter, etag string) {
}
}
}
-
-func getEtag(resp *http.Response) (etag string) {
- etag = resp.Header.Get("ETag")
- if strings.HasPrefix(etag, "\"") && strings.HasSuffix(etag, "\"") {
- return etag[1 : len(etag)-1]
- }
- return
-}
diff --git a/weed/server/volume_server_test.go b/weed/server/volume_server_test.go
deleted file mode 100644
index ac1ad774e..000000000
--- a/weed/server/volume_server_test.go
+++ /dev/null
@@ -1,69 +0,0 @@
-package weed_server
-
-import (
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/pb/volume_server_pb"
- "github.com/seaweedfs/seaweedfs/weed/storage"
-)
-
-func TestMaintenanceMode(t *testing.T) {
- testCases := []struct {
- name string
- pb *volume_server_pb.VolumeServerState
- want bool
- wantCheckErr string
- }{
- {
- name: "non-initialized state",
- pb: nil,
- want: false,
- wantCheckErr: "",
- },
- {
- name: "maintenance mode disabled",
- pb: &volume_server_pb.VolumeServerState{
- Maintenance: false,
- },
- want: false,
- wantCheckErr: "",
- },
- {
- name: "maintenance mode enabled",
- pb: &volume_server_pb.VolumeServerState{
- Maintenance: true,
- },
- want: true,
- wantCheckErr: "volume server test_1234 is in maintenance mode",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- vs := VolumeServer{
- store: &storage.Store{
- Id: "test_1234",
- State: storage.NewStateFromProto("/some/path.pb", tc.pb),
- },
- }
-
- if got, want := vs.MaintenanceMode(), tc.want; got != want {
- t.Errorf("MaintenanceMode() returned %v, want %v", got, want)
- }
-
- err, wantErrStr := vs.CheckMaintenanceMode(), tc.wantCheckErr
- if err != nil {
- if wantErrStr == "" {
- t.Errorf("CheckMaintenanceMode() returned error %v, want nil", err)
- }
- if errStr := err.Error(); errStr != wantErrStr {
- t.Errorf("CheckMaintenanceMode() returned error %q, want %q", errStr, wantErrStr)
- }
- } else {
- if wantErrStr != "" {
- t.Errorf("CheckMaintenanceMode() returned no error, want %q", wantErrStr)
- }
- }
- })
- }
-}
diff --git a/weed/sftpd/sftp_file_writer.go b/weed/sftpd/sftp_file_writer.go
index fed60eec0..3f6d915b0 100644
--- a/weed/sftpd/sftp_file_writer.go
+++ b/weed/sftpd/sftp_file_writer.go
@@ -32,28 +32,6 @@ type bufferReader struct {
i int64
}
-func NewBufferReader(b []byte) *bufferReader { return &bufferReader{b: b} }
-
-func (r *bufferReader) Read(p []byte) (int, error) {
- if r.i >= int64(len(r.b)) {
- return 0, io.EOF
- }
- n := copy(p, r.b[r.i:])
- r.i += int64(n)
- return n, nil
-}
-
-func (r *bufferReader) ReadAt(p []byte, off int64) (int, error) {
- if off >= int64(len(r.b)) {
- return 0, io.EOF
- }
- n := copy(p, r.b[off:])
- if n < len(p) {
- return n, io.EOF
- }
- return n, nil
-}
-
// listerat implements sftp.ListerAt.
type listerat []os.FileInfo
diff --git a/weed/shell/command_ec_common.go b/weed/shell/command_ec_common.go
index ba84fc7f7..aef207772 100644
--- a/weed/shell/command_ec_common.go
+++ b/weed/shell/command_ec_common.go
@@ -406,30 +406,6 @@ func sortEcNodesByFreeslotsAscending(ecNodes []*EcNode) {
})
}
-// if the index node changed the freeEcSlot, need to keep every EcNode still sorted
-func ensureSortedEcNodes(data []*CandidateEcNode, index int, lessThan func(i, j int) bool) {
- for i := index - 1; i >= 0; i-- {
- if lessThan(i+1, i) {
- swap(data, i, i+1)
- } else {
- break
- }
- }
- for i := index + 1; i < len(data); i++ {
- if lessThan(i, i-1) {
- swap(data, i, i-1)
- } else {
- break
- }
- }
-}
-
-func swap(data []*CandidateEcNode, i, j int) {
- t := data[i]
- data[i] = data[j]
- data[j] = t
-}
-
func countShards(ecShardInfos []*master_pb.VolumeEcShardInformationMessage) (count int) {
for _, eci := range ecShardInfos {
count += erasure_coding.GetShardCount(eci)
@@ -1135,48 +1111,6 @@ func (ecb *ecBalancer) pickRackForShardType(
return selected.id, nil
}
-func (ecb *ecBalancer) pickRackToBalanceShardsInto(rackToEcNodes map[RackId]*EcRack, rackToShardCount map[string]int) (RackId, error) {
- targets := []RackId{}
- targetShards := -1
- for _, shards := range rackToShardCount {
- if shards > targetShards {
- targetShards = shards
- }
- }
-
- details := ""
- for rackId, rack := range rackToEcNodes {
- shards := rackToShardCount[string(rackId)]
-
- if rack.freeEcSlot <= 0 {
- details += fmt.Sprintf(" Skipped %s because it has no free slots\n", rackId)
- continue
- }
- // For EC shards, replica placement constraint only applies when DiffRackCount > 0.
- // When DiffRackCount = 0 (e.g., replica placement "000"), EC shards should be
- // distributed freely across racks for fault tolerance - the "000" means
- // "no volume replication needed" because erasure coding provides redundancy.
- if ecb.replicaPlacement != nil && ecb.replicaPlacement.DiffRackCount > 0 && shards > ecb.replicaPlacement.DiffRackCount {
- details += fmt.Sprintf(" Skipped %s because shards %d > replica placement limit for other racks (%d)\n", rackId, shards, ecb.replicaPlacement.DiffRackCount)
- continue
- }
-
- if shards < targetShards {
- // Favor racks with less shards, to ensure an uniform distribution.
- targets = nil
- targetShards = shards
- }
- if shards == targetShards {
- targets = append(targets, rackId)
- }
- }
-
- if len(targets) == 0 {
- return "", errors.New(details)
- }
- return targets[rand.IntN(len(targets))], nil
-}
-
func (ecb *ecBalancer) balanceEcShardsWithinRacks(collection string) error {
// collect vid => []ecNode, since previous steps can change the locations
vidLocations := ecb.collectVolumeIdToEcNodes(collection)
@@ -1567,46 +1501,6 @@ func (ecb *ecBalancer) pickOneEcNodeAndMoveOneShard(existingLocation *EcNode, co
return moveMountedShardToEcNode(ecb.commandEnv, existingLocation, collection, vid, shardId, destNode, destDiskId, ecb.applyBalancing, ecb.diskType)
}
-func pickNEcShardsToMoveFrom(ecNodes []*EcNode, vid needle.VolumeId, n int, diskType types.DiskType) map[erasure_coding.ShardId]*EcNode {
- picked := make(map[erasure_coding.ShardId]*EcNode)
- var candidateEcNodes []*CandidateEcNode
- for _, ecNode := range ecNodes {
- si := findEcVolumeShardsInfo(ecNode, vid, diskType)
- if si.Count() > 0 {
- candidateEcNodes = append(candidateEcNodes, &CandidateEcNode{
- ecNode: ecNode,
- shardCount: si.Count(),
- })
- }
- }
- slices.SortFunc(candidateEcNodes, func(a, b *CandidateEcNode) int {
- return b.shardCount - a.shardCount
- })
- for i := 0; i < n; i++ {
- selectedEcNodeIndex := -1
- for i, candidateEcNode := range candidateEcNodes {
- si := findEcVolumeShardsInfo(candidateEcNode.ecNode, vid, diskType)
- if si.Count() > 0 {
- selectedEcNodeIndex = i
- for _, shardId := range si.Ids() {
- candidateEcNode.shardCount--
- picked[shardId] = candidateEcNode.ecNode
- candidateEcNode.ecNode.deleteEcVolumeShards(vid, []erasure_coding.ShardId{shardId}, diskType)
- break
- }
- break
- }
- }
- if selectedEcNodeIndex >= 0 {
- ensureSortedEcNodes(candidateEcNodes, selectedEcNodeIndex, func(i, j int) bool {
- return candidateEcNodes[i].shardCount > candidateEcNodes[j].shardCount
- })
- }
-
- }
- return picked
-}
-
func (ecb *ecBalancer) collectVolumeIdToEcNodes(collection string) map[needle.VolumeId][]*EcNode {
vidLocations := make(map[needle.VolumeId][]*EcNode)
for _, ecNode := range ecb.ecNodes {
diff --git a/weed/shell/command_ec_common_test.go b/weed/shell/command_ec_common_test.go
deleted file mode 100644
index ff186f21d..000000000
--- a/weed/shell/command_ec_common_test.go
+++ /dev/null
@@ -1,354 +0,0 @@
-package shell
-
-import (
- "fmt"
- "reflect"
- "strings"
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/pb/master_pb"
- "github.com/seaweedfs/seaweedfs/weed/storage/erasure_coding"
- "github.com/seaweedfs/seaweedfs/weed/storage/needle"
- "github.com/seaweedfs/seaweedfs/weed/storage/super_block"
- "github.com/seaweedfs/seaweedfs/weed/storage/types"
-)
-
-func errorCheck(got error, want string) error {
- if got == nil && want == "" {
- return nil
- }
- if got != nil && want == "" {
- return fmt.Errorf("expected no error, got %q", got.Error())
- }
- if got == nil && want != "" {
- return fmt.Errorf("got no error, expected %q", want)
- }
- if !strings.Contains(got.Error(), want) {
- return fmt.Errorf("expected error %q, got %q", want, got.Error())
- }
- return nil
-}
-
-func TestCollectCollectionsForVolumeIds(t *testing.T) {
- testCases := []struct {
- topology *master_pb.TopologyInfo
- vids []needle.VolumeId
- want []string
- }{
- // normal volumes
- {testTopology1, nil, nil},
- {testTopology1, []needle.VolumeId{}, nil},
- {testTopology1, []needle.VolumeId{needle.VolumeId(9999)}, nil},
- {testTopology1, []needle.VolumeId{needle.VolumeId(2)}, []string{""}},
- {testTopology1, []needle.VolumeId{needle.VolumeId(2), needle.VolumeId(272)}, []string{"", "collection2"}},
- {testTopology1, []needle.VolumeId{needle.VolumeId(2), needle.VolumeId(272), needle.VolumeId(299)}, []string{"", "collection2"}},
- {testTopology1, []needle.VolumeId{needle.VolumeId(272), needle.VolumeId(299), needle.VolumeId(95)}, []string{"collection1", "collection2"}},
- {testTopology1, []needle.VolumeId{needle.VolumeId(272), needle.VolumeId(299), needle.VolumeId(95), needle.VolumeId(51)}, []string{"collection1", "collection2"}},
- {testTopology1, []needle.VolumeId{needle.VolumeId(272), needle.VolumeId(299), needle.VolumeId(95), needle.VolumeId(51), needle.VolumeId(15)}, []string{"collection0", "collection1", "collection2"}},
- // EC volumes
- {testTopology2, []needle.VolumeId{needle.VolumeId(9577)}, []string{"s3qldata"}},
- {testTopology2, []needle.VolumeId{needle.VolumeId(9577), needle.VolumeId(12549)}, []string{"s3qldata"}},
- // normal + EC volumes
- {testTopology2, []needle.VolumeId{needle.VolumeId(18111)}, []string{"s3qldata"}},
- {testTopology2, []needle.VolumeId{needle.VolumeId(8677)}, []string{"s3qldata"}},
- {testTopology2, []needle.VolumeId{needle.VolumeId(18111), needle.VolumeId(8677)}, []string{"s3qldata"}},
- }
-
- for _, tc := range testCases {
- got := collectCollectionsForVolumeIds(tc.topology, tc.vids)
- if !reflect.DeepEqual(got, tc.want) {
- t.Errorf("for %v: got %v, want %v", tc.vids, got, tc.want)
- }
- }
-}
-
-func TestParseReplicaPlacementArg(t *testing.T) {
- getDefaultReplicaPlacementOrig := getDefaultReplicaPlacement
- getDefaultReplicaPlacement = func(commandEnv *CommandEnv) (*super_block.ReplicaPlacement, error) {
- return super_block.NewReplicaPlacementFromString("123")
- }
- defer func() {
- getDefaultReplicaPlacement = getDefaultReplicaPlacementOrig
- }()
-
- testCases := []struct {
- argument string
- want string
- wantErr string
- }{
- {"lalala", "lal", "unexpected replication type"},
- {"", "123", ""},
- {"021", "021", ""},
- }
-
- for _, tc := range testCases {
- commandEnv := &CommandEnv{}
- got, gotErr := parseReplicaPlacementArg(commandEnv, tc.argument)
-
- if err := errorCheck(gotErr, tc.wantErr); err != nil {
- t.Errorf("argument %q: %s", tc.argument, err.Error())
- continue
- }
-
- want, _ := super_block.NewReplicaPlacementFromString(tc.want)
- if !got.Equals(want) {
- t.Errorf("got replica placement %q, want %q", got.String(), want.String())
- }
- }
-}
-
-func TestEcDistribution(t *testing.T) {
-
- // find out all volume servers with one slot left.
- ecNodes, totalFreeEcSlots := collectEcVolumeServersByDc(testTopology1, "", types.HardDriveType)
-
- sortEcNodesByFreeslotsDescending(ecNodes)
-
- if totalFreeEcSlots < erasure_coding.TotalShardsCount {
- t.Errorf("not enough free ec shard slots: %d", totalFreeEcSlots)
- }
- allocatedDataNodes := ecNodes
- if len(allocatedDataNodes) > erasure_coding.TotalShardsCount {
- allocatedDataNodes = allocatedDataNodes[:erasure_coding.TotalShardsCount]
- }
-
- for _, dn := range allocatedDataNodes {
- // fmt.Printf("info %+v %+v\n", dn.info, dn)
- fmt.Printf("=> %+v %+v\n", dn.info.Id, dn.freeEcSlot)
- }
-}
-
-func TestPickRackToBalanceShardsInto(t *testing.T) {
- testCases := []struct {
- topology *master_pb.TopologyInfo
- vid string
- replicaPlacement string
- wantOneOf []string
- wantErr string
- }{
- // Non-EC volumes. We don't care about these, but the function should return all racks as a safeguard.
- {testTopologyEc, "", "123", []string{"rack1", "rack2", "rack3", "rack4", "rack5", "rack6"}, ""},
- {testTopologyEc, "6225", "123", []string{"rack1", "rack2", "rack3", "rack4", "rack5", "rack6"}, ""},
- {testTopologyEc, "6226", "123", []string{"rack1", "rack2", "rack3", "rack4", "rack5", "rack6"}, ""},
- {testTopologyEc, "6241", "123", []string{"rack1", "rack2", "rack3", "rack4", "rack5", "rack6"}, ""},
- {testTopologyEc, "6242", "123", []string{"rack1", "rack2", "rack3", "rack4", "rack5", "rack6"}, ""},
- // EC volumes.
- // With replication "000" (DiffRackCount=0), EC shards should be distributed freely
- // because erasure coding provides its own redundancy. No replica placement error.
- {testTopologyEc, "9577", "", []string{"rack1", "rack2", "rack3"}, ""},
- {testTopologyEc, "9577", "111", []string{"rack1", "rack2", "rack3"}, ""},
- {testTopologyEc, "9577", "222", []string{"rack1", "rack2", "rack3"}, ""},
- {testTopologyEc, "10457", "222", []string{"rack1"}, ""},
- {testTopologyEc, "12737", "222", []string{"rack2"}, ""},
- {testTopologyEc, "14322", "222", []string{"rack3"}, ""},
- }
-
- for _, tc := range testCases {
- vid, _ := needle.NewVolumeId(tc.vid)
- ecNodes, _ := collectEcVolumeServersByDc(tc.topology, "", types.HardDriveType)
- rp, _ := super_block.NewReplicaPlacementFromString(tc.replicaPlacement)
-
- ecb := &ecBalancer{
- ecNodes: ecNodes,
- replicaPlacement: rp,
- diskType: types.HardDriveType,
- }
-
- racks := ecb.racks()
- rackToShardCount := countShardsByRack(vid, ecNodes, types.HardDriveType)
-
- got, gotErr := ecb.pickRackToBalanceShardsInto(racks, rackToShardCount)
- if err := errorCheck(gotErr, tc.wantErr); err != nil {
- t.Errorf("volume %q: %s", tc.vid, err.Error())
- continue
- }
-
- if string(got) == "" && len(tc.wantOneOf) == 0 {
- continue
- }
- found := false
- for _, want := range tc.wantOneOf {
- if got := string(got); got == want {
- found = true
- break
- }
- }
- if !(found) {
- t.Errorf("expected one of %v for volume %q, got %q", tc.wantOneOf, tc.vid, got)
- }
- }
-}
-func TestPickEcNodeToBalanceShardsInto(t *testing.T) {
- testCases := []struct {
- topology *master_pb.TopologyInfo
- nodeId string
- vid string
- wantOneOf []string
- wantErr string
- }{
- {testTopologyEc, "", "", nil, "INTERNAL: missing source nodes"},
- {testTopologyEc, "idontexist", "12737", nil, "INTERNAL: missing source nodes"},
- // Non-EC nodes. We don't care about these, but the function should return all available target nodes as a safeguard.
- {
- testTopologyEc, "172.19.0.10:8702", "6225", []string{
- "172.19.0.13:8701", "172.19.0.14:8711", "172.19.0.16:8704", "172.19.0.17:8703",
- "172.19.0.19:8700", "172.19.0.20:8706", "172.19.0.21:8710", "172.19.0.3:8708",
- "172.19.0.4:8707", "172.19.0.5:8705", "172.19.0.6:8713", "172.19.0.8:8709",
- "172.19.0.9:8712"},
- "",
- },
- {
- testTopologyEc, "172.19.0.8:8709", "6226", []string{
- "172.19.0.10:8702", "172.19.0.13:8701", "172.19.0.14:8711", "172.19.0.16:8704",
- "172.19.0.17:8703", "172.19.0.19:8700", "172.19.0.20:8706", "172.19.0.21:8710",
- "172.19.0.3:8708", "172.19.0.4:8707", "172.19.0.5:8705", "172.19.0.6:8713",
- "172.19.0.9:8712"},
- "",
- },
- // EC volumes.
- {testTopologyEc, "172.19.0.10:8702", "14322", []string{
- "172.19.0.14:8711", "172.19.0.5:8705", "172.19.0.6:8713"},
- ""},
- {testTopologyEc, "172.19.0.13:8701", "10457", []string{
- "172.19.0.10:8702", "172.19.0.6:8713"},
- ""},
- {testTopologyEc, "172.19.0.17:8703", "12737", []string{
- "172.19.0.13:8701"},
- ""},
- {testTopologyEc, "172.19.0.20:8706", "14322", []string{
- "172.19.0.14:8711", "172.19.0.5:8705", "172.19.0.6:8713"},
- ""},
- }
-
- for _, tc := range testCases {
- vid, _ := needle.NewVolumeId(tc.vid)
- allEcNodes, _ := collectEcVolumeServersByDc(tc.topology, "", types.HardDriveType)
-
- ecb := &ecBalancer{
- ecNodes: allEcNodes,
- diskType: types.HardDriveType,
- }
-
- // Resolve target node by name
- var ecNode *EcNode
- for _, n := range allEcNodes {
- if n.info.Id == tc.nodeId {
- ecNode = n
- break
- }
- }
-
- got, gotErr := ecb.pickEcNodeToBalanceShardsInto(vid, ecNode, allEcNodes)
- if err := errorCheck(gotErr, tc.wantErr); err != nil {
- t.Errorf("node %q, volume %q: %s", tc.nodeId, tc.vid, err.Error())
- continue
- }
-
- if got == nil {
- if len(tc.wantOneOf) == 0 {
- continue
- }
- t.Errorf("node %q, volume %q: got no node, want %q", tc.nodeId, tc.vid, tc.wantOneOf)
- continue
- }
- found := false
- for _, want := range tc.wantOneOf {
- if got := got.info.Id; got == want {
- found = true
- break
- }
- }
- if !(found) {
- t.Errorf("expected one of %v for volume %q, got %q", tc.wantOneOf, tc.vid, got.info.Id)
- }
- }
-}
-
-func TestCountFreeShardSlots(t *testing.T) {
- testCases := []struct {
- name string
- topology *master_pb.TopologyInfo
- diskType types.DiskType
- want map[string]int
- }{
- {
- name: "topology #1, free HDD shards",
- topology: testTopology1,
- diskType: types.HardDriveType,
- want: map[string]int{
- "192.168.1.1:8080": 17330,
- "192.168.1.2:8080": 1540,
- "192.168.1.4:8080": 1900,
- "192.168.1.5:8080": 27010,
- "192.168.1.6:8080": 17420,
- },
- },
- {
- name: "topology #1, no free SSD shards available",
- topology: testTopology1,
- diskType: types.SsdType,
- want: map[string]int{
- "192.168.1.1:8080": 0,
- "192.168.1.2:8080": 0,
- "192.168.1.4:8080": 0,
- "192.168.1.5:8080": 0,
- "192.168.1.6:8080": 0,
- },
- },
- {
- name: "topology #2, no negative free HDD shards",
- topology: testTopology2,
- diskType: types.HardDriveType,
- want: map[string]int{
- "172.19.0.3:8708": 0,
- "172.19.0.4:8707": 8,
- "172.19.0.5:8705": 58,
- "172.19.0.6:8713": 39,
- "172.19.0.8:8709": 8,
- "172.19.0.9:8712": 0,
- "172.19.0.10:8702": 0,
- "172.19.0.13:8701": 0,
- "172.19.0.14:8711": 0,
- "172.19.0.16:8704": 89,
- "172.19.0.17:8703": 0,
- "172.19.0.19:8700": 9,
- "172.19.0.20:8706": 0,
- "172.19.0.21:8710": 9,
- },
- },
- {
- name: "topology #2, no free SSD shards available",
- topology: testTopology2,
- diskType: types.SsdType,
- want: map[string]int{
- "172.19.0.10:8702": 0,
- "172.19.0.13:8701": 0,
- "172.19.0.14:8711": 0,
- "172.19.0.16:8704": 0,
- "172.19.0.17:8703": 0,
- "172.19.0.19:8700": 0,
- "172.19.0.20:8706": 0,
- "172.19.0.21:8710": 0,
- "172.19.0.3:8708": 0,
- "172.19.0.4:8707": 0,
- "172.19.0.5:8705": 0,
- "172.19.0.6:8713": 0,
- "172.19.0.8:8709": 0,
- "172.19.0.9:8712": 0,
- },
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- got := map[string]int{}
- eachDataNode(tc.topology, func(dc DataCenterId, rack RackId, dn *master_pb.DataNodeInfo) {
- got[dn.Id] = countFreeShardSlots(dn, tc.diskType)
- })
-
- if !reflect.DeepEqual(got, tc.want) {
- t.Errorf("got %v, want %v", got, tc.want)
- }
- })
- }
-}
diff --git a/weed/shell/commands.go b/weed/shell/commands.go
index 741dff6b0..6679c15c9 100644
--- a/weed/shell/commands.go
+++ b/weed/shell/commands.go
@@ -4,8 +4,6 @@ import (
"context"
"fmt"
"io"
- "net/url"
- "strconv"
"strings"
"github.com/seaweedfs/seaweedfs/weed/operation"
@@ -138,25 +136,6 @@ func (ce *CommandEnv) GetDataCenter() string {
return ce.MasterClient.GetDataCenter()
}
-func parseFilerUrl(entryPath string) (filerServer string, filerPort int64, path string, err error) {
- if strings.HasPrefix(entryPath, "http") {
- var u *url.URL
- u, err = url.Parse(entryPath)
- if err != nil {
- return
- }
- filerServer = u.Hostname()
- portString := u.Port()
- if portString != "" {
- filerPort, err = strconv.ParseInt(portString, 10, 32)
- }
- path = u.Path
- } else {
- err = fmt.Errorf("path should have full url /path/to/dirOrFile : %s", entryPath)
- }
- return
-}
-
func findInputDirectory(args []string) (input string) {
input = "."
if len(args) > 0 {
diff --git a/weed/shell/ec_proportional_rebalance.go b/weed/shell/ec_proportional_rebalance.go
index 52adf4297..8d6b1c1b7 100644
--- a/weed/shell/ec_proportional_rebalance.go
+++ b/weed/shell/ec_proportional_rebalance.go
@@ -1,8 +1,6 @@
package shell
import (
- "fmt"
-
"github.com/seaweedfs/seaweedfs/weed/storage/erasure_coding"
"github.com/seaweedfs/seaweedfs/weed/storage/erasure_coding/distribution"
"github.com/seaweedfs/seaweedfs/weed/storage/needle"
@@ -13,18 +11,6 @@ import (
// ECDistribution is an alias to the distribution package type for backward compatibility
type ECDistribution = distribution.ECDistribution
-// CalculateECDistribution computes the target EC shard distribution based on replication policy.
-// This is a convenience wrapper that uses the default 10+4 EC configuration.
-// For custom EC ratios, use the distribution package directly.
-func CalculateECDistribution(totalShards, parityShards int, rp *super_block.ReplicaPlacement) *ECDistribution {
- ec := distribution.ECConfig{
- DataShards: totalShards - parityShards,
- ParityShards: parityShards,
- }
- rep := distribution.NewReplicationConfig(rp)
- return distribution.CalculateDistribution(ec, rep)
-}
-
// TopologyDistributionAnalysis holds the current shard distribution analysis
// This wraps the distribution package's TopologyAnalysis with shell-specific EcNode handling
type TopologyDistributionAnalysis struct {
@@ -34,99 +20,6 @@ type TopologyDistributionAnalysis struct {
nodeMap map[string]*EcNode // nodeID -> EcNode
}
-// NewTopologyDistributionAnalysis creates a new analysis structure
-func NewTopologyDistributionAnalysis() *TopologyDistributionAnalysis {
- return &TopologyDistributionAnalysis{
- inner: distribution.NewTopologyAnalysis(),
- nodeMap: make(map[string]*EcNode),
- }
-}
-
-// AddNode adds a node and its shards to the analysis
-func (a *TopologyDistributionAnalysis) AddNode(node *EcNode, shardsInfo *erasure_coding.ShardsInfo) {
- nodeId := node.info.Id
-
- // Create distribution.TopologyNode from EcNode
- topoNode := &distribution.TopologyNode{
- NodeID: nodeId,
- DataCenter: string(node.dc),
- Rack: string(node.rack),
- FreeSlots: node.freeEcSlot,
- TotalShards: shardsInfo.Count(),
- ShardIDs: shardsInfo.IdsInt(),
- }
-
- a.inner.AddNode(topoNode)
- a.nodeMap[nodeId] = node
-
- // Add shard locations
- for _, shardId := range shardsInfo.Ids() {
- a.inner.AddShardLocation(distribution.ShardLocation{
- ShardID: int(shardId),
- NodeID: nodeId,
- DataCenter: string(node.dc),
- Rack: string(node.rack),
- })
- }
-}
-
-// Finalize completes the analysis
-func (a *TopologyDistributionAnalysis) Finalize() {
- a.inner.Finalize()
-}
-
-// String returns a summary
-func (a *TopologyDistributionAnalysis) String() string {
- return a.inner.String()
-}
-
-// DetailedString returns detailed analysis
-func (a *TopologyDistributionAnalysis) DetailedString() string {
- return a.inner.DetailedString()
-}
-
-// GetShardsByDC returns shard counts by DC
-func (a *TopologyDistributionAnalysis) GetShardsByDC() map[DataCenterId]int {
- result := make(map[DataCenterId]int)
- for dc, count := range a.inner.ShardsByDC {
- result[DataCenterId(dc)] = count
- }
- return result
-}
-
-// GetShardsByRack returns shard counts by rack
-func (a *TopologyDistributionAnalysis) GetShardsByRack() map[RackId]int {
- result := make(map[RackId]int)
- for rack, count := range a.inner.ShardsByRack {
- result[RackId(rack)] = count
- }
- return result
-}
-
-// GetShardsByNode returns shard counts by node
-func (a *TopologyDistributionAnalysis) GetShardsByNode() map[EcNodeId]int {
- result := make(map[EcNodeId]int)
- for nodeId, count := range a.inner.ShardsByNode {
- result[EcNodeId(nodeId)] = count
- }
- return result
-}
-
-// AnalyzeVolumeDistribution creates an analysis of current shard distribution for a volume
-func AnalyzeVolumeDistribution(volumeId needle.VolumeId, locations []*EcNode, diskType types.DiskType) *TopologyDistributionAnalysis {
- analysis := NewTopologyDistributionAnalysis()
-
- for _, node := range locations {
- si := findEcVolumeShardsInfo(node, volumeId, diskType)
- if si.Count() > 0 {
- analysis.AddNode(node, si)
- }
- }
-
- analysis.Finalize()
- return analysis
-}
-
// ECShardMove represents a planned shard move (shell-specific with EcNode references)
type ECShardMove struct {
VolumeId needle.VolumeId
@@ -136,12 +29,6 @@ type ECShardMove struct {
Reason string
}
-// String returns a human-readable description
-func (m ECShardMove) String() string {
- return fmt.Sprintf("volume %d shard %d: %s -> %s (%s)",
- m.VolumeId, m.ShardId, m.SourceNode.info.Id, m.DestNode.info.Id, m.Reason)
-}
-
// ProportionalECRebalancer implements proportional shard distribution for shell commands
type ProportionalECRebalancer struct {
ecNodes []*EcNode
@@ -149,133 +36,3 @@ type ProportionalECRebalancer struct {
diskType types.DiskType
ecConfig distribution.ECConfig
}
-
-// NewProportionalECRebalancer creates a new proportional rebalancer with default EC config
-func NewProportionalECRebalancer(
- ecNodes []*EcNode,
- rp *super_block.ReplicaPlacement,
- diskType types.DiskType,
-) *ProportionalECRebalancer {
- return NewProportionalECRebalancerWithConfig(
- ecNodes,
- rp,
- diskType,
- distribution.DefaultECConfig(),
- )
-}
-
-// NewProportionalECRebalancerWithConfig creates a rebalancer with custom EC configuration
-func NewProportionalECRebalancerWithConfig(
- ecNodes []*EcNode,
- rp *super_block.ReplicaPlacement,
- diskType types.DiskType,
- ecConfig distribution.ECConfig,
-) *ProportionalECRebalancer {
- return &ProportionalECRebalancer{
- ecNodes: ecNodes,
- replicaPlacement: rp,
- diskType: diskType,
- ecConfig: ecConfig,
- }
-}
-
-// PlanMoves generates a plan for moving shards to achieve proportional distribution
-func (r *ProportionalECRebalancer) PlanMoves(
- volumeId needle.VolumeId,
- locations []*EcNode,
-) ([]ECShardMove, error) {
- // Build topology analysis
- analysis := distribution.NewTopologyAnalysis()
- nodeMap := make(map[string]*EcNode)
-
- // Add all EC nodes to the analysis (even those without shards)
- for _, node := range r.ecNodes {
- nodeId := node.info.Id
- topoNode := &distribution.TopologyNode{
- NodeID: nodeId,
- DataCenter: string(node.dc),
- Rack: string(node.rack),
- FreeSlots: node.freeEcSlot,
- }
- analysis.AddNode(topoNode)
- nodeMap[nodeId] = node
- }
-
- // Add shard locations from nodes that have shards
- for _, node := range locations {
- nodeId := node.info.Id
- si := findEcVolumeShardsInfo(node, volumeId, r.diskType)
- for _, shardId := range si.Ids() {
- analysis.AddShardLocation(distribution.ShardLocation{
- ShardID: int(shardId),
- NodeID: nodeId,
- DataCenter: string(node.dc),
- Rack: string(node.rack),
- })
- }
- if _, exists := nodeMap[nodeId]; !exists {
- nodeMap[nodeId] = node
- }
- }
-
- analysis.Finalize()
-
- // Create rebalancer and plan moves
- rep := distribution.NewReplicationConfig(r.replicaPlacement)
- rebalancer := distribution.NewRebalancer(r.ecConfig, rep)
-
- plan, err := rebalancer.PlanRebalance(analysis)
- if err != nil {
- return nil, err
- }
-
- // Convert distribution moves to shell moves
- var moves []ECShardMove
- for _, move := range plan.Moves {
- srcNode := nodeMap[move.SourceNode.NodeID]
- destNode := nodeMap[move.DestNode.NodeID]
- if srcNode == nil || destNode == nil {
- continue
- }
-
- moves = append(moves, ECShardMove{
- VolumeId: volumeId,
- ShardId: erasure_coding.ShardId(move.ShardID),
- SourceNode: srcNode,
- DestNode: destNode,
- Reason: move.Reason,
- })
- }
-
- return moves, nil
-}
-
-// GetDistributionSummary returns a summary of the planned distribution
-func GetDistributionSummary(rp *super_block.ReplicaPlacement) string {
- ec := distribution.DefaultECConfig()
- rep := distribution.NewReplicationConfig(rp)
- dist := distribution.CalculateDistribution(ec, rep)
- return dist.Summary()
-}
-
-// GetDistributionSummaryWithConfig returns a summary with custom EC configuration
-func GetDistributionSummaryWithConfig(rp *super_block.ReplicaPlacement, ecConfig distribution.ECConfig) string {
- rep := distribution.NewReplicationConfig(rp)
- dist := distribution.CalculateDistribution(ecConfig, rep)
- return dist.Summary()
-}
-
-// GetFaultToleranceAnalysis returns fault tolerance analysis for the given configuration
-func GetFaultToleranceAnalysis(rp *super_block.ReplicaPlacement) string {
- ec := distribution.DefaultECConfig()
- rep := distribution.NewReplicationConfig(rp)
- dist := distribution.CalculateDistribution(ec, rep)
- return dist.FaultToleranceAnalysis()
-}
-
-// GetFaultToleranceAnalysisWithConfig returns fault tolerance analysis with custom EC configuration
-func GetFaultToleranceAnalysisWithConfig(rp *super_block.ReplicaPlacement, ecConfig distribution.ECConfig) string {
- rep := distribution.NewReplicationConfig(rp)
- dist := distribution.CalculateDistribution(ecConfig, rep)
- return dist.FaultToleranceAnalysis()
-}
diff --git a/weed/shell/ec_proportional_rebalance_test.go b/weed/shell/ec_proportional_rebalance_test.go
deleted file mode 100644
index c8ec99e0a..000000000
--- a/weed/shell/ec_proportional_rebalance_test.go
+++ /dev/null
@@ -1,251 +0,0 @@
-package shell
-
-import (
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/pb/master_pb"
- "github.com/seaweedfs/seaweedfs/weed/storage/erasure_coding"
- "github.com/seaweedfs/seaweedfs/weed/storage/erasure_coding/distribution"
- "github.com/seaweedfs/seaweedfs/weed/storage/needle"
- "github.com/seaweedfs/seaweedfs/weed/storage/super_block"
- "github.com/seaweedfs/seaweedfs/weed/storage/types"
-)
-
-func TestCalculateECDistributionShell(t *testing.T) {
- // Test the shell wrapper function
- rp, _ := super_block.NewReplicaPlacementFromString("110")
-
- dist := CalculateECDistribution(
- erasure_coding.TotalShardsCount,
- erasure_coding.ParityShardsCount,
- rp,
- )
-
- if dist.ReplicationConfig.MinDataCenters != 2 {
- t.Errorf("Expected 2 DCs, got %d", dist.ReplicationConfig.MinDataCenters)
- }
- if dist.TargetShardsPerDC != 7 {
- t.Errorf("Expected 7 shards per DC, got %d", dist.TargetShardsPerDC)
- }
-
- t.Log(dist.Summary())
-}
-
-func TestAnalyzeVolumeDistributionShell(t *testing.T) {
- diskType := types.HardDriveType
- diskTypeKey := string(diskType)
-
- // Build a topology with unbalanced distribution
- node1 := &EcNode{
- info: &master_pb.DataNodeInfo{
- Id: "127.0.0.1:8080",
- DiskInfos: map[string]*master_pb.DiskInfo{
- diskTypeKey: {
- Type: diskTypeKey,
- MaxVolumeCount: 10,
- EcShardInfos: []*master_pb.VolumeEcShardInformationMessage{
- {
- Id: 1,
- Collection: "test",
- EcIndexBits: 0x3FFF, // All 14 shards
- },
- },
- },
- },
- },
- dc: "dc1",
- rack: "rack1",
- freeEcSlot: 5,
- }
-
- node2 := &EcNode{
- info: &master_pb.DataNodeInfo{
- Id: "127.0.0.1:8081",
- DiskInfos: map[string]*master_pb.DiskInfo{
- diskTypeKey: {
- Type: diskTypeKey,
- MaxVolumeCount: 10,
- EcShardInfos: []*master_pb.VolumeEcShardInformationMessage{},
- },
- },
- },
- dc: "dc2",
- rack: "rack2",
- freeEcSlot: 10,
- }
-
- locations := []*EcNode{node1, node2}
- volumeId := needle.VolumeId(1)
-
- analysis := AnalyzeVolumeDistribution(volumeId, locations, diskType)
-
- shardsByDC := analysis.GetShardsByDC()
- if shardsByDC["dc1"] != 14 {
- t.Errorf("Expected 14 shards in dc1, got %d", shardsByDC["dc1"])
- }
-
- t.Log(analysis.DetailedString())
-}
-
-func TestProportionalRebalancerShell(t *testing.T) {
- diskType := types.HardDriveType
- diskTypeKey := string(diskType)
-
- // Build topology: 2 DCs, 2 racks each, all shards on one node
- nodes := []*EcNode{
- {
- info: &master_pb.DataNodeInfo{
- Id: "dc1-rack1-node1",
- DiskInfos: map[string]*master_pb.DiskInfo{
- diskTypeKey: {
- Type: diskTypeKey,
- MaxVolumeCount: 10,
- EcShardInfos: []*master_pb.VolumeEcShardInformationMessage{
- {Id: 1, Collection: "test", EcIndexBits: 0x3FFF},
- },
- },
- },
- },
- dc: "dc1", rack: "dc1-rack1", freeEcSlot: 0,
- },
- {
- info: &master_pb.DataNodeInfo{
- Id: "dc1-rack2-node1",
- DiskInfos: map[string]*master_pb.DiskInfo{
- diskTypeKey: {Type: diskTypeKey, MaxVolumeCount: 10},
- },
- },
- dc: "dc1", rack: "dc1-rack2", freeEcSlot: 10,
- },
- {
- info: &master_pb.DataNodeInfo{
- Id: "dc2-rack1-node1",
- DiskInfos: map[string]*master_pb.DiskInfo{
- diskTypeKey: {Type: diskTypeKey, MaxVolumeCount: 10},
- },
- },
- dc: "dc2", rack: "dc2-rack1", freeEcSlot: 10,
- },
- {
- info: &master_pb.DataNodeInfo{
- Id: "dc2-rack2-node1",
- DiskInfos: map[string]*master_pb.DiskInfo{
- diskTypeKey: {Type: diskTypeKey, MaxVolumeCount: 10},
- },
- },
- dc: "dc2", rack: "dc2-rack2", freeEcSlot: 10,
- },
- }
-
- rp, _ := super_block.NewReplicaPlacementFromString("110")
- rebalancer := NewProportionalECRebalancer(nodes, rp, diskType)
-
- volumeId := needle.VolumeId(1)
- moves, err := rebalancer.PlanMoves(volumeId, []*EcNode{nodes[0]})
-
- if err != nil {
- t.Fatalf("PlanMoves failed: %v", err)
- }
-
- t.Logf("Planned %d moves", len(moves))
- for i, move := range moves {
- t.Logf(" %d. %s", i+1, move.String())
- }
-
- // Verify moves to dc2
- movedToDC2 := 0
- for _, move := range moves {
- if move.DestNode.dc == "dc2" {
- movedToDC2++
- }
- }
-
- if movedToDC2 == 0 {
- t.Error("Expected some moves to dc2")
- }
-}
-
-func TestCustomECConfigRebalancer(t *testing.T) {
- diskType := types.HardDriveType
- diskTypeKey := string(diskType)
-
- // Test with custom 8+4 EC configuration
- ecConfig, err := distribution.NewECConfig(8, 4)
- if err != nil {
- t.Fatalf("Failed to create EC config: %v", err)
- }
-
- // Build topology for 12 shards (8+4)
- nodes := []*EcNode{
- {
- info: &master_pb.DataNodeInfo{
- Id: "dc1-node1",
- DiskInfos: map[string]*master_pb.DiskInfo{
- diskTypeKey: {
- Type: diskTypeKey,
- MaxVolumeCount: 10,
- EcShardInfos: []*master_pb.VolumeEcShardInformationMessage{
- {Id: 1, Collection: "test", EcIndexBits: 0x0FFF}, // 12 shards (bits 0-11)
- },
- },
- },
- },
- dc: "dc1", rack: "dc1-rack1", freeEcSlot: 0,
- },
- {
- info: &master_pb.DataNodeInfo{
- Id: "dc2-node1",
- DiskInfos: map[string]*master_pb.DiskInfo{
- diskTypeKey: {Type: diskTypeKey, MaxVolumeCount: 10},
- },
- },
- dc: "dc2", rack: "dc2-rack1", freeEcSlot: 10,
- },
- {
- info: &master_pb.DataNodeInfo{
- Id: "dc3-node1",
- DiskInfos: map[string]*master_pb.DiskInfo{
- diskTypeKey: {Type: diskTypeKey, MaxVolumeCount: 10},
- },
- },
- dc: "dc3", rack: "dc3-rack1", freeEcSlot: 10,
- },
- }
-
- rp, _ := super_block.NewReplicaPlacementFromString("200") // 3 DCs
- rebalancer := NewProportionalECRebalancerWithConfig(nodes, rp, diskType, ecConfig)
-
- volumeId := needle.VolumeId(1)
- moves, err := rebalancer.PlanMoves(volumeId, []*EcNode{nodes[0]})
-
- if err != nil {
- t.Fatalf("PlanMoves failed: %v", err)
- }
-
- t.Logf("Custom 8+4 EC with 200 replication: planned %d moves", len(moves))
-
- // Get the distribution summary
- summary := GetDistributionSummaryWithConfig(rp, ecConfig)
- t.Log(summary)
-
- analysis := GetFaultToleranceAnalysisWithConfig(rp, ecConfig)
- t.Log(analysis)
-}
-
-func TestGetDistributionSummaryShell(t *testing.T) {
- rp, _ := super_block.NewReplicaPlacementFromString("110")
-
- summary := GetDistributionSummary(rp)
- t.Log(summary)
-
- if len(summary) == 0 {
- t.Error("Summary should not be empty")
- }
-
- analysis := GetFaultToleranceAnalysis(rp)
- t.Log(analysis)
-
- if len(analysis) == 0 {
- t.Error("Analysis should not be empty")
- }
-}
diff --git a/weed/shell/shell_liner.go b/weed/shell/shell_liner.go
index 00831d42e..78afc7880 100644
--- a/weed/shell/shell_liner.go
+++ b/weed/shell/shell_liner.go
@@ -126,17 +126,6 @@ func processEachCmd(cmd string, commandEnv *CommandEnv) bool {
return false
}
-func stripQuotes(s string) string {
- tokens, unbalanced := parseShellInput(s, false)
- if unbalanced {
- return s
- }
- if len(tokens) > 0 {
- return tokens[0]
- }
- return ""
-}
-
func splitCommandLine(line string) []string {
tokens, _ := parseShellInput(line, true)
return tokens
diff --git a/weed/shell/shell_liner_test.go b/weed/shell/shell_liner_test.go
deleted file mode 100644
index bfdd2b378..000000000
--- a/weed/shell/shell_liner_test.go
+++ /dev/null
@@ -1,105 +0,0 @@
-package shell
-
-import (
- "flag"
- "reflect"
- "testing"
-)
-
-func TestSplitCommandLine(t *testing.T) {
- tests := []struct {
- input string
- expected []string
- }{
- {
- input: `s3.configure -user=test`,
- expected: []string{`s3.configure`, `-user=test`},
- },
- {
- input: `s3.configure -user=Test_number_004 -account_display_name="Test number 004" -actions=write -apply`,
- expected: []string{`s3.configure`, `-user=Test_number_004`, `-account_display_name=Test number 004`, `-actions=write`, `-apply`},
- },
- {
- input: `s3.configure -user=Test_number_004 -account_display_name='Test number 004' -actions=write -apply`,
- expected: []string{`s3.configure`, `-user=Test_number_004`, `-account_display_name=Test number 004`, `-actions=write`, `-apply`},
- },
- {
- input: `s3.configure -flag="a b"c'd e'`,
- expected: []string{`s3.configure`, `-flag=a bcd e`},
- },
- {
- input: `s3.configure -name="a\"b"`,
- expected: []string{`s3.configure`, `-name=a"b`},
- },
- {
- input: `s3.configure -path='a\ b'`,
- expected: []string{`s3.configure`, `-path=a\ b`},
- },
- }
-
- for _, tt := range tests {
- got := splitCommandLine(tt.input)
- if !reflect.DeepEqual(got, tt.expected) {
- t.Errorf("input: %s\ngot: %v\nwant: %v", tt.input, got, tt.expected)
- }
- }
-}
-
-func TestStripQuotes(t *testing.T) {
- tests := []struct {
- input string
- expected string
- }{
- {input: `"Test number 004"`, expected: `Test number 004`},
- {input: `'Test number 004'`, expected: `Test number 004`},
- {input: `-account_display_name="Test number 004"`, expected: `-account_display_name=Test number 004`},
- {input: `-flag="a"b'c'`, expected: `-flag=abc`},
- {input: `-name="a\"b"`, expected: `-name=a"b`},
- {input: `-path='a\ b'`, expected: `-path=a\ b`},
- {input: `"unbalanced`, expected: `"unbalanced`},
- {input: `'unbalanced`, expected: `'unbalanced`},
- {input: `-name="a\"b`, expected: `-name="a\"b`},
- {input: `trailing\`, expected: `trailing\`},
- }
-
- for _, tt := range tests {
- got := stripQuotes(tt.input)
- if got != tt.expected {
- t.Errorf("input: %s, got: %s, want: %s", tt.input, got, tt.expected)
- }
- }
-}
-
-func TestFlagParsing(t *testing.T) {
- fs := flag.NewFlagSet("test", flag.ContinueOnError)
- displayName := fs.String("account_display_name", "", "display name")
-
- rawArg := `-account_display_name="Test number 004"`
- args := []string{stripQuotes(rawArg)}
- err := fs.Parse(args)
- if err != nil {
- t.Fatal(err)
- }
-
- expected := "Test number 004"
- if *displayName != expected {
- t.Errorf("got: [%s], want: [%s]", *displayName, expected)
- }
-}
-
-func TestEscapedFlagParsing(t *testing.T) {
- fs := flag.NewFlagSet("test", flag.ContinueOnError)
- name := fs.String("name", "", "name")
-
- rawArg := `-name="a\"b"`
- args := []string{stripQuotes(rawArg)}
- err := fs.Parse(args)
- if err != nil {
- t.Fatal(err)
- }
-
- expected := `a"b`
- if *name != expected {
- t.Errorf("got: [%s], want: [%s]", *name, expected)
- }
-}
diff --git a/weed/stats/disk_common.go b/weed/stats/disk_common.go
deleted file mode 100644
index 936c77e91..000000000
--- a/weed/stats/disk_common.go
+++ /dev/null
@@ -1,17 +0,0 @@
-package stats
-
-import "github.com/seaweedfs/seaweedfs/weed/pb/volume_server_pb"
-
-func calculateDiskRemaining(disk *volume_server_pb.DiskStatus) {
- disk.Used = disk.All - disk.Free
-
- if disk.All > 0 {
- disk.PercentFree = float32((float64(disk.Free) / float64(disk.All)) * 100)
- disk.PercentUsed = float32((float64(disk.Used) / float64(disk.All)) * 100)
- } else {
- disk.PercentFree = 0
- disk.PercentUsed = 0
- }
-
- return
-}
diff --git a/weed/stats/stats.go b/weed/stats/stats.go
index 6d3d55cc6..f875f3780 100644
--- a/weed/stats/stats.go
+++ b/weed/stats/stats.go
@@ -62,12 +62,6 @@ func ConnectionOpen() {
func ConnectionClose() {
Chan.Connections <- NewTimedValue(time.Now(), -1)
}
-func RequestOpen() {
- Chan.Requests <- NewTimedValue(time.Now(), 1)
-}
-func RequestClose() {
- Chan.Requests <- NewTimedValue(time.Now(), -1)
-}
func AssignRequest() {
Chan.AssignRequests <- NewTimedValue(time.Now(), 1)
}
diff --git a/weed/storage/erasure_coding/distribution/analysis.go b/weed/storage/erasure_coding/distribution/analysis.go
index 22923e671..b939df53e 100644
--- a/weed/storage/erasure_coding/distribution/analysis.go
+++ b/weed/storage/erasure_coding/distribution/analysis.go
@@ -1,10 +1,5 @@
package distribution
-import (
- "fmt"
- "slices"
-)
-
// ShardLocation represents where a shard is located in the topology
type ShardLocation struct {
ShardID int
@@ -47,101 +42,6 @@ type TopologyAnalysis struct {
TotalDCs int
}
-// NewTopologyAnalysis creates a new empty analysis
-func NewTopologyAnalysis() *TopologyAnalysis {
- return &TopologyAnalysis{
- ShardsByDC: make(map[string]int),
- ShardsByRack: make(map[string]int),
- ShardsByNode: make(map[string]int),
- DCToShards: make(map[string][]int),
- RackToShards: make(map[string][]int),
- NodeToShards: make(map[string][]int),
- DCToRacks: make(map[string][]string),
- RackToNodes: make(map[string][]*TopologyNode),
- AllNodes: make(map[string]*TopologyNode),
- }
-}
-
-// AddShardLocation adds a shard location to the analysis
-func (a *TopologyAnalysis) AddShardLocation(loc ShardLocation) {
- // Update counts
- a.ShardsByDC[loc.DataCenter]++
- a.ShardsByRack[loc.Rack]++
- a.ShardsByNode[loc.NodeID]++
-
- // Update shard lists
- a.DCToShards[loc.DataCenter] = append(a.DCToShards[loc.DataCenter], loc.ShardID)
- a.RackToShards[loc.Rack] = append(a.RackToShards[loc.Rack], loc.ShardID)
- a.NodeToShards[loc.NodeID] = append(a.NodeToShards[loc.NodeID], loc.ShardID)
-
- a.TotalShards++
-}
-
-// AddNode adds a node to the topology (even if it has no shards)
-func (a *TopologyAnalysis) AddNode(node *TopologyNode) {
- if _, exists := a.AllNodes[node.NodeID]; exists {
- return // Already added
- }
-
- a.AllNodes[node.NodeID] = node
- a.TotalNodes++
-
- // Update topology structure
- if !slices.Contains(a.DCToRacks[node.DataCenter], node.Rack) {
- a.DCToRacks[node.DataCenter] = append(a.DCToRacks[node.DataCenter], node.Rack)
- }
- a.RackToNodes[node.Rack] = append(a.RackToNodes[node.Rack], node)
-
- // Update counts
- if _, exists := a.ShardsByDC[node.DataCenter]; !exists {
- a.TotalDCs++
- }
- if _, exists := a.ShardsByRack[node.Rack]; !exists {
- a.TotalRacks++
- }
-}
-
-// Finalize computes final statistics after all data is added
-func (a *TopologyAnalysis) Finalize() {
- // Ensure we have accurate DC and rack counts
- dcSet := make(map[string]bool)
- rackSet := make(map[string]bool)
- for _, node := range a.AllNodes {
- dcSet[node.DataCenter] = true
- rackSet[node.Rack] = true
- }
- a.TotalDCs = len(dcSet)
- a.TotalRacks = len(rackSet)
- a.TotalNodes = len(a.AllNodes)
-}
-
-// String returns a summary of the analysis
-func (a *TopologyAnalysis) String() string {
- return fmt.Sprintf("TopologyAnalysis{shards:%d, nodes:%d, racks:%d, dcs:%d}",
- a.TotalShards, a.TotalNodes, a.TotalRacks, a.TotalDCs)
-}
-
-// DetailedString returns a detailed multi-line summary
-func (a *TopologyAnalysis) DetailedString() string {
- s := fmt.Sprintf("Topology Analysis:\n")
- s += fmt.Sprintf(" Total Shards: %d\n", a.TotalShards)
- s += fmt.Sprintf(" Data Centers: %d\n", a.TotalDCs)
- for dc, count := range a.ShardsByDC {
- s += fmt.Sprintf(" %s: %d shards\n", dc, count)
- }
- s += fmt.Sprintf(" Racks: %d\n", a.TotalRacks)
- for rack, count := range a.ShardsByRack {
- s += fmt.Sprintf(" %s: %d shards\n", rack, count)
- }
- s += fmt.Sprintf(" Nodes: %d\n", a.TotalNodes)
- for nodeID, count := range a.ShardsByNode {
- if count > 0 {
- s += fmt.Sprintf(" %s: %d shards\n", nodeID, count)
- }
- }
- return s
-}
-
// TopologyExcess represents a topology level (DC/rack/node) with excess shards
type TopologyExcess struct {
ID string // DC/rack/node ID
@@ -150,91 +50,3 @@ type TopologyExcess struct {
Shards []int // Shard IDs at this level
Nodes []*TopologyNode // Nodes at this level (for finding sources)
}
-
-// CalculateDCExcess returns DCs with more shards than the target
-func CalculateDCExcess(analysis *TopologyAnalysis, dist *ECDistribution) []TopologyExcess {
- var excess []TopologyExcess
-
- for dc, count := range analysis.ShardsByDC {
- if count > dist.TargetShardsPerDC {
- // Collect nodes in this DC
- var nodes []*TopologyNode
- for _, rack := range analysis.DCToRacks[dc] {
- nodes = append(nodes, analysis.RackToNodes[rack]...)
- }
- excess = append(excess, TopologyExcess{
- ID: dc,
- Level: "dc",
- Excess: count - dist.TargetShardsPerDC,
- Shards: analysis.DCToShards[dc],
- Nodes: nodes,
- })
- }
- }
-
- // Sort by excess (most excess first)
- slices.SortFunc(excess, func(a, b TopologyExcess) int {
- return b.Excess - a.Excess
- })
-
- return excess
-}
-
-// CalculateRackExcess returns racks with more shards than the target (within a DC)
-func CalculateRackExcess(analysis *TopologyAnalysis, dc string, targetPerRack int) []TopologyExcess {
- var excess []TopologyExcess
-
- for _, rack := range analysis.DCToRacks[dc] {
- count := analysis.ShardsByRack[rack]
- if count > targetPerRack {
- excess = append(excess, TopologyExcess{
- ID: rack,
- Level: "rack",
- Excess: count - targetPerRack,
- Shards: analysis.RackToShards[rack],
- Nodes: analysis.RackToNodes[rack],
- })
- }
- }
-
- slices.SortFunc(excess, func(a, b TopologyExcess) int {
- return b.Excess - a.Excess
- })
-
- return excess
-}
-
-// CalculateUnderservedDCs returns DCs that have fewer shards than target
-func CalculateUnderservedDCs(analysis *TopologyAnalysis, dist *ECDistribution) []string {
- var underserved []string
-
- // Check existing DCs
- for dc, count := range analysis.ShardsByDC {
- if count < dist.TargetShardsPerDC {
- underserved = append(underserved, dc)
- }
- }
-
- // Check DCs with nodes but no shards
- for dc := range analysis.DCToRacks {
- if _, exists := analysis.ShardsByDC[dc]; !exists {
- underserved = append(underserved, dc)
- }
- }
-
- return underserved
-}
-
-// CalculateUnderservedRacks returns racks that have fewer shards than target
-func CalculateUnderservedRacks(analysis *TopologyAnalysis, dc string, targetPerRack int) []string {
- var underserved []string
-
- for _, rack := range analysis.DCToRacks[dc] {
- count := analysis.ShardsByRack[rack]
- if count < targetPerRack {
- underserved = append(underserved, rack)
- }
- }
-
- return underserved
-}
diff --git a/weed/storage/erasure_coding/distribution/config.go b/weed/storage/erasure_coding/distribution/config.go
index e89d6eeb6..b4935b0c7 100644
--- a/weed/storage/erasure_coding/distribution/config.go
+++ b/weed/storage/erasure_coding/distribution/config.go
@@ -1,12 +1,6 @@
// Package distribution provides EC shard distribution algorithms with configurable EC ratios.
package distribution
-import (
- "fmt"
-
- "github.com/seaweedfs/seaweedfs/weed/storage/super_block"
-)
-
// ECConfig holds erasure coding configuration parameters.
// This replaces hard-coded constants like DataShardsCount=10, ParityShardsCount=4.
type ECConfig struct {
@@ -14,113 +8,6 @@ type ECConfig struct {
ParityShards int // Number of parity shards (e.g., 4)
}
-// DefaultECConfig returns the standard 10+4 EC configuration
-func DefaultECConfig() ECConfig {
- return ECConfig{
- DataShards: 10,
- ParityShards: 4,
- }
-}
-
-// NewECConfig creates a new EC configuration with validation
-func NewECConfig(dataShards, parityShards int) (ECConfig, error) {
- if dataShards <= 0 {
- return ECConfig{}, fmt.Errorf("dataShards must be positive, got %d", dataShards)
- }
- if parityShards <= 0 {
- return ECConfig{}, fmt.Errorf("parityShards must be positive, got %d", parityShards)
- }
- if dataShards+parityShards > 32 {
- return ECConfig{}, fmt.Errorf("total shards (%d+%d=%d) exceeds maximum of 32",
- dataShards, parityShards, dataShards+parityShards)
- }
- return ECConfig{
- DataShards: dataShards,
- ParityShards: parityShards,
- }, nil
-}
-
-// TotalShards returns the total number of shards (data + parity)
-func (c ECConfig) TotalShards() int {
- return c.DataShards + c.ParityShards
-}
-
-// MaxTolerableLoss returns the maximum number of shards that can be lost
-// while still being able to reconstruct the data
-func (c ECConfig) MaxTolerableLoss() int {
- return c.ParityShards
-}
-
-// MinShardsForReconstruction returns the minimum number of shards needed
-// to reconstruct the original data
-func (c ECConfig) MinShardsForReconstruction() int {
- return c.DataShards
-}
-
-// String returns a human-readable representation
-func (c ECConfig) String() string {
- return fmt.Sprintf("%d+%d (total: %d, can lose: %d)",
- c.DataShards, c.ParityShards, c.TotalShards(), c.MaxTolerableLoss())
-}
-
-// IsDataShard returns true if the shard ID is a data shard (0 to DataShards-1)
-func (c ECConfig) IsDataShard(shardID int) bool {
- return shardID >= 0 && shardID < c.DataShards
-}
-
-// IsParityShard returns true if the shard ID is a parity shard (DataShards to TotalShards-1)
-func (c ECConfig) IsParityShard(shardID int) bool {
- return shardID >= c.DataShards && shardID < c.TotalShards()
-}
-
-// SortShardsDataFirst returns a copy of shards sorted with data shards first.
-// This is useful for initial placement where data shards should be spread out first.
-func (c ECConfig) SortShardsDataFirst(shards []int) []int {
- result := make([]int, len(shards))
- copy(result, shards)
-
- // Partition: data shards first, then parity shards
- dataIdx := 0
- parityIdx := len(result) - 1
-
- sorted := make([]int, len(result))
- for _, s := range result {
- if c.IsDataShard(s) {
- sorted[dataIdx] = s
- dataIdx++
- } else {
- sorted[parityIdx] = s
- parityIdx--
- }
- }
-
- return sorted
-}
-
-// SortShardsParityFirst returns a copy of shards sorted with parity shards first.
-// This is useful for rebalancing where we prefer to move parity shards.
-func (c ECConfig) SortShardsParityFirst(shards []int) []int {
- result := make([]int, len(shards))
- copy(result, shards)
-
- // Partition: parity shards first, then data shards
- parityIdx := 0
- dataIdx := len(result) - 1
-
- sorted := make([]int, len(result))
- for _, s := range result {
- if c.IsParityShard(s) {
- sorted[parityIdx] = s
- parityIdx++
- } else {
- sorted[dataIdx] = s
- dataIdx--
- }
- }
-
- return sorted
-}
-
// ReplicationConfig holds the parsed replication policy
type ReplicationConfig struct {
MinDataCenters int // X+1 from XYZ replication (minimum DCs to use)
@@ -130,42 +17,3 @@ type ReplicationConfig struct {
// Original replication string (for logging/debugging)
Original string
}
-
-// NewReplicationConfig creates a ReplicationConfig from a ReplicaPlacement
-func NewReplicationConfig(rp *super_block.ReplicaPlacement) ReplicationConfig {
- if rp == nil {
- return ReplicationConfig{
- MinDataCenters: 1,
- MinRacksPerDC: 1,
- MinNodesPerRack: 1,
- Original: "000",
- }
- }
- return ReplicationConfig{
- MinDataCenters: rp.DiffDataCenterCount + 1,
- MinRacksPerDC: rp.DiffRackCount + 1,
- MinNodesPerRack: rp.SameRackCount + 1,
- Original: rp.String(),
- }
-}
-
-// NewReplicationConfigFromString creates a ReplicationConfig from a replication string
-func NewReplicationConfigFromString(replication string) (ReplicationConfig, error) {
- rp, err := super_block.NewReplicaPlacementFromString(replication)
- if err != nil {
- return ReplicationConfig{}, err
- }
- return NewReplicationConfig(rp), nil
-}
-
-// TotalPlacementSlots returns the minimum number of unique placement locations
-// based on the replication policy
-func (r ReplicationConfig) TotalPlacementSlots() int {
- return r.MinDataCenters * r.MinRacksPerDC * r.MinNodesPerRack
-}
-
-// String returns a human-readable representation
-func (r ReplicationConfig) String() string {
- return fmt.Sprintf("replication=%s (DCs:%d, Racks/DC:%d, Nodes/Rack:%d)",
- r.Original, r.MinDataCenters, r.MinRacksPerDC, r.MinNodesPerRack)
-}
diff --git a/weed/storage/erasure_coding/distribution/distribution.go b/weed/storage/erasure_coding/distribution/distribution.go
index 03deea710..1ef05c55d 100644
--- a/weed/storage/erasure_coding/distribution/distribution.go
+++ b/weed/storage/erasure_coding/distribution/distribution.go
@@ -1,9 +1,5 @@
package distribution
-import (
- "fmt"
-)
-
// ECDistribution represents the target distribution of EC shards
// based on EC configuration and replication policy.
type ECDistribution struct {
@@ -24,137 +20,3 @@ type ECDistribution struct {
MaxShardsPerRack int
MaxShardsPerNode int
}
-
-// CalculateDistribution computes the target EC shard distribution based on
-// EC configuration and replication policy.
-//
-// The algorithm:
-// 1. Uses replication policy to determine minimum topology spread
-// 2. Calculates target shards per level (evenly distributed)
-// 3. Calculates max shards per level (for fault tolerance)
-func CalculateDistribution(ec ECConfig, rep ReplicationConfig) *ECDistribution {
- totalShards := ec.TotalShards()
-
- // Target distribution (balanced, rounded up to ensure all shards placed)
- targetShardsPerDC := ceilDivide(totalShards, rep.MinDataCenters)
- targetShardsPerRack := ceilDivide(targetShardsPerDC, rep.MinRacksPerDC)
- targetShardsPerNode := ceilDivide(targetShardsPerRack, rep.MinNodesPerRack)
-
- // Maximum limits for fault tolerance
- // The key constraint: losing one failure domain shouldn't lose more than parityShards
- // So max shards per domain = totalShards - parityShards + tolerance
- // We add small tolerance (+2) to allow for imbalanced topologies
- faultToleranceLimit := totalShards - ec.ParityShards + 1
-
- maxShardsPerDC := min(faultToleranceLimit, targetShardsPerDC+2)
- maxShardsPerRack := min(faultToleranceLimit, targetShardsPerRack+2)
- maxShardsPerNode := min(faultToleranceLimit, targetShardsPerNode+2)
-
- return &ECDistribution{
- ECConfig: ec,
- ReplicationConfig: rep,
- TargetShardsPerDC: targetShardsPerDC,
- TargetShardsPerRack: targetShardsPerRack,
- TargetShardsPerNode: targetShardsPerNode,
- MaxShardsPerDC: maxShardsPerDC,
- MaxShardsPerRack: maxShardsPerRack,
- MaxShardsPerNode: maxShardsPerNode,
- }
-}
-
-// String returns a human-readable description of the distribution
-func (d *ECDistribution) String() string {
- return fmt.Sprintf(
- "ECDistribution{EC:%s, DCs:%d (target:%d/max:%d), Racks/DC:%d (target:%d/max:%d), Nodes/Rack:%d (target:%d/max:%d)}",
- d.ECConfig.String(),
- d.ReplicationConfig.MinDataCenters, d.TargetShardsPerDC, d.MaxShardsPerDC,
- d.ReplicationConfig.MinRacksPerDC, d.TargetShardsPerRack, d.MaxShardsPerRack,
- d.ReplicationConfig.MinNodesPerRack, d.TargetShardsPerNode, d.MaxShardsPerNode,
- )
-}
-
-// Summary returns a multi-line summary of the distribution plan
-func (d *ECDistribution) Summary() string {
- summary := fmt.Sprintf("EC Configuration: %s\n", d.ECConfig.String())
- summary += fmt.Sprintf("Replication: %s\n", d.ReplicationConfig.String())
- summary += fmt.Sprintf("Distribution Plan:\n")
- summary += fmt.Sprintf(" Data Centers: %d (target %d shards each, max %d)\n",
- d.ReplicationConfig.MinDataCenters, d.TargetShardsPerDC, d.MaxShardsPerDC)
- summary += fmt.Sprintf(" Racks per DC: %d (target %d shards each, max %d)\n",
- d.ReplicationConfig.MinRacksPerDC, d.TargetShardsPerRack, d.MaxShardsPerRack)
- summary += fmt.Sprintf(" Nodes per Rack: %d (target %d shards each, max %d)\n",
- d.ReplicationConfig.MinNodesPerRack, d.TargetShardsPerNode, d.MaxShardsPerNode)
- return summary
-}
-
-// CanSurviveDCFailure returns true if the distribution can survive
-// complete loss of one data center
-func (d *ECDistribution) CanSurviveDCFailure() bool {
- // After losing one DC with max shards, check if remaining shards are enough
- remainingAfterDCLoss := d.ECConfig.TotalShards() - d.TargetShardsPerDC
- return remainingAfterDCLoss >= d.ECConfig.MinShardsForReconstruction()
-}
-
-// CanSurviveRackFailure returns true if the distribution can survive
-// complete loss of one rack
-func (d *ECDistribution) CanSurviveRackFailure() bool {
- remainingAfterRackLoss := d.ECConfig.TotalShards() - d.TargetShardsPerRack
- return remainingAfterRackLoss >= d.ECConfig.MinShardsForReconstruction()
-}
-
-// MinDCsForDCFaultTolerance calculates the minimum number of DCs needed
-// to survive complete DC failure with this EC configuration
-func (d *ECDistribution) MinDCsForDCFaultTolerance() int {
- // To survive DC failure, max shards per DC = parityShards
- maxShardsPerDC := d.ECConfig.MaxTolerableLoss()
- if maxShardsPerDC == 0 {
- return d.ECConfig.TotalShards() // Would need one DC per shard
- }
- return ceilDivide(d.ECConfig.TotalShards(), maxShardsPerDC)
-}
-
-// FaultToleranceAnalysis returns a detailed analysis of fault tolerance
-func (d *ECDistribution) FaultToleranceAnalysis() string {
- analysis := fmt.Sprintf("Fault Tolerance Analysis for %s:\n", d.ECConfig.String())
-
- // DC failure
- dcSurvive := d.CanSurviveDCFailure()
- shardsAfterDC := d.ECConfig.TotalShards() - d.TargetShardsPerDC
- analysis += fmt.Sprintf(" DC Failure: %s\n", boolToResult(dcSurvive))
- analysis += fmt.Sprintf(" - Losing one DC loses ~%d shards\n", d.TargetShardsPerDC)
- analysis += fmt.Sprintf(" - Remaining: %d shards (need %d)\n", shardsAfterDC, d.ECConfig.DataShards)
- if !dcSurvive {
- analysis += fmt.Sprintf(" - Need at least %d DCs for DC fault tolerance\n", d.MinDCsForDCFaultTolerance())
- }
-
- // Rack failure
- rackSurvive := d.CanSurviveRackFailure()
- shardsAfterRack := d.ECConfig.TotalShards() - d.TargetShardsPerRack
- analysis += fmt.Sprintf(" Rack Failure: %s\n", boolToResult(rackSurvive))
- analysis += fmt.Sprintf(" - Losing one rack loses ~%d shards\n", d.TargetShardsPerRack)
- analysis += fmt.Sprintf(" - Remaining: %d shards (need %d)\n", shardsAfterRack, d.ECConfig.DataShards)
-
- // Node failure (usually survivable)
- shardsAfterNode := d.ECConfig.TotalShards() - d.TargetShardsPerNode
- nodeSurvive := shardsAfterNode >= d.ECConfig.DataShards
- analysis += fmt.Sprintf(" Node Failure: %s\n", boolToResult(nodeSurvive))
- analysis += fmt.Sprintf(" - Losing one node loses ~%d shards\n", d.TargetShardsPerNode)
- analysis += fmt.Sprintf(" - Remaining: %d shards (need %d)\n", shardsAfterNode, d.ECConfig.DataShards)
-
- return analysis
-}
-
-func boolToResult(b bool) string {
- if b {
- return "SURVIVABLE ✓"
- }
- return "NOT SURVIVABLE ✗"
-}
-
-// ceilDivide performs ceiling division
-func ceilDivide(a, b int) int {
- if b <= 0 {
- return a
- }
- return (a + b - 1) / b
-}
diff --git a/weed/storage/erasure_coding/distribution/distribution_test.go b/weed/storage/erasure_coding/distribution/distribution_test.go
deleted file mode 100644
index dc6a19192..000000000
--- a/weed/storage/erasure_coding/distribution/distribution_test.go
+++ /dev/null
@@ -1,565 +0,0 @@
-package distribution
-
-import (
- "testing"
-)
-
-func TestNewECConfig(t *testing.T) {
- tests := []struct {
- name string
- dataShards int
- parityShards int
- wantErr bool
- }{
- {"valid 10+4", 10, 4, false},
- {"valid 8+4", 8, 4, false},
- {"valid 6+3", 6, 3, false},
- {"valid 4+2", 4, 2, false},
- {"invalid data=0", 0, 4, true},
- {"invalid parity=0", 10, 0, true},
- {"invalid total>32", 20, 15, true},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- config, err := NewECConfig(tt.dataShards, tt.parityShards)
- if (err != nil) != tt.wantErr {
- t.Errorf("NewECConfig() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
- if !tt.wantErr {
- if config.DataShards != tt.dataShards {
- t.Errorf("DataShards = %d, want %d", config.DataShards, tt.dataShards)
- }
- if config.ParityShards != tt.parityShards {
- t.Errorf("ParityShards = %d, want %d", config.ParityShards, tt.parityShards)
- }
- if config.TotalShards() != tt.dataShards+tt.parityShards {
- t.Errorf("TotalShards() = %d, want %d", config.TotalShards(), tt.dataShards+tt.parityShards)
- }
- }
- })
- }
-}
-
-func TestCalculateDistribution(t *testing.T) {
- tests := []struct {
- name string
- ecConfig ECConfig
- replication string
- expectedMinDCs int
- expectedMinRacksPerDC int
- expectedMinNodesPerRack int
- expectedTargetPerDC int
- expectedTargetPerRack int
- expectedTargetPerNode int
- }{
- {
- name: "10+4 with 000",
- ecConfig: DefaultECConfig(),
- replication: "000",
- expectedMinDCs: 1,
- expectedMinRacksPerDC: 1,
- expectedMinNodesPerRack: 1,
- expectedTargetPerDC: 14,
- expectedTargetPerRack: 14,
- expectedTargetPerNode: 14,
- },
- {
- name: "10+4 with 100",
- ecConfig: DefaultECConfig(),
- replication: "100",
- expectedMinDCs: 2,
- expectedMinRacksPerDC: 1,
- expectedMinNodesPerRack: 1,
- expectedTargetPerDC: 7,
- expectedTargetPerRack: 7,
- expectedTargetPerNode: 7,
- },
- {
- name: "10+4 with 110",
- ecConfig: DefaultECConfig(),
- replication: "110",
- expectedMinDCs: 2,
- expectedMinRacksPerDC: 2,
- expectedMinNodesPerRack: 1,
- expectedTargetPerDC: 7,
- expectedTargetPerRack: 4,
- expectedTargetPerNode: 4,
- },
- {
- name: "10+4 with 200",
- ecConfig: DefaultECConfig(),
- replication: "200",
- expectedMinDCs: 3,
- expectedMinRacksPerDC: 1,
- expectedMinNodesPerRack: 1,
- expectedTargetPerDC: 5,
- expectedTargetPerRack: 5,
- expectedTargetPerNode: 5,
- },
- {
- name: "8+4 with 110",
- ecConfig: ECConfig{
- DataShards: 8,
- ParityShards: 4,
- },
- replication: "110",
- expectedMinDCs: 2,
- expectedMinRacksPerDC: 2,
- expectedMinNodesPerRack: 1,
- expectedTargetPerDC: 6, // 12/2 = 6
- expectedTargetPerRack: 3, // 6/2 = 3
- expectedTargetPerNode: 3,
- },
- {
- name: "6+3 with 100",
- ecConfig: ECConfig{
- DataShards: 6,
- ParityShards: 3,
- },
- replication: "100",
- expectedMinDCs: 2,
- expectedMinRacksPerDC: 1,
- expectedMinNodesPerRack: 1,
- expectedTargetPerDC: 5, // ceil(9/2) = 5
- expectedTargetPerRack: 5,
- expectedTargetPerNode: 5,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- rep, err := NewReplicationConfigFromString(tt.replication)
- if err != nil {
- t.Fatalf("Failed to parse replication %s: %v", tt.replication, err)
- }
-
- dist := CalculateDistribution(tt.ecConfig, rep)
-
- if dist.ReplicationConfig.MinDataCenters != tt.expectedMinDCs {
- t.Errorf("MinDataCenters = %d, want %d", dist.ReplicationConfig.MinDataCenters, tt.expectedMinDCs)
- }
- if dist.ReplicationConfig.MinRacksPerDC != tt.expectedMinRacksPerDC {
- t.Errorf("MinRacksPerDC = %d, want %d", dist.ReplicationConfig.MinRacksPerDC, tt.expectedMinRacksPerDC)
- }
- if dist.ReplicationConfig.MinNodesPerRack != tt.expectedMinNodesPerRack {
- t.Errorf("MinNodesPerRack = %d, want %d", dist.ReplicationConfig.MinNodesPerRack, tt.expectedMinNodesPerRack)
- }
- if dist.TargetShardsPerDC != tt.expectedTargetPerDC {
- t.Errorf("TargetShardsPerDC = %d, want %d", dist.TargetShardsPerDC, tt.expectedTargetPerDC)
- }
- if dist.TargetShardsPerRack != tt.expectedTargetPerRack {
- t.Errorf("TargetShardsPerRack = %d, want %d", dist.TargetShardsPerRack, tt.expectedTargetPerRack)
- }
- if dist.TargetShardsPerNode != tt.expectedTargetPerNode {
- t.Errorf("TargetShardsPerNode = %d, want %d", dist.TargetShardsPerNode, tt.expectedTargetPerNode)
- }
-
- t.Logf("Distribution for %s: %s", tt.name, dist.String())
- })
- }
-}
-
-func TestFaultToleranceAnalysis(t *testing.T) {
- tests := []struct {
- name string
- ecConfig ECConfig
- replication string
- canSurviveDC bool
- canSurviveRack bool
- }{
- // 10+4 = 14 shards, need 10 to reconstruct, can lose 4
- {"10+4 000", DefaultECConfig(), "000", false, false}, // All in one, any failure is fatal
- {"10+4 100", DefaultECConfig(), "100", false, false}, // 7 per DC/rack, 7 remaining < 10
- {"10+4 200", DefaultECConfig(), "200", false, false}, // 5 per DC/rack, 9 remaining < 10
- {"10+4 110", DefaultECConfig(), "110", false, true}, // 4 per rack, 10 remaining = enough for rack
-
- // 8+4 = 12 shards, need 8 to reconstruct, can lose 4
- {"8+4 100", ECConfig{8, 4}, "100", false, false}, // 6 per DC/rack, 6 remaining < 8
- {"8+4 200", ECConfig{8, 4}, "200", true, true}, // 4 per DC/rack, 8 remaining = enough!
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- rep, _ := NewReplicationConfigFromString(tt.replication)
- dist := CalculateDistribution(tt.ecConfig, rep)
-
- if dist.CanSurviveDCFailure() != tt.canSurviveDC {
- t.Errorf("CanSurviveDCFailure() = %v, want %v", dist.CanSurviveDCFailure(), tt.canSurviveDC)
- }
- if dist.CanSurviveRackFailure() != tt.canSurviveRack {
- t.Errorf("CanSurviveRackFailure() = %v, want %v", dist.CanSurviveRackFailure(), tt.canSurviveRack)
- }
-
- t.Log(dist.FaultToleranceAnalysis())
- })
- }
-}
-
-func TestMinDCsForDCFaultTolerance(t *testing.T) {
- tests := []struct {
- name string
- ecConfig ECConfig
- minDCs int
- }{
- // 10+4: can lose 4, so max 4 per DC, 14/4 = 4 DCs needed
- {"10+4", DefaultECConfig(), 4},
- // 8+4: can lose 4, so max 4 per DC, 12/4 = 3 DCs needed
- {"8+4", ECConfig{8, 4}, 3},
- // 6+3: can lose 3, so max 3 per DC, 9/3 = 3 DCs needed
- {"6+3", ECConfig{6, 3}, 3},
- // 4+2: can lose 2, so max 2 per DC, 6/2 = 3 DCs needed
- {"4+2", ECConfig{4, 2}, 3},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- rep, _ := NewReplicationConfigFromString("000")
- dist := CalculateDistribution(tt.ecConfig, rep)
-
- if dist.MinDCsForDCFaultTolerance() != tt.minDCs {
- t.Errorf("MinDCsForDCFaultTolerance() = %d, want %d",
- dist.MinDCsForDCFaultTolerance(), tt.minDCs)
- }
-
- t.Logf("%s: needs %d DCs for DC fault tolerance", tt.name, dist.MinDCsForDCFaultTolerance())
- })
- }
-}
-
-func TestTopologyAnalysis(t *testing.T) {
- analysis := NewTopologyAnalysis()
-
- // Add nodes to topology
- node1 := &TopologyNode{
- NodeID: "node1",
- DataCenter: "dc1",
- Rack: "rack1",
- FreeSlots: 5,
- }
- node2 := &TopologyNode{
- NodeID: "node2",
- DataCenter: "dc1",
- Rack: "rack2",
- FreeSlots: 10,
- }
- node3 := &TopologyNode{
- NodeID: "node3",
- DataCenter: "dc2",
- Rack: "rack3",
- FreeSlots: 10,
- }
-
- analysis.AddNode(node1)
- analysis.AddNode(node2)
- analysis.AddNode(node3)
-
- // Add shard locations (all on node1)
- for i := 0; i < 14; i++ {
- analysis.AddShardLocation(ShardLocation{
- ShardID: i,
- NodeID: "node1",
- DataCenter: "dc1",
- Rack: "rack1",
- })
- }
-
- analysis.Finalize()
-
- // Verify counts
- if analysis.TotalShards != 14 {
- t.Errorf("TotalShards = %d, want 14", analysis.TotalShards)
- }
- if analysis.ShardsByDC["dc1"] != 14 {
- t.Errorf("ShardsByDC[dc1] = %d, want 14", analysis.ShardsByDC["dc1"])
- }
- if analysis.ShardsByRack["rack1"] != 14 {
- t.Errorf("ShardsByRack[rack1] = %d, want 14", analysis.ShardsByRack["rack1"])
- }
- if analysis.ShardsByNode["node1"] != 14 {
- t.Errorf("ShardsByNode[node1] = %d, want 14", analysis.ShardsByNode["node1"])
- }
-
- t.Log(analysis.DetailedString())
-}
-
-func TestRebalancer(t *testing.T) {
- // Build topology: 2 DCs, 2 racks each, all shards on one node
- analysis := NewTopologyAnalysis()
-
- // Add nodes
- nodes := []*TopologyNode{
- {NodeID: "dc1-rack1-node1", DataCenter: "dc1", Rack: "dc1-rack1", FreeSlots: 0},
- {NodeID: "dc1-rack2-node1", DataCenter: "dc1", Rack: "dc1-rack2", FreeSlots: 10},
- {NodeID: "dc2-rack1-node1", DataCenter: "dc2", Rack: "dc2-rack1", FreeSlots: 10},
- {NodeID: "dc2-rack2-node1", DataCenter: "dc2", Rack: "dc2-rack2", FreeSlots: 10},
- }
- for _, node := range nodes {
- analysis.AddNode(node)
- }
-
- // Add all 14 shards to first node
- for i := 0; i < 14; i++ {
- analysis.AddShardLocation(ShardLocation{
- ShardID: i,
- NodeID: "dc1-rack1-node1",
- DataCenter: "dc1",
- Rack: "dc1-rack1",
- })
- }
- analysis.Finalize()
-
- // Create rebalancer with 110 replication (2 DCs, 2 racks each)
- ec := DefaultECConfig()
- rep, _ := NewReplicationConfigFromString("110")
- rebalancer := NewRebalancer(ec, rep)
-
- plan, err := rebalancer.PlanRebalance(analysis)
- if err != nil {
- t.Fatalf("PlanRebalance failed: %v", err)
- }
-
- t.Logf("Planned %d moves", plan.TotalMoves)
- t.Log(plan.DetailedString())
-
- // Verify we're moving shards to dc2
- movedToDC2 := 0
- for _, move := range plan.Moves {
- if move.DestNode.DataCenter == "dc2" {
- movedToDC2++
- }
- }
-
- if movedToDC2 == 0 {
- t.Error("Expected some moves to dc2")
- }
-
- // With "110" replication, target is 7 shards per DC
- // Starting with 14 in dc1, should plan to move 7 to dc2
- if plan.MovesAcrossDC < 7 {
- t.Errorf("Expected at least 7 cross-DC moves for 110 replication, got %d", plan.MovesAcrossDC)
- }
-}
-
-func TestCustomECRatios(t *testing.T) {
- // Test various custom EC ratios that seaweed-enterprise might use
- ratios := []struct {
- name string
- data int
- parity int
- }{
- {"4+2", 4, 2},
- {"6+3", 6, 3},
- {"8+2", 8, 2},
- {"8+4", 8, 4},
- {"10+4", 10, 4},
- {"12+4", 12, 4},
- {"16+4", 16, 4},
- }
-
- for _, ratio := range ratios {
- t.Run(ratio.name, func(t *testing.T) {
- ec, err := NewECConfig(ratio.data, ratio.parity)
- if err != nil {
- t.Fatalf("Failed to create EC config: %v", err)
- }
-
- rep, _ := NewReplicationConfigFromString("110")
- dist := CalculateDistribution(ec, rep)
-
- t.Logf("EC %s with replication 110:", ratio.name)
- t.Logf(" Total shards: %d", ec.TotalShards())
- t.Logf(" Can lose: %d shards", ec.MaxTolerableLoss())
- t.Logf(" Target per DC: %d", dist.TargetShardsPerDC)
- t.Logf(" Target per rack: %d", dist.TargetShardsPerRack)
- t.Logf(" Min DCs for DC fault tolerance: %d", dist.MinDCsForDCFaultTolerance())
-
- // Verify basic sanity
- if dist.TargetShardsPerDC*2 < ec.TotalShards() {
- t.Errorf("Target per DC (%d) * 2 should be >= total (%d)",
- dist.TargetShardsPerDC, ec.TotalShards())
- }
- })
- }
-}
-
-func TestShardClassification(t *testing.T) {
- ec := DefaultECConfig() // 10+4
-
- // Test IsDataShard
- for i := 0; i < 10; i++ {
- if !ec.IsDataShard(i) {
- t.Errorf("Shard %d should be a data shard", i)
- }
- if ec.IsParityShard(i) {
- t.Errorf("Shard %d should not be a parity shard", i)
- }
- }
-
- // Test IsParityShard
- for i := 10; i < 14; i++ {
- if ec.IsDataShard(i) {
- t.Errorf("Shard %d should not be a data shard", i)
- }
- if !ec.IsParityShard(i) {
- t.Errorf("Shard %d should be a parity shard", i)
- }
- }
-
- // Test with custom 8+4 EC
- ec84, _ := NewECConfig(8, 4)
- for i := 0; i < 8; i++ {
- if !ec84.IsDataShard(i) {
- t.Errorf("8+4 EC: Shard %d should be a data shard", i)
- }
- }
- for i := 8; i < 12; i++ {
- if !ec84.IsParityShard(i) {
- t.Errorf("8+4 EC: Shard %d should be a parity shard", i)
- }
- }
-}
-
-func TestSortShardsDataFirst(t *testing.T) {
- ec := DefaultECConfig() // 10+4
-
- // Mixed shards: [0, 10, 5, 11, 2, 12, 7, 13]
- shards := []int{0, 10, 5, 11, 2, 12, 7, 13}
- sorted := ec.SortShardsDataFirst(shards)
-
- t.Logf("Original: %v", shards)
- t.Logf("Sorted (data first): %v", sorted)
-
- // First 4 should be data shards (0, 5, 2, 7)
- for i := 0; i < 4; i++ {
- if !ec.IsDataShard(sorted[i]) {
- t.Errorf("Position %d should be a data shard, got %d", i, sorted[i])
- }
- }
-
- // Last 4 should be parity shards (10, 11, 12, 13)
- for i := 4; i < 8; i++ {
- if !ec.IsParityShard(sorted[i]) {
- t.Errorf("Position %d should be a parity shard, got %d", i, sorted[i])
- }
- }
-}
-
-func TestSortShardsParityFirst(t *testing.T) {
- ec := DefaultECConfig() // 10+4
-
- // Mixed shards: [0, 10, 5, 11, 2, 12, 7, 13]
- shards := []int{0, 10, 5, 11, 2, 12, 7, 13}
- sorted := ec.SortShardsParityFirst(shards)
-
- t.Logf("Original: %v", shards)
- t.Logf("Sorted (parity first): %v", sorted)
-
- // First 4 should be parity shards (10, 11, 12, 13)
- for i := 0; i < 4; i++ {
- if !ec.IsParityShard(sorted[i]) {
- t.Errorf("Position %d should be a parity shard, got %d", i, sorted[i])
- }
- }
-
- // Last 4 should be data shards (0, 5, 2, 7)
- for i := 4; i < 8; i++ {
- if !ec.IsDataShard(sorted[i]) {
- t.Errorf("Position %d should be a data shard, got %d", i, sorted[i])
- }
- }
-}
-
-func TestRebalancerPrefersMovingParityShards(t *testing.T) {
- // Build topology where one node has all shards including mix of data and parity
- analysis := NewTopologyAnalysis()
-
- // Node 1: Has all 14 shards (mixed data and parity)
- node1 := &TopologyNode{
- NodeID: "node1",
- DataCenter: "dc1",
- Rack: "rack1",
- FreeSlots: 0,
- }
- analysis.AddNode(node1)
-
- // Node 2: Empty, ready to receive
- node2 := &TopologyNode{
- NodeID: "node2",
- DataCenter: "dc1",
- Rack: "rack1",
- FreeSlots: 10,
- }
- analysis.AddNode(node2)
-
- // Add all 14 shards to node1
- for i := 0; i < 14; i++ {
- analysis.AddShardLocation(ShardLocation{
- ShardID: i,
- NodeID: "node1",
- DataCenter: "dc1",
- Rack: "rack1",
- })
- }
- analysis.Finalize()
-
- // Create rebalancer
- ec := DefaultECConfig()
- rep, _ := NewReplicationConfigFromString("000")
- rebalancer := NewRebalancer(ec, rep)
-
- plan, err := rebalancer.PlanRebalance(analysis)
- if err != nil {
- t.Fatalf("PlanRebalance failed: %v", err)
- }
-
- t.Logf("Planned %d moves", len(plan.Moves))
-
- // Check that parity shards are moved first
- parityMovesFirst := 0
- dataMovesFirst := 0
- seenDataMove := false
-
- for _, move := range plan.Moves {
- isParity := ec.IsParityShard(move.ShardID)
- t.Logf("Move shard %d (parity=%v): %s -> %s",
- move.ShardID, isParity, move.SourceNode.NodeID, move.DestNode.NodeID)
-
- if isParity && !seenDataMove {
- parityMovesFirst++
- } else if !isParity {
- seenDataMove = true
- dataMovesFirst++
- }
- }
-
- t.Logf("Parity moves before first data move: %d", parityMovesFirst)
- t.Logf("Data moves: %d", dataMovesFirst)
-
- // With 10+4 EC, there are 4 parity shards
- // They should be moved before data shards when possible
- if parityMovesFirst < 4 && len(plan.Moves) >= 4 {
- t.Logf("Note: Expected parity shards to be moved first, but got %d parity moves before data moves", parityMovesFirst)
- }
-}
-
-func TestDistributionSummary(t *testing.T) {
- ec := DefaultECConfig()
- rep, _ := NewReplicationConfigFromString("110")
- dist := CalculateDistribution(ec, rep)
-
- summary := dist.Summary()
- t.Log(summary)
-
- if len(summary) == 0 {
- t.Error("Summary should not be empty")
- }
-
- analysis := dist.FaultToleranceAnalysis()
- t.Log(analysis)
-
- if len(analysis) == 0 {
- t.Error("Fault tolerance analysis should not be empty")
- }
-}
diff --git a/weed/storage/erasure_coding/distribution/rebalancer.go b/weed/storage/erasure_coding/distribution/rebalancer.go
index cd8b87de6..2442e59a9 100644
--- a/weed/storage/erasure_coding/distribution/rebalancer.go
+++ b/weed/storage/erasure_coding/distribution/rebalancer.go
@@ -1,10 +1,5 @@
package distribution
-import (
- "fmt"
- "slices"
-)
-
// ShardMove represents a planned shard move
type ShardMove struct {
ShardID int
@@ -13,12 +8,6 @@ type ShardMove struct {
Reason string
}
-// String returns a human-readable description of the move
-func (m ShardMove) String() string {
- return fmt.Sprintf("shard %d: %s -> %s (%s)",
- m.ShardID, m.SourceNode.NodeID, m.DestNode.NodeID, m.Reason)
-}
-
// RebalancePlan contains the complete plan for rebalancing EC shards
type RebalancePlan struct {
Moves []ShardMove
@@ -32,346 +21,8 @@ type RebalancePlan struct {
MovesWithinRack int
}
-// String returns a summary of the plan
-func (p *RebalancePlan) String() string {
- return fmt.Sprintf("RebalancePlan{moves:%d, acrossDC:%d, acrossRack:%d, withinRack:%d}",
- p.TotalMoves, p.MovesAcrossDC, p.MovesAcrossRack, p.MovesWithinRack)
-}
-
-// DetailedString returns a detailed multi-line summary
-func (p *RebalancePlan) DetailedString() string {
- s := fmt.Sprintf("Rebalance Plan:\n")
- s += fmt.Sprintf(" Total Moves: %d\n", p.TotalMoves)
- s += fmt.Sprintf(" Across DC: %d\n", p.MovesAcrossDC)
- s += fmt.Sprintf(" Across Rack: %d\n", p.MovesAcrossRack)
- s += fmt.Sprintf(" Within Rack: %d\n", p.MovesWithinRack)
- s += fmt.Sprintf("\nMoves:\n")
- for i, move := range p.Moves {
- s += fmt.Sprintf(" %d. %s\n", i+1, move.String())
- }
- return s
-}
-
// Rebalancer plans shard moves to achieve proportional distribution
type Rebalancer struct {
ecConfig ECConfig
repConfig ReplicationConfig
}
-
-// NewRebalancer creates a new rebalancer with the given configuration
-func NewRebalancer(ec ECConfig, rep ReplicationConfig) *Rebalancer {
- return &Rebalancer{
- ecConfig: ec,
- repConfig: rep,
- }
-}
-
-// PlanRebalance creates a rebalancing plan based on current topology analysis
-func (r *Rebalancer) PlanRebalance(analysis *TopologyAnalysis) (*RebalancePlan, error) {
- dist := CalculateDistribution(r.ecConfig, r.repConfig)
-
- plan := &RebalancePlan{
- Distribution: dist,
- Analysis: analysis,
- }
-
- // Step 1: Balance across data centers
- dcMoves := r.planDCMoves(analysis, dist)
- for _, move := range dcMoves {
- plan.Moves = append(plan.Moves, move)
- plan.MovesAcrossDC++
- }
-
- // Update analysis after DC moves (for planning purposes)
- r.applyMovesToAnalysis(analysis, dcMoves)
-
- // Step 2: Balance across racks within each DC
- rackMoves := r.planRackMoves(analysis, dist)
- for _, move := range rackMoves {
- plan.Moves = append(plan.Moves, move)
- plan.MovesAcrossRack++
- }
-
- // Update analysis after rack moves
- r.applyMovesToAnalysis(analysis, rackMoves)
-
- // Step 3: Balance across nodes within each rack
- nodeMoves := r.planNodeMoves(analysis, dist)
- for _, move := range nodeMoves {
- plan.Moves = append(plan.Moves, move)
- plan.MovesWithinRack++
- }
-
- plan.TotalMoves = len(plan.Moves)
-
- return plan, nil
-}
-
-// planDCMoves plans moves to balance shards across data centers
-func (r *Rebalancer) planDCMoves(analysis *TopologyAnalysis, dist *ECDistribution) []ShardMove {
- var moves []ShardMove
-
- overDCs := CalculateDCExcess(analysis, dist)
- underDCs := CalculateUnderservedDCs(analysis, dist)
-
- underIdx := 0
- for _, over := range overDCs {
- for over.Excess > 0 && underIdx < len(underDCs) {
- destDC := underDCs[underIdx]
-
- // Find a shard and source node
- shardID, srcNode := r.pickShardToMove(analysis, over.Nodes)
- if srcNode == nil {
- break
- }
-
- // Find destination node in target DC
- destNode := r.pickBestDestination(analysis, destDC, "", dist)
- if destNode == nil {
- underIdx++
- continue
- }
-
- moves = append(moves, ShardMove{
- ShardID: shardID,
- SourceNode: srcNode,
- DestNode: destNode,
- Reason: fmt.Sprintf("balance DC: %s -> %s", srcNode.DataCenter, destDC),
- })
-
- over.Excess--
- analysis.ShardsByDC[srcNode.DataCenter]--
- analysis.ShardsByDC[destDC]++
-
- // Check if destDC reached target
- if analysis.ShardsByDC[destDC] >= dist.TargetShardsPerDC {
- underIdx++
- }
- }
- }
-
- return moves
-}
-
-// planRackMoves plans moves to balance shards across racks within each DC
-func (r *Rebalancer) planRackMoves(analysis *TopologyAnalysis, dist *ECDistribution) []ShardMove {
- var moves []ShardMove
-
- for dc := range analysis.DCToRacks {
- dcShards := analysis.ShardsByDC[dc]
- numRacks := len(analysis.DCToRacks[dc])
- if numRacks == 0 {
- continue
- }
-
- targetPerRack := ceilDivide(dcShards, max(numRacks, dist.ReplicationConfig.MinRacksPerDC))
-
- overRacks := CalculateRackExcess(analysis, dc, targetPerRack)
- underRacks := CalculateUnderservedRacks(analysis, dc, targetPerRack)
-
- underIdx := 0
- for _, over := range overRacks {
- for over.Excess > 0 && underIdx < len(underRacks) {
- destRack := underRacks[underIdx]
-
- // Find shard and source node
- shardID, srcNode := r.pickShardToMove(analysis, over.Nodes)
- if srcNode == nil {
- break
- }
-
- // Find destination node in target rack
- destNode := r.pickBestDestination(analysis, dc, destRack, dist)
- if destNode == nil {
- underIdx++
- continue
- }
-
- moves = append(moves, ShardMove{
- ShardID: shardID,
- SourceNode: srcNode,
- DestNode: destNode,
- Reason: fmt.Sprintf("balance rack: %s -> %s", srcNode.Rack, destRack),
- })
-
- over.Excess--
- analysis.ShardsByRack[srcNode.Rack]--
- analysis.ShardsByRack[destRack]++
-
- if analysis.ShardsByRack[destRack] >= targetPerRack {
- underIdx++
- }
- }
- }
- }
-
- return moves
-}
-
-// planNodeMoves plans moves to balance shards across nodes within each rack
-func (r *Rebalancer) planNodeMoves(analysis *TopologyAnalysis, dist *ECDistribution) []ShardMove {
- var moves []ShardMove
-
- for rack, nodes := range analysis.RackToNodes {
- if len(nodes) <= 1 {
- continue
- }
-
- rackShards := analysis.ShardsByRack[rack]
- targetPerNode := ceilDivide(rackShards, max(len(nodes), dist.ReplicationConfig.MinNodesPerRack))
-
- // Find over and under nodes
- var overNodes []*TopologyNode
- var underNodes []*TopologyNode
-
- for _, node := range nodes {
- count := analysis.ShardsByNode[node.NodeID]
- if count > targetPerNode {
- overNodes = append(overNodes, node)
- } else if count < targetPerNode {
- underNodes = append(underNodes, node)
- }
- }
-
- // Sort by excess/deficit
- slices.SortFunc(overNodes, func(a, b *TopologyNode) int {
- return analysis.ShardsByNode[b.NodeID] - analysis.ShardsByNode[a.NodeID]
- })
-
- underIdx := 0
- for _, srcNode := range overNodes {
- excess := analysis.ShardsByNode[srcNode.NodeID] - targetPerNode
-
- for excess > 0 && underIdx < len(underNodes) {
- destNode := underNodes[underIdx]
-
- // Pick a shard from this node, preferring parity shards
- shards := analysis.NodeToShards[srcNode.NodeID]
- if len(shards) == 0 {
- break
- }
-
- // Find a parity shard first, fallback to data shard
- shardID := -1
- shardIdx := -1
- for i, s := range shards {
- if r.ecConfig.IsParityShard(s) {
- shardID = s
- shardIdx = i
- break
- }
- }
- if shardID == -1 {
- shardID = shards[0]
- shardIdx = 0
- }
-
- moves = append(moves, ShardMove{
- ShardID: shardID,
- SourceNode: srcNode,
- DestNode: destNode,
- Reason: fmt.Sprintf("balance node: %s -> %s", srcNode.NodeID, destNode.NodeID),
- })
-
- excess--
- analysis.ShardsByNode[srcNode.NodeID]--
- analysis.ShardsByNode[destNode.NodeID]++
-
- // Update shard lists - remove the specific shard we picked
- analysis.NodeToShards[srcNode.NodeID] = append(
- shards[:shardIdx], shards[shardIdx+1:]...)
- analysis.NodeToShards[destNode.NodeID] = append(
- analysis.NodeToShards[destNode.NodeID], shardID)
-
- if analysis.ShardsByNode[destNode.NodeID] >= targetPerNode {
- underIdx++
- }
- }
- }
- }
-
- return moves
-}
-
-// pickShardToMove selects a shard and its node from the given nodes.
-// It prefers to move parity shards first, keeping data shards spread out
-// since data shards serve read requests while parity shards are only for reconstruction.
-func (r *Rebalancer) pickShardToMove(analysis *TopologyAnalysis, nodes []*TopologyNode) (int, *TopologyNode) {
- // Sort by shard count (most shards first)
- slices.SortFunc(nodes, func(a, b *TopologyNode) int {
- return analysis.ShardsByNode[b.NodeID] - analysis.ShardsByNode[a.NodeID]
- })
-
- // First pass: try to find a parity shard to move (prefer moving parity)
- for _, node := range nodes {
- shards := analysis.NodeToShards[node.NodeID]
- for _, shardID := range shards {
- if r.ecConfig.IsParityShard(shardID) {
- return shardID, node
- }
- }
- }
-
- // Second pass: if no parity shards, move a data shard
- for _, node := range nodes {
- shards := analysis.NodeToShards[node.NodeID]
- if len(shards) > 0 {
- return shards[0], node
- }
- }
-
- return -1, nil
-}
-
-// pickBestDestination selects the best destination node
-func (r *Rebalancer) pickBestDestination(analysis *TopologyAnalysis, targetDC, targetRack string, dist *ECDistribution) *TopologyNode {
- var candidates []*TopologyNode
-
- // Collect candidates
- for _, node := range analysis.AllNodes {
- // Filter by DC if specified
- if targetDC != "" && node.DataCenter != targetDC {
- continue
- }
- // Filter by rack if specified
- if targetRack != "" && node.Rack != targetRack {
- continue
- }
- // Check capacity
- if node.FreeSlots <= 0 {
- continue
- }
- // Check max shards limit
- if analysis.ShardsByNode[node.NodeID] >= dist.MaxShardsPerNode {
- continue
- }
-
- candidates = append(candidates, node)
- }
-
- if len(candidates) == 0 {
- return nil
- }
-
- // Sort by: 1) fewer shards, 2) more free slots
- slices.SortFunc(candidates, func(a, b *TopologyNode) int {
- aShards := analysis.ShardsByNode[a.NodeID]
- bShards := analysis.ShardsByNode[b.NodeID]
- if aShards != bShards {
- return aShards - bShards
- }
- return b.FreeSlots - a.FreeSlots
- })
-
- return candidates[0]
-}
-
-// applyMovesToAnalysis is a no-op placeholder for potential future use.
-// Note: All planners (planDCMoves, planRackMoves, planNodeMoves) update
-// their respective counts (ShardsByDC, ShardsByRack, ShardsByNode) and
-// shard lists (NodeToShards) inline during planning. This avoids duplicate
-// updates that would occur if we also updated counts here.
-func (r *Rebalancer) applyMovesToAnalysis(analysis *TopologyAnalysis, moves []ShardMove) {
- // Counts are already updated by the individual planners.
- // This function is kept for API compatibility and potential future use.
-}
diff --git a/weed/storage/erasure_coding/ec_shards_info.go b/weed/storage/erasure_coding/ec_shards_info.go
index 55838eb4e..0d2ce5b63 100644
--- a/weed/storage/erasure_coding/ec_shards_info.go
+++ b/weed/storage/erasure_coding/ec_shards_info.go
@@ -53,19 +53,6 @@ func NewShardsInfo() *ShardsInfo {
}
}
-// Initializes a ShardsInfo from a ECVolume.
-func ShardsInfoFromVolume(ev *EcVolume) *ShardsInfo {
- res := &ShardsInfo{
- shards: make([]ShardInfo, len(ev.Shards)),
- }
- // Build shards directly to avoid locking in Set() since res is not yet shared
- for i, s := range ev.Shards {
- res.shards[i] = NewShardInfo(s.ShardId, ShardSize(s.Size()))
- res.shardBits = res.shardBits.Set(s.ShardId)
- }
- return res
-}
-
// Initializes a ShardsInfo from a VolumeEcShardInformationMessage proto.
func ShardsInfoFromVolumeEcShardInformationMessage(vi *master_pb.VolumeEcShardInformationMessage) *ShardsInfo {
res := NewShardsInfo()
diff --git a/weed/storage/erasure_coding/placement/placement.go b/weed/storage/erasure_coding/placement/placement.go
index 67e21c1f8..bda050b82 100644
--- a/weed/storage/erasure_coding/placement/placement.go
+++ b/weed/storage/erasure_coding/placement/placement.go
@@ -64,18 +64,6 @@ type PlacementRequest struct {
PreferDifferentRacks bool
}
-// DefaultPlacementRequest returns the default placement configuration
-func DefaultPlacementRequest() PlacementRequest {
- return PlacementRequest{
- ShardsNeeded: 14,
- MaxShardsPerServer: 0,
- MaxShardsPerRack: 0,
- MaxTaskLoad: 5,
- PreferDifferentServers: true,
- PreferDifferentRacks: true,
- }
-}
-
// PlacementResult contains the selected destinations for EC shards
type PlacementResult struct {
SelectedDisks []*DiskCandidate
@@ -270,15 +258,6 @@ func groupDisksByRack(disks []*DiskCandidate) map[string][]*DiskCandidate {
return result
}
-// groupDisksByServer groups disks by their server
-func groupDisksByServer(disks []*DiskCandidate) map[string][]*DiskCandidate {
- result := make(map[string][]*DiskCandidate)
- for _, disk := range disks {
- result[disk.NodeID] = append(result[disk.NodeID], disk)
- }
- return result
-}
-
// getRackKey returns the unique key for a rack (dc:rack)
func getRackKey(disk *DiskCandidate) string {
return fmt.Sprintf("%s:%s", disk.DataCenter, disk.Rack)
@@ -393,28 +372,3 @@ func addDiskToResult(result *PlacementResult, disk *DiskCandidate,
result.ShardsPerRack[rackKey]++
result.ShardsPerDC[disk.DataCenter]++
}
-
-// VerifySpread checks if the placement result meets diversity requirements
-func VerifySpread(result *PlacementResult, minServers, minRacks int) error {
- if result.ServersUsed < minServers {
- return fmt.Errorf("only %d servers used, need at least %d", result.ServersUsed, minServers)
- }
- if result.RacksUsed < minRacks {
- return fmt.Errorf("only %d racks used, need at least %d", result.RacksUsed, minRacks)
- }
- return nil
-}
-
-// CalculateIdealDistribution returns the ideal number of shards per server
-// when we have a certain number of shards and servers
-func CalculateIdealDistribution(totalShards, numServers int) (min, max int) {
- if numServers <= 0 {
- return 0, totalShards
- }
- min = totalShards / numServers
- max = min
- if totalShards%numServers != 0 {
- max = min + 1
- }
- return
-}
diff --git a/weed/storage/erasure_coding/placement/placement_test.go b/weed/storage/erasure_coding/placement/placement_test.go
deleted file mode 100644
index 7501dfa9e..000000000
--- a/weed/storage/erasure_coding/placement/placement_test.go
+++ /dev/null
@@ -1,517 +0,0 @@
-package placement
-
-import (
- "testing"
-)
-
-// Helper function to create disk candidates for testing
-func makeDisk(nodeID string, diskID uint32, dc, rack string, freeSlots int) *DiskCandidate {
- return &DiskCandidate{
- NodeID: nodeID,
- DiskID: diskID,
- DataCenter: dc,
- Rack: rack,
- VolumeCount: 0,
- MaxVolumeCount: 100,
- ShardCount: 0,
- FreeSlots: freeSlots,
- LoadCount: 0,
- }
-}
-
-func TestSelectDestinations_SingleRack(t *testing.T) {
- // Test: 3 servers in same rack, each with 2 disks, need 6 shards
- // Expected: Should spread across all 6 disks (one per disk)
- disks := []*DiskCandidate{
- makeDisk("server1", 0, "dc1", "rack1", 10),
- makeDisk("server1", 1, "dc1", "rack1", 10),
- makeDisk("server2", 0, "dc1", "rack1", 10),
- makeDisk("server2", 1, "dc1", "rack1", 10),
- makeDisk("server3", 0, "dc1", "rack1", 10),
- makeDisk("server3", 1, "dc1", "rack1", 10),
- }
-
- config := PlacementRequest{
- ShardsNeeded: 6,
- PreferDifferentServers: true,
- PreferDifferentRacks: true,
- }
-
- result, err := SelectDestinations(disks, config)
- if err != nil {
- t.Fatalf("unexpected error: %v", err)
- }
-
- if len(result.SelectedDisks) != 6 {
- t.Errorf("expected 6 selected disks, got %d", len(result.SelectedDisks))
- }
-
- // Verify all 3 servers are used
- if result.ServersUsed != 3 {
- t.Errorf("expected 3 servers used, got %d", result.ServersUsed)
- }
-
- // Verify each disk is unique
- diskSet := make(map[string]bool)
- for _, disk := range result.SelectedDisks {
- key := getDiskKey(disk)
- if diskSet[key] {
- t.Errorf("disk %s selected multiple times", key)
- }
- diskSet[key] = true
- }
-}
-
-func TestSelectDestinations_MultipleRacks(t *testing.T) {
- // Test: 2 racks with 2 servers each, each server has 2 disks
- // Need 8 shards
- // Expected: Should spread across all 8 disks
- disks := []*DiskCandidate{
- makeDisk("server1", 0, "dc1", "rack1", 10),
- makeDisk("server1", 1, "dc1", "rack1", 10),
- makeDisk("server2", 0, "dc1", "rack1", 10),
- makeDisk("server2", 1, "dc1", "rack1", 10),
- makeDisk("server3", 0, "dc1", "rack2", 10),
- makeDisk("server3", 1, "dc1", "rack2", 10),
- makeDisk("server4", 0, "dc1", "rack2", 10),
- makeDisk("server4", 1, "dc1", "rack2", 10),
- }
-
- config := PlacementRequest{
- ShardsNeeded: 8,
- PreferDifferentServers: true,
- PreferDifferentRacks: true,
- }
-
- result, err := SelectDestinations(disks, config)
- if err != nil {
- t.Fatalf("unexpected error: %v", err)
- }
-
- if len(result.SelectedDisks) != 8 {
- t.Errorf("expected 8 selected disks, got %d", len(result.SelectedDisks))
- }
-
- // Verify all 4 servers are used
- if result.ServersUsed != 4 {
- t.Errorf("expected 4 servers used, got %d", result.ServersUsed)
- }
-
- // Verify both racks are used
- if result.RacksUsed != 2 {
- t.Errorf("expected 2 racks used, got %d", result.RacksUsed)
- }
-}
-
-func TestSelectDestinations_PrefersDifferentServers(t *testing.T) {
- // Test: 4 servers with 4 disks each, need 4 shards
- // Expected: Should use one disk from each server
- disks := []*DiskCandidate{
- makeDisk("server1", 0, "dc1", "rack1", 10),
- makeDisk("server1", 1, "dc1", "rack1", 10),
- makeDisk("server1", 2, "dc1", "rack1", 10),
- makeDisk("server1", 3, "dc1", "rack1", 10),
- makeDisk("server2", 0, "dc1", "rack1", 10),
- makeDisk("server2", 1, "dc1", "rack1", 10),
- makeDisk("server2", 2, "dc1", "rack1", 10),
- makeDisk("server2", 3, "dc1", "rack1", 10),
- makeDisk("server3", 0, "dc1", "rack1", 10),
- makeDisk("server3", 1, "dc1", "rack1", 10),
- makeDisk("server3", 2, "dc1", "rack1", 10),
- makeDisk("server3", 3, "dc1", "rack1", 10),
- makeDisk("server4", 0, "dc1", "rack1", 10),
- makeDisk("server4", 1, "dc1", "rack1", 10),
- makeDisk("server4", 2, "dc1", "rack1", 10),
- makeDisk("server4", 3, "dc1", "rack1", 10),
- }
-
- config := PlacementRequest{
- ShardsNeeded: 4,
- PreferDifferentServers: true,
- PreferDifferentRacks: true,
- }
-
- result, err := SelectDestinations(disks, config)
- if err != nil {
- t.Fatalf("unexpected error: %v", err)
- }
-
- if len(result.SelectedDisks) != 4 {
- t.Errorf("expected 4 selected disks, got %d", len(result.SelectedDisks))
- }
-
- // Verify all 4 servers are used (one shard per server)
- if result.ServersUsed != 4 {
- t.Errorf("expected 4 servers used, got %d", result.ServersUsed)
- }
-
- // Each server should have exactly 1 shard
- for server, count := range result.ShardsPerServer {
- if count != 1 {
- t.Errorf("server %s has %d shards, expected 1", server, count)
- }
- }
-}
-
-func TestSelectDestinations_SpilloverToMultipleDisksPerServer(t *testing.T) {
- // Test: 2 servers with 4 disks each, need 6 shards
- // Expected: First pick one from each server (2 shards), then one more from each (4 shards),
- // then fill remaining from any server (6 shards)
- disks := []*DiskCandidate{
- makeDisk("server1", 0, "dc1", "rack1", 10),
- makeDisk("server1", 1, "dc1", "rack1", 10),
- makeDisk("server1", 2, "dc1", "rack1", 10),
- makeDisk("server1", 3, "dc1", "rack1", 10),
- makeDisk("server2", 0, "dc1", "rack1", 10),
- makeDisk("server2", 1, "dc1", "rack1", 10),
- makeDisk("server2", 2, "dc1", "rack1", 10),
- makeDisk("server2", 3, "dc1", "rack1", 10),
- }
-
- config := PlacementRequest{
- ShardsNeeded: 6,
- PreferDifferentServers: true,
- PreferDifferentRacks: true,
- }
-
- result, err := SelectDestinations(disks, config)
- if err != nil {
- t.Fatalf("unexpected error: %v", err)
- }
-
- if len(result.SelectedDisks) != 6 {
- t.Errorf("expected 6 selected disks, got %d", len(result.SelectedDisks))
- }
-
- // Both servers should be used
- if result.ServersUsed != 2 {
- t.Errorf("expected 2 servers used, got %d", result.ServersUsed)
- }
-
- // Each server should have exactly 3 shards (balanced)
- for server, count := range result.ShardsPerServer {
- if count != 3 {
- t.Errorf("server %s has %d shards, expected 3", server, count)
- }
- }
-}
-
-func TestSelectDestinations_MaxShardsPerServer(t *testing.T) {
- // Test: 2 servers with 4 disks each, need 6 shards, max 2 per server
- // Expected: Should only select 4 shards (2 per server limit)
- disks := []*DiskCandidate{
- makeDisk("server1", 0, "dc1", "rack1", 10),
- makeDisk("server1", 1, "dc1", "rack1", 10),
- makeDisk("server1", 2, "dc1", "rack1", 10),
- makeDisk("server1", 3, "dc1", "rack1", 10),
- makeDisk("server2", 0, "dc1", "rack1", 10),
- makeDisk("server2", 1, "dc1", "rack1", 10),
- makeDisk("server2", 2, "dc1", "rack1", 10),
- makeDisk("server2", 3, "dc1", "rack1", 10),
- }
-
- config := PlacementRequest{
- ShardsNeeded: 6,
- MaxShardsPerServer: 2,
- PreferDifferentServers: true,
- PreferDifferentRacks: true,
- }
-
- result, err := SelectDestinations(disks, config)
- if err != nil {
- t.Fatalf("unexpected error: %v", err)
- }
-
- // Should only get 4 shards due to server limit
- if len(result.SelectedDisks) != 4 {
- t.Errorf("expected 4 selected disks (limit 2 per server), got %d", len(result.SelectedDisks))
- }
-
- // No server should exceed the limit
- for server, count := range result.ShardsPerServer {
- if count > 2 {
- t.Errorf("server %s has %d shards, exceeds limit of 2", server, count)
- }
- }
-}
-
-func TestSelectDestinations_14ShardsAcross7Servers(t *testing.T) {
- // Test: Real-world EC scenario - 14 shards across 7 servers with 2 disks each
- // Expected: Should spread evenly (2 shards per server)
- var disks []*DiskCandidate
- for i := 1; i <= 7; i++ {
- serverID := "server" + string(rune('0'+i))
- disks = append(disks, makeDisk(serverID, 0, "dc1", "rack1", 10))
- disks = append(disks, makeDisk(serverID, 1, "dc1", "rack1", 10))
- }
-
- config := PlacementRequest{
- ShardsNeeded: 14,
- PreferDifferentServers: true,
- PreferDifferentRacks: true,
- }
-
- result, err := SelectDestinations(disks, config)
- if err != nil {
- t.Fatalf("unexpected error: %v", err)
- }
-
- if len(result.SelectedDisks) != 14 {
- t.Errorf("expected 14 selected disks, got %d", len(result.SelectedDisks))
- }
-
- // All 7 servers should be used
- if result.ServersUsed != 7 {
- t.Errorf("expected 7 servers used, got %d", result.ServersUsed)
- }
-
- // Each server should have exactly 2 shards
- for server, count := range result.ShardsPerServer {
- if count != 2 {
- t.Errorf("server %s has %d shards, expected 2", server, count)
- }
- }
-}
-
-func TestSelectDestinations_FewerServersThanShards(t *testing.T) {
- // Test: Only 3 servers but need 6 shards
- // Expected: Should distribute evenly (2 per server)
- disks := []*DiskCandidate{
- makeDisk("server1", 0, "dc1", "rack1", 10),
- makeDisk("server1", 1, "dc1", "rack1", 10),
- makeDisk("server1", 2, "dc1", "rack1", 10),
- makeDisk("server2", 0, "dc1", "rack1", 10),
- makeDisk("server2", 1, "dc1", "rack1", 10),
- makeDisk("server2", 2, "dc1", "rack1", 10),
- makeDisk("server3", 0, "dc1", "rack1", 10),
- makeDisk("server3", 1, "dc1", "rack1", 10),
- makeDisk("server3", 2, "dc1", "rack1", 10),
- }
-
- config := PlacementRequest{
- ShardsNeeded: 6,
- PreferDifferentServers: true,
- PreferDifferentRacks: true,
- }
-
- result, err := SelectDestinations(disks, config)
- if err != nil {
- t.Fatalf("unexpected error: %v", err)
- }
-
- if len(result.SelectedDisks) != 6 {
- t.Errorf("expected 6 selected disks, got %d", len(result.SelectedDisks))
- }
-
- // All 3 servers should be used
- if result.ServersUsed != 3 {
- t.Errorf("expected 3 servers used, got %d", result.ServersUsed)
- }
-
- // Each server should have exactly 2 shards
- for server, count := range result.ShardsPerServer {
- if count != 2 {
- t.Errorf("server %s has %d shards, expected 2", server, count)
- }
- }
-}
-
-func TestSelectDestinations_NoSuitableDisks(t *testing.T) {
- // Test: All disks have no free slots
- disks := []*DiskCandidate{
- {NodeID: "server1", DiskID: 0, DataCenter: "dc1", Rack: "rack1", FreeSlots: 0},
- {NodeID: "server2", DiskID: 0, DataCenter: "dc1", Rack: "rack1", FreeSlots: 0},
- }
-
- config := PlacementRequest{
- ShardsNeeded: 4,
- PreferDifferentServers: true,
- PreferDifferentRacks: true,
- }
-
- _, err := SelectDestinations(disks, config)
- if err == nil {
- t.Error("expected error for no suitable disks, got nil")
- }
-}
-
-func TestSelectDestinations_EmptyInput(t *testing.T) {
- config := DefaultPlacementRequest()
- _, err := SelectDestinations([]*DiskCandidate{}, config)
- if err == nil {
- t.Error("expected error for empty input, got nil")
- }
-}
-
-func TestSelectDestinations_FiltersByLoad(t *testing.T) {
- // Test: Some disks have too high load
- disks := []*DiskCandidate{
- {NodeID: "server1", DiskID: 0, DataCenter: "dc1", Rack: "rack1", FreeSlots: 10, LoadCount: 10},
- {NodeID: "server2", DiskID: 0, DataCenter: "dc1", Rack: "rack1", FreeSlots: 10, LoadCount: 2},
- {NodeID: "server3", DiskID: 0, DataCenter: "dc1", Rack: "rack1", FreeSlots: 10, LoadCount: 1},
- }
-
- config := PlacementRequest{
- ShardsNeeded: 2,
- MaxTaskLoad: 5,
- PreferDifferentServers: true,
- PreferDifferentRacks: true,
- }
-
- result, err := SelectDestinations(disks, config)
- if err != nil {
- t.Fatalf("unexpected error: %v", err)
- }
-
- // Should only select from server2 and server3 (server1 has too high load)
- for _, disk := range result.SelectedDisks {
- if disk.NodeID == "server1" {
- t.Errorf("disk from server1 should not be selected (load too high)")
- }
- }
-}
-
-func TestCalculateDiskScore(t *testing.T) {
- // Test that score calculation works as expected
- lowUtilDisk := &DiskCandidate{
- VolumeCount: 10,
- MaxVolumeCount: 100,
- ShardCount: 0,
- LoadCount: 0,
- }
-
- highUtilDisk := &DiskCandidate{
- VolumeCount: 90,
- MaxVolumeCount: 100,
- ShardCount: 5,
- LoadCount: 5,
- }
-
- lowScore := calculateDiskScore(lowUtilDisk)
- highScore := calculateDiskScore(highUtilDisk)
-
- if lowScore <= highScore {
- t.Errorf("low utilization disk should have higher score: low=%f, high=%f", lowScore, highScore)
- }
-}
-
-func TestCalculateIdealDistribution(t *testing.T) {
- tests := []struct {
- totalShards int
- numServers int
- expectedMin int
- expectedMax int
- }{
- {14, 7, 2, 2}, // Even distribution
- {14, 4, 3, 4}, // Uneven: 14/4 = 3 remainder 2
- {6, 3, 2, 2}, // Even distribution
- {7, 3, 2, 3}, // Uneven: 7/3 = 2 remainder 1
- {10, 0, 0, 10}, // Edge case: no servers
- {0, 5, 0, 0}, // Edge case: no shards
- }
-
- for _, tt := range tests {
- min, max := CalculateIdealDistribution(tt.totalShards, tt.numServers)
- if min != tt.expectedMin || max != tt.expectedMax {
- t.Errorf("CalculateIdealDistribution(%d, %d) = (%d, %d), want (%d, %d)",
- tt.totalShards, tt.numServers, min, max, tt.expectedMin, tt.expectedMax)
- }
- }
-}
-
-func TestVerifySpread(t *testing.T) {
- result := &PlacementResult{
- ServersUsed: 3,
- RacksUsed: 2,
- }
-
- // Should pass
- if err := VerifySpread(result, 3, 2); err != nil {
- t.Errorf("unexpected error: %v", err)
- }
-
- // Should fail - not enough servers
- if err := VerifySpread(result, 4, 2); err == nil {
- t.Error("expected error for insufficient servers")
- }
-
- // Should fail - not enough racks
- if err := VerifySpread(result, 3, 3); err == nil {
- t.Error("expected error for insufficient racks")
- }
-}
-
-func TestSelectDestinations_MultiDC(t *testing.T) {
- // Test: 2 DCs, each with 2 racks, each rack has 2 servers
- disks := []*DiskCandidate{
- // DC1, Rack1
- makeDisk("dc1-r1-s1", 0, "dc1", "rack1", 10),
- makeDisk("dc1-r1-s1", 1, "dc1", "rack1", 10),
- makeDisk("dc1-r1-s2", 0, "dc1", "rack1", 10),
- makeDisk("dc1-r1-s2", 1, "dc1", "rack1", 10),
- // DC1, Rack2
- makeDisk("dc1-r2-s1", 0, "dc1", "rack2", 10),
- makeDisk("dc1-r2-s1", 1, "dc1", "rack2", 10),
- makeDisk("dc1-r2-s2", 0, "dc1", "rack2", 10),
- makeDisk("dc1-r2-s2", 1, "dc1", "rack2", 10),
- // DC2, Rack1
- makeDisk("dc2-r1-s1", 0, "dc2", "rack1", 10),
- makeDisk("dc2-r1-s1", 1, "dc2", "rack1", 10),
- makeDisk("dc2-r1-s2", 0, "dc2", "rack1", 10),
- makeDisk("dc2-r1-s2", 1, "dc2", "rack1", 10),
- // DC2, Rack2
- makeDisk("dc2-r2-s1", 0, "dc2", "rack2", 10),
- makeDisk("dc2-r2-s1", 1, "dc2", "rack2", 10),
- makeDisk("dc2-r2-s2", 0, "dc2", "rack2", 10),
- makeDisk("dc2-r2-s2", 1, "dc2", "rack2", 10),
- }
-
- config := PlacementRequest{
- ShardsNeeded: 8,
- PreferDifferentServers: true,
- PreferDifferentRacks: true,
- }
-
- result, err := SelectDestinations(disks, config)
- if err != nil {
- t.Fatalf("unexpected error: %v", err)
- }
-
- if len(result.SelectedDisks) != 8 {
- t.Errorf("expected 8 selected disks, got %d", len(result.SelectedDisks))
- }
-
- // Should use all 4 racks
- if result.RacksUsed != 4 {
- t.Errorf("expected 4 racks used, got %d", result.RacksUsed)
- }
-
- // Should use both DCs
- if result.DCsUsed != 2 {
- t.Errorf("expected 2 DCs used, got %d", result.DCsUsed)
- }
-}
-
-func TestSelectDestinations_SameRackDifferentDC(t *testing.T) {
- // Test: Same rack name in different DCs should be treated as different racks
- disks := []*DiskCandidate{
- makeDisk("dc1-s1", 0, "dc1", "rack1", 10),
- makeDisk("dc2-s1", 0, "dc2", "rack1", 10),
- }
-
- config := PlacementRequest{
- ShardsNeeded: 2,
- PreferDifferentServers: true,
- PreferDifferentRacks: true,
- }
-
- result, err := SelectDestinations(disks, config)
- if err != nil {
- t.Fatalf("unexpected error: %v", err)
- }
-
- // Should use 2 racks (dc1:rack1 and dc2:rack1 are different)
- if result.RacksUsed != 2 {
- t.Errorf("expected 2 racks used (different DCs), got %d", result.RacksUsed)
- }
-}
diff --git a/weed/storage/idx/binary_search.go b/weed/storage/idx/binary_search.go
deleted file mode 100644
index 9f1dcef40..000000000
--- a/weed/storage/idx/binary_search.go
+++ /dev/null
@@ -1,29 +0,0 @@
-package idx
-
-import (
- "github.com/seaweedfs/seaweedfs/weed/storage/types"
-)
-
-// FirstInvalidIndex find the first index the failed lessThanOrEqualToFn function's requirement.
-func FirstInvalidIndex(bytes []byte, lessThanOrEqualToFn func(key types.NeedleId, offset types.Offset, size types.Size) (bool, error)) (int, error) {
- left, right := 0, len(bytes)/types.NeedleMapEntrySize-1
- index := right + 1
- for left <= right {
- mid := left + (right-left)>>1
- loc := mid * types.NeedleMapEntrySize
- key := types.BytesToNeedleId(bytes[loc : loc+types.NeedleIdSize])
- offset := types.BytesToOffset(bytes[loc+types.NeedleIdSize : loc+types.NeedleIdSize+types.OffsetSize])
- size := types.BytesToSize(bytes[loc+types.NeedleIdSize+types.OffsetSize : loc+types.NeedleIdSize+types.OffsetSize+types.SizeSize])
- res, err := lessThanOrEqualToFn(key, offset, size)
- if err != nil {
- return -1, err
- }
- if res {
- left = mid + 1
- } else {
- index = mid
- right = mid - 1
- }
- }
- return index, nil
-}
diff --git a/weed/storage/idx_binary_search_test.go b/weed/storage/idx_binary_search_test.go
deleted file mode 100644
index 77d38e562..000000000
--- a/weed/storage/idx_binary_search_test.go
+++ /dev/null
@@ -1,71 +0,0 @@
-package storage
-
-import (
- "os"
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/storage/idx"
- "github.com/seaweedfs/seaweedfs/weed/storage/needle"
- "github.com/seaweedfs/seaweedfs/weed/storage/super_block"
- "github.com/seaweedfs/seaweedfs/weed/storage/types"
- "github.com/stretchr/testify/assert"
-)
-
-func TestFirstInvalidIndex(t *testing.T) {
- dir := t.TempDir()
-
- v, err := NewVolume(dir, dir, "", 1, NeedleMapInMemory, &super_block.ReplicaPlacement{}, &needle.TTL{}, 0, needle.GetCurrentVersion(), 0, 0)
- if err != nil {
- t.Fatalf("volume creation: %v", err)
- }
- defer v.Close()
- type WriteInfo struct {
- offset int64
- size int32
- }
- // initialize 20 needles then update first 10 needles
- for i := 1; i <= 30; i++ {
- n := newRandomNeedle(uint64(i))
- n.Flags = 0x08
- _, _, _, err := v.writeNeedle2(n, true, false)
- if err != nil {
- t.Fatalf("write needle %d: %v", i, err)
- }
- }
- b, err := os.ReadFile(v.IndexFileName() + ".idx")
- if err != nil {
- t.Fatal(err)
- }
- // base case every record is valid -> nothing is filtered
- index, err := idx.FirstInvalidIndex(b, func(key types.NeedleId, offset types.Offset, size types.Size) (bool, error) {
- return true, nil
- })
- if err != nil {
- t.Fatalf("failed to complete binary search %v", err)
- }
- assert.Equal(t, 30, index, "when every record is valid nothing should be filtered from binary search")
- index, err = idx.FirstInvalidIndex(b, func(key types.NeedleId, offset types.Offset, size types.Size) (bool, error) {
- return false, nil
- })
- if err != nil {
- t.Fatal(err)
- }
- assert.Equal(t, 0, index, "when every record is invalid everything should be filtered from binary search")
- index, err = idx.FirstInvalidIndex(b, func(key types.NeedleId, offset types.Offset, size types.Size) (bool, error) {
- return key < 20, nil
- })
- if err != nil {
- t.Fatal(err)
- }
- // needle key range from 1 to 30 so < 20 means 19 keys are valid and cutoff the bytes at 19 * 16 = 304
- assert.Equal(t, 19, index, "when every record is invalid everything should be filtered from binary search")
-
- index, err = idx.FirstInvalidIndex(b, func(key types.NeedleId, offset types.Offset, size types.Size) (bool, error) {
- return key <= 1, nil
- })
- if err != nil {
- t.Fatal(err)
- }
- // needle key range from 1 to 30 so <=1 1 means 1 key is valid and cutoff the bytes at 1 * 16 = 16
- assert.Equal(t, 1, index, "when every record is invalid everything should be filtered from binary search")
-}
diff --git a/weed/storage/needle/crc.go b/weed/storage/needle/crc.go
index 6ac31cb43..b1c092c49 100644
--- a/weed/storage/needle/crc.go
+++ b/weed/storage/needle/crc.go
@@ -32,24 +32,7 @@ func (n *Needle) Etag() string {
return fmt.Sprintf("%x", bits)
}
-func NewCRCwriter(w io.Writer) *CRCwriter {
-
- return &CRCwriter{
- crc: CRC(0),
- w: w,
- }
-
-}
-
type CRCwriter struct {
crc CRC
w io.Writer
}
-
-func (c *CRCwriter) Write(p []byte) (n int, err error) {
- n, err = c.w.Write(p) // with each write ...
- c.crc = c.crc.Update(p)
- return
-}
-
-func (c *CRCwriter) Sum() uint32 { return uint32(c.crc) } // final hash
diff --git a/weed/storage/needle/needle_write.go b/weed/storage/needle/needle_write.go
index 009bf393e..d90807d70 100644
--- a/weed/storage/needle/needle_write.go
+++ b/weed/storage/needle/needle_write.go
@@ -1,7 +1,6 @@
package needle
import (
- "bytes"
"fmt"
"github.com/seaweedfs/seaweedfs/weed/glog"
@@ -83,27 +82,3 @@ func WriteNeedleBlob(w backend.BackendStorageFile, dataSlice []byte, size Size,
return
}
-
-// prepareNeedleWrite encapsulates the common beginning logic for all versioned writeNeedle functions.
-func prepareNeedleWrite(w backend.BackendStorageFile, n *Needle) (offset uint64, bytesBuffer *bytes.Buffer, cleanup func(err error), err error) {
- end, _, e := w.GetStat()
- if e != nil {
- err = fmt.Errorf("Cannot Read Current Volume Position: %w", e)
- return
- }
- offset = uint64(end)
- if offset >= MaxPossibleVolumeSize && len(n.Data) != 0 {
- err = fmt.Errorf("Volume Size %d Exceeded %d", offset, MaxPossibleVolumeSize)
- return
- }
- bytesBuffer = buffer_pool.SyncPoolGetBuffer()
- cleanup = func(err error) {
- if err != nil {
- if te := w.Truncate(end); te != nil {
- // handle error or log
- }
- }
- buffer_pool.SyncPoolPutBuffer(bytesBuffer)
- }
- return
-}
diff --git a/weed/storage/store_state.go b/weed/storage/store_state.go
index 2bac4fae6..9014b2a2e 100644
--- a/weed/storage/store_state.go
+++ b/weed/storage/store_state.go
@@ -34,16 +34,6 @@ func NewState(dir string) (*State, error) {
return state, err
}
-func NewStateFromProto(filePath string, state *volume_server_pb.VolumeServerState) *State {
- pb := &volume_server_pb.VolumeServerState{}
- proto.Merge(pb, state)
-
- return &State{
- filePath: filePath,
- pb: pb,
- }
-}
-
func (st *State) Proto() *volume_server_pb.VolumeServerState {
st.mu.Lock()
defer st.mu.Unlock()
diff --git a/weed/topology/capacity_reservation_test.go b/weed/topology/capacity_reservation_test.go
deleted file mode 100644
index 38cb14c50..000000000
--- a/weed/topology/capacity_reservation_test.go
+++ /dev/null
@@ -1,215 +0,0 @@
-package topology
-
-import (
- "sync"
- "testing"
- "time"
-
- "github.com/seaweedfs/seaweedfs/weed/storage/types"
-)
-
-func TestCapacityReservations_BasicOperations(t *testing.T) {
- cr := newCapacityReservations()
- diskType := types.HardDriveType
-
- // Test initial state
- if count := cr.getReservedCount(diskType); count != 0 {
- t.Errorf("Expected 0 reserved count initially, got %d", count)
- }
-
- // Test add reservation
- reservationId := cr.addReservation(diskType, 5)
- if reservationId == "" {
- t.Error("Expected non-empty reservation ID")
- }
-
- if count := cr.getReservedCount(diskType); count != 5 {
- t.Errorf("Expected 5 reserved count, got %d", count)
- }
-
- // Test multiple reservations
- cr.addReservation(diskType, 3)
- if count := cr.getReservedCount(diskType); count != 8 {
- t.Errorf("Expected 8 reserved count after second reservation, got %d", count)
- }
-
- // Test remove reservation
- success := cr.removeReservation(reservationId)
- if !success {
- t.Error("Expected successful removal of existing reservation")
- }
-
- if count := cr.getReservedCount(diskType); count != 3 {
- t.Errorf("Expected 3 reserved count after removal, got %d", count)
- }
-
- // Test remove non-existent reservation
- success = cr.removeReservation("non-existent-id")
- if success {
- t.Error("Expected failure when removing non-existent reservation")
- }
-}
-
-func TestCapacityReservations_ExpiredCleaning(t *testing.T) {
- cr := newCapacityReservations()
- diskType := types.HardDriveType
-
- // Add reservations and manipulate their creation time
- reservationId1 := cr.addReservation(diskType, 3)
- reservationId2 := cr.addReservation(diskType, 2)
-
- // Make one reservation "old"
- cr.Lock()
- if reservation, exists := cr.reservations[reservationId1]; exists {
- reservation.createdAt = time.Now().Add(-10 * time.Minute) // 10 minutes ago
- }
- cr.Unlock()
-
- // Clean expired reservations (5 minute expiration)
- cr.cleanExpiredReservations(5 * time.Minute)
-
- // Only the non-expired reservation should remain
- if count := cr.getReservedCount(diskType); count != 2 {
- t.Errorf("Expected 2 reserved count after cleaning, got %d", count)
- }
-
- // Verify the right reservation was kept
- if !cr.removeReservation(reservationId2) {
- t.Error("Expected recent reservation to still exist")
- }
-
- if cr.removeReservation(reservationId1) {
- t.Error("Expected old reservation to be cleaned up")
- }
-}
-
-func TestCapacityReservations_DifferentDiskTypes(t *testing.T) {
- cr := newCapacityReservations()
-
- // Add reservations for different disk types
- cr.addReservation(types.HardDriveType, 5)
- cr.addReservation(types.SsdType, 3)
-
- // Check counts are separate
- if count := cr.getReservedCount(types.HardDriveType); count != 5 {
- t.Errorf("Expected 5 HDD reserved count, got %d", count)
- }
-
- if count := cr.getReservedCount(types.SsdType); count != 3 {
- t.Errorf("Expected 3 SSD reserved count, got %d", count)
- }
-}
-
-func TestNodeImpl_ReservationMethods(t *testing.T) {
- // Create a test data node
- dn := NewDataNode("test-node")
- diskType := types.HardDriveType
-
- // Set up some capacity
- diskUsage := dn.diskUsages.getOrCreateDisk(diskType)
- diskUsage.maxVolumeCount = 10
- diskUsage.volumeCount = 5 // 5 volumes free initially
-
- option := &VolumeGrowOption{DiskType: diskType}
-
- // Test available space calculation
- available := dn.AvailableSpaceFor(option)
- if available != 5 {
- t.Errorf("Expected 5 available slots, got %d", available)
- }
-
- availableForReservation := dn.AvailableSpaceForReservation(option)
- if availableForReservation != 5 {
- t.Errorf("Expected 5 available slots for reservation, got %d", availableForReservation)
- }
-
- // Test successful reservation
- reservationId, success := dn.TryReserveCapacity(diskType, 3)
- if !success {
- t.Error("Expected successful reservation")
- }
- if reservationId == "" {
- t.Error("Expected non-empty reservation ID")
- }
-
- // Available space should be reduced by reservations
- availableForReservation = dn.AvailableSpaceForReservation(option)
- if availableForReservation != 2 {
- t.Errorf("Expected 2 available slots after reservation, got %d", availableForReservation)
- }
-
- // Base available space should remain unchanged
- available = dn.AvailableSpaceFor(option)
- if available != 5 {
- t.Errorf("Expected base available to remain 5, got %d", available)
- }
-
- // Test reservation failure when insufficient capacity
- _, success = dn.TryReserveCapacity(diskType, 3)
- if success {
- t.Error("Expected reservation failure due to insufficient capacity")
- }
-
- // Test release reservation
- dn.ReleaseReservedCapacity(reservationId)
- availableForReservation = dn.AvailableSpaceForReservation(option)
- if availableForReservation != 5 {
- t.Errorf("Expected 5 available slots after release, got %d", availableForReservation)
- }
-}
-
-func TestNodeImpl_ConcurrentReservations(t *testing.T) {
- dn := NewDataNode("test-node")
- diskType := types.HardDriveType
-
- // Set up capacity
- diskUsage := dn.diskUsages.getOrCreateDisk(diskType)
- diskUsage.maxVolumeCount = 10
- diskUsage.volumeCount = 0 // 10 volumes free initially
-
- // Test concurrent reservations using goroutines
- var wg sync.WaitGroup
- var reservationIds sync.Map
- concurrentRequests := 10
- wg.Add(concurrentRequests)
-
- for i := 0; i < concurrentRequests; i++ {
- go func(i int) {
- defer wg.Done()
- if reservationId, success := dn.TryReserveCapacity(diskType, 1); success {
- reservationIds.Store(reservationId, true)
- t.Logf("goroutine %d: Successfully reserved %s", i, reservationId)
- } else {
- t.Errorf("goroutine %d: Expected successful reservation", i)
- }
- }(i)
- }
-
- wg.Wait()
-
- // Should have no more capacity
- option := &VolumeGrowOption{DiskType: diskType}
- if available := dn.AvailableSpaceForReservation(option); available != 0 {
- t.Errorf("Expected 0 available slots after all reservations, got %d", available)
- // Debug: check total reserved
- reservedCount := dn.capacityReservations.getReservedCount(diskType)
- t.Logf("Debug: Total reserved count: %d", reservedCount)
- }
-
- // Next reservation should fail
- _, success := dn.TryReserveCapacity(diskType, 1)
- if success {
- t.Error("Expected reservation failure when at capacity")
- }
-
- // Release all reservations
- reservationIds.Range(func(key, value interface{}) bool {
- dn.ReleaseReservedCapacity(key.(string))
- return true
- })
-
- // Should have full capacity back
- if available := dn.AvailableSpaceForReservation(option); available != 10 {
- t.Errorf("Expected 10 available slots after releasing all, got %d", available)
- }
-}
diff --git a/weed/topology/disk.go b/weed/topology/disk.go
index fa99ef37a..3616ff928 100644
--- a/weed/topology/disk.go
+++ b/weed/topology/disk.go
@@ -118,16 +118,6 @@ func (a *DiskUsageCounts) FreeSpace() int64 {
return freeVolumeSlotCount
}
-func (a *DiskUsageCounts) minus(b *DiskUsageCounts) *DiskUsageCounts {
- return &DiskUsageCounts{
- volumeCount: a.volumeCount - b.volumeCount,
- remoteVolumeCount: a.remoteVolumeCount - b.remoteVolumeCount,
- activeVolumeCount: a.activeVolumeCount - b.activeVolumeCount,
- ecShardCount: a.ecShardCount - b.ecShardCount,
- maxVolumeCount: a.maxVolumeCount - b.maxVolumeCount,
- }
-}
-
func (du *DiskUsages) getOrCreateDisk(diskType types.DiskType) *DiskUsageCounts {
du.Lock()
defer du.Unlock()
diff --git a/weed/topology/node.go b/weed/topology/node.go
index d32927fca..66d44a8e1 100644
--- a/weed/topology/node.go
+++ b/weed/topology/node.go
@@ -40,13 +40,6 @@ func newCapacityReservations() *CapacityReservations {
}
}
-func (cr *CapacityReservations) addReservation(diskType types.DiskType, count int64) string {
- cr.Lock()
- defer cr.Unlock()
-
- return cr.doAddReservation(diskType, count)
-}
-
func (cr *CapacityReservations) removeReservation(reservationId string) bool {
cr.Lock()
defer cr.Unlock()
diff --git a/weed/topology/volume_layout.go b/weed/topology/volume_layout.go
index ecbacef75..6a7ca2c89 100644
--- a/weed/topology/volume_layout.go
+++ b/weed/topology/volume_layout.go
@@ -40,10 +40,6 @@ func ExistCopies() stateIndicator {
return func(state copyState) bool { return state != noCopies }
}
-func NoCopies() stateIndicator {
- return func(state copyState) bool { return state == noCopies }
-}
-
type volumesBinaryState struct {
rp *super_block.ReplicaPlacement
name volumeState // the name for volume state (eg. "Readonly", "Oversized")
@@ -264,12 +260,6 @@ func (vl *VolumeLayout) isCrowdedVolume(v *storage.VolumeInfo) bool {
return float64(v.Size) > float64(vl.volumeSizeLimit)*VolumeGrowStrategy.Threshold
}
-func (vl *VolumeLayout) isWritable(v *storage.VolumeInfo) bool {
- return !vl.isOversized(v) &&
- v.Version == needle.GetCurrentVersion() &&
- !v.ReadOnly
-}
-
func (vl *VolumeLayout) isEmpty() bool {
vl.accessLock.RLock()
defer vl.accessLock.RUnlock()
diff --git a/weed/topology/volume_layout_test.go b/weed/topology/volume_layout_test.go
deleted file mode 100644
index 999c8de8e..000000000
--- a/weed/topology/volume_layout_test.go
+++ /dev/null
@@ -1,190 +0,0 @@
-package topology
-
-import (
- "testing"
-
- "github.com/seaweedfs/seaweedfs/weed/storage"
- "github.com/seaweedfs/seaweedfs/weed/storage/needle"
- "github.com/seaweedfs/seaweedfs/weed/storage/super_block"
- "github.com/seaweedfs/seaweedfs/weed/storage/types"
-)
-
-func TestVolumesBinaryState(t *testing.T) {
- vids := []needle.VolumeId{
- needle.VolumeId(1),
- needle.VolumeId(2),
- needle.VolumeId(3),
- needle.VolumeId(4),
- needle.VolumeId(5),
- }
-
- dns := []*DataNode{
- &DataNode{
- Ip: "127.0.0.1",
- Port: 8081,
- },
- &DataNode{
- Ip: "127.0.0.1",
- Port: 8082,
- },
- &DataNode{
- Ip: "127.0.0.1",
- Port: 8083,
- },
- }
-
- rp, _ := super_block.NewReplicaPlacementFromString("002")
-
- state_exist := NewVolumesBinaryState(readOnlyState, rp, ExistCopies())
- state_exist.Add(vids[0], dns[0])
- state_exist.Add(vids[0], dns[1])
- state_exist.Add(vids[1], dns[2])
- state_exist.Add(vids[2], dns[1])
- state_exist.Add(vids[4], dns[1])
- state_exist.Add(vids[4], dns[2])
-
- state_no := NewVolumesBinaryState(readOnlyState, rp, NoCopies())
- state_no.Add(vids[0], dns[0])
- state_no.Add(vids[0], dns[1])
- state_no.Add(vids[3], dns[1])
-
- tests := []struct {
- name string
- state *volumesBinaryState
- expectResult []bool
- update func()
- expectResultAfterUpdate []bool
- }{
- {
- name: "mark true when copies exist",
- state: state_exist,
- expectResult: []bool{true, true, true, false, true},
- update: func() {
- state_exist.Remove(vids[0], dns[2])
- state_exist.Remove(vids[1], dns[2])
- state_exist.Remove(vids[3], dns[2])
- state_exist.Remove(vids[4], dns[1])
- state_exist.Remove(vids[4], dns[2])
- },
- expectResultAfterUpdate: []bool{true, false, true, false, false},
- },
- {
- name: "mark true when no copies exist",
- state: state_no,
- expectResult: []bool{false, true, true, false, true},
- update: func() {
- state_no.Remove(vids[0], dns[2])
- state_no.Remove(vids[1], dns[2])
- state_no.Add(vids[2], dns[1])
- state_no.Remove(vids[3], dns[1])
- state_no.Remove(vids[4], dns[2])
- },
- expectResultAfterUpdate: []bool{false, true, false, true, true},
- },
- }
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- var result []bool
- for index, _ := range vids {
- result = append(result, test.state.IsTrue(vids[index]))
- }
- if len(result) != len(test.expectResult) {
- t.Fatalf("len(result) != len(expectResult), got %d, expected %d\n",
- len(result), len(test.expectResult))
- }
- for index, val := range result {
- if val != test.expectResult[index] {
- t.Fatalf("result not matched, index %d, got %v, expected %v\n",
- index, val, test.expectResult[index])
- }
- }
- test.update()
- var updateResult []bool
- for index, _ := range vids {
- updateResult = append(updateResult, test.state.IsTrue(vids[index]))
- }
- if len(updateResult) != len(test.expectResultAfterUpdate) {
- t.Fatalf("len(updateResult) != len(expectResultAfterUpdate), got %d, expected %d\n",
- len(updateResult), len(test.expectResultAfterUpdate))
- }
- for index, val := range updateResult {
- if val != test.expectResultAfterUpdate[index] {
- t.Fatalf("update result not matched, index %d, got %v, expected %v\n",
- index, val, test.expectResultAfterUpdate[index])
- }
- }
- })
- }
-}
-
-func TestVolumeLayoutCrowdedState(t *testing.T) {
- rp, _ := super_block.NewReplicaPlacementFromString("000")
- ttl, _ := needle.ReadTTL("")
- diskType := types.HardDriveType
-
- vl := NewVolumeLayout(rp, ttl, diskType, 1024*1024*1024, false)
-
- vid := needle.VolumeId(1)
- dn := &DataNode{
- NodeImpl: NodeImpl{
- id: "test-node",
- },
- Ip: "127.0.0.1",
- Port: 8080,
- }
-
- // Create a volume info
- volumeInfo := &storage.VolumeInfo{
- Id: vid,
- ReplicaPlacement: rp,
- Ttl: ttl,
- DiskType: string(diskType),
- }
-
- // Register the volume
- vl.RegisterVolume(volumeInfo, dn)
-
- // Add the volume to writables
- vl.accessLock.Lock()
- vl.setVolumeWritable(vid)
- vl.accessLock.Unlock()
-
- // Mark the volume as crowded
- vl.SetVolumeCrowded(vid)
-
- t.Run("should be crowded after being marked", func(t *testing.T) {
- vl.accessLock.RLock()
- _, isCrowded := vl.crowded[vid]
- vl.accessLock.RUnlock()
- if !isCrowded {
- t.Fatal("Volume should be marked as crowded after SetVolumeCrowded")
- }
- })
-
- // Remove from writable (simulating temporary unwritable state)
- vl.accessLock.Lock()
- vl.removeFromWritable(vid)
- vl.accessLock.Unlock()
-
- t.Run("should remain crowded after becoming unwritable", func(t *testing.T) {
- // This is the fix for issue #6712 - crowded state should persist
- vl.accessLock.RLock()
- _, stillCrowded := vl.crowded[vid]
- vl.accessLock.RUnlock()
- if !stillCrowded {
- t.Fatal("Volume should remain crowded after becoming unwritable (fix for issue #6712)")
- }
- })
-
- // Now unregister the volume completely
- vl.UnRegisterVolume(volumeInfo, dn)
-
- t.Run("should not be crowded after unregistering", func(t *testing.T) {
- vl.accessLock.RLock()
- _, stillCrowdedAfterUnregister := vl.crowded[vid]
- vl.accessLock.RUnlock()
- if stillCrowdedAfterUnregister {
- t.Fatal("Volume should be removed from crowded map after full unregistration")
- }
- })
-}
diff --git a/weed/util/bytes.go b/weed/util/bytes.go
index faf7df916..43008c42f 100644
--- a/weed/util/bytes.go
+++ b/weed/util/bytes.go
@@ -120,10 +120,6 @@ func Base64Encode(data []byte) string {
return base64.StdEncoding.EncodeToString(data)
}
-func Base64Md5(data []byte) string {
- return Base64Encode(Md5(data))
-}
-
func Md5(data []byte) []byte {
hash := md5.New()
hash.Write(data)
diff --git a/weed/util/limited_async_pool.go b/weed/util/limited_async_pool.go
deleted file mode 100644
index 51dfd6252..000000000
--- a/weed/util/limited_async_pool.go
+++ /dev/null
@@ -1,66 +0,0 @@
-package util
-
-// initial version comes from https://hackernoon.com/asyncawait-in-golang-an-introductory-guide-ol1e34sg
-
-import (
- "container/list"
- "context"
- "sync"
-)
-
-type Future interface {
- Await() interface{}
-}
-
-type future struct {
- await func(ctx context.Context) interface{}
-}
-
-func (f future) Await() interface{} {
- return f.await(context.Background())
-}
-
-type LimitedAsyncExecutor struct {
- executor *LimitedConcurrentExecutor
- futureList *list.List
- futureListCond *sync.Cond
-}
-
-func NewLimitedAsyncExecutor(limit int) *LimitedAsyncExecutor {
- return &LimitedAsyncExecutor{
- executor: NewLimitedConcurrentExecutor(limit),
- futureList: list.New(),
- futureListCond: sync.NewCond(&sync.Mutex{}),
- }
-}
-
-func (ae *LimitedAsyncExecutor) Execute(job func() interface{}) {
- var result interface{}
- c := make(chan struct{})
- ae.executor.Execute(func() {
- defer close(c)
- result = job()
- })
- f := future{await: func(ctx context.Context) interface{} {
- select {
- case <-ctx.Done():
- return ctx.Err()
- case <-c:
- return result
- }
- }}
- ae.futureListCond.L.Lock()
- ae.futureList.PushBack(f)
- ae.futureListCond.Signal()
- ae.futureListCond.L.Unlock()
-}
-
-func (ae *LimitedAsyncExecutor) NextFuture() Future {
- ae.futureListCond.L.Lock()
- for ae.futureList.Len() == 0 {
- ae.futureListCond.Wait()
- }
- f := ae.futureList.Remove(ae.futureList.Front())
- ae.futureListCond.L.Unlock()
- return f.(Future)
-}
diff --git a/weed/util/limited_async_pool_test.go b/weed/util/limited_async_pool_test.go
deleted file mode 100644
index 1289f4f33..000000000
--- a/weed/util/limited_async_pool_test.go
+++ /dev/null
@@ -1,64 +0,0 @@
-package util
-
-import (
- "fmt"
- "sort"
- "testing"
- "time"
-
- "github.com/stretchr/testify/assert"
-)
-
-func TestAsyncPool(t *testing.T) {
- p := NewLimitedAsyncExecutor(3)
-
- p.Execute(FirstFunc)
- p.Execute(SecondFunc)
- p.Execute(ThirdFunc)
- p.Execute(FourthFunc)
- p.Execute(FifthFunc)
-
- var sorted_results []int
- for i := 0; i < 5; i++ {
- f := p.NextFuture()
- x := f.Await().(int)
- println(x)
- sorted_results = append(sorted_results, x)
- }
- assert.True(t, sort.IntsAreSorted(sorted_results), "results should be sorted")
-}
-
-func FirstFunc() any {
- fmt.Println("-- Executing first function --")
- time.Sleep(70 * time.Millisecond)
- fmt.Println("-- First Function finished --")
- return 1
-}
-
-func SecondFunc() any {
- fmt.Println("-- Executing second function --")
- time.Sleep(50 * time.Millisecond)
- fmt.Println("-- Second Function finished --")
- return 2
-}
-
-func ThirdFunc() any {
- fmt.Println("-- Executing third function --")
- time.Sleep(20 * time.Millisecond)
- fmt.Println("-- Third Function finished --")
- return 3
-}
-
-func FourthFunc() any {
- fmt.Println("-- Executing fourth function --")
- time.Sleep(100 * time.Millisecond)
- fmt.Println("-- Fourth Function finished --")
- return 4
-}
-
-func FifthFunc() any {
- fmt.Println("-- Executing fifth function --")
- time.Sleep(40 * time.Millisecond)
- fmt.Println("-- Fourth fifth finished --")
- return 5
-}
diff --git a/weed/util/lock_table.go b/weed/util/lock_table.go
index 8f65aac06..65daae39c 100644
--- a/weed/util/lock_table.go
+++ b/weed/util/lock_table.go
@@ -175,7 +175,3 @@ func (lt *LockTable[T]) ReleaseLock(key T, lock *ActiveLock) {
// Notify the next waiter
entry.cond.Broadcast()
}
-
-func main() {
-
-}
diff --git a/weed/wdclient/net2/base_connection_pool.go b/weed/wdclient/net2/base_connection_pool.go
deleted file mode 100644
index 0b79130e3..000000000
--- a/weed/wdclient/net2/base_connection_pool.go
+++ /dev/null
@@ -1,159 +0,0 @@
-package net2
-
-import (
- "net"
- "strings"
- "time"
-
- rp "github.com/seaweedfs/seaweedfs/weed/wdclient/resource_pool"
-)
-
-const defaultDialTimeout = 1 * time.Second
-
-func defaultDialFunc(network string, address string) (net.Conn, error) {
- return net.DialTimeout(network, address, defaultDialTimeout)
-}
-
-func parseResourceLocation(resourceLocation string) (
- network string,
- address string) {
-
- idx := strings.Index(resourceLocation, " ")
- if idx >= 0 {
- return resourceLocation[:idx], resourceLocation[idx+1:]
- }
-
- return "", resourceLocation
-}
-
-// A thin wrapper around the underlying resource pool.
-type connectionPoolImpl struct {
- options ConnectionOptions
-
- pool rp.ResourcePool
-}
-
-// This returns a connection pool where all connections are connected
-// to the same (network, address)
-func newBaseConnectionPool(
- options ConnectionOptions,
- createPool func(rp.Options) rp.ResourcePool) ConnectionPool {
-
- dial := options.Dial
- if dial == nil {
- dial = defaultDialFunc
- }
-
- openFunc := func(loc string) (interface{}, error) {
- network, address := parseResourceLocation(loc)
- return dial(network, address)
- }
-
- closeFunc := func(handle interface{}) error {
- return handle.(net.Conn).Close()
- }
-
- poolOptions := rp.Options{
- MaxActiveHandles: options.MaxActiveConnections,
- MaxIdleHandles: options.MaxIdleConnections,
- MaxIdleTime: options.MaxIdleTime,
- OpenMaxConcurrency: options.DialMaxConcurrency,
- Open: openFunc,
- Close: closeFunc,
- NowFunc: options.NowFunc,
- }
-
- return &connectionPoolImpl{
- options: options,
- pool: createPool(poolOptions),
- }
-}
-
-// This returns a connection pool where all connections are connected
-// to the same (network, address)
-func NewSimpleConnectionPool(options ConnectionOptions) ConnectionPool {
- return newBaseConnectionPool(options, rp.NewSimpleResourcePool)
-}
-
-// This returns a connection pool that manages multiple (network, address)
-// entries. The connections to each (network, address) entry acts
-// independently. For example ("tcp", "localhost:11211") could act as memcache
-// shard 0 and ("tcp", "localhost:11212") could act as memcache shard 1.
-func NewMultiConnectionPool(options ConnectionOptions) ConnectionPool {
- return newBaseConnectionPool(
- options,
- func(poolOptions rp.Options) rp.ResourcePool {
- return rp.NewMultiResourcePool(poolOptions, nil)
- })
-}
-
-// See ConnectionPool for documentation.
-func (p *connectionPoolImpl) NumActive() int32 {
- return p.pool.NumActive()
-}
-
-// See ConnectionPool for documentation.
-func (p *connectionPoolImpl) ActiveHighWaterMark() int32 {
- return p.pool.ActiveHighWaterMark()
-}
-
-// This returns the number of alive idle connections. This method is not part
-// of ConnectionPool's API. It is used only for testing.
-func (p *connectionPoolImpl) NumIdle() int {
- return p.pool.NumIdle()
-}
-
-// BaseConnectionPool can only register a single (network, address) entry.
-// Register should be call before any Get calls.
-func (p *connectionPoolImpl) Register(network string, address string) error {
- return p.pool.Register(network + " " + address)
-}
-
-// BaseConnectionPool has nothing to do on Unregister.
-func (p *connectionPoolImpl) Unregister(network string, address string) error {
- return nil
-}
-
-func (p *connectionPoolImpl) ListRegistered() []NetworkAddress {
- result := make([]NetworkAddress, 0, 1)
- for _, location := range p.pool.ListRegistered() {
- network, address := parseResourceLocation(location)
-
- result = append(
- result,
- NetworkAddress{
- Network: network,
- Address: address,
- })
- }
- return result
-}
-
-// This gets an active connection from the connection pool. Note that network
-// and address arguments are ignored (The connections with point to the
-// network/address provided by the first Register call).
-func (p *connectionPoolImpl) Get(
- network string,
- address string) (ManagedConn, error) {
-
- handle, err := p.pool.Get(network + " " + address)
- if err != nil {
- return nil, err
- }
- return NewManagedConn(network, address, handle, p, p.options), nil
-}
-
-// See ConnectionPool for documentation.
-func (p *connectionPoolImpl) Release(conn ManagedConn) error {
- return conn.ReleaseConnection()
-}
-
-// See ConnectionPool for documentation.
-func (p *connectionPoolImpl) Discard(conn ManagedConn) error {
- return conn.DiscardConnection()
-}
-
-// See ConnectionPool for documentation.
-func (p *connectionPoolImpl) EnterLameDuckMode() {
- p.pool.EnterLameDuckMode()
-}
diff --git a/weed/wdclient/net2/connection_pool.go b/weed/wdclient/net2/connection_pool.go
deleted file mode 100644
index 5b8d4d232..000000000
--- a/weed/wdclient/net2/connection_pool.go
+++ /dev/null
@@ -1,97 +0,0 @@
-package net2
-
-import (
- "net"
- "time"
-)
-
-type ConnectionOptions struct {
- // The maximum number of connections that can be active per host at any
- // given time (A non-positive value indicates the number of connections
- // is unbounded).
- MaxActiveConnections int32
-
- // The maximum number of idle connections per host that are kept alive by
- // the connection pool.
- MaxIdleConnections uint32
-
- // The maximum amount of time an idle connection can alive (if specified).
- MaxIdleTime *time.Duration
-
- // This limits the number of concurrent Dial calls (there's no limit when
- // DialMaxConcurrency is non-positive).
- DialMaxConcurrency int
-
- // Dial specifies the dial function for creating network connections.
- // If Dial is nil, net.DialTimeout is used, with timeout set to 1 second.
- Dial func(network string, address string) (net.Conn, error)
-
- // This specifies the now time function. When the function is non-nil, the
- // connection pool will use the specified function instead of time.Now to
- // generate the current time.
- NowFunc func() time.Time
-
- // This specifies the timeout for any Read() operation.
- // Note that setting this to 0 (i.e. not setting it) will make
- // read operations block indefinitely.
- ReadTimeout time.Duration
-
- // This specifies the timeout for any Write() operation.
- // Note that setting this to 0 (i.e. not setting it) will make
- // write operations block indefinitely.
- WriteTimeout time.Duration
-}
-
-func (o ConnectionOptions) getCurrentTime() time.Time {
- if o.NowFunc == nil {
- return time.Now()
- } else {
- return o.NowFunc()
- }
-}
-
-// A generic interface for managed connection pool. All connection pool
-// implementations must be threadsafe.
-type ConnectionPool interface {
- // This returns the number of active connections that are on loan.
- NumActive() int32
-
- // This returns the highest number of active connections for the entire
- // lifetime of the pool.
- ActiveHighWaterMark() int32
-
- // This returns the number of idle connections that are in the pool.
- NumIdle() int
-
- // This associates (network, address) to the connection pool; afterwhich,
- // the user can get connections to (network, address).
- Register(network string, address string) error
-
- // This dissociate (network, address) from the connection pool;
- // afterwhich, the user can no longer get connections to
- // (network, address).
- Unregister(network string, address string) error
-
- // This returns the list of registered (network, address) entries.
- ListRegistered() []NetworkAddress
-
- // This gets an active connection from the connection pool. The connection
- // will remain active until one of the following is called:
- // 1. conn.ReleaseConnection()
- // 2. conn.DiscardConnection()
- // 3. pool.Release(conn)
- // 4. pool.Discard(conn)
- Get(network string, address string) (ManagedConn, error)
-
- // This releases an active connection back to the connection pool.
- Release(conn ManagedConn) error
-
- // This discards an active connection from the connection pool.
- Discard(conn ManagedConn) error
-
- // Enter the connection pool into lame duck mode. The connection pool
- // will no longer return connections, and all idle connections are closed
- // immediately (including active connections that are released back to the
- // pool afterward).
- EnterLameDuckMode()
-}
diff --git a/weed/wdclient/net2/doc.go b/weed/wdclient/net2/doc.go
deleted file mode 100644
index fd1c6323d..000000000
--- a/weed/wdclient/net2/doc.go
+++ /dev/null
@@ -1,6 +0,0 @@
-// net2 is a collection of functions meant to supplement the capabilities
-// provided by the standard "net" package.
-package net2
-
-// copied from https://github.com/dropbox/godropbox/tree/master/net2
-// removed other dependencies
diff --git a/weed/wdclient/net2/managed_connection.go b/weed/wdclient/net2/managed_connection.go
deleted file mode 100644
index d4696739e..000000000
--- a/weed/wdclient/net2/managed_connection.go
+++ /dev/null
@@ -1,186 +0,0 @@
-package net2
-
-import (
- "fmt"
- "net"
- "time"
-
- "errors"
-
- "github.com/seaweedfs/seaweedfs/weed/wdclient/resource_pool"
-)
-
-// Dial's arguments.
-type NetworkAddress struct {
- Network string
- Address string
-}
-
-// A connection managed by a connection pool. NOTE: SetDeadline,
-// SetReadDeadline and SetWriteDeadline are disabled for managed connections.
-// (The deadlines are set by the connection pool).
-type ManagedConn interface {
- net.Conn
-
- // This returns the original (network, address) entry used for creating
- // the connection.
- Key() NetworkAddress
-
- // This returns the underlying net.Conn implementation.
- RawConn() net.Conn
-
- // This returns the connection pool which owns this connection.
- Owner() ConnectionPool
-
- // This indicates a user is done with the connection and releases the
- // connection back to the connection pool.
- ReleaseConnection() error
-
- // This indicates the connection is an invalid state, and that the
- // connection should be discarded from the connection pool.
- DiscardConnection() error
-}
-
-// A physical implementation of ManagedConn
-type managedConnImpl struct {
- addr NetworkAddress
- handle resource_pool.ManagedHandle
- pool ConnectionPool
- options ConnectionOptions
-}
-
-// This creates a managed connection wrapper.
-func NewManagedConn(
- network string,
- address string,
- handle resource_pool.ManagedHandle,
- pool ConnectionPool,
- options ConnectionOptions) ManagedConn {
-
- addr := NetworkAddress{
- Network: network,
- Address: address,
- }
-
- return &managedConnImpl{
- addr: addr,
- handle: handle,
- pool: pool,
- options: options,
- }
-}
-
-func (c *managedConnImpl) rawConn() (net.Conn, error) {
- h, err := c.handle.Handle()
- return h.(net.Conn), err
-}
-
-// See ManagedConn for documentation.
-func (c *managedConnImpl) RawConn() net.Conn {
- h, _ := c.handle.Handle()
- return h.(net.Conn)
-}
-
-// See ManagedConn for documentation.
-func (c *managedConnImpl) Key() NetworkAddress {
- return c.addr
-}
-
-// See ManagedConn for documentation.
-func (c *managedConnImpl) Owner() ConnectionPool {
- return c.pool
-}
-
-// See ManagedConn for documentation.
-func (c *managedConnImpl) ReleaseConnection() error {
- return c.handle.Release()
-}
-
-// See ManagedConn for documentation.
-func (c *managedConnImpl) DiscardConnection() error {
- return c.handle.Discard()
-}
-
-// See net.Conn for documentation
-func (c *managedConnImpl) Read(b []byte) (n int, err error) {
- conn, err := c.rawConn()
- if err != nil {
- return 0, err
- }
-
- if c.options.ReadTimeout > 0 {
- deadline := c.options.getCurrentTime().Add(c.options.ReadTimeout)
- _ = conn.SetReadDeadline(deadline)
- }
- n, err = conn.Read(b)
- if err != nil {
- var localAddr string
- if conn.LocalAddr() != nil {
- localAddr = conn.LocalAddr().String()
- } else {
- localAddr = "(nil)"
- }
-
- var remoteAddr string
- if conn.RemoteAddr() != nil {
- remoteAddr = conn.RemoteAddr().String()
- } else {
- remoteAddr = "(nil)"
- }
- err = fmt.Errorf("Read error from host: %s <-> %s: %v", localAddr, remoteAddr, err)
- }
- return
-}
-
-// See net.Conn for documentation
-func (c *managedConnImpl) Write(b []byte) (n int, err error) {
- conn, err := c.rawConn()
- if err != nil {
- return 0, err
- }
-
- if c.options.WriteTimeout > 0 {
- deadline := c.options.getCurrentTime().Add(c.options.WriteTimeout)
- _ = conn.SetWriteDeadline(deadline)
- }
- n, err = conn.Write(b)
- if err != nil {
- err = fmt.Errorf("Write error: %w", err)
- }
- return
-}
-
-// See net.Conn for documentation
-func (c *managedConnImpl) Close() error {
- return c.handle.Discard()
-}
-
-// See net.Conn for documentation
-func (c *managedConnImpl) LocalAddr() net.Addr {
- conn, _ := c.rawConn()
- return conn.LocalAddr()
-}
-
-// See net.Conn for documentation
-func (c *managedConnImpl) RemoteAddr() net.Addr {
- conn, _ := c.rawConn()
- return conn.RemoteAddr()
-}
-
-// SetDeadline is disabled for managed connection (The deadline is set by
-// us, with respect to the read/write timeouts specified in ConnectionOptions).
-func (c *managedConnImpl) SetDeadline(t time.Time) error {
- return errors.New("Cannot set deadline for managed connection")
-}
-
-// SetReadDeadline is disabled for managed connection (The deadline is set by
-// us with respect to the read timeout specified in ConnectionOptions).
-func (c *managedConnImpl) SetReadDeadline(t time.Time) error {
- return errors.New("Cannot set read deadline for managed connection")
-}
-
-// SetWriteDeadline is disabled for managed connection (The deadline is set by
-// us with respect to the write timeout specified in ConnectionOptions).
-func (c *managedConnImpl) SetWriteDeadline(t time.Time) error {
- return errors.New("Cannot set write deadline for managed connection")
-}
diff --git a/weed/wdclient/net2/port.go b/weed/wdclient/net2/port.go
deleted file mode 100644
index f83adba28..000000000
--- a/weed/wdclient/net2/port.go
+++ /dev/null
@@ -1,19 +0,0 @@
-package net2
-
-import (
- "net"
- "strconv"
-)
-
-// Returns the port information.
-func GetPort(addr net.Addr) (int, error) {
- _, lport, err := net.SplitHostPort(addr.String())
- if err != nil {
- return -1, err
- }
- lportInt, err := strconv.Atoi(lport)
- if err != nil {
- return -1, err
- }
- return lportInt, nil
-}
diff --git a/weed/wdclient/resource_pool/doc.go b/weed/wdclient/resource_pool/doc.go
deleted file mode 100644
index b8b3f92fa..000000000
--- a/weed/wdclient/resource_pool/doc.go
+++ /dev/null
@@ -1,5 +0,0 @@
-// A generic resource pool for managing resources such as network connections.
-package resource_pool
-
-// copied from https://github.com/dropbox/godropbox/tree/master/resource_pool
-// removed other dependencies
diff --git a/weed/wdclient/resource_pool/managed_handle.go b/weed/wdclient/resource_pool/managed_handle.go
deleted file mode 100644
index 936c2d7c3..000000000
--- a/weed/wdclient/resource_pool/managed_handle.go
+++ /dev/null
@@ -1,97 +0,0 @@
-package resource_pool
-
-import (
- "sync/atomic"
-
- "errors"
-)
-
-// A resource handle managed by a resource pool.
-type ManagedHandle interface {
- // This returns the handle's resource location.
- ResourceLocation() string
-
- // This returns the underlying resource handle (or error if the handle
- // is no longer active).
- Handle() (interface{}, error)
-
- // This returns the resource pool which owns this handle.
- Owner() ResourcePool
-
- // The releases the underlying resource handle to the caller and marks the
- // managed handle as inactive. The caller is responsible for cleaning up
- // the released handle. This returns nil if the managed handle no longer
- // owns the resource.
- ReleaseUnderlyingHandle() interface{}
-
- // This indicates a user is done with the handle and releases the handle
- // back to the resource pool.
- Release() error
-
- // This indicates the handle is an invalid state, and that the
- // connection should be discarded from the connection pool.
- Discard() error
-}
-
-// A physical implementation of ManagedHandle
-type managedHandleImpl struct {
- location string
- handle interface{}
- pool ResourcePool
- isActive int32 // atomic bool
- options Options
-}
-
-// This creates a managed handle wrapper.
-func NewManagedHandle(
- resourceLocation string,
- handle interface{},
- pool ResourcePool,
- options Options) ManagedHandle {
-
- h := &managedHandleImpl{
- location: resourceLocation,
- handle: handle,
- pool: pool,
- options: options,
- }
- atomic.StoreInt32(&h.isActive, 1)
-
- return h
-}
-
-// See ManagedHandle for documentation.
-func (c *managedHandleImpl) ResourceLocation() string {
- return c.location
-}
-
-// See ManagedHandle for documentation.
-func (c *managedHandleImpl) Handle() (interface{}, error) {
- if atomic.LoadInt32(&c.isActive) == 0 {
- return c.handle, errors.New("Resource handle is no longer valid")
- }
- return c.handle, nil
-}
-
-// See ManagedHandle for documentation.
-func (c *managedHandleImpl) Owner() ResourcePool {
- return c.pool
-}
-
-// See ManagedHandle for documentation.
-func (c *managedHandleImpl) ReleaseUnderlyingHandle() interface{} {
- if atomic.CompareAndSwapInt32(&c.isActive, 1, 0) {
- return c.handle
- }
- return nil
-}
-
-// See ManagedHandle for documentation.
-func (c *managedHandleImpl) Release() error {
- return c.pool.Release(c)
-}
-
-// See ManagedHandle for documentation.
-func (c *managedHandleImpl) Discard() error {
- return c.pool.Discard(c)
-}
diff --git a/weed/wdclient/resource_pool/multi_resource_pool.go b/weed/wdclient/resource_pool/multi_resource_pool.go
deleted file mode 100644
index 9ac25526d..000000000
--- a/weed/wdclient/resource_pool/multi_resource_pool.go
+++ /dev/null
@@ -1,200 +0,0 @@
-package resource_pool
-
-import (
- "fmt"
- "sync"
-
- "errors"
-)
-
-// A resource pool implementation that manages multiple resource location
-// entries. The handles to each resource location entry acts independently.
-// For example "tcp localhost:11211" could act as memcache
-// shard 0 and "tcp localhost:11212" could act as memcache shard 1.
-type multiResourcePool struct {
- options Options
-
- createPool func(Options) ResourcePool
-
- rwMutex sync.RWMutex
- isLameDuck bool // guarded by rwMutex
- // NOTE: the locationPools is guarded by rwMutex, but the pool entries
- // are not.
- locationPools map[string]ResourcePool
-}
-
-// This returns a MultiResourcePool, which manages multiple
-// resource location entries. The handles to each resource location
-// entry acts independently.
-//
-// When createPool is nil, NewSimpleResourcePool is used as default.
-func NewMultiResourcePool(
- options Options,
- createPool func(Options) ResourcePool) ResourcePool {
-
- if createPool == nil {
- createPool = NewSimpleResourcePool
- }
-
- return &multiResourcePool{
- options: options,
- createPool: createPool,
- rwMutex: sync.RWMutex{},
- isLameDuck: false,
- locationPools: make(map[string]ResourcePool),
- }
-}
-
-// See ResourcePool for documentation.
-func (p *multiResourcePool) NumActive() int32 {
- total := int32(0)
-
- p.rwMutex.RLock()
- defer p.rwMutex.RUnlock()
-
- for _, pool := range p.locationPools {
- total += pool.NumActive()
- }
- return total
-}
-
-// See ResourcePool for documentation.
-func (p *multiResourcePool) ActiveHighWaterMark() int32 {
- high := int32(0)
-
- p.rwMutex.RLock()
- defer p.rwMutex.RUnlock()
-
- for _, pool := range p.locationPools {
- val := pool.ActiveHighWaterMark()
- if val > high {
- high = val
- }
- }
- return high
-}
-
-// See ResourcePool for documentation.
-func (p *multiResourcePool) NumIdle() int {
- total := 0
-
- p.rwMutex.RLock()
- defer p.rwMutex.RUnlock()
-
- for _, pool := range p.locationPools {
- total += pool.NumIdle()
- }
- return total
-}
-
-// See ResourcePool for documentation.
-func (p *multiResourcePool) Register(resourceLocation string) error {
- if resourceLocation == "" {
- return errors.New("Registering invalid resource location")
- }
-
- p.rwMutex.Lock()
- defer p.rwMutex.Unlock()
-
- if p.isLameDuck {
- return fmt.Errorf(
- "Cannot register %s to lame duck resource pool",
- resourceLocation)
- }
-
- if _, inMap := p.locationPools[resourceLocation]; inMap {
- return nil
- }
-
- pool := p.createPool(p.options)
- if err := pool.Register(resourceLocation); err != nil {
- return err
- }
-
- p.locationPools[resourceLocation] = pool
- return nil
-}
-
-// See ResourcePool for documentation.
-func (p *multiResourcePool) Unregister(resourceLocation string) error {
- p.rwMutex.Lock()
- defer p.rwMutex.Unlock()
-
- if pool, inMap := p.locationPools[resourceLocation]; inMap {
- _ = pool.Unregister("")
- pool.EnterLameDuckMode()
- delete(p.locationPools, resourceLocation)
- }
- return nil
-}
-
-func (p *multiResourcePool) ListRegistered() []string {
- p.rwMutex.RLock()
- defer p.rwMutex.RUnlock()
-
- result := make([]string, 0, len(p.locationPools))
- for key, _ := range p.locationPools {
- result = append(result, key)
- }
-
- return result
-}
-
-// See ResourcePool for documentation.
-func (p *multiResourcePool) Get(
- resourceLocation string) (ManagedHandle, error) {
-
- pool := p.getPool(resourceLocation)
- if pool == nil {
- return nil, fmt.Errorf(
- "%s is not registered in the resource pool",
- resourceLocation)
- }
- return pool.Get(resourceLocation)
-}
-
-// See ResourcePool for documentation.
-func (p *multiResourcePool) Release(handle ManagedHandle) error {
- pool := p.getPool(handle.ResourceLocation())
- if pool == nil {
- return errors.New(
- "Resource pool cannot take control of a handle owned " +
- "by another resource pool")
- }
-
- return pool.Release(handle)
-}
-
-// See ResourcePool for documentation.
-func (p *multiResourcePool) Discard(handle ManagedHandle) error {
- pool := p.getPool(handle.ResourceLocation())
- if pool == nil {
- return errors.New(
- "Resource pool cannot take control of a handle owned " +
- "by another resource pool")
- }
-
- return pool.Discard(handle)
-}
-
-// See ResourcePool for documentation.
-func (p *multiResourcePool) EnterLameDuckMode() {
- p.rwMutex.Lock()
- defer p.rwMutex.Unlock()
-
- p.isLameDuck = true
-
- for _, pool := range p.locationPools {
- pool.EnterLameDuckMode()
- }
-}
-
-func (p *multiResourcePool) getPool(resourceLocation string) ResourcePool {
- p.rwMutex.RLock()
- defer p.rwMutex.RUnlock()
-
- if pool, inMap := p.locationPools[resourceLocation]; inMap {
- return pool
- }
- return nil
-}
diff --git a/weed/wdclient/resource_pool/resource_pool.go b/weed/wdclient/resource_pool/resource_pool.go
deleted file mode 100644
index 26c433f50..000000000
--- a/weed/wdclient/resource_pool/resource_pool.go
+++ /dev/null
@@ -1,96 +0,0 @@
-package resource_pool
-
-import (
- "time"
-)
-
-type Options struct {
- // The maximum number of active resource handles per resource location. (A
- // non-positive value indicates the number of active resource handles is
- // unbounded).
- MaxActiveHandles int32
-
- // The maximum number of idle resource handles per resource location that
- // are kept alive by the resource pool.
- MaxIdleHandles uint32
-
- // The maximum amount of time an idle resource handle can remain alive (if
- // specified).
- MaxIdleTime *time.Duration
-
- // This limits the number of concurrent Open calls (there's no limit when
- // OpenMaxConcurrency is non-positive).
- OpenMaxConcurrency int
-
- // This function creates a resource handle (e.g., a connection) for a
- // resource location. The function must be thread-safe.
- Open func(resourceLocation string) (
- handle interface{},
- err error)
-
- // This function destroys a resource handle and performs the necessary
- // cleanup to free up resources. The function must be thread-safe.
- Close func(handle interface{}) error
-
- // This specifies the now time function. When the function is non-nil, the
- // resource pool will use the specified function instead of time.Now to
- // generate the current time.
- NowFunc func() time.Time
-}
-
-func (o Options) getCurrentTime() time.Time {
- if o.NowFunc == nil {
- return time.Now()
- } else {
- return o.NowFunc()
- }
-}
-
-// A generic interface for managed resource pool. All resource pool
-// implementations must be threadsafe.
-type ResourcePool interface {
- // This returns the number of active resource handles.
- NumActive() int32
-
- // This returns the highest number of actives handles for the entire
- // lifetime of the pool. If the pool contains multiple sub-pools, the
- // high water mark is the max of the sub-pools' high water marks.
- ActiveHighWaterMark() int32
-
- // This returns the number of alive idle handles. NOTE: This is only used
- // for testing.
- NumIdle() int
-
- // This associates a resource location to the resource pool; afterwhich,
- // the user can get resource handles for the resource location.
- Register(resourceLocation string) error
-
- // This dissociates a resource location from the resource pool; afterwhich,
- // the user can no longer get resource handles for the resource location.
- // If the given resource location corresponds to a sub-pool, the unregistered
- // sub-pool will enter lame duck mode.
- Unregister(resourceLocation string) error
-
- // This returns the list of registered resource location entries.
- ListRegistered() []string
-
- // This gets an active resource handle from the resource pool. The
- // handle will remain active until one of the following is called:
- // 1. handle.Release()
- // 2. handle.Discard()
- // 3. pool.Release(handle)
- // 4. pool.Discard(handle)
- Get(key string) (ManagedHandle, error)
-
- // This releases an active resource handle back to the resource pool.
- Release(handle ManagedHandle) error
-
- // This discards an active resource from the resource pool.
- Discard(handle ManagedHandle) error
-
- // Enter the resource pool into lame duck mode. The resource pool
- // will no longer return resource handles, and all idle resource handles
- // are closed immediately (including active resource handles that are
- // released back to the pool afterward).
- EnterLameDuckMode()
-}
diff --git a/weed/wdclient/resource_pool/semaphore.go b/weed/wdclient/resource_pool/semaphore.go
deleted file mode 100644
index 9bd6afc33..000000000
--- a/weed/wdclient/resource_pool/semaphore.go
+++ /dev/null
@@ -1,154 +0,0 @@
-package resource_pool
-
-import (
- "fmt"
- "sync"
- "sync/atomic"
- "time"
-)
-
-type Semaphore interface {
- // Increment the semaphore counter by one.
- Release()
-
- // Decrement the semaphore counter by one, and block if counter < 0
- Acquire()
-
- // Decrement the semaphore counter by one, and block if counter < 0
- // Wait for up to the given duration. Returns true if did not timeout
- TryAcquire(timeout time.Duration) bool
-}
-
-// A simple counting Semaphore.
-type boundedSemaphore struct {
- slots chan struct{}
-}
-
-// Create a bounded semaphore. The count parameter must be a positive number.
-// NOTE: The bounded semaphore will panic if the user tries to Release
-// beyond the specified count.
-func NewBoundedSemaphore(count uint) Semaphore {
- sem := &boundedSemaphore{
- slots: make(chan struct{}, int(count)),
- }
- for i := 0; i < cap(sem.slots); i++ {
- sem.slots <- struct{}{}
- }
- return sem
-}
-
-// Acquire returns on successful acquisition.
-func (sem *boundedSemaphore) Acquire() {
- <-sem.slots
-}
-
-// TryAcquire returns true if it acquires a resource slot within the
-// timeout, false otherwise.
-func (sem *boundedSemaphore) TryAcquire(timeout time.Duration) bool {
- if timeout > 0 {
- // Wait until we get a slot or timeout expires.
- tm := time.NewTimer(timeout)
- defer tm.Stop()
- select {
- case <-sem.slots:
- return true
- case <-tm.C:
- // Timeout expired. In very rare cases this might happen even if
- // there is a slot available, e.g. GC pause after we create the timer
- // and select randomly picked this one out of the two available channels.
- // We should do one final immediate check below.
- }
- }
-
- // Return true if we have a slot available immediately and false otherwise.
- select {
- case <-sem.slots:
- return true
- default:
- return false
- }
-}
-
-// Release the acquired semaphore. You must not release more than you
-// have acquired.
-func (sem *boundedSemaphore) Release() {
- select {
- case sem.slots <- struct{}{}:
- default:
- // slots is buffered. If a send blocks, it indicates a programming
- // error.
- panic(fmt.Errorf("too many releases for boundedSemaphore"))
- }
-}
-
-// This returns an unbound counting semaphore with the specified initial count.
-// The semaphore counter can be arbitrary large (i.e., Release can be called
-// unlimited amount of times).
-//
-// NOTE: In general, users should use bounded semaphore since it is more
-// efficient than unbounded semaphore.
-func NewUnboundedSemaphore(initialCount int) Semaphore {
- res := &unboundedSemaphore{
- counter: int64(initialCount),
- }
- res.cond.L = &res.lock
- return res
-}
-
-type unboundedSemaphore struct {
- lock sync.Mutex
- cond sync.Cond
- counter int64
-}
-
-func (s *unboundedSemaphore) Release() {
- s.lock.Lock()
- s.counter += 1
- if s.counter > 0 {
- // Not broadcasting here since it's unlike we can satisfy all waiting
- // goroutines. Instead, we will Signal again if there are left over
- // quota after Acquire, in case of lost wakeups.
- s.cond.Signal()
- }
- s.lock.Unlock()
-}
-
-func (s *unboundedSemaphore) Acquire() {
- s.lock.Lock()
- for s.counter < 1 {
- s.cond.Wait()
- }
- s.counter -= 1
- if s.counter > 0 {
- s.cond.Signal()
- }
- s.lock.Unlock()
-}
-
-func (s *unboundedSemaphore) TryAcquire(timeout time.Duration) bool {
- done := make(chan bool, 1)
- // Gate used to communicate between the threads and decide what the result
- // is. If the main thread decides, we have timed out, otherwise we succeed.
- decided := new(int32)
- atomic.StoreInt32(decided, 0)
- go func() {
- s.Acquire()
- if atomic.SwapInt32(decided, 1) == 0 {
- // Acquire won the race
- done <- true
- } else {
- // If we already decided the result, and this thread did not win
- s.Release()
- }
- }()
- select {
- case <-done:
- return true
- case <-time.After(timeout):
- if atomic.SwapInt32(decided, 1) == 1 {
- // The other thread already decided the result
- return true
- }
- return false
- }
-}
diff --git a/weed/wdclient/resource_pool/simple_resource_pool.go b/weed/wdclient/resource_pool/simple_resource_pool.go
deleted file mode 100644
index 99f555a02..000000000
--- a/weed/wdclient/resource_pool/simple_resource_pool.go
+++ /dev/null
@@ -1,343 +0,0 @@
-package resource_pool
-
-import (
- "errors"
- "fmt"
- "sync"
- "sync/atomic"
- "time"
-)
-
-type idleHandle struct {
- handle interface{}
- keepUntil *time.Time
-}
-
-type TooManyHandles struct {
- location string
-}
-
-func (t TooManyHandles) Error() string {
- return fmt.Sprintf("Too many handles to %s", t.location)
-}
-
-type OpenHandleError struct {
- location string
- err error
-}
-
-func (o OpenHandleError) Error() string {
- return fmt.Sprintf("Failed to open resource handle: %s (%v)", o.location, o.err)
-}
-
-// A resource pool implementation where all handles are associated to the
-// same resource location.
-type simpleResourcePool struct {
- options Options
-
- numActive *int32 // atomic counter
-
- activeHighWaterMark *int32 // atomic / monotonically increasing value
-
- openTokens Semaphore
-
- mutex sync.Mutex
- location string // guard by mutex
- idleHandles []*idleHandle // guarded by mutex
- isLameDuck bool // guarded by mutex
-}
-
-// This returns a SimpleResourcePool, where all handles are associated to a
-// single resource location.
-func NewSimpleResourcePool(options Options) ResourcePool {
- numActive := new(int32)
- atomic.StoreInt32(numActive, 0)
-
- activeHighWaterMark := new(int32)
- atomic.StoreInt32(activeHighWaterMark, 0)
-
- var tokens Semaphore
- if options.OpenMaxConcurrency > 0 {
- tokens = NewBoundedSemaphore(uint(options.OpenMaxConcurrency))
- }
-
- return &simpleResourcePool{
- location: "",
- options: options,
- numActive: numActive,
- activeHighWaterMark: activeHighWaterMark,
- openTokens: tokens,
- mutex: sync.Mutex{},
- idleHandles: make([]*idleHandle, 0, 0),
- isLameDuck: false,
- }
-}
-
-// See ResourcePool for documentation.
-func (p *simpleResourcePool) NumActive() int32 {
- return atomic.LoadInt32(p.numActive)
-}
-
-// See ResourcePool for documentation.
-func (p *simpleResourcePool) ActiveHighWaterMark() int32 {
- return atomic.LoadInt32(p.activeHighWaterMark)
-}
-
-// See ResourcePool for documentation.
-func (p *simpleResourcePool) NumIdle() int {
- p.mutex.Lock()
- defer p.mutex.Unlock()
- return len(p.idleHandles)
-}
-
-// SimpleResourcePool can only register a single (network, address) entry.
-// Register should be call before any Get calls.
-func (p *simpleResourcePool) Register(resourceLocation string) error {
- if resourceLocation == "" {
- return errors.New("Invalid resource location")
- }
-
- p.mutex.Lock()
- defer p.mutex.Unlock()
-
- if p.isLameDuck {
- return fmt.Errorf(
- "cannot register %s to lame duck resource pool",
- resourceLocation)
- }
-
- if p.location == "" {
- p.location = resourceLocation
- return nil
- }
- return errors.New("SimpleResourcePool can only register one location")
-}
-
-// SimpleResourcePool will enter lame duck mode upon calling Unregister.
-func (p *simpleResourcePool) Unregister(resourceLocation string) error {
- p.EnterLameDuckMode()
- return nil
-}
-
-func (p *simpleResourcePool) ListRegistered() []string {
- p.mutex.Lock()
- defer p.mutex.Unlock()
-
- if p.location != "" {
- return []string{p.location}
- }
- return []string{}
-}
-
-func (p *simpleResourcePool) getLocation() (string, error) {
- p.mutex.Lock()
- defer p.mutex.Unlock()
-
- if p.location == "" {
- return "", fmt.Errorf(
- "resource location is not set for SimpleResourcePool")
- }
-
- if p.isLameDuck {
- return "", fmt.Errorf(
- "lame duck resource pool cannot return handles to %s",
- p.location)
- }
-
- return p.location, nil
-}
-
-// This gets an active resource from the resource pool. Note that the
-// resourceLocation argument is ignored (The handles are associated to the
-// resource location provided by the first Register call).
-func (p *simpleResourcePool) Get(unused string) (ManagedHandle, error) {
- activeCount := atomic.AddInt32(p.numActive, 1)
- if p.options.MaxActiveHandles > 0 &&
- activeCount > p.options.MaxActiveHandles {
-
- atomic.AddInt32(p.numActive, -1)
- return nil, TooManyHandles{p.location}
- }
-
- highest := atomic.LoadInt32(p.activeHighWaterMark)
- for activeCount > highest &&
- !atomic.CompareAndSwapInt32(
- p.activeHighWaterMark,
- highest,
- activeCount) {
-
- highest = atomic.LoadInt32(p.activeHighWaterMark)
- }
-
- if h := p.getIdleHandle(); h != nil {
- return h, nil
- }
-
- location, err := p.getLocation()
- if err != nil {
- atomic.AddInt32(p.numActive, -1)
- return nil, err
- }
-
- if p.openTokens != nil {
- // Current implementation does not wait for tokens to become available.
- // If that causes availability hits, we could increase the wait,
- // similar to simple_pool.go.
- if p.openTokens.TryAcquire(0) {
- defer p.openTokens.Release()
- } else {
- // We could not immediately acquire a token.
- // Instead of waiting
- atomic.AddInt32(p.numActive, -1)
- return nil, OpenHandleError{
- p.location, errors.New("Open Error: reached OpenMaxConcurrency")}
- }
- }
-
- handle, err := p.options.Open(location)
- if err != nil {
- atomic.AddInt32(p.numActive, -1)
- return nil, OpenHandleError{p.location, err}
- }
-
- return NewManagedHandle(p.location, handle, p, p.options), nil
-}
-
-// See ResourcePool for documentation.
-func (p *simpleResourcePool) Release(handle ManagedHandle) error {
- if pool, ok := handle.Owner().(*simpleResourcePool); !ok || pool != p {
- return errors.New(
- "Resource pool cannot take control of a handle owned " +
- "by another resource pool")
- }
-
- h := handle.ReleaseUnderlyingHandle()
- if h != nil {
- // We can unref either before or after queuing the idle handle.
- // The advantage of unref-ing before queuing is that there is
- // a higher chance of successful Get when number of active handles
- // is close to the limit (but potentially more handle creation).
- // The advantage of queuing before unref-ing is that there's a
- // higher chance of reusing handle (but potentially more Get failures).
- atomic.AddInt32(p.numActive, -1)
- p.queueIdleHandles(h)
- }
-
- return nil
-}
-
-// See ResourcePool for documentation.
-func (p *simpleResourcePool) Discard(handle ManagedHandle) error {
- if pool, ok := handle.Owner().(*simpleResourcePool); !ok || pool != p {
- return errors.New(
- "Resource pool cannot take control of a handle owned " +
- "by another resource pool")
- }
-
- h := handle.ReleaseUnderlyingHandle()
- if h != nil {
- atomic.AddInt32(p.numActive, -1)
- if err := p.options.Close(h); err != nil {
- return fmt.Errorf("failed to close resource handle: %w", err)
- }
- }
- return nil
-}
-
-// See ResourcePool for documentation.
-func (p *simpleResourcePool) EnterLameDuckMode() {
- p.mutex.Lock()
-
- toClose := p.idleHandles
- p.isLameDuck = true
- p.idleHandles = []*idleHandle{}
-
- p.mutex.Unlock()
-
- p.closeHandles(toClose)
-}
-
-// This returns an idle resource, if there is one.
-func (p *simpleResourcePool) getIdleHandle() ManagedHandle {
- var toClose []*idleHandle
- defer func() {
- // NOTE: Must keep the closure around to late bind the toClose slice.
- p.closeHandles(toClose)
- }()
-
- now := p.options.getCurrentTime()
-
- p.mutex.Lock()
- defer p.mutex.Unlock()
-
- var i int
- for i = 0; i < len(p.idleHandles); i++ {
- idle := p.idleHandles[i]
- if idle.keepUntil == nil || now.Before(*idle.keepUntil) {
- break
- }
- }
- if i > 0 {
- toClose = p.idleHandles[0:i]
- }
-
- if i < len(p.idleHandles) {
- idle := p.idleHandles[i]
- p.idleHandles = p.idleHandles[i+1:]
- return NewManagedHandle(p.location, idle.handle, p, p.options)
- }
-
- if len(p.idleHandles) > 0 {
- p.idleHandles = []*idleHandle{}
- }
- return nil
-}
-
-// This adds an idle resource to the pool.
-func (p *simpleResourcePool) queueIdleHandles(handle interface{}) {
- var toClose []*idleHandle
- defer func() {
- // NOTE: Must keep the closure around to late bind the toClose slice.
- p.closeHandles(toClose)
- }()
-
- now := p.options.getCurrentTime()
- var keepUntil *time.Time
- if p.options.MaxIdleTime != nil {
- // NOTE: Assign to temp variable first to work around compiler bug
- x := now.Add(*p.options.MaxIdleTime)
- keepUntil = &x
- }
-
- p.mutex.Lock()
- defer p.mutex.Unlock()
-
- if p.isLameDuck {
- toClose = []*idleHandle{
- {handle: handle},
- }
- return
- }
-
- p.idleHandles = append(
- p.idleHandles,
- &idleHandle{
- handle: handle,
- keepUntil: keepUntil,
- })
-
- nIdleHandles := uint32(len(p.idleHandles))
- if nIdleHandles > p.options.MaxIdleHandles {
- handlesToClose := nIdleHandles - p.options.MaxIdleHandles
- toClose = p.idleHandles[0:handlesToClose]
- p.idleHandles = p.idleHandles[handlesToClose:nIdleHandles]
- }
-}
-
-// Closes resources, at this point it is assumed that this resources
-// are no longer referenced from the main idleHandles slice.
-func (p *simpleResourcePool) closeHandles(handles []*idleHandle) {
- for _, handle := range handles {
- _ = p.options.Close(handle.handle)
- }
-}
diff --git a/weed/weed.go b/weed/weed.go
index f940cdacd..f83777bf5 100644
--- a/weed/weed.go
+++ b/weed/weed.go
@@ -196,17 +196,9 @@ func help(args []string) {
var atexitFuncs []func()
-func atexit(f func()) {
- atexitFuncs = append(atexitFuncs, f)
-}
-
func exit() {
for _, f := range atexitFuncs {
f()
}
os.Exit(exitStatus)
}
-
-func debug(params ...interface{}) {
- glog.V(4).Infoln(params...)
-}
diff --git a/weed/worker/registry.go b/weed/worker/registry.go
index 0b40ddec4..fd6cecf30 100644
--- a/weed/worker/registry.go
+++ b/weed/worker/registry.go
@@ -1,9 +1,7 @@
package worker
import (
- "fmt"
"sync"
- "time"
"github.com/seaweedfs/seaweedfs/weed/worker/types"
)
@@ -15,334 +13,6 @@ type Registry struct {
mutex sync.RWMutex
}
-// NewRegistry creates a new worker registry
-func NewRegistry() *Registry {
- return &Registry{
- workers: make(map[string]*types.WorkerData),
- stats: &types.RegistryStats{
- TotalWorkers: 0,
- ActiveWorkers: 0,
- BusyWorkers: 0,
- IdleWorkers: 0,
- TotalTasks: 0,
- CompletedTasks: 0,
- FailedTasks: 0,
- StartTime: time.Now(),
- },
- }
-}
-
-// RegisterWorker registers a new worker
-func (r *Registry) RegisterWorker(worker *types.WorkerData) error {
- r.mutex.Lock()
- defer r.mutex.Unlock()
-
- if _, exists := r.workers[worker.ID]; exists {
- return fmt.Errorf("worker %s already registered", worker.ID)
- }
-
- r.workers[worker.ID] = worker
- r.updateStats()
- return nil
-}
-
-// UnregisterWorker removes a worker from the registry
-func (r *Registry) UnregisterWorker(workerID string) error {
- r.mutex.Lock()
- defer r.mutex.Unlock()
-
- if _, exists := r.workers[workerID]; !exists {
- return fmt.Errorf("worker %s not found", workerID)
- }
-
- delete(r.workers, workerID)
- r.updateStats()
- return nil
-}
-
-// GetWorker returns a worker by ID
-func (r *Registry) GetWorker(workerID string) (*types.WorkerData, bool) {
- r.mutex.RLock()
- defer r.mutex.RUnlock()
-
- worker, exists := r.workers[workerID]
- return worker, exists
-}
-
-// ListWorkers returns all registered workers
-func (r *Registry) ListWorkers() []*types.WorkerData {
- r.mutex.RLock()
- defer r.mutex.RUnlock()
-
- workers := make([]*types.WorkerData, 0, len(r.workers))
- for _, worker := range r.workers {
- workers = append(workers, worker)
- }
- return workers
-}
-
-// GetWorkersByCapability returns workers that support a specific capability
-func (r *Registry) GetWorkersByCapability(capability types.TaskType) []*types.WorkerData {
- r.mutex.RLock()
- defer r.mutex.RUnlock()
-
- var workers []*types.WorkerData
- for _, worker := range r.workers {
- for _, cap := range worker.Capabilities {
- if cap == capability {
- workers = append(workers, worker)
- break
- }
- }
- }
- return workers
-}
-
-// GetAvailableWorkers returns workers that are available for new tasks
-func (r *Registry) GetAvailableWorkers() []*types.WorkerData {
- r.mutex.RLock()
- defer r.mutex.RUnlock()
-
- var workers []*types.WorkerData
- for _, worker := range r.workers {
- if worker.Status == "active" && worker.CurrentLoad < worker.MaxConcurrent {
- workers = append(workers, worker)
- }
- }
- return workers
-}
-
-// GetBestWorkerForTask returns the best worker for a specific task
-func (r *Registry) GetBestWorkerForTask(taskType types.TaskType) *types.WorkerData {
- r.mutex.RLock()
- defer r.mutex.RUnlock()
-
- var bestWorker *types.WorkerData
- var bestScore float64
-
- for _, worker := range r.workers {
- // Check if worker supports this task type
- supportsTask := false
- for _, cap := range worker.Capabilities {
- if cap == taskType {
- supportsTask = true
- break
- }
- }
-
- if !supportsTask {
- continue
- }
-
- // Check if worker is available
- if worker.Status != "active" || worker.CurrentLoad >= worker.MaxConcurrent {
- continue
- }
-
- // Calculate score based on current load and capacity
- score := float64(worker.MaxConcurrent-worker.CurrentLoad) / float64(worker.MaxConcurrent)
- if bestWorker == nil || score > bestScore {
- bestWorker = worker
- bestScore = score
- }
- }
-
- return bestWorker
-}
-
-// UpdateWorkerHeartbeat updates the last heartbeat time for a worker
-func (r *Registry) UpdateWorkerHeartbeat(workerID string) error {
- r.mutex.Lock()
- defer r.mutex.Unlock()
-
- worker, exists := r.workers[workerID]
- if !exists {
- return fmt.Errorf("worker %s not found", workerID)
- }
-
- worker.LastHeartbeat = time.Now()
- return nil
-}
-
-// UpdateWorkerLoad updates the current load for a worker
-func (r *Registry) UpdateWorkerLoad(workerID string, load int) error {
- r.mutex.Lock()
- defer r.mutex.Unlock()
-
- worker, exists := r.workers[workerID]
- if !exists {
- return fmt.Errorf("worker %s not found", workerID)
- }
-
- worker.CurrentLoad = load
- if load >= worker.MaxConcurrent {
- worker.Status = "busy"
- } else {
- worker.Status = "active"
- }
-
- r.updateStats()
- return nil
-}
-
-// UpdateWorkerStatus updates the status of a worker
-func (r *Registry) UpdateWorkerStatus(workerID string, status string) error {
- r.mutex.Lock()
- defer r.mutex.Unlock()
-
- worker, exists := r.workers[workerID]
- if !exists {
- return fmt.Errorf("worker %s not found", workerID)
- }
-
- worker.Status = status
- r.updateStats()
- return nil
-}
-
-// CleanupStaleWorkers removes workers that haven't sent heartbeats recently
-func (r *Registry) CleanupStaleWorkers(timeout time.Duration) int {
- r.mutex.Lock()
- defer r.mutex.Unlock()
-
- var removedCount int
- cutoff := time.Now().Add(-timeout)
-
- for workerID, worker := range r.workers {
- if worker.LastHeartbeat.Before(cutoff) {
- delete(r.workers, workerID)
- removedCount++
- }
- }
-
- if removedCount > 0 {
- r.updateStats()
- }
-
- return removedCount
-}
-
-// GetStats returns current registry statistics
-func (r *Registry) GetStats() *types.RegistryStats {
- r.mutex.RLock()
- defer r.mutex.RUnlock()
-
- // Create a copy of the stats to avoid race conditions
- stats := *r.stats
- return &stats
-}
-
-// updateStats updates the registry statistics (must be called with lock held)
-func (r *Registry) updateStats() {
- r.stats.TotalWorkers = len(r.workers)
- r.stats.ActiveWorkers = 0
- r.stats.BusyWorkers = 0
- r.stats.IdleWorkers = 0
-
- for _, worker := range r.workers {
- switch worker.Status {
- case "active":
- if worker.CurrentLoad > 0 {
- r.stats.ActiveWorkers++
- } else {
- r.stats.IdleWorkers++
- }
- case "busy":
- r.stats.BusyWorkers++
- }
- }
-
- r.stats.Uptime = time.Since(r.stats.StartTime)
- r.stats.LastUpdated = time.Now()
-}
-
-// GetTaskCapabilities returns all task capabilities available in the registry
-func (r *Registry) GetTaskCapabilities() []types.TaskType {
- r.mutex.RLock()
- defer r.mutex.RUnlock()
-
- capabilitySet := make(map[types.TaskType]bool)
- for _, worker := range r.workers {
- for _, cap := range worker.Capabilities {
- capabilitySet[cap] = true
- }
- }
-
- var capabilities []types.TaskType
- for cap := range capabilitySet {
- capabilities = append(capabilities, cap)
- }
-
- return capabilities
-}
-
-// GetWorkersByStatus returns workers filtered by status
-func (r *Registry) GetWorkersByStatus(status string) []*types.WorkerData {
- r.mutex.RLock()
- defer r.mutex.RUnlock()
-
- var workers []*types.WorkerData
- for _, worker := range r.workers {
- if worker.Status == status {
- workers = append(workers, worker)
- }
- }
- return workers
-}
-
-// GetWorkerCount returns the total number of registered workers
-func (r *Registry) GetWorkerCount() int {
- r.mutex.RLock()
- defer r.mutex.RUnlock()
- return len(r.workers)
-}
-
-// GetWorkerIDs returns all worker IDs
-func (r *Registry) GetWorkerIDs() []string {
- r.mutex.RLock()
- defer r.mutex.RUnlock()
-
- ids := make([]string, 0, len(r.workers))
- for id := range r.workers {
- ids = append(ids, id)
- }
- return ids
-}
-
-// GetWorkerSummary returns a summary of all workers
-func (r *Registry) GetWorkerSummary() *types.WorkerSummary {
- r.mutex.RLock()
- defer r.mutex.RUnlock()
-
- summary := &types.WorkerSummary{
- TotalWorkers: len(r.workers),
- ByStatus: make(map[string]int),
- ByCapability: make(map[types.TaskType]int),
- TotalLoad: 0,
- MaxCapacity: 0,
- }
-
- for _, worker := range r.workers {
- summary.ByStatus[worker.Status]++
- summary.TotalLoad += worker.CurrentLoad
- summary.MaxCapacity += worker.MaxConcurrent
-
- for _, cap := range worker.Capabilities {
- summary.ByCapability[cap]++
- }
- }
-
- return summary
-}
-
// Default global registry instance
var defaultRegistry *Registry
var registryOnce sync.Once
-
-// GetDefaultRegistry returns the default global registry
-func GetDefaultRegistry() *Registry {
- registryOnce.Do(func() {
- defaultRegistry = NewRegistry()
- })
- return defaultRegistry
-}
diff --git a/weed/worker/tasks/balance/monitoring.go b/weed/worker/tasks/balance/monitoring.go
deleted file mode 100644
index 517de2484..000000000
--- a/weed/worker/tasks/balance/monitoring.go
+++ /dev/null
@@ -1,138 +0,0 @@
-package balance
-
-import (
- "sync"
- "time"
-)
-
-// BalanceMetrics contains balance-specific monitoring data
-type BalanceMetrics struct {
- // Execution metrics
- VolumesBalanced int64 `json:"volumes_balanced"`
- TotalDataTransferred int64 `json:"total_data_transferred"`
- AverageImbalance float64 `json:"average_imbalance"`
- LastBalanceTime time.Time `json:"last_balance_time"`
-
- // Performance metrics
- AverageTransferSpeed float64 `json:"average_transfer_speed_mbps"`
- TotalExecutionTime int64 `json:"total_execution_time_seconds"`
- SuccessfulOperations int64 `json:"successful_operations"`
- FailedOperations int64 `json:"failed_operations"`
-
- // Current task metrics
- CurrentImbalanceScore float64 `json:"current_imbalance_score"`
- PlannedDestinations int `json:"planned_destinations"`
-
- mutex sync.RWMutex
-}
-
-// NewBalanceMetrics creates a new balance metrics instance
-func NewBalanceMetrics() *BalanceMetrics {
- return &BalanceMetrics{
- LastBalanceTime: time.Now(),
- }
-}
-
-// RecordVolumeBalanced records a successful volume balance operation
-func (m *BalanceMetrics) RecordVolumeBalanced(volumeSize int64, transferTime time.Duration) {
- m.mutex.Lock()
- defer m.mutex.Unlock()
-
- m.VolumesBalanced++
- m.TotalDataTransferred += volumeSize
- m.SuccessfulOperations++
- m.LastBalanceTime = time.Now()
- m.TotalExecutionTime += int64(transferTime.Seconds())
-
- // Calculate average transfer speed (MB/s)
- if transferTime > 0 {
- speedMBps := float64(volumeSize) / (1024 * 1024) / transferTime.Seconds()
- if m.AverageTransferSpeed == 0 {
- m.AverageTransferSpeed = speedMBps
- } else {
- // Exponential moving average
- m.AverageTransferSpeed = 0.8*m.AverageTransferSpeed + 0.2*speedMBps
- }
- }
-}
-
-// RecordFailure records a failed balance operation
-func (m *BalanceMetrics) RecordFailure() {
- m.mutex.Lock()
- defer m.mutex.Unlock()
-
- m.FailedOperations++
-}
-
-// UpdateImbalanceScore updates the current cluster imbalance score
-func (m *BalanceMetrics) UpdateImbalanceScore(score float64) {
- m.mutex.Lock()
- defer m.mutex.Unlock()
-
- m.CurrentImbalanceScore = score
-
- // Update average imbalance with exponential moving average
- if m.AverageImbalance == 0 {
- m.AverageImbalance = score
- } else {
- m.AverageImbalance = 0.9*m.AverageImbalance + 0.1*score
- }
-}
-
-// SetPlannedDestinations sets the number of planned destinations
-func (m *BalanceMetrics) SetPlannedDestinations(count int) {
- m.mutex.Lock()
- defer m.mutex.Unlock()
-
- m.PlannedDestinations = count
-}
-
-// GetMetrics returns a copy of the current metrics (without the mutex)
-func (m *BalanceMetrics) GetMetrics() BalanceMetrics {
- m.mutex.RLock()
- defer m.mutex.RUnlock()
-
- // Create a copy without the mutex to avoid copying lock value
- return BalanceMetrics{
- VolumesBalanced: m.VolumesBalanced,
- TotalDataTransferred: m.TotalDataTransferred,
- AverageImbalance: m.AverageImbalance,
- LastBalanceTime: m.LastBalanceTime,
- AverageTransferSpeed: m.AverageTransferSpeed,
- TotalExecutionTime: m.TotalExecutionTime,
- SuccessfulOperations: m.SuccessfulOperations,
- FailedOperations: m.FailedOperations,
- CurrentImbalanceScore: m.CurrentImbalanceScore,
- PlannedDestinations: m.PlannedDestinations,
- }
-}
-
-// GetSuccessRate returns the success rate as a percentage
-func (m *BalanceMetrics) GetSuccessRate() float64 {
- m.mutex.RLock()
- defer m.mutex.RUnlock()
-
- total := m.SuccessfulOperations + m.FailedOperations
- if total == 0 {
- return 100.0
- }
- return float64(m.SuccessfulOperations) / float64(total) * 100.0
-}
-
-// Reset resets all metrics to zero
-func (m *BalanceMetrics) Reset() {
- m.mutex.Lock()
- defer m.mutex.Unlock()
-
- *m = BalanceMetrics{
- LastBalanceTime: time.Now(),
- }
-}
-
-// Global metrics instance for balance tasks
-var globalBalanceMetrics = NewBalanceMetrics()
-
-// GetGlobalBalanceMetrics returns the global balance metrics instance
-func GetGlobalBalanceMetrics() *BalanceMetrics {
- return globalBalanceMetrics
-}
diff --git a/weed/worker/tasks/base/registration.go b/weed/worker/tasks/base/registration.go
index f69db6b48..12335eb15 100644
--- a/weed/worker/tasks/base/registration.go
+++ b/weed/worker/tasks/base/registration.go
@@ -74,26 +74,6 @@ type GenericUIProvider struct {
taskDef *TaskDefinition
}
-// GetTaskType returns the task type
-func (ui *GenericUIProvider) GetTaskType() types.TaskType {
- return ui.taskDef.Type
-}
-
-// GetDisplayName returns the human-readable name
-func (ui *GenericUIProvider) GetDisplayName() string {
- return ui.taskDef.DisplayName
-}
-
-// GetDescription returns a description of what this task does
-func (ui *GenericUIProvider) GetDescription() string {
- return ui.taskDef.Description
-}
-
-// GetIcon returns the icon CSS class for this task type
-func (ui *GenericUIProvider) GetIcon() string {
- return ui.taskDef.Icon
-}
-
// GetCurrentConfig returns current config as TaskConfig
func (ui *GenericUIProvider) GetCurrentConfig() types.TaskConfig {
return ui.taskDef.Config
diff --git a/weed/worker/tasks/base/task_definition.go b/weed/worker/tasks/base/task_definition.go
index 5ebc2a4b6..04c5de06f 100644
--- a/weed/worker/tasks/base/task_definition.go
+++ b/weed/worker/tasks/base/task_definition.go
@@ -2,8 +2,6 @@ package base
import (
"fmt"
- "reflect"
- "strings"
"time"
"github.com/seaweedfs/seaweedfs/weed/admin/config"
@@ -75,108 +73,6 @@ func (c *BaseConfig) Validate() error {
return nil
}
-// StructToMap converts any struct to a map using reflection
-func StructToMap(obj interface{}) map[string]interface{} {
- result := make(map[string]interface{})
- val := reflect.ValueOf(obj)
-
- // Handle pointer to struct
- if val.Kind() == reflect.Ptr {
- val = val.Elem()
- }
-
- if val.Kind() != reflect.Struct {
- return result
- }
-
- typ := val.Type()
-
- for i := 0; i < val.NumField(); i++ {
- field := val.Field(i)
- fieldType := typ.Field(i)
-
- // Skip unexported fields
- if !field.CanInterface() {
- continue
- }
-
- // Handle embedded structs recursively (before JSON tag check)
- if field.Kind() == reflect.Struct && fieldType.Anonymous {
- embeddedMap := StructToMap(field.Interface())
- for k, v := range embeddedMap {
- result[k] = v
- }
- continue
- }
-
- // Get JSON tag name
- jsonTag := fieldType.Tag.Get("json")
- if jsonTag == "" || jsonTag == "-" {
- continue
- }
-
- // Remove options like ",omitempty"
- if commaIdx := strings.Index(jsonTag, ","); commaIdx >= 0 {
- jsonTag = jsonTag[:commaIdx]
- }
-
- result[jsonTag] = field.Interface()
- }
- return result
-}
-
-// MapToStruct loads data from map into struct using reflection
-func MapToStruct(data map[string]interface{}, obj interface{}) error {
- val := reflect.ValueOf(obj)
-
- // Must be pointer to struct
- if val.Kind() != reflect.Ptr || val.Elem().Kind() != reflect.Struct {
- return fmt.Errorf("obj must be pointer to struct")
- }
-
- val = val.Elem()
- typ := val.Type()
-
- for i := 0; i < val.NumField(); i++ {
- field := val.Field(i)
- fieldType := typ.Field(i)
-
- // Skip unexported fields
- if !field.CanSet() {
- continue
- }
-
- // Handle embedded structs recursively (before JSON tag check)
- if field.Kind() == reflect.Struct && fieldType.Anonymous {
- err := MapToStruct(data, field.Addr().Interface())
- if err != nil {
- return err
- }
- continue
- }
-
- // Get JSON tag name
- jsonTag := fieldType.Tag.Get("json")
- if jsonTag == "" || jsonTag == "-" {
- continue
- }
-
- // Remove options like ",omitempty"
- if commaIdx := strings.Index(jsonTag, ","); commaIdx >= 0 {
- jsonTag = jsonTag[:commaIdx]
- }
-
- if value, exists := data[jsonTag]; exists {
- err := setFieldValue(field, value)
- if err != nil {
- return fmt.Errorf("failed to set field %s: %v", jsonTag, err)
- }
- }
- }
-
- return nil
-}
-
// ToMap converts config to map using reflection
// ToTaskPolicy converts BaseConfig to protobuf (partial implementation)
// Note: Concrete implementations should override this to include task-specific config
@@ -207,66 +103,3 @@ func (c *BaseConfig) ApplySchemaDefaults(schema *config.Schema) error {
// Use reflection-based approach for BaseConfig since it needs to handle embedded structs
return schema.ApplyDefaultsToProtobuf(c)
}
-
-// setFieldValue sets a field value with type conversion
-func setFieldValue(field reflect.Value, value interface{}) error {
- if value == nil {
- return nil
- }
-
- valueVal := reflect.ValueOf(value)
- fieldType := field.Type()
- valueType := valueVal.Type()
-
- // Direct assignment if types match
- if valueType.AssignableTo(fieldType) {
- field.Set(valueVal)
- return nil
- }
-
- // Type conversion for common cases
- switch fieldType.Kind() {
- case reflect.Bool:
- if b, ok := value.(bool); ok {
- field.SetBool(b)
- } else {
- return fmt.Errorf("cannot convert %T to bool", value)
- }
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- switch v := value.(type) {
- case int:
- field.SetInt(int64(v))
- case int32:
- field.SetInt(int64(v))
- case int64:
- field.SetInt(v)
- case float64:
- field.SetInt(int64(v))
- default:
- return fmt.Errorf("cannot convert %T to int", value)
- }
- case reflect.Float32, reflect.Float64:
- switch v := value.(type) {
- case float32:
- field.SetFloat(float64(v))
- case float64:
- field.SetFloat(v)
- case int:
- field.SetFloat(float64(v))
- case int64:
- field.SetFloat(float64(v))
- default:
- return fmt.Errorf("cannot convert %T to float", value)
- }
- case reflect.String:
- if s, ok := value.(string); ok {
- field.SetString(s)
- } else {
- return fmt.Errorf("cannot convert %T to string", value)
- }
- default:
- return fmt.Errorf("unsupported field type %s", fieldType.Kind())
- }
-
- return nil
-}
diff --git a/weed/worker/tasks/base/task_definition_test.go b/weed/worker/tasks/base/task_definition_test.go
deleted file mode 100644
index a0a0a5a24..000000000
--- a/weed/worker/tasks/base/task_definition_test.go
+++ /dev/null
@@ -1,338 +0,0 @@
-package base
-
-import (
- "reflect"
- "testing"
-)
-
-// Test structs that mirror the actual configuration structure
-type TestBaseConfig struct {
- Enabled bool `json:"enabled"`
- ScanIntervalSeconds int `json:"scan_interval_seconds"`
- MaxConcurrent int `json:"max_concurrent"`
-}
-
-type TestTaskConfig struct {
- TestBaseConfig
- TaskSpecificField float64 `json:"task_specific_field"`
- AnotherSpecificField string `json:"another_specific_field"`
-}
-
-type TestNestedConfig struct {
- TestBaseConfig
- NestedStruct struct {
- NestedField string `json:"nested_field"`
- } `json:"nested_struct"`
- TaskField int `json:"task_field"`
-}
-
-func TestStructToMap_WithEmbeddedStruct(t *testing.T) {
- // Test case 1: Basic embedded struct
- config := &TestTaskConfig{
- TestBaseConfig: TestBaseConfig{
- Enabled: true,
- ScanIntervalSeconds: 1800,
- MaxConcurrent: 3,
- },
- TaskSpecificField: 0.25,
- AnotherSpecificField: "test_value",
- }
-
- result := StructToMap(config)
-
- // Verify all fields are present
- expectedFields := map[string]interface{}{
- "enabled": true,
- "scan_interval_seconds": 1800,
- "max_concurrent": 3,
- "task_specific_field": 0.25,
- "another_specific_field": "test_value",
- }
-
- if len(result) != len(expectedFields) {
- t.Errorf("Expected %d fields, got %d. Result: %+v", len(expectedFields), len(result), result)
- }
-
- for key, expectedValue := range expectedFields {
- if actualValue, exists := result[key]; !exists {
- t.Errorf("Missing field: %s", key)
- } else if !reflect.DeepEqual(actualValue, expectedValue) {
- t.Errorf("Field %s: expected %v (%T), got %v (%T)", key, expectedValue, expectedValue, actualValue, actualValue)
- }
- }
-}
-
-func TestStructToMap_WithNestedStruct(t *testing.T) {
- config := &TestNestedConfig{
- TestBaseConfig: TestBaseConfig{
- Enabled: false,
- ScanIntervalSeconds: 3600,
- MaxConcurrent: 1,
- },
- NestedStruct: struct {
- NestedField string `json:"nested_field"`
- }{
- NestedField: "nested_value",
- },
- TaskField: 42,
- }
-
- result := StructToMap(config)
-
- // Verify embedded struct fields are included
- if enabled, exists := result["enabled"]; !exists || enabled != false {
- t.Errorf("Expected enabled=false from embedded struct, got %v", enabled)
- }
-
- if scanInterval, exists := result["scan_interval_seconds"]; !exists || scanInterval != 3600 {
- t.Errorf("Expected scan_interval_seconds=3600 from embedded struct, got %v", scanInterval)
- }
-
- if maxConcurrent, exists := result["max_concurrent"]; !exists || maxConcurrent != 1 {
- t.Errorf("Expected max_concurrent=1 from embedded struct, got %v", maxConcurrent)
- }
-
- // Verify regular fields are included
- if taskField, exists := result["task_field"]; !exists || taskField != 42 {
- t.Errorf("Expected task_field=42, got %v", taskField)
- }
-
- // Verify nested struct is included as a whole
- if nestedStruct, exists := result["nested_struct"]; !exists {
- t.Errorf("Missing nested_struct field")
- } else {
- // The nested struct should be included as-is, not flattened
- if nested, ok := nestedStruct.(struct {
- NestedField string `json:"nested_field"`
- }); !ok || nested.NestedField != "nested_value" {
- t.Errorf("Expected nested_struct with NestedField='nested_value', got %v", nestedStruct)
- }
- }
-}
-
-func TestMapToStruct_WithEmbeddedStruct(t *testing.T) {
- // Test data with all fields including embedded struct fields
- data := map[string]interface{}{
- "enabled": true,
- "scan_interval_seconds": 2400,
- "max_concurrent": 5,
- "task_specific_field": 0.15,
- "another_specific_field": "updated_value",
- }
-
- config := &TestTaskConfig{}
- err := MapToStruct(data, config)
-
- if err != nil {
- t.Fatalf("MapToStruct failed: %v", err)
- }
-
- // Verify embedded struct fields were set
- if config.Enabled != true {
- t.Errorf("Expected Enabled=true, got %v", config.Enabled)
- }
-
- if config.ScanIntervalSeconds != 2400 {
- t.Errorf("Expected ScanIntervalSeconds=2400, got %v", config.ScanIntervalSeconds)
- }
-
- if config.MaxConcurrent != 5 {
- t.Errorf("Expected MaxConcurrent=5, got %v", config.MaxConcurrent)
- }
-
- // Verify regular fields were set
- if config.TaskSpecificField != 0.15 {
- t.Errorf("Expected TaskSpecificField=0.15, got %v", config.TaskSpecificField)
- }
-
- if config.AnotherSpecificField != "updated_value" {
- t.Errorf("Expected AnotherSpecificField='updated_value', got %v", config.AnotherSpecificField)
- }
-}
-
-func TestMapToStruct_PartialData(t *testing.T) {
- // Test with only some fields present (simulating form data)
- data := map[string]interface{}{
- "enabled": false,
- "max_concurrent": 2,
- "task_specific_field": 0.30,
- }
-
- // Start with some initial values
- config := &TestTaskConfig{
- TestBaseConfig: TestBaseConfig{
- Enabled: true,
- ScanIntervalSeconds: 1800,
- MaxConcurrent: 1,
- },
- TaskSpecificField: 0.20,
- AnotherSpecificField: "initial_value",
- }
-
- err := MapToStruct(data, config)
-
- if err != nil {
- t.Fatalf("MapToStruct failed: %v", err)
- }
-
- // Verify updated fields
- if config.Enabled != false {
- t.Errorf("Expected Enabled=false (updated), got %v", config.Enabled)
- }
-
- if config.MaxConcurrent != 2 {
- t.Errorf("Expected MaxConcurrent=2 (updated), got %v", config.MaxConcurrent)
- }
-
- if config.TaskSpecificField != 0.30 {
- t.Errorf("Expected TaskSpecificField=0.30 (updated), got %v", config.TaskSpecificField)
- }
-
- // Verify unchanged fields remain the same
- if config.ScanIntervalSeconds != 1800 {
- t.Errorf("Expected ScanIntervalSeconds=1800 (unchanged), got %v", config.ScanIntervalSeconds)
- }
-
- if config.AnotherSpecificField != "initial_value" {
- t.Errorf("Expected AnotherSpecificField='initial_value' (unchanged), got %v", config.AnotherSpecificField)
- }
-}
-
-func TestRoundTripSerialization(t *testing.T) {
- // Test complete round-trip: struct -> map -> struct
- original := &TestTaskConfig{
- TestBaseConfig: TestBaseConfig{
- Enabled: true,
- ScanIntervalSeconds: 3600,
- MaxConcurrent: 4,
- },
- TaskSpecificField: 0.18,
- AnotherSpecificField: "round_trip_test",
- }
-
- // Convert to map
- dataMap := StructToMap(original)
-
- // Convert back to struct
- roundTrip := &TestTaskConfig{}
- err := MapToStruct(dataMap, roundTrip)
-
- if err != nil {
- t.Fatalf("Round-trip MapToStruct failed: %v", err)
- }
-
- // Verify all fields match
- if !reflect.DeepEqual(original.TestBaseConfig, roundTrip.TestBaseConfig) {
- t.Errorf("BaseConfig mismatch:\nOriginal: %+v\nRound-trip: %+v", original.TestBaseConfig, roundTrip.TestBaseConfig)
- }
-
- if original.TaskSpecificField != roundTrip.TaskSpecificField {
- t.Errorf("TaskSpecificField mismatch: %v != %v", original.TaskSpecificField, roundTrip.TaskSpecificField)
- }
-
- if original.AnotherSpecificField != roundTrip.AnotherSpecificField {
- t.Errorf("AnotherSpecificField mismatch: %v != %v", original.AnotherSpecificField, roundTrip.AnotherSpecificField)
- }
-}
-
-func TestStructToMap_EmptyStruct(t *testing.T) {
- config := &TestTaskConfig{}
- result := StructToMap(config)
-
- // Should still include all fields, even with zero values
- expectedFields := []string{"enabled", "scan_interval_seconds", "max_concurrent", "task_specific_field", "another_specific_field"}
-
- for _, field := range expectedFields {
- if _, exists := result[field]; !exists {
- t.Errorf("Missing field: %s", field)
- }
- }
-}
-
-func TestStructToMap_NilPointer(t *testing.T) {
- var config *TestTaskConfig = nil
- result := StructToMap(config)
-
- if len(result) != 0 {
- t.Errorf("Expected empty map for nil pointer, got %+v", result)
- }
-}
-
-func TestMapToStruct_InvalidInput(t *testing.T) {
- data := map[string]interface{}{
- "enabled": "not_a_bool", // Wrong type
- }
-
- config := &TestTaskConfig{}
- err := MapToStruct(data, config)
-
- if err == nil {
- t.Errorf("Expected error for invalid input type, but got none")
- }
-}
-
-func TestMapToStruct_NonPointer(t *testing.T) {
- data := map[string]interface{}{
- "enabled": true,
- }
-
- config := TestTaskConfig{} // Not a pointer
- err := MapToStruct(data, config)
-
- if err == nil {
- t.Errorf("Expected error for non-pointer input, but got none")
- }
-}
-
-// Benchmark tests to ensure performance is reasonable
-func BenchmarkStructToMap(b *testing.B) {
- config := &TestTaskConfig{
- TestBaseConfig: TestBaseConfig{
- Enabled: true,
- ScanIntervalSeconds: 1800,
- MaxConcurrent: 3,
- },
- TaskSpecificField: 0.25,
- AnotherSpecificField: "benchmark_test",
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = StructToMap(config)
- }
-}
-
-func BenchmarkMapToStruct(b *testing.B) {
- data := map[string]interface{}{
- "enabled": true,
- "scan_interval_seconds": 1800,
- "max_concurrent": 3,
- "task_specific_field": 0.25,
- "another_specific_field": "benchmark_test",
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- config := &TestTaskConfig{}
- _ = MapToStruct(data, config)
- }
-}
-
-func BenchmarkRoundTrip(b *testing.B) {
- original := &TestTaskConfig{
- TestBaseConfig: TestBaseConfig{
- Enabled: true,
- ScanIntervalSeconds: 1800,
- MaxConcurrent: 3,
- },
- TaskSpecificField: 0.25,
- AnotherSpecificField: "benchmark_test",
- }
-
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- dataMap := StructToMap(original)
- roundTrip := &TestTaskConfig{}
- _ = MapToStruct(dataMap, roundTrip)
- }
-}
diff --git a/weed/worker/tasks/erasure_coding/monitoring.go b/weed/worker/tasks/erasure_coding/monitoring.go
deleted file mode 100644
index 799eb62c8..000000000
--- a/weed/worker/tasks/erasure_coding/monitoring.go
+++ /dev/null
@@ -1,229 +0,0 @@
-package erasure_coding
-
-import (
- "sync"
- "time"
-)
-
-// ErasureCodingMetrics contains erasure coding-specific monitoring data
-type ErasureCodingMetrics struct {
- // Execution metrics
- VolumesEncoded int64 `json:"volumes_encoded"`
- TotalShardsCreated int64 `json:"total_shards_created"`
- TotalDataProcessed int64 `json:"total_data_processed"`
- TotalSourcesRemoved int64 `json:"total_sources_removed"`
- LastEncodingTime time.Time `json:"last_encoding_time"`
-
- // Performance metrics
- AverageEncodingTime int64 `json:"average_encoding_time_seconds"`
- AverageShardSize int64 `json:"average_shard_size"`
- AverageDataShards int `json:"average_data_shards"`
- AverageParityShards int `json:"average_parity_shards"`
- SuccessfulOperations int64 `json:"successful_operations"`
- FailedOperations int64 `json:"failed_operations"`
-
- // Distribution metrics
- ShardsPerDataCenter map[string]int64 `json:"shards_per_datacenter"`
- ShardsPerRack map[string]int64 `json:"shards_per_rack"`
- PlacementSuccessRate float64 `json:"placement_success_rate"`
-
- // Current task metrics
- CurrentVolumeSize int64 `json:"current_volume_size"`
- CurrentShardCount int `json:"current_shard_count"`
- VolumesPendingEncoding int `json:"volumes_pending_encoding"`
-
- mutex sync.RWMutex
-}
-
-// NewErasureCodingMetrics creates a new erasure coding metrics instance
-func NewErasureCodingMetrics() *ErasureCodingMetrics {
- return &ErasureCodingMetrics{
- LastEncodingTime: time.Now(),
- ShardsPerDataCenter: make(map[string]int64),
- ShardsPerRack: make(map[string]int64),
- }
-}
-
-// RecordVolumeEncoded records a successful volume encoding operation
-func (m *ErasureCodingMetrics) RecordVolumeEncoded(volumeSize int64, shardsCreated int, dataShards int, parityShards int, encodingTime time.Duration, sourceRemoved bool) {
- m.mutex.Lock()
- defer m.mutex.Unlock()
-
- m.VolumesEncoded++
- m.TotalShardsCreated += int64(shardsCreated)
- m.TotalDataProcessed += volumeSize
- m.SuccessfulOperations++
- m.LastEncodingTime = time.Now()
-
- if sourceRemoved {
- m.TotalSourcesRemoved++
- }
-
- // Update average encoding time
- if m.AverageEncodingTime == 0 {
- m.AverageEncodingTime = int64(encodingTime.Seconds())
- } else {
- // Exponential moving average
- newTime := int64(encodingTime.Seconds())
- m.AverageEncodingTime = (m.AverageEncodingTime*4 + newTime) / 5
- }
-
- // Update average shard size
- if shardsCreated > 0 {
- avgShardSize := volumeSize / int64(shardsCreated)
- if m.AverageShardSize == 0 {
- m.AverageShardSize = avgShardSize
- } else {
- m.AverageShardSize = (m.AverageShardSize*4 + avgShardSize) / 5
- }
- }
-
- // Update average data/parity shards
- if m.AverageDataShards == 0 {
- m.AverageDataShards = dataShards
- m.AverageParityShards = parityShards
- } else {
- m.AverageDataShards = (m.AverageDataShards*4 + dataShards) / 5
- m.AverageParityShards = (m.AverageParityShards*4 + parityShards) / 5
- }
-}
-
-// RecordFailure records a failed erasure coding operation
-func (m *ErasureCodingMetrics) RecordFailure() {
- m.mutex.Lock()
- defer m.mutex.Unlock()
-
- m.FailedOperations++
-}
-
-// RecordShardPlacement records shard placement for distribution tracking
-func (m *ErasureCodingMetrics) RecordShardPlacement(dataCenter string, rack string) {
- m.mutex.Lock()
- defer m.mutex.Unlock()
-
- m.ShardsPerDataCenter[dataCenter]++
- rackKey := dataCenter + ":" + rack
- m.ShardsPerRack[rackKey]++
-}
-
-// UpdateCurrentVolumeInfo updates current volume processing information
-func (m *ErasureCodingMetrics) UpdateCurrentVolumeInfo(volumeSize int64, shardCount int) {
- m.mutex.Lock()
- defer m.mutex.Unlock()
-
- m.CurrentVolumeSize = volumeSize
- m.CurrentShardCount = shardCount
-}
-
-// SetVolumesPendingEncoding sets the number of volumes pending encoding
-func (m *ErasureCodingMetrics) SetVolumesPendingEncoding(count int) {
- m.mutex.Lock()
- defer m.mutex.Unlock()
-
- m.VolumesPendingEncoding = count
-}
-
-// UpdatePlacementSuccessRate updates the placement success rate
-func (m *ErasureCodingMetrics) UpdatePlacementSuccessRate(rate float64) {
- m.mutex.Lock()
- defer m.mutex.Unlock()
-
- if m.PlacementSuccessRate == 0 {
- m.PlacementSuccessRate = rate
- } else {
- // Exponential moving average
- m.PlacementSuccessRate = 0.8*m.PlacementSuccessRate + 0.2*rate
- }
-}
-
-// GetMetrics returns a copy of the current metrics (without the mutex)
-func (m *ErasureCodingMetrics) GetMetrics() ErasureCodingMetrics {
- m.mutex.RLock()
- defer m.mutex.RUnlock()
-
- // Create deep copy of maps
- shardsPerDC := make(map[string]int64)
- for k, v := range m.ShardsPerDataCenter {
- shardsPerDC[k] = v
- }
-
- shardsPerRack := make(map[string]int64)
- for k, v := range m.ShardsPerRack {
- shardsPerRack[k] = v
- }
-
- // Create a copy without the mutex to avoid copying lock value
- return ErasureCodingMetrics{
- VolumesEncoded: m.VolumesEncoded,
- TotalShardsCreated: m.TotalShardsCreated,
- TotalDataProcessed: m.TotalDataProcessed,
- TotalSourcesRemoved: m.TotalSourcesRemoved,
- LastEncodingTime: m.LastEncodingTime,
- AverageEncodingTime: m.AverageEncodingTime,
- AverageShardSize: m.AverageShardSize,
- AverageDataShards: m.AverageDataShards,
- AverageParityShards: m.AverageParityShards,
- SuccessfulOperations: m.SuccessfulOperations,
- FailedOperations: m.FailedOperations,
- ShardsPerDataCenter: shardsPerDC,
- ShardsPerRack: shardsPerRack,
- PlacementSuccessRate: m.PlacementSuccessRate,
- CurrentVolumeSize: m.CurrentVolumeSize,
- CurrentShardCount: m.CurrentShardCount,
- VolumesPendingEncoding: m.VolumesPendingEncoding,
- }
-}
-
-// GetSuccessRate returns the success rate as a percentage
-func (m *ErasureCodingMetrics) GetSuccessRate() float64 {
- m.mutex.RLock()
- defer m.mutex.RUnlock()
-
- total := m.SuccessfulOperations + m.FailedOperations
- if total == 0 {
- return 100.0
- }
- return float64(m.SuccessfulOperations) / float64(total) * 100.0
-}
-
-// GetAverageDataProcessed returns the average data processed per volume
-func (m *ErasureCodingMetrics) GetAverageDataProcessed() float64 {
- m.mutex.RLock()
- defer m.mutex.RUnlock()
-
- if m.VolumesEncoded == 0 {
- return 0
- }
- return float64(m.TotalDataProcessed) / float64(m.VolumesEncoded)
-}
-
-// GetSourceRemovalRate returns the percentage of sources removed after encoding
-func (m *ErasureCodingMetrics) GetSourceRemovalRate() float64 {
- m.mutex.RLock()
- defer m.mutex.RUnlock()
-
- if m.VolumesEncoded == 0 {
- return 0
- }
- return float64(m.TotalSourcesRemoved) / float64(m.VolumesEncoded) * 100.0
-}
-
-// Reset resets all metrics to zero
-func (m *ErasureCodingMetrics) Reset() {
- m.mutex.Lock()
- defer m.mutex.Unlock()
-
- *m = ErasureCodingMetrics{
- LastEncodingTime: time.Now(),
- ShardsPerDataCenter: make(map[string]int64),
- ShardsPerRack: make(map[string]int64),
- }
-}
-
-// Global metrics instance for erasure coding tasks
-var globalErasureCodingMetrics = NewErasureCodingMetrics()
-
-// GetGlobalErasureCodingMetrics returns the global erasure coding metrics instance
-func GetGlobalErasureCodingMetrics() *ErasureCodingMetrics {
- return globalErasureCodingMetrics
-}
diff --git a/weed/worker/tasks/registry.go b/weed/worker/tasks/registry.go
index 626a54a14..fb1c477cf 100644
--- a/weed/worker/tasks/registry.go
+++ b/weed/worker/tasks/registry.go
@@ -64,51 +64,6 @@ func AutoRegisterUI(registerFunc func(*types.UIRegistry)) {
glog.V(1).Infof("Auto-registered task UI provider")
}
-// SetDefaultCapabilitiesFromRegistry sets the default worker capabilities
-// based on all registered task types
-func SetDefaultCapabilitiesFromRegistry() {
- typesRegistry := GetGlobalTypesRegistry()
-
- var capabilities []types.TaskType
- for taskType := range typesRegistry.GetAllDetectors() {
- capabilities = append(capabilities, taskType)
- }
-
- // Set the default capabilities in the types package
- types.SetDefaultCapabilities(capabilities)
-
- glog.V(1).Infof("Set default worker capabilities from registry: %v", capabilities)
-}
-
-// BuildMaintenancePolicyFromTasks creates a maintenance policy with default configurations
-// from all registered tasks using their UI providers
-func BuildMaintenancePolicyFromTasks() *types.MaintenancePolicy {
- policy := types.NewMaintenancePolicy()
-
- // Get all registered task types from the UI registry
- uiRegistry := GetGlobalUIRegistry()
-
- for taskType, provider := range uiRegistry.GetAllProviders() {
- // Get the default configuration from the UI provider
- defaultConfig := provider.GetCurrentConfig()
-
- // Set the configuration in the policy
- policy.SetTaskConfig(taskType, defaultConfig)
-
- glog.V(3).Infof("Added default config for task type %s to policy", taskType)
- }
-
- glog.V(2).Infof("Built maintenance policy with %d task configurations", len(policy.TaskConfigs))
- return policy
-}
-
-// SetMaintenancePolicyFromTasks sets the default maintenance policy from registered tasks
-func SetMaintenancePolicyFromTasks() {
- // This function can be called to initialize the policy from registered tasks
- // For now, we'll just log that this should be called by the integration layer
- glog.V(1).Infof("SetMaintenancePolicyFromTasks called - policy should be built by the integration layer")
-}
-
// TaskRegistry manages task factories
type TaskRegistry struct {
factories map[types.TaskType]types.TaskFactory
diff --git a/weed/worker/tasks/schema_provider.go b/weed/worker/tasks/schema_provider.go
index 4d69556b1..9715aad17 100644
--- a/weed/worker/tasks/schema_provider.go
+++ b/weed/worker/tasks/schema_provider.go
@@ -36,16 +36,3 @@ func RegisterTaskConfigSchema(taskType string, provider TaskConfigSchemaProvider
defer globalSchemaRegistry.mutex.Unlock()
globalSchemaRegistry.providers[taskType] = provider
}
-
-// GetTaskConfigSchema returns the schema for the specified task type
-func GetTaskConfigSchema(taskType string) *TaskConfigSchema {
- globalSchemaRegistry.mutex.RLock()
- provider, exists := globalSchemaRegistry.providers[taskType]
- globalSchemaRegistry.mutex.RUnlock()
-
- if !exists {
- return nil
- }
-
- return provider.GetConfigSchema()
-}
diff --git a/weed/worker/tasks/task.go b/weed/worker/tasks/task.go
index f3eed8b2d..4ce022326 100644
--- a/weed/worker/tasks/task.go
+++ b/weed/worker/tasks/task.go
@@ -1,12 +1,9 @@
package tasks
import (
- "context"
- "fmt"
"sync"
"time"
- "github.com/seaweedfs/seaweedfs/weed/glog"
"github.com/seaweedfs/seaweedfs/weed/pb/worker_pb"
"github.com/seaweedfs/seaweedfs/weed/worker/types"
)
@@ -26,353 +23,11 @@ type BaseTask struct {
currentStage string // Current stage description
}
-// NewBaseTask creates a new base task
-func NewBaseTask(taskType types.TaskType) *BaseTask {
- return &BaseTask{
- taskType: taskType,
- progress: 0.0,
- cancelled: false,
- loggerConfig: DefaultTaskLoggerConfig(),
- }
-}
-
-// NewBaseTaskWithLogger creates a new base task with custom logger configuration
-func NewBaseTaskWithLogger(taskType types.TaskType, loggerConfig TaskLoggerConfig) *BaseTask {
- return &BaseTask{
- taskType: taskType,
- progress: 0.0,
- cancelled: false,
- loggerConfig: loggerConfig,
- }
-}
-
-// InitializeLogger initializes the task logger with task details
-func (t *BaseTask) InitializeLogger(taskID string, workerID string, params types.TaskParams) error {
- return t.InitializeTaskLogger(taskID, workerID, params)
-}
-
-// InitializeTaskLogger initializes the task logger with task details (LoggerProvider interface)
-func (t *BaseTask) InitializeTaskLogger(taskID string, workerID string, params types.TaskParams) error {
- t.mutex.Lock()
- defer t.mutex.Unlock()
-
- t.taskID = taskID
-
- logger, err := NewTaskLogger(taskID, t.taskType, workerID, params, t.loggerConfig)
- if err != nil {
- return fmt.Errorf("failed to initialize task logger: %w", err)
- }
-
- t.logger = logger
- t.logger.Info("BaseTask initialized for task %s (type: %s)", taskID, t.taskType)
-
- return nil
-}
-
-// Type returns the task type
-func (t *BaseTask) Type() types.TaskType {
- return t.taskType
-}
-
-// GetProgress returns the current progress (0.0 to 100.0)
-func (t *BaseTask) GetProgress() float64 {
- t.mutex.RLock()
- defer t.mutex.RUnlock()
- return t.progress
-}
-
-// SetProgress sets the current progress and logs it
-func (t *BaseTask) SetProgress(progress float64) {
- t.mutex.Lock()
- if progress < 0 {
- progress = 0
- }
- if progress > 100 {
- progress = 100
- }
- oldProgress := t.progress
- callback := t.progressCallback
- stage := t.currentStage
- t.progress = progress
- t.mutex.Unlock()
-
- // Log progress change
- if t.logger != nil && progress != oldProgress {
- message := stage
- if message == "" {
- message = fmt.Sprintf("Progress updated from %.1f%% to %.1f%%", oldProgress, progress)
- }
- t.logger.LogProgress(progress, message)
- }
-
- // Call progress callback if set
- if callback != nil && progress != oldProgress {
- callback(progress, stage)
- }
-}
-
-// SetProgressWithStage sets the current progress with a stage description
-func (t *BaseTask) SetProgressWithStage(progress float64, stage string) {
- t.mutex.Lock()
- if progress < 0 {
- progress = 0
- }
- if progress > 100 {
- progress = 100
- }
- callback := t.progressCallback
- t.progress = progress
- t.currentStage = stage
- t.mutex.Unlock()
-
- // Log progress change
- if t.logger != nil {
- t.logger.LogProgress(progress, stage)
- }
-
- // Call progress callback if set
- if callback != nil {
- callback(progress, stage)
- }
-}
-
-// SetCurrentStage sets the current stage description
-func (t *BaseTask) SetCurrentStage(stage string) {
- t.mutex.Lock()
- defer t.mutex.Unlock()
- t.currentStage = stage
-}
-
-// GetCurrentStage returns the current stage description
-func (t *BaseTask) GetCurrentStage() string {
- t.mutex.RLock()
- defer t.mutex.RUnlock()
- return t.currentStage
-}
-
-// Cancel cancels the task
-func (t *BaseTask) Cancel() error {
- t.mutex.Lock()
- defer t.mutex.Unlock()
-
- if t.cancelled {
- return nil
- }
-
- t.cancelled = true
-
- if t.logger != nil {
- t.logger.LogStatus("cancelled", "Task cancelled by request")
- t.logger.Warning("Task %s was cancelled", t.taskID)
- }
-
- return nil
-}
-
-// IsCancelled returns whether the task is cancelled
-func (t *BaseTask) IsCancelled() bool {
- t.mutex.RLock()
- defer t.mutex.RUnlock()
- return t.cancelled
-}
-
-// SetStartTime sets the task start time
-func (t *BaseTask) SetStartTime(startTime time.Time) {
- t.mutex.Lock()
- defer t.mutex.Unlock()
- t.startTime = startTime
-
- if t.logger != nil {
- t.logger.LogStatus("running", fmt.Sprintf("Task started at %s", startTime.Format(time.RFC3339)))
- }
-}
-
-// GetStartTime returns the task start time
-func (t *BaseTask) GetStartTime() time.Time {
- t.mutex.RLock()
- defer t.mutex.RUnlock()
- return t.startTime
-}
-
-// SetEstimatedDuration sets the estimated duration
-func (t *BaseTask) SetEstimatedDuration(duration time.Duration) {
- t.mutex.Lock()
- defer t.mutex.Unlock()
- t.estimatedDuration = duration
-
- if t.logger != nil {
- t.logger.LogWithFields("INFO", "Estimated duration set", map[string]interface{}{
- "estimated_duration": duration.String(),
- "estimated_seconds": duration.Seconds(),
- })
- }
-}
-
-// GetEstimatedDuration returns the estimated duration
-func (t *BaseTask) GetEstimatedDuration() time.Duration {
- t.mutex.RLock()
- defer t.mutex.RUnlock()
- return t.estimatedDuration
-}
-
-// SetProgressCallback sets the progress callback function
-func (t *BaseTask) SetProgressCallback(callback func(float64, string)) {
- t.mutex.Lock()
- defer t.mutex.Unlock()
- t.progressCallback = callback
-}
-
-// SetLoggerConfig sets the logger configuration for this task
-func (t *BaseTask) SetLoggerConfig(config TaskLoggerConfig) {
- t.mutex.Lock()
- defer t.mutex.Unlock()
- t.loggerConfig = config
-}
-
-// GetLogger returns the task logger
-func (t *BaseTask) GetLogger() TaskLogger {
- t.mutex.RLock()
- defer t.mutex.RUnlock()
- return t.logger
-}
-
-// GetTaskLogger returns the task logger (LoggerProvider interface)
-func (t *BaseTask) GetTaskLogger() TaskLogger {
- t.mutex.RLock()
- defer t.mutex.RUnlock()
- return t.logger
-}
-
-// LogInfo logs an info message
-func (t *BaseTask) LogInfo(message string, args ...interface{}) {
- if t.logger != nil {
- t.logger.Info(message, args...)
- }
-}
-
-// LogWarning logs a warning message
-func (t *BaseTask) LogWarning(message string, args ...interface{}) {
- if t.logger != nil {
- t.logger.Warning(message, args...)
- }
-}
-
-// LogError logs an error message
-func (t *BaseTask) LogError(message string, args ...interface{}) {
- if t.logger != nil {
- t.logger.Error(message, args...)
- }
-}
-
-// LogDebug logs a debug message
-func (t *BaseTask) LogDebug(message string, args ...interface{}) {
- if t.logger != nil {
- t.logger.Debug(message, args...)
- }
-}
-
-// LogWithFields logs a message with structured fields
-func (t *BaseTask) LogWithFields(level string, message string, fields map[string]interface{}) {
- if t.logger != nil {
- t.logger.LogWithFields(level, message, fields)
- }
-}
-
-// FinishTask finalizes the task and closes the logger
-func (t *BaseTask) FinishTask(success bool, errorMsg string) error {
- if t.logger != nil {
- if success {
- t.logger.LogStatus("completed", "Task completed successfully")
- t.logger.Info("Task %s finished successfully", t.taskID)
- } else {
- t.logger.LogStatus("failed", fmt.Sprintf("Task failed: %s", errorMsg))
- t.logger.Error("Task %s failed: %s", t.taskID, errorMsg)
- }
-
- // Close logger
- if err := t.logger.Close(); err != nil {
- glog.Errorf("Failed to close task logger: %v", err)
- }
- }
-
- return nil
-}
-
-// ExecuteTask is a wrapper that handles common task execution logic with logging
-func (t *BaseTask) ExecuteTask(ctx context.Context, params types.TaskParams, executor func(context.Context, types.TaskParams) error) error {
- // Initialize logger if not already done
- if t.logger == nil {
- // Generate a temporary task ID if none provided
- if t.taskID == "" {
- t.taskID = fmt.Sprintf("task_%d", time.Now().UnixNano())
- }
-
- workerID := "unknown"
- if err := t.InitializeLogger(t.taskID, workerID, params); err != nil {
- glog.Warningf("Failed to initialize task logger: %v", err)
- }
- }
-
- t.SetStartTime(time.Now())
- t.SetProgress(0)
-
- if t.logger != nil {
- t.logger.LogWithFields("INFO", "Task execution started", map[string]interface{}{
- "volume_id": params.VolumeID,
- "server": getServerFromSources(params.TypedParams.Sources),
- "collection": params.Collection,
- })
- }
-
- // Create a context that can be cancelled
- ctx, cancel := context.WithCancel(ctx)
- defer cancel()
-
- // Monitor for cancellation
- go func() {
- for !t.IsCancelled() {
- select {
- case <-ctx.Done():
- return
- case <-time.After(time.Second):
- // Check cancellation every second
- }
- }
- t.LogWarning("Task cancellation detected, cancelling context")
- cancel()
- }()
-
- // Execute the actual task
- t.LogInfo("Starting task executor")
- err := executor(ctx, params)
-
- if err != nil {
- t.LogError("Task executor failed: %v", err)
- t.FinishTask(false, err.Error())
- return err
- }
-
- if t.IsCancelled() {
- t.LogWarning("Task was cancelled during execution")
- t.FinishTask(false, "cancelled")
- return context.Canceled
- }
-
- t.SetProgress(100)
- t.LogInfo("Task executor completed successfully")
- t.FinishTask(true, "")
- return nil
-}
-
// UnsupportedTaskTypeError represents an error for unsupported task types
type UnsupportedTaskTypeError struct {
TaskType types.TaskType
}
-func (e *UnsupportedTaskTypeError) Error() string {
- return "unsupported task type: " + string(e.TaskType)
-}
-
// BaseTaskFactory provides common functionality for task factories
type BaseTaskFactory struct {
taskType types.TaskType
@@ -399,37 +54,12 @@ func (f *BaseTaskFactory) Description() string {
return f.description
}
-// ValidateParams validates task parameters
-func ValidateParams(params types.TaskParams, requiredFields ...string) error {
- for _, field := range requiredFields {
- switch field {
- case "volume_id":
- if params.VolumeID == 0 {
- return &ValidationError{Field: field, Message: "volume_id is required"}
- }
- case "server":
- if len(params.TypedParams.Sources) == 0 {
- return &ValidationError{Field: field, Message: "server is required"}
- }
- case "collection":
- if params.Collection == "" {
- return &ValidationError{Field: field, Message: "collection is required"}
- }
- }
- }
- return nil
-}
-
// ValidationError represents a parameter validation error
type ValidationError struct {
Field string
Message string
}
-func (e *ValidationError) Error() string {
- return e.Field + ": " + e.Message
-}
-
// getServerFromSources extracts the server address from unified sources
func getServerFromSources(sources []*worker_pb.TaskSource) string {
if len(sources) > 0 {
diff --git a/weed/worker/tasks/task_log_handler.go b/weed/worker/tasks/task_log_handler.go
index fee62325e..e2d2fc185 100644
--- a/weed/worker/tasks/task_log_handler.go
+++ b/weed/worker/tasks/task_log_handler.go
@@ -223,36 +223,3 @@ func (h *TaskLogHandler) readTaskLogEntries(logDir string, request *worker_pb.Ta
return pbEntries, nil
}
-
-// ListAvailableTaskLogs returns a list of available task log directories
-func (h *TaskLogHandler) ListAvailableTaskLogs() ([]string, error) {
- entries, err := os.ReadDir(h.baseLogDir)
- if err != nil {
- return nil, fmt.Errorf("failed to read base log directory: %w", err)
- }
-
- var taskDirs []string
- for _, entry := range entries {
- if entry.IsDir() {
- taskDirs = append(taskDirs, entry.Name())
- }
- }
-
- return taskDirs, nil
-}
-
-// CleanupOldLogs removes old task logs beyond the specified limit
-func (h *TaskLogHandler) CleanupOldLogs(maxTasks int) error {
- config := TaskLoggerConfig{
- BaseLogDir: h.baseLogDir,
- MaxTasks: maxTasks,
- }
-
- // Create a temporary logger to trigger cleanup
- tempLogger := &FileTaskLogger{
- config: config,
- }
-
- tempLogger.cleanupOldLogs()
- return nil
-}
diff --git a/weed/worker/tasks/ui_base.go b/weed/worker/tasks/ui_base.go
index eb9369337..265914aa6 100644
--- a/weed/worker/tasks/ui_base.go
+++ b/weed/worker/tasks/ui_base.go
@@ -1,9 +1,6 @@
package tasks
import (
- "reflect"
-
- "github.com/seaweedfs/seaweedfs/weed/glog"
"github.com/seaweedfs/seaweedfs/weed/pb/worker_pb"
"github.com/seaweedfs/seaweedfs/weed/worker/types"
)
@@ -85,100 +82,5 @@ type CommonConfigGetter[T any] struct {
schedulerFunc func() T
}
-// NewCommonConfigGetter creates a new common config getter
-func NewCommonConfigGetter[T any](
- defaultConfig T,
- detectorFunc func() T,
- schedulerFunc func() T,
-) *CommonConfigGetter[T] {
- return &CommonConfigGetter[T]{
- defaultConfig: defaultConfig,
- detectorFunc: detectorFunc,
- schedulerFunc: schedulerFunc,
- }
-}
-
-// GetConfig returns the merged configuration
-func (cg *CommonConfigGetter[T]) GetConfig() T {
- config := cg.defaultConfig
-
- // Apply detector values if available
- if cg.detectorFunc != nil {
- detectorConfig := cg.detectorFunc()
- mergeConfigs(&config, detectorConfig)
- }
-
- // Apply scheduler values if available
- if cg.schedulerFunc != nil {
- schedulerConfig := cg.schedulerFunc()
- mergeConfigs(&config, schedulerConfig)
- }
-
- return config
-}
-
-// mergeConfigs merges non-zero values from source into dest
-func mergeConfigs[T any](dest *T, source T) {
- destValue := reflect.ValueOf(dest).Elem()
- sourceValue := reflect.ValueOf(source)
-
- if destValue.Kind() != reflect.Struct || sourceValue.Kind() != reflect.Struct {
- return
- }
-
- for i := 0; i < destValue.NumField(); i++ {
- destField := destValue.Field(i)
- sourceField := sourceValue.Field(i)
-
- if !destField.CanSet() {
- continue
- }
-
- // Only copy non-zero values
- if !sourceField.IsZero() {
- if destField.Type() == sourceField.Type() {
- destField.Set(sourceField)
- }
- }
- }
-}
-
// RegisterUIFunc provides a common registration function signature
type RegisterUIFunc[D, S any] func(uiRegistry *types.UIRegistry, detector D, scheduler S)
-
-// CommonRegisterUI provides a common registration implementation
-func CommonRegisterUI[D, S any](
- taskType types.TaskType,
- displayName string,
- uiRegistry *types.UIRegistry,
- detector D,
- scheduler S,
- schemaFunc func() *TaskConfigSchema,
- configFunc func() types.TaskConfig,
- applyTaskPolicyFunc func(policy *worker_pb.TaskPolicy) error,
- applyTaskConfigFunc func(config types.TaskConfig) error,
-) {
- // Get metadata from schema
- schema := schemaFunc()
- description := "Task configuration"
- icon := "fas fa-cog"
-
- if schema != nil {
- description = schema.Description
- icon = schema.Icon
- }
-
- uiProvider := NewBaseUIProvider(
- taskType,
- displayName,
- description,
- icon,
- schemaFunc,
- configFunc,
- applyTaskPolicyFunc,
- applyTaskConfigFunc,
- )
-
- uiRegistry.RegisterUI(uiProvider)
- glog.V(1).Infof("Registered %s task UI provider", taskType)
-}
diff --git a/weed/worker/tasks/util/csv.go b/weed/worker/tasks/util/csv.go
deleted file mode 100644
index 50fb09bff..000000000
--- a/weed/worker/tasks/util/csv.go
+++ /dev/null
@@ -1,20 +0,0 @@
-package util
-
-import "strings"
-
-// ParseCSVSet splits a comma-separated string into a set of trimmed,
-// non-empty values. Returns nil if the input is empty.
-func ParseCSVSet(csv string) map[string]bool {
- csv = strings.TrimSpace(csv)
- if csv == "" {
- return nil
- }
- set := make(map[string]bool)
- for _, item := range strings.Split(csv, ",") {
- trimmed := strings.TrimSpace(item)
- if trimmed != "" {
- set[trimmed] = true
- }
- }
- return set
-}
diff --git a/weed/worker/tasks/vacuum/monitoring.go b/weed/worker/tasks/vacuum/monitoring.go
deleted file mode 100644
index c7dfd673e..000000000
--- a/weed/worker/tasks/vacuum/monitoring.go
+++ /dev/null
@@ -1,151 +0,0 @@
-package vacuum
-
-import (
- "sync"
- "time"
-)
-
-// VacuumMetrics contains vacuum-specific monitoring data
-type VacuumMetrics struct {
- // Execution metrics
- VolumesVacuumed int64 `json:"volumes_vacuumed"`
- TotalSpaceReclaimed int64 `json:"total_space_reclaimed"`
- TotalFilesProcessed int64 `json:"total_files_processed"`
- TotalGarbageCollected int64 `json:"total_garbage_collected"`
- LastVacuumTime time.Time `json:"last_vacuum_time"`
-
- // Performance metrics
- AverageVacuumTime int64 `json:"average_vacuum_time_seconds"`
- AverageGarbageRatio float64 `json:"average_garbage_ratio"`
- SuccessfulOperations int64 `json:"successful_operations"`
- FailedOperations int64 `json:"failed_operations"`
-
- // Current task metrics
- CurrentGarbageRatio float64 `json:"current_garbage_ratio"`
- VolumesPendingVacuum int `json:"volumes_pending_vacuum"`
-
- mutex sync.RWMutex
-}
-
-// NewVacuumMetrics creates a new vacuum metrics instance
-func NewVacuumMetrics() *VacuumMetrics {
- return &VacuumMetrics{
- LastVacuumTime: time.Now(),
- }
-}
-
-// RecordVolumeVacuumed records a successful volume vacuum operation
-func (m *VacuumMetrics) RecordVolumeVacuumed(spaceReclaimed int64, filesProcessed int64, garbageCollected int64, vacuumTime time.Duration, garbageRatio float64) {
- m.mutex.Lock()
- defer m.mutex.Unlock()
-
- m.VolumesVacuumed++
- m.TotalSpaceReclaimed += spaceReclaimed
- m.TotalFilesProcessed += filesProcessed
- m.TotalGarbageCollected += garbageCollected
- m.SuccessfulOperations++
- m.LastVacuumTime = time.Now()
-
- // Update average vacuum time
- if m.AverageVacuumTime == 0 {
- m.AverageVacuumTime = int64(vacuumTime.Seconds())
- } else {
- // Exponential moving average
- newTime := int64(vacuumTime.Seconds())
- m.AverageVacuumTime = (m.AverageVacuumTime*4 + newTime) / 5
- }
-
- // Update average garbage ratio
- if m.AverageGarbageRatio == 0 {
- m.AverageGarbageRatio = garbageRatio
- } else {
- // Exponential moving average
- m.AverageGarbageRatio = 0.8*m.AverageGarbageRatio + 0.2*garbageRatio
- }
-}
-
-// RecordFailure records a failed vacuum operation
-func (m *VacuumMetrics) RecordFailure() {
- m.mutex.Lock()
- defer m.mutex.Unlock()
-
- m.FailedOperations++
-}
-
-// UpdateCurrentGarbageRatio updates the current volume's garbage ratio
-func (m *VacuumMetrics) UpdateCurrentGarbageRatio(ratio float64) {
- m.mutex.Lock()
- defer m.mutex.Unlock()
-
- m.CurrentGarbageRatio = ratio
-}
-
-// SetVolumesPendingVacuum sets the number of volumes pending vacuum
-func (m *VacuumMetrics) SetVolumesPendingVacuum(count int) {
- m.mutex.Lock()
- defer m.mutex.Unlock()
-
- m.VolumesPendingVacuum = count
-}
-
-// GetMetrics returns a copy of the current metrics (without the mutex)
-func (m *VacuumMetrics) GetMetrics() VacuumMetrics {
- m.mutex.RLock()
- defer m.mutex.RUnlock()
-
- // Create a copy without the mutex to avoid copying lock value
- return VacuumMetrics{
- VolumesVacuumed: m.VolumesVacuumed,
- TotalSpaceReclaimed: m.TotalSpaceReclaimed,
- TotalFilesProcessed: m.TotalFilesProcessed,
- TotalGarbageCollected: m.TotalGarbageCollected,
- LastVacuumTime: m.LastVacuumTime,
- AverageVacuumTime: m.AverageVacuumTime,
- AverageGarbageRatio: m.AverageGarbageRatio,
- SuccessfulOperations: m.SuccessfulOperations,
- FailedOperations: m.FailedOperations,
- CurrentGarbageRatio: m.CurrentGarbageRatio,
- VolumesPendingVacuum: m.VolumesPendingVacuum,
- }
-}
-
-// GetSuccessRate returns the success rate as a percentage
-func (m *VacuumMetrics) GetSuccessRate() float64 {
- m.mutex.RLock()
- defer m.mutex.RUnlock()
-
- total := m.SuccessfulOperations + m.FailedOperations
- if total == 0 {
- return 100.0
- }
- return float64(m.SuccessfulOperations) / float64(total) * 100.0
-}
-
-// GetAverageSpaceReclaimed returns the average space reclaimed per volume
-func (m *VacuumMetrics) GetAverageSpaceReclaimed() float64 {
- m.mutex.RLock()
- defer m.mutex.RUnlock()
-
- if m.VolumesVacuumed == 0 {
- return 0
- }
- return float64(m.TotalSpaceReclaimed) / float64(m.VolumesVacuumed)
-}
-
-// Reset resets all metrics to zero
-func (m *VacuumMetrics) Reset() {
- m.mutex.Lock()
- defer m.mutex.Unlock()
-
- *m = VacuumMetrics{
- LastVacuumTime: time.Now(),
- }
-}
-
-// Global metrics instance for vacuum tasks
-var globalVacuumMetrics = NewVacuumMetrics()
-
-// GetGlobalVacuumMetrics returns the global vacuum metrics instance
-func GetGlobalVacuumMetrics() *VacuumMetrics {
- return globalVacuumMetrics
-}
diff --git a/weed/worker/types/config_types.go b/weed/worker/types/config_types.go
index 5a9e94fd5..1f91ec085 100644
--- a/weed/worker/types/config_types.go
+++ b/weed/worker/types/config_types.go
@@ -109,15 +109,6 @@ type MaintenanceWorkersData struct {
var defaultCapabilities []TaskType
var defaultCapabilitiesMutex sync.RWMutex
-// SetDefaultCapabilities sets the default capabilities for workers
-// This should be called after task registration is complete
-func SetDefaultCapabilities(capabilities []TaskType) {
- defaultCapabilitiesMutex.Lock()
- defer defaultCapabilitiesMutex.Unlock()
- defaultCapabilities = make([]TaskType, len(capabilities))
- copy(defaultCapabilities, capabilities)
-}
-
// GetDefaultCapabilities returns the default capabilities for workers
func GetDefaultCapabilities() []TaskType {
defaultCapabilitiesMutex.RLock()
@@ -129,18 +120,6 @@ func GetDefaultCapabilities() []TaskType {
return result
}
-// DefaultMaintenanceConfig returns default maintenance configuration
-func DefaultMaintenanceConfig() *MaintenanceConfig {
- return &MaintenanceConfig{
- Enabled: true,
- ScanInterval: 30 * time.Minute,
- CleanInterval: 6 * time.Hour,
- TaskRetention: 7 * 24 * time.Hour, // 7 days
- WorkerTimeout: 5 * time.Minute,
- Policy: NewMaintenancePolicy(),
- }
-}
-
// DefaultWorkerConfig returns default worker configuration
func DefaultWorkerConfig() *WorkerConfig {
// Get dynamic capabilities from registered task types
@@ -154,119 +133,3 @@ func DefaultWorkerConfig() *WorkerConfig {
Capabilities: capabilities,
}
}
-
-// NewMaintenancePolicy creates a new dynamic maintenance policy
-func NewMaintenancePolicy() *MaintenancePolicy {
- return &MaintenancePolicy{
- TaskConfigs: make(map[TaskType]interface{}),
- GlobalSettings: &GlobalMaintenanceSettings{
- DefaultMaxConcurrent: 2,
- MaintenanceEnabled: true,
- DefaultScanInterval: 30 * time.Minute,
- DefaultTaskTimeout: 5 * time.Minute,
- DefaultRetryCount: 3,
- DefaultRetryInterval: 5 * time.Minute,
- DefaultPriorityBoostAge: 24 * time.Hour,
- GlobalConcurrentLimit: 5,
- },
- }
-}
-
-// SetTaskConfig sets the configuration for a specific task type
-func (p *MaintenancePolicy) SetTaskConfig(taskType TaskType, config interface{}) {
- if p.TaskConfigs == nil {
- p.TaskConfigs = make(map[TaskType]interface{})
- }
- p.TaskConfigs[taskType] = config
-}
-
-// GetTaskConfig returns the configuration for a specific task type
-func (p *MaintenancePolicy) GetTaskConfig(taskType TaskType) interface{} {
- if p.TaskConfigs == nil {
- return nil
- }
- return p.TaskConfigs[taskType]
-}
-
-// IsTaskEnabled returns whether a task type is enabled (generic helper)
-func (p *MaintenancePolicy) IsTaskEnabled(taskType TaskType) bool {
- if !p.GlobalSettings.MaintenanceEnabled {
- return false
- }
-
- config := p.GetTaskConfig(taskType)
- if config == nil {
- return false
- }
-
- // Try to get enabled field from config using type assertion
- if configMap, ok := config.(map[string]interface{}); ok {
- if enabled, exists := configMap["enabled"]; exists {
- if enabledBool, ok := enabled.(bool); ok {
- return enabledBool
- }
- }
- }
-
- // If we can't determine from config, default to global setting
- return p.GlobalSettings.MaintenanceEnabled
-}
-
-// GetMaxConcurrent returns the max concurrent setting for a task type
-func (p *MaintenancePolicy) GetMaxConcurrent(taskType TaskType) int {
- config := p.GetTaskConfig(taskType)
- if config == nil {
- return p.GlobalSettings.DefaultMaxConcurrent
- }
-
- // Try to get max_concurrent field from config
- if configMap, ok := config.(map[string]interface{}); ok {
- if maxConcurrent, exists := configMap["max_concurrent"]; exists {
- if maxConcurrentInt, ok := maxConcurrent.(int); ok {
- return maxConcurrentInt
- }
- if maxConcurrentFloat, ok := maxConcurrent.(float64); ok {
- return int(maxConcurrentFloat)
- }
- }
- }
-
- return p.GlobalSettings.DefaultMaxConcurrent
-}
-
-// GetScanInterval returns the scan interval for a task type
-func (p *MaintenancePolicy) GetScanInterval(taskType TaskType) time.Duration {
- config := p.GetTaskConfig(taskType)
- if config == nil {
- return p.GlobalSettings.DefaultScanInterval
- }
-
- // Try to get scan_interval field from config
- if configMap, ok := config.(map[string]interface{}); ok {
- if scanInterval, exists := configMap["scan_interval"]; exists {
- if scanIntervalDuration, ok := scanInterval.(time.Duration); ok {
- return scanIntervalDuration
- }
- if scanIntervalString, ok := scanInterval.(string); ok {
- if duration, err := time.ParseDuration(scanIntervalString); err == nil {
- return duration
- }
- }
- }
- }
-
- return p.GlobalSettings.DefaultScanInterval
-}
-
-// GetAllTaskTypes returns all configured task types
-func (p *MaintenancePolicy) GetAllTaskTypes() []TaskType {
- if p.TaskConfigs == nil {
- return []TaskType{}
- }
-
- taskTypes := make([]TaskType, 0, len(p.TaskConfigs))
- for taskType := range p.TaskConfigs {
- taskTypes = append(taskTypes, taskType)
- }
- return taskTypes
-}
diff --git a/weed/worker/types/task.go b/weed/worker/types/task.go
index 7e924453c..5ebed89c1 100644
--- a/weed/worker/types/task.go
+++ b/weed/worker/types/task.go
@@ -53,14 +53,6 @@ type Logger interface {
// NoOpLogger is a logger that does nothing (silent)
type NoOpLogger struct{}
-func (l *NoOpLogger) Info(msg string, args ...interface{}) {}
-func (l *NoOpLogger) Warning(msg string, args ...interface{}) {}
-func (l *NoOpLogger) Error(msg string, args ...interface{}) {}
-func (l *NoOpLogger) Debug(msg string, args ...interface{}) {}
-func (l *NoOpLogger) WithFields(fields map[string]interface{}) Logger {
- return l // Return self since we're doing nothing anyway
-}
-
// GlogFallbackLogger is a logger that falls back to glog
type GlogFallbackLogger struct{}
@@ -137,87 +129,3 @@ type UnifiedBaseTask struct {
currentStage string
workingDir string
}
-
-// NewBaseTask creates a new base task
-func NewUnifiedBaseTask(id string, taskType TaskType) *UnifiedBaseTask {
- return &UnifiedBaseTask{
- id: id,
- taskType: taskType,
- }
-}
-
-// ID returns the task ID
-func (t *UnifiedBaseTask) ID() string {
- return t.id
-}
-
-// Type returns the task type
-func (t *UnifiedBaseTask) Type() TaskType {
- return t.taskType
-}
-
-// SetProgressCallback sets the progress callback
-func (t *UnifiedBaseTask) SetProgressCallback(callback func(float64, string)) {
- t.progressCallback = callback
-}
-
-// ReportProgress reports current progress through the callback
-func (t *UnifiedBaseTask) ReportProgress(progress float64) {
- if t.progressCallback != nil {
- t.progressCallback(progress, t.currentStage)
- }
-}
-
-// ReportProgressWithStage reports current progress with a specific stage description
-func (t *UnifiedBaseTask) ReportProgressWithStage(progress float64, stage string) {
- t.currentStage = stage
- if t.progressCallback != nil {
- t.progressCallback(progress, stage)
- }
-}
-
-// SetCurrentStage sets the current stage description
-func (t *UnifiedBaseTask) SetCurrentStage(stage string) {
- t.currentStage = stage
-}
-
-// GetCurrentStage returns the current stage description
-func (t *UnifiedBaseTask) GetCurrentStage() string {
- return t.currentStage
-}
-
-// Cancel marks the task as cancelled
-func (t *UnifiedBaseTask) Cancel() error {
- t.cancelled = true
- return nil
-}
-
-// IsCancellable returns true if the task can be cancelled
-func (t *UnifiedBaseTask) IsCancellable() bool {
- return true
-}
-
-// IsCancelled returns true if the task has been cancelled
-func (t *UnifiedBaseTask) IsCancelled() bool {
- return t.cancelled
-}
-
-// SetLogger sets the task logger
-func (t *UnifiedBaseTask) SetLogger(logger Logger) {
- t.logger = logger
-}
-
-// GetLogger returns the task logger
-func (t *UnifiedBaseTask) GetLogger() Logger {
- return t.logger
-}
-
-// SetWorkingDir sets the task working directory
-func (t *UnifiedBaseTask) SetWorkingDir(workingDir string) {
- t.workingDir = workingDir
-}
-
-// GetWorkingDir returns the task working directory
-func (t *UnifiedBaseTask) GetWorkingDir() string {
- return t.workingDir
-}
diff --git a/weed/worker/types/task_ui.go b/weed/worker/types/task_ui.go
index 8a57e83be..e10d727ac 100644
--- a/weed/worker/types/task_ui.go
+++ b/weed/worker/types/task_ui.go
@@ -6,47 +6,6 @@ import (
"github.com/seaweedfs/seaweedfs/weed/pb/worker_pb"
)
-// Helper function to convert seconds to the most appropriate interval unit
-func secondsToIntervalValueUnit(totalSeconds int) (int, string) {
- if totalSeconds == 0 {
- return 0, "minute"
- }
-
- // Preserve seconds when not divisible by minutes
- if totalSeconds < 60 || totalSeconds%60 != 0 {
- return totalSeconds, "second"
- }
-
- // Check if it's evenly divisible by days
- if totalSeconds%(24*3600) == 0 {
- return totalSeconds / (24 * 3600), "day"
- }
-
- // Check if it's evenly divisible by hours
- if totalSeconds%3600 == 0 {
- return totalSeconds / 3600, "hour"
- }
-
- // Default to minutes
- return totalSeconds / 60, "minute"
-}
-
-// Helper function to convert interval value and unit to seconds
-func IntervalValueUnitToSeconds(value int, unit string) int {
- switch unit {
- case "day":
- return value * 24 * 3600
- case "hour":
- return value * 3600
- case "minute":
- return value * 60
- case "second":
- return value
- default:
- return value * 60 // Default to minutes
- }
-}
-
// TaskConfig defines the interface for task configurations
// This matches the interfaces used in base package and handlers
type TaskConfig interface {
diff --git a/weed/worker/types/typed_task_interface.go b/weed/worker/types/typed_task_interface.go
index 39eaa2286..1ff26ab40 100644
--- a/weed/worker/types/typed_task_interface.go
+++ b/weed/worker/types/typed_task_interface.go
@@ -90,24 +90,6 @@ func (r *TypedTaskRegistry) RegisterTypedTask(taskType TaskType, creator TypedTa
r.creators[taskType] = creator
}
-// CreateTypedTask creates a new typed task instance
-func (r *TypedTaskRegistry) CreateTypedTask(taskType TaskType) (TypedTaskInterface, error) {
- creator, exists := r.creators[taskType]
- if !exists {
- return nil, ErrTaskTypeNotFound
- }
- return creator(), nil
-}
-
-// GetSupportedTypes returns all registered typed task types
-func (r *TypedTaskRegistry) GetSupportedTypes() []TaskType {
- types := make([]TaskType, 0, len(r.creators))
- for taskType := range r.creators {
- types = append(types, taskType)
- }
- return types
-}
-
// Global typed task registry
var globalTypedTaskRegistry = NewTypedTaskRegistry()
@@ -115,8 +97,3 @@ var globalTypedTaskRegistry = NewTypedTaskRegistry()
func RegisterGlobalTypedTask(taskType TaskType, creator TypedTaskCreator) {
globalTypedTaskRegistry.RegisterTypedTask(taskType, creator)
}
-
-// GetGlobalTypedTaskRegistry returns the global typed task registry
-func GetGlobalTypedTaskRegistry() *TypedTaskRegistry {
- return globalTypedTaskRegistry
-}
diff --git a/weed/worker/types/worker.go b/weed/worker/types/worker.go
index 9db5ba2c4..ac6cfac08 100644
--- a/weed/worker/types/worker.go
+++ b/weed/worker/types/worker.go
@@ -30,47 +30,3 @@ type BaseWorker struct {
currentTasks map[string]Task
logger Logger
}
-
-// NewBaseWorker creates a new base worker
-func NewBaseWorker(id string) *BaseWorker {
- return &BaseWorker{
- id: id,
- currentTasks: make(map[string]Task),
- }
-}
-
-// Configure applies worker configuration
-func (w *BaseWorker) Configure(config WorkerCreationConfig) error {
- w.id = config.ID
- w.capabilities = config.Capabilities
- w.maxConcurrent = config.MaxConcurrent
-
- if config.LoggerFactory != nil {
- logger, err := config.LoggerFactory.CreateLogger(context.Background(), LoggerConfig{
- ServiceName: "worker-" + w.id,
- MinLevel: LogLevelInfo,
- })
- if err != nil {
- return err
- }
- w.logger = logger
- }
-
- return nil
-}
-
-// GetCapabilities returns worker capabilities
-func (w *BaseWorker) GetCapabilities() []TaskType {
- return w.capabilities
-}
-
-// GetStatus returns current worker status
-func (w *BaseWorker) GetStatus() WorkerStatus {
- return WorkerStatus{
- WorkerID: w.id,
- Status: "active",
- Capabilities: w.capabilities,
- MaxConcurrent: w.maxConcurrent,
- CurrentLoad: len(w.currentTasks),
- }
-}
diff --git a/weed/worker/worker.go b/weed/worker/worker.go
index be2a2e9df..ffcd00b9e 100644
--- a/weed/worker/worker.go
+++ b/weed/worker/worker.go
@@ -383,31 +383,6 @@ func (w *Worker) setReqTick(tick *time.Ticker) *time.Ticker {
return w.getReqTick()
}
-func (w *Worker) getStartTime() time.Time {
- respCh := make(chan time.Time, 1)
- w.cmds <- workerCommand{
- action: ActionGetStartTime,
- data: respCh,
- }
- return <-respCh
-}
-func (w *Worker) getCompletedTasks() int {
- respCh := make(chan int, 1)
- w.cmds <- workerCommand{
- action: ActionGetCompletedTasks,
- data: respCh,
- }
- return <-respCh
-}
-func (w *Worker) getFailedTasks() int {
- respCh := make(chan int, 1)
- w.cmds <- workerCommand{
- action: ActionGetFailedTasks,
- data: respCh,
- }
- return <-respCh
-}
-
// getTaskLoggerConfig returns the task logger configuration with worker's log directory
func (w *Worker) getTaskLoggerConfig() tasks.TaskLoggerConfig {
config := tasks.DefaultTaskLoggerConfig()
@@ -543,27 +518,6 @@ func (w *Worker) handleStop(cmd workerCommand) {
cmd.resp <- nil
}
-// RegisterTask registers a task factory
-func (w *Worker) RegisterTask(taskType types.TaskType, factory types.TaskFactory) {
- w.registry.Register(taskType, factory)
-}
-
-// GetCapabilities returns the worker capabilities
-func (w *Worker) GetCapabilities() []types.TaskType {
- return w.config.Capabilities
-}
-
-// GetStatus returns the current worker status
-func (w *Worker) GetStatus() types.WorkerStatus {
- respCh := make(statusResponse, 1)
- w.cmds <- workerCommand{
- action: ActionGetStatus,
- data: respCh,
- resp: nil,
- }
- return <-respCh
-}
-
// HandleTask handles a task execution
func (w *Worker) HandleTask(task *types.TaskInput) error {
glog.V(1).Infof("Worker %s received task %s (type: %s, volume: %d)",
@@ -579,26 +533,6 @@ func (w *Worker) HandleTask(task *types.TaskInput) error {
return nil
}
-// SetCapabilities sets the worker capabilities
-func (w *Worker) SetCapabilities(capabilities []types.TaskType) {
- w.config.Capabilities = capabilities
-}
-
-// SetMaxConcurrent sets the maximum concurrent tasks
-func (w *Worker) SetMaxConcurrent(max int) {
- w.config.MaxConcurrent = max
-}
-
-// SetHeartbeatInterval sets the heartbeat interval
-func (w *Worker) SetHeartbeatInterval(interval time.Duration) {
- w.config.HeartbeatInterval = interval
-}
-
-// SetTaskRequestInterval sets the task request interval
-func (w *Worker) SetTaskRequestInterval(interval time.Duration) {
- w.config.TaskRequestInterval = interval
-}
-
// SetAdminClient sets the admin client
func (w *Worker) SetAdminClient(client AdminClient) {
w.cmds <- workerCommand{
@@ -828,11 +762,6 @@ func (w *Worker) requestTasks() {
}
}
-// GetTaskRegistry returns the task registry
-func (w *Worker) GetTaskRegistry() *tasks.TaskRegistry {
- return w.registry
-}
-
// connectionMonitorLoop monitors connection status
func (w *Worker) connectionMonitorLoop() {
ticker := time.NewTicker(30 * time.Second) // Check every 30 seconds
@@ -867,34 +796,6 @@ func (w *Worker) connectionMonitorLoop() {
}
}
-// GetConfig returns the worker configuration
-func (w *Worker) GetConfig() *types.WorkerConfig {
- return w.config
-}
-
-// GetPerformanceMetrics returns performance metrics
-func (w *Worker) GetPerformanceMetrics() *types.WorkerPerformance {
-
- uptime := time.Since(w.getStartTime())
- var successRate float64
- totalTasks := w.getCompletedTasks() + w.getFailedTasks()
- if totalTasks > 0 {
- successRate = float64(w.getCompletedTasks()) / float64(totalTasks) * 100
- }
-
- return &types.WorkerPerformance{
- TasksCompleted: w.getCompletedTasks(),
- TasksFailed: w.getFailedTasks(),
- AverageTaskTime: 0, // Would need to track this
- Uptime: uptime,
- SuccessRate: successRate,
- }
-}
-
-func (w *Worker) GetAdmin() AdminClient {
- return w.getAdmin()
-}
-
// messageProcessingLoop processes incoming admin messages
func (w *Worker) messageProcessingLoop() {
glog.Infof("MESSAGE LOOP STARTED: Worker %s message processing loop started", w.id)