chore: remove ~50k lines of unreachable dead code (#8913)
* chore: remove unreachable dead code across the codebase Remove ~50,000 lines of unreachable code identified by static analysis. Major removals: - weed/filer/redis_lua: entire unused Redis Lua filer store implementation - weed/wdclient/net2, resource_pool: unused connection/resource pool packages - weed/plugin/worker/lifecycle: unused lifecycle plugin worker - weed/s3api: unused S3 policy templates, presigned URL IAM, streaming copy, multipart IAM, key rotation, and various SSE helper functions - weed/mq/kafka: unused partition mapping, compression, schema, and protocol functions - weed/mq/offset: unused SQL storage and migration code - weed/worker: unused registry, task, and monitoring functions - weed/query: unused SQL engine, parquet scanner, and type functions - weed/shell: unused EC proportional rebalance functions - weed/storage/erasure_coding/distribution: unused distribution analysis functions - Individual unreachable functions removed from 150+ files across admin, credential, filer, iam, kms, mount, mq, operation, pb, s3api, server, shell, storage, topology, and util packages * fix(s3): reset shared memory store in IAM test to prevent flaky failure TestLoadIAMManagerFromConfig_EmptyConfigWithFallbackKey was flaky because the MemoryStore credential backend is a singleton registered via init(). Earlier tests that create anonymous identities pollute the shared store, causing LookupAnonymous() to unexpectedly return true. Fix by calling Reset() on the memory store before the test runs. * style: run gofmt on changed files * fix: restore KMS functions used by integration tests * fix(plugin): prevent panic on send to closed worker session channel The Plugin.sendToWorker method could panic with "send on closed channel" when a worker disconnected while a message was being sent. The race was between streamSession.close() closing the outgoing channel and sendToWorker writing to it concurrently. Add a done channel to streamSession that is closed before the outgoing channel, and check it in sendToWorker's select to safely detect closed sessions without panicking.
This commit is contained in:
1
.claude/scheduled_tasks.lock
Normal file
1
.claude/scheduled_tasks.lock
Normal file
@@ -0,0 +1 @@
|
||||
{"sessionId":"d6574c47-eafc-4a94-9dce-f9ffea22b53c","pid":10111,"acquiredAt":1775248373916}
|
||||
5
.superset/config.json
Normal file
5
.superset/config.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"setup": [],
|
||||
"teardown": [],
|
||||
"run": []
|
||||
}
|
||||
11
seaweed-volume/Cargo.lock
generated
11
seaweed-volume/Cargo.lock
generated
@@ -2561,6 +2561,15 @@ version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe"
|
||||
|
||||
[[package]]
|
||||
name = "openssl-src"
|
||||
version = "300.5.5+3.5.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3f1787d533e03597a7934fd0a765f0d28e94ecc5fb7789f8053b1e699a56f709"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openssl-sys"
|
||||
version = "0.9.111"
|
||||
@@ -2569,6 +2578,7 @@ checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"libc",
|
||||
"openssl-src",
|
||||
"pkg-config",
|
||||
"vcpkg",
|
||||
]
|
||||
@@ -4654,6 +4664,7 @@ dependencies = [
|
||||
"memmap2",
|
||||
"mime_guess",
|
||||
"multer",
|
||||
"openssl",
|
||||
"parking_lot 0.12.5",
|
||||
"pprof",
|
||||
"prometheus",
|
||||
|
||||
@@ -447,11 +447,6 @@ type QueueStats = maintenance.QueueStats
|
||||
type WorkerDetailsData = maintenance.WorkerDetailsData
|
||||
type WorkerPerformance = maintenance.WorkerPerformance
|
||||
|
||||
// GetTaskIcon returns the icon CSS class for a task type from its UI provider
|
||||
func GetTaskIcon(taskType MaintenanceTaskType) string {
|
||||
return maintenance.GetTaskIcon(taskType)
|
||||
}
|
||||
|
||||
// Status constants (these are still static)
|
||||
const (
|
||||
TaskStatusPending = maintenance.TaskStatusPending
|
||||
|
||||
@@ -312,29 +312,6 @@ func (h *ClusterHandlers) ShowClusterFilers(w http.ResponseWriter, r *http.Reque
|
||||
}
|
||||
}
|
||||
|
||||
// ShowClusterBrokers renders the cluster message brokers page
|
||||
func (h *ClusterHandlers) ShowClusterBrokers(w http.ResponseWriter, r *http.Request) {
|
||||
// Get cluster brokers data
|
||||
brokersData, err := h.adminServer.GetClusterBrokers()
|
||||
if err != nil {
|
||||
writeJSONError(w, http.StatusInternalServerError, "Failed to get cluster brokers: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
username := usernameOrDefault(r)
|
||||
brokersData.Username = username
|
||||
|
||||
// Render HTML template
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
brokersComponent := app.ClusterBrokers(*brokersData)
|
||||
viewCtx := layout.NewViewContext(r, username, dash.CSRFTokenFromContext(r.Context()))
|
||||
layoutComponent := layout.Layout(viewCtx, brokersComponent)
|
||||
if err := layoutComponent.Render(r.Context(), w); err != nil {
|
||||
writeJSONError(w, http.StatusInternalServerError, "Failed to render template: "+err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// GetClusterTopology returns the cluster topology as JSON
|
||||
func (h *ClusterHandlers) GetClusterTopology(w http.ResponseWriter, r *http.Request) {
|
||||
topology, err := h.adminServer.GetClusterTopology()
|
||||
|
||||
@@ -78,34 +78,6 @@ func (h *MessageQueueHandlers) ShowTopics(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
}
|
||||
|
||||
// ShowSubscribers renders the message queue subscribers page
|
||||
func (h *MessageQueueHandlers) ShowSubscribers(w http.ResponseWriter, r *http.Request) {
|
||||
// Get subscribers data
|
||||
subscribersData, err := h.adminServer.GetSubscribers()
|
||||
if err != nil {
|
||||
writeJSONError(w, http.StatusInternalServerError, "Failed to get subscribers: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Set username
|
||||
username := dash.UsernameFromContext(r.Context())
|
||||
if username == "" {
|
||||
username = "admin"
|
||||
}
|
||||
subscribersData.Username = username
|
||||
|
||||
// Render HTML template
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
subscribersComponent := app.Subscribers(*subscribersData)
|
||||
viewCtx := layout.NewViewContext(r, username, dash.CSRFTokenFromContext(r.Context()))
|
||||
layoutComponent := layout.Layout(viewCtx, subscribersComponent)
|
||||
err = layoutComponent.Render(r.Context(), w)
|
||||
if err != nil {
|
||||
writeJSONError(w, http.StatusInternalServerError, "Failed to render template: "+err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// ShowTopicDetails renders the topic details page
|
||||
func (h *MessageQueueHandlers) ShowTopicDetails(w http.ResponseWriter, r *http.Request) {
|
||||
// Get topic parameters from URL
|
||||
|
||||
@@ -1,124 +0,0 @@
|
||||
package maintenance
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/worker_pb"
|
||||
)
|
||||
|
||||
// VerifyProtobufConfig demonstrates that the protobuf configuration system is working
|
||||
func VerifyProtobufConfig() error {
|
||||
// Create configuration manager
|
||||
configManager := NewMaintenanceConfigManager()
|
||||
config := configManager.GetConfig()
|
||||
|
||||
// Verify basic configuration
|
||||
if !config.Enabled {
|
||||
return fmt.Errorf("expected config to be enabled by default")
|
||||
}
|
||||
|
||||
if config.ScanIntervalSeconds != 30*60 {
|
||||
return fmt.Errorf("expected scan interval to be 1800 seconds, got %d", config.ScanIntervalSeconds)
|
||||
}
|
||||
|
||||
// Verify policy configuration
|
||||
if config.Policy == nil {
|
||||
return fmt.Errorf("expected policy to be configured")
|
||||
}
|
||||
|
||||
if config.Policy.GlobalMaxConcurrent != 4 {
|
||||
return fmt.Errorf("expected global max concurrent to be 4, got %d", config.Policy.GlobalMaxConcurrent)
|
||||
}
|
||||
|
||||
// Verify task policies
|
||||
vacuumPolicy := config.Policy.TaskPolicies["vacuum"]
|
||||
if vacuumPolicy == nil {
|
||||
return fmt.Errorf("expected vacuum policy to be configured")
|
||||
}
|
||||
|
||||
if !vacuumPolicy.Enabled {
|
||||
return fmt.Errorf("expected vacuum policy to be enabled")
|
||||
}
|
||||
|
||||
// Verify typed configuration access
|
||||
vacuumConfig := vacuumPolicy.GetVacuumConfig()
|
||||
if vacuumConfig == nil {
|
||||
return fmt.Errorf("expected vacuum config to be accessible")
|
||||
}
|
||||
|
||||
if vacuumConfig.GarbageThreshold != 0.3 {
|
||||
return fmt.Errorf("expected garbage threshold to be 0.3, got %f", vacuumConfig.GarbageThreshold)
|
||||
}
|
||||
|
||||
// Verify helper functions work
|
||||
if !IsTaskEnabled(config.Policy, "vacuum") {
|
||||
return fmt.Errorf("expected vacuum task to be enabled via helper function")
|
||||
}
|
||||
|
||||
maxConcurrent := GetMaxConcurrent(config.Policy, "vacuum")
|
||||
if maxConcurrent != 2 {
|
||||
return fmt.Errorf("expected vacuum max concurrent to be 2, got %d", maxConcurrent)
|
||||
}
|
||||
|
||||
// Verify erasure coding configuration
|
||||
ecPolicy := config.Policy.TaskPolicies["erasure_coding"]
|
||||
if ecPolicy == nil {
|
||||
return fmt.Errorf("expected EC policy to be configured")
|
||||
}
|
||||
|
||||
ecConfig := ecPolicy.GetErasureCodingConfig()
|
||||
if ecConfig == nil {
|
||||
return fmt.Errorf("expected EC config to be accessible")
|
||||
}
|
||||
|
||||
// Verify configurable EC fields only
|
||||
if ecConfig.FullnessRatio <= 0 || ecConfig.FullnessRatio > 1 {
|
||||
return fmt.Errorf("expected EC config to have valid fullness ratio (0-1), got %f", ecConfig.FullnessRatio)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetProtobufConfigSummary returns a summary of the current protobuf configuration
|
||||
func GetProtobufConfigSummary() string {
|
||||
configManager := NewMaintenanceConfigManager()
|
||||
config := configManager.GetConfig()
|
||||
|
||||
summary := fmt.Sprintf("SeaweedFS Protobuf Maintenance Configuration:\n")
|
||||
summary += fmt.Sprintf(" Enabled: %v\n", config.Enabled)
|
||||
summary += fmt.Sprintf(" Scan Interval: %d seconds\n", config.ScanIntervalSeconds)
|
||||
summary += fmt.Sprintf(" Max Retries: %d\n", config.MaxRetries)
|
||||
summary += fmt.Sprintf(" Global Max Concurrent: %d\n", config.Policy.GlobalMaxConcurrent)
|
||||
summary += fmt.Sprintf(" Task Policies: %d configured\n", len(config.Policy.TaskPolicies))
|
||||
|
||||
for taskType, policy := range config.Policy.TaskPolicies {
|
||||
summary += fmt.Sprintf(" %s: enabled=%v, max_concurrent=%d\n",
|
||||
taskType, policy.Enabled, policy.MaxConcurrent)
|
||||
}
|
||||
|
||||
return summary
|
||||
}
|
||||
|
||||
// CreateCustomConfig demonstrates creating a custom protobuf configuration
|
||||
func CreateCustomConfig() *worker_pb.MaintenanceConfig {
|
||||
return &worker_pb.MaintenanceConfig{
|
||||
Enabled: true,
|
||||
ScanIntervalSeconds: 60 * 60, // 1 hour
|
||||
MaxRetries: 5,
|
||||
Policy: &worker_pb.MaintenancePolicy{
|
||||
GlobalMaxConcurrent: 8,
|
||||
TaskPolicies: map[string]*worker_pb.TaskPolicy{
|
||||
"custom_vacuum": {
|
||||
Enabled: true,
|
||||
MaxConcurrent: 4,
|
||||
TaskConfig: &worker_pb.TaskPolicy_VacuumConfig{
|
||||
VacuumConfig: &worker_pb.VacuumTaskConfig{
|
||||
GarbageThreshold: 0.5,
|
||||
MinVolumeAgeHours: 48,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1,24 +1,9 @@
|
||||
package maintenance
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/worker_pb"
|
||||
)
|
||||
|
||||
// MaintenanceConfigManager handles protobuf-based configuration
|
||||
type MaintenanceConfigManager struct {
|
||||
config *worker_pb.MaintenanceConfig
|
||||
}
|
||||
|
||||
// NewMaintenanceConfigManager creates a new config manager with defaults
|
||||
func NewMaintenanceConfigManager() *MaintenanceConfigManager {
|
||||
return &MaintenanceConfigManager{
|
||||
config: DefaultMaintenanceConfigProto(),
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultMaintenanceConfigProto returns default configuration as protobuf
|
||||
func DefaultMaintenanceConfigProto() *worker_pb.MaintenanceConfig {
|
||||
return &worker_pb.MaintenanceConfig{
|
||||
@@ -34,253 +19,3 @@ func DefaultMaintenanceConfigProto() *worker_pb.MaintenanceConfig {
|
||||
Policy: nil,
|
||||
}
|
||||
}
|
||||
|
||||
// GetConfig returns the current configuration
|
||||
func (mcm *MaintenanceConfigManager) GetConfig() *worker_pb.MaintenanceConfig {
|
||||
return mcm.config
|
||||
}
|
||||
|
||||
// Type-safe configuration accessors
|
||||
|
||||
// GetVacuumConfig returns vacuum-specific configuration for a task type
|
||||
func (mcm *MaintenanceConfigManager) GetVacuumConfig(taskType string) *worker_pb.VacuumTaskConfig {
|
||||
if policy := mcm.getTaskPolicy(taskType); policy != nil {
|
||||
if vacuumConfig := policy.GetVacuumConfig(); vacuumConfig != nil {
|
||||
return vacuumConfig
|
||||
}
|
||||
}
|
||||
// Return defaults if not configured
|
||||
return &worker_pb.VacuumTaskConfig{
|
||||
GarbageThreshold: 0.3,
|
||||
MinVolumeAgeHours: 24,
|
||||
}
|
||||
}
|
||||
|
||||
// GetErasureCodingConfig returns EC-specific configuration for a task type
|
||||
func (mcm *MaintenanceConfigManager) GetErasureCodingConfig(taskType string) *worker_pb.ErasureCodingTaskConfig {
|
||||
if policy := mcm.getTaskPolicy(taskType); policy != nil {
|
||||
if ecConfig := policy.GetErasureCodingConfig(); ecConfig != nil {
|
||||
return ecConfig
|
||||
}
|
||||
}
|
||||
// Return defaults if not configured
|
||||
return &worker_pb.ErasureCodingTaskConfig{
|
||||
FullnessRatio: 0.95,
|
||||
QuietForSeconds: 3600,
|
||||
MinVolumeSizeMb: 100,
|
||||
CollectionFilter: "",
|
||||
}
|
||||
}
|
||||
|
||||
// GetBalanceConfig returns balance-specific configuration for a task type
|
||||
func (mcm *MaintenanceConfigManager) GetBalanceConfig(taskType string) *worker_pb.BalanceTaskConfig {
|
||||
if policy := mcm.getTaskPolicy(taskType); policy != nil {
|
||||
if balanceConfig := policy.GetBalanceConfig(); balanceConfig != nil {
|
||||
return balanceConfig
|
||||
}
|
||||
}
|
||||
// Return defaults if not configured
|
||||
return &worker_pb.BalanceTaskConfig{
|
||||
ImbalanceThreshold: 0.2,
|
||||
MinServerCount: 2,
|
||||
}
|
||||
}
|
||||
|
||||
// GetReplicationConfig returns replication-specific configuration for a task type
|
||||
func (mcm *MaintenanceConfigManager) GetReplicationConfig(taskType string) *worker_pb.ReplicationTaskConfig {
|
||||
if policy := mcm.getTaskPolicy(taskType); policy != nil {
|
||||
if replicationConfig := policy.GetReplicationConfig(); replicationConfig != nil {
|
||||
return replicationConfig
|
||||
}
|
||||
}
|
||||
// Return defaults if not configured
|
||||
return &worker_pb.ReplicationTaskConfig{
|
||||
TargetReplicaCount: 2,
|
||||
}
|
||||
}
|
||||
|
||||
// Typed convenience methods for getting task configurations
|
||||
|
||||
// GetVacuumTaskConfigForType returns vacuum configuration for a specific task type
|
||||
func (mcm *MaintenanceConfigManager) GetVacuumTaskConfigForType(taskType string) *worker_pb.VacuumTaskConfig {
|
||||
return GetVacuumTaskConfig(mcm.config.Policy, MaintenanceTaskType(taskType))
|
||||
}
|
||||
|
||||
// GetErasureCodingTaskConfigForType returns erasure coding configuration for a specific task type
|
||||
func (mcm *MaintenanceConfigManager) GetErasureCodingTaskConfigForType(taskType string) *worker_pb.ErasureCodingTaskConfig {
|
||||
return GetErasureCodingTaskConfig(mcm.config.Policy, MaintenanceTaskType(taskType))
|
||||
}
|
||||
|
||||
// GetBalanceTaskConfigForType returns balance configuration for a specific task type
|
||||
func (mcm *MaintenanceConfigManager) GetBalanceTaskConfigForType(taskType string) *worker_pb.BalanceTaskConfig {
|
||||
return GetBalanceTaskConfig(mcm.config.Policy, MaintenanceTaskType(taskType))
|
||||
}
|
||||
|
||||
// GetReplicationTaskConfigForType returns replication configuration for a specific task type
|
||||
func (mcm *MaintenanceConfigManager) GetReplicationTaskConfigForType(taskType string) *worker_pb.ReplicationTaskConfig {
|
||||
return GetReplicationTaskConfig(mcm.config.Policy, MaintenanceTaskType(taskType))
|
||||
}
|
||||
|
||||
// Helper methods
|
||||
|
||||
func (mcm *MaintenanceConfigManager) getTaskPolicy(taskType string) *worker_pb.TaskPolicy {
|
||||
if mcm.config.Policy != nil && mcm.config.Policy.TaskPolicies != nil {
|
||||
return mcm.config.Policy.TaskPolicies[taskType]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsTaskEnabled returns whether a task type is enabled
|
||||
func (mcm *MaintenanceConfigManager) IsTaskEnabled(taskType string) bool {
|
||||
if policy := mcm.getTaskPolicy(taskType); policy != nil {
|
||||
return policy.Enabled
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetMaxConcurrent returns the max concurrent limit for a task type
|
||||
func (mcm *MaintenanceConfigManager) GetMaxConcurrent(taskType string) int32 {
|
||||
if policy := mcm.getTaskPolicy(taskType); policy != nil {
|
||||
return policy.MaxConcurrent
|
||||
}
|
||||
return 1 // Default
|
||||
}
|
||||
|
||||
// GetRepeatInterval returns the repeat interval for a task type in seconds
|
||||
func (mcm *MaintenanceConfigManager) GetRepeatInterval(taskType string) int32 {
|
||||
if policy := mcm.getTaskPolicy(taskType); policy != nil {
|
||||
return policy.RepeatIntervalSeconds
|
||||
}
|
||||
return mcm.config.Policy.DefaultRepeatIntervalSeconds
|
||||
}
|
||||
|
||||
// GetCheckInterval returns the check interval for a task type in seconds
|
||||
func (mcm *MaintenanceConfigManager) GetCheckInterval(taskType string) int32 {
|
||||
if policy := mcm.getTaskPolicy(taskType); policy != nil {
|
||||
return policy.CheckIntervalSeconds
|
||||
}
|
||||
return mcm.config.Policy.DefaultCheckIntervalSeconds
|
||||
}
|
||||
|
||||
// Duration accessor methods
|
||||
|
||||
// GetScanInterval returns the scan interval as a time.Duration
|
||||
func (mcm *MaintenanceConfigManager) GetScanInterval() time.Duration {
|
||||
return time.Duration(mcm.config.ScanIntervalSeconds) * time.Second
|
||||
}
|
||||
|
||||
// GetWorkerTimeout returns the worker timeout as a time.Duration
|
||||
func (mcm *MaintenanceConfigManager) GetWorkerTimeout() time.Duration {
|
||||
return time.Duration(mcm.config.WorkerTimeoutSeconds) * time.Second
|
||||
}
|
||||
|
||||
// GetTaskTimeout returns the task timeout as a time.Duration
|
||||
func (mcm *MaintenanceConfigManager) GetTaskTimeout() time.Duration {
|
||||
return time.Duration(mcm.config.TaskTimeoutSeconds) * time.Second
|
||||
}
|
||||
|
||||
// GetRetryDelay returns the retry delay as a time.Duration
|
||||
func (mcm *MaintenanceConfigManager) GetRetryDelay() time.Duration {
|
||||
return time.Duration(mcm.config.RetryDelaySeconds) * time.Second
|
||||
}
|
||||
|
||||
// GetCleanupInterval returns the cleanup interval as a time.Duration
|
||||
func (mcm *MaintenanceConfigManager) GetCleanupInterval() time.Duration {
|
||||
return time.Duration(mcm.config.CleanupIntervalSeconds) * time.Second
|
||||
}
|
||||
|
||||
// GetTaskRetention returns the task retention period as a time.Duration
|
||||
func (mcm *MaintenanceConfigManager) GetTaskRetention() time.Duration {
|
||||
return time.Duration(mcm.config.TaskRetentionSeconds) * time.Second
|
||||
}
|
||||
|
||||
// ValidateMaintenanceConfigWithSchema validates protobuf maintenance configuration using ConfigField rules
|
||||
func ValidateMaintenanceConfigWithSchema(config *worker_pb.MaintenanceConfig) error {
|
||||
if config == nil {
|
||||
return fmt.Errorf("configuration cannot be nil")
|
||||
}
|
||||
|
||||
// Get the schema to access field validation rules
|
||||
schema := GetMaintenanceConfigSchema()
|
||||
|
||||
// Validate each field individually using the ConfigField rules
|
||||
if err := validateFieldWithSchema(schema, "enabled", config.Enabled); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateFieldWithSchema(schema, "scan_interval_seconds", int(config.ScanIntervalSeconds)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateFieldWithSchema(schema, "worker_timeout_seconds", int(config.WorkerTimeoutSeconds)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateFieldWithSchema(schema, "task_timeout_seconds", int(config.TaskTimeoutSeconds)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateFieldWithSchema(schema, "retry_delay_seconds", int(config.RetryDelaySeconds)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateFieldWithSchema(schema, "max_retries", int(config.MaxRetries)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateFieldWithSchema(schema, "cleanup_interval_seconds", int(config.CleanupIntervalSeconds)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateFieldWithSchema(schema, "task_retention_seconds", int(config.TaskRetentionSeconds)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate policy fields if present
|
||||
if config.Policy != nil {
|
||||
// Note: These field names might need to be adjusted based on the actual schema
|
||||
if err := validatePolicyField("global_max_concurrent", int(config.Policy.GlobalMaxConcurrent)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validatePolicyField("default_repeat_interval_seconds", int(config.Policy.DefaultRepeatIntervalSeconds)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validatePolicyField("default_check_interval_seconds", int(config.Policy.DefaultCheckIntervalSeconds)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateFieldWithSchema validates a single field using its ConfigField definition
|
||||
func validateFieldWithSchema(schema *MaintenanceConfigSchema, fieldName string, value interface{}) error {
|
||||
field := schema.GetFieldByName(fieldName)
|
||||
if field == nil {
|
||||
// Field not in schema, skip validation
|
||||
return nil
|
||||
}
|
||||
|
||||
return field.ValidateValue(value)
|
||||
}
|
||||
|
||||
// validatePolicyField validates policy fields (simplified validation for now)
|
||||
func validatePolicyField(fieldName string, value int) error {
|
||||
switch fieldName {
|
||||
case "global_max_concurrent":
|
||||
if value < 1 || value > 20 {
|
||||
return fmt.Errorf("Global Max Concurrent must be between 1 and 20, got %d", value)
|
||||
}
|
||||
case "default_repeat_interval":
|
||||
if value < 1 || value > 168 {
|
||||
return fmt.Errorf("Default Repeat Interval must be between 1 and 168 hours, got %d", value)
|
||||
}
|
||||
case "default_check_interval":
|
||||
if value < 1 || value > 168 {
|
||||
return fmt.Errorf("Default Check Interval must be between 1 and 168 hours, got %d", value)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1055,28 +1055,6 @@ func (mq *MaintenanceQueue) getMaxConcurrentForTaskType(taskType MaintenanceTask
|
||||
return 1
|
||||
}
|
||||
|
||||
// getRunningTasks returns all currently running tasks
|
||||
func (mq *MaintenanceQueue) getRunningTasks() []*MaintenanceTask {
|
||||
var runningTasks []*MaintenanceTask
|
||||
for _, task := range mq.tasks {
|
||||
if task.Status == TaskStatusAssigned || task.Status == TaskStatusInProgress {
|
||||
runningTasks = append(runningTasks, task)
|
||||
}
|
||||
}
|
||||
return runningTasks
|
||||
}
|
||||
|
||||
// getAvailableWorkers returns all workers that can take more work
|
||||
func (mq *MaintenanceQueue) getAvailableWorkers() []*MaintenanceWorker {
|
||||
var availableWorkers []*MaintenanceWorker
|
||||
for _, worker := range mq.workers {
|
||||
if worker.Status == "active" && worker.CurrentLoad < worker.MaxConcurrent {
|
||||
availableWorkers = append(availableWorkers, worker)
|
||||
}
|
||||
}
|
||||
return availableWorkers
|
||||
}
|
||||
|
||||
// trackPendingOperation adds a task to the pending operations tracker
|
||||
func (mq *MaintenanceQueue) trackPendingOperation(task *MaintenanceTask) {
|
||||
if mq.integration == nil {
|
||||
|
||||
@@ -2,15 +2,11 @@ package maintenance
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/glog"
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/master_pb"
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/worker_pb"
|
||||
"github.com/seaweedfs/seaweedfs/weed/worker/tasks"
|
||||
"github.com/seaweedfs/seaweedfs/weed/worker/types"
|
||||
)
|
||||
|
||||
// AdminClient interface defines what the maintenance system needs from the admin server
|
||||
@@ -21,51 +17,6 @@ type AdminClient interface {
|
||||
// MaintenanceTaskType represents different types of maintenance operations
|
||||
type MaintenanceTaskType string
|
||||
|
||||
// GetRegisteredMaintenanceTaskTypes returns all registered task types as MaintenanceTaskType values
|
||||
// sorted alphabetically for consistent menu ordering
|
||||
func GetRegisteredMaintenanceTaskTypes() []MaintenanceTaskType {
|
||||
typesRegistry := tasks.GetGlobalTypesRegistry()
|
||||
var taskTypes []MaintenanceTaskType
|
||||
|
||||
for workerTaskType := range typesRegistry.GetAllDetectors() {
|
||||
maintenanceTaskType := MaintenanceTaskType(string(workerTaskType))
|
||||
taskTypes = append(taskTypes, maintenanceTaskType)
|
||||
}
|
||||
|
||||
// Sort task types alphabetically to ensure consistent menu ordering
|
||||
sort.Slice(taskTypes, func(i, j int) bool {
|
||||
return string(taskTypes[i]) < string(taskTypes[j])
|
||||
})
|
||||
|
||||
return taskTypes
|
||||
}
|
||||
|
||||
// GetMaintenanceTaskType returns a specific task type if it's registered, or empty string if not found
|
||||
func GetMaintenanceTaskType(taskTypeName string) MaintenanceTaskType {
|
||||
typesRegistry := tasks.GetGlobalTypesRegistry()
|
||||
|
||||
for workerTaskType := range typesRegistry.GetAllDetectors() {
|
||||
if string(workerTaskType) == taskTypeName {
|
||||
return MaintenanceTaskType(taskTypeName)
|
||||
}
|
||||
}
|
||||
|
||||
return MaintenanceTaskType("")
|
||||
}
|
||||
|
||||
// IsMaintenanceTaskTypeRegistered checks if a task type is registered
|
||||
func IsMaintenanceTaskTypeRegistered(taskType MaintenanceTaskType) bool {
|
||||
typesRegistry := tasks.GetGlobalTypesRegistry()
|
||||
|
||||
for workerTaskType := range typesRegistry.GetAllDetectors() {
|
||||
if string(workerTaskType) == string(taskType) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// MaintenanceTaskPriority represents task execution priority
|
||||
type MaintenanceTaskPriority int
|
||||
|
||||
@@ -200,14 +151,6 @@ func GetTaskPolicy(mp *MaintenancePolicy, taskType MaintenanceTaskType) *TaskPol
|
||||
return mp.TaskPolicies[string(taskType)]
|
||||
}
|
||||
|
||||
// SetTaskPolicy sets the policy for a specific task type
|
||||
func SetTaskPolicy(mp *MaintenancePolicy, taskType MaintenanceTaskType, policy *TaskPolicy) {
|
||||
if mp.TaskPolicies == nil {
|
||||
mp.TaskPolicies = make(map[string]*TaskPolicy)
|
||||
}
|
||||
mp.TaskPolicies[string(taskType)] = policy
|
||||
}
|
||||
|
||||
// IsTaskEnabled returns whether a task type is enabled
|
||||
func IsTaskEnabled(mp *MaintenancePolicy, taskType MaintenanceTaskType) bool {
|
||||
policy := GetTaskPolicy(mp, taskType)
|
||||
@@ -235,84 +178,6 @@ func GetRepeatInterval(mp *MaintenancePolicy, taskType MaintenanceTaskType) int
|
||||
return int(policy.RepeatIntervalSeconds)
|
||||
}
|
||||
|
||||
// GetVacuumTaskConfig returns the vacuum task configuration
|
||||
func GetVacuumTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType) *worker_pb.VacuumTaskConfig {
|
||||
policy := GetTaskPolicy(mp, taskType)
|
||||
if policy == nil {
|
||||
return nil
|
||||
}
|
||||
return policy.GetVacuumConfig()
|
||||
}
|
||||
|
||||
// GetErasureCodingTaskConfig returns the erasure coding task configuration
|
||||
func GetErasureCodingTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType) *worker_pb.ErasureCodingTaskConfig {
|
||||
policy := GetTaskPolicy(mp, taskType)
|
||||
if policy == nil {
|
||||
return nil
|
||||
}
|
||||
return policy.GetErasureCodingConfig()
|
||||
}
|
||||
|
||||
// GetBalanceTaskConfig returns the balance task configuration
|
||||
func GetBalanceTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType) *worker_pb.BalanceTaskConfig {
|
||||
policy := GetTaskPolicy(mp, taskType)
|
||||
if policy == nil {
|
||||
return nil
|
||||
}
|
||||
return policy.GetBalanceConfig()
|
||||
}
|
||||
|
||||
// GetReplicationTaskConfig returns the replication task configuration
|
||||
func GetReplicationTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType) *worker_pb.ReplicationTaskConfig {
|
||||
policy := GetTaskPolicy(mp, taskType)
|
||||
if policy == nil {
|
||||
return nil
|
||||
}
|
||||
return policy.GetReplicationConfig()
|
||||
}
|
||||
|
||||
// Note: GetTaskConfig was removed - use typed getters: GetVacuumTaskConfig, GetErasureCodingTaskConfig, GetBalanceTaskConfig, or GetReplicationTaskConfig
|
||||
|
||||
// SetVacuumTaskConfig sets the vacuum task configuration
|
||||
func SetVacuumTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType, config *worker_pb.VacuumTaskConfig) {
|
||||
policy := GetTaskPolicy(mp, taskType)
|
||||
if policy != nil {
|
||||
policy.TaskConfig = &worker_pb.TaskPolicy_VacuumConfig{
|
||||
VacuumConfig: config,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetErasureCodingTaskConfig sets the erasure coding task configuration
|
||||
func SetErasureCodingTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType, config *worker_pb.ErasureCodingTaskConfig) {
|
||||
policy := GetTaskPolicy(mp, taskType)
|
||||
if policy != nil {
|
||||
policy.TaskConfig = &worker_pb.TaskPolicy_ErasureCodingConfig{
|
||||
ErasureCodingConfig: config,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetBalanceTaskConfig sets the balance task configuration
|
||||
func SetBalanceTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType, config *worker_pb.BalanceTaskConfig) {
|
||||
policy := GetTaskPolicy(mp, taskType)
|
||||
if policy != nil {
|
||||
policy.TaskConfig = &worker_pb.TaskPolicy_BalanceConfig{
|
||||
BalanceConfig: config,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetReplicationTaskConfig sets the replication task configuration
|
||||
func SetReplicationTaskConfig(mp *MaintenancePolicy, taskType MaintenanceTaskType, config *worker_pb.ReplicationTaskConfig) {
|
||||
policy := GetTaskPolicy(mp, taskType)
|
||||
if policy != nil {
|
||||
policy.TaskConfig = &worker_pb.TaskPolicy_ReplicationConfig{
|
||||
ReplicationConfig: config,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetTaskConfig sets a configuration value for a task type (legacy method - use typed setters above)
|
||||
// Note: SetTaskConfig was removed - use typed setters: SetVacuumTaskConfig, SetErasureCodingTaskConfig, SetBalanceTaskConfig, or SetReplicationTaskConfig
|
||||
|
||||
@@ -475,180 +340,6 @@ type ClusterReplicationTask struct {
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// BuildMaintenancePolicyFromTasks creates a maintenance policy with configurations
|
||||
// from all registered tasks using their UI providers
|
||||
func BuildMaintenancePolicyFromTasks() *MaintenancePolicy {
|
||||
policy := &MaintenancePolicy{
|
||||
TaskPolicies: make(map[string]*TaskPolicy),
|
||||
GlobalMaxConcurrent: 4,
|
||||
DefaultRepeatIntervalSeconds: 6 * 3600, // 6 hours in seconds
|
||||
DefaultCheckIntervalSeconds: 12 * 3600, // 12 hours in seconds
|
||||
}
|
||||
|
||||
// Get all registered task types from the UI registry
|
||||
uiRegistry := tasks.GetGlobalUIRegistry()
|
||||
typesRegistry := tasks.GetGlobalTypesRegistry()
|
||||
|
||||
for taskType, provider := range uiRegistry.GetAllProviders() {
|
||||
// Convert task type to maintenance task type
|
||||
maintenanceTaskType := MaintenanceTaskType(string(taskType))
|
||||
|
||||
// Get the default configuration from the UI provider
|
||||
defaultConfig := provider.GetCurrentConfig()
|
||||
|
||||
// Create task policy from UI configuration
|
||||
taskPolicy := &TaskPolicy{
|
||||
Enabled: true, // Default enabled
|
||||
MaxConcurrent: 2, // Default concurrency
|
||||
RepeatIntervalSeconds: policy.DefaultRepeatIntervalSeconds,
|
||||
CheckIntervalSeconds: policy.DefaultCheckIntervalSeconds,
|
||||
}
|
||||
|
||||
// Extract configuration using TaskConfig interface - no more map conversions!
|
||||
if taskConfig, ok := defaultConfig.(interface{ ToTaskPolicy() *worker_pb.TaskPolicy }); ok {
|
||||
// Use protobuf directly for clean, type-safe config extraction
|
||||
pbTaskPolicy := taskConfig.ToTaskPolicy()
|
||||
taskPolicy.Enabled = pbTaskPolicy.Enabled
|
||||
taskPolicy.MaxConcurrent = pbTaskPolicy.MaxConcurrent
|
||||
if pbTaskPolicy.RepeatIntervalSeconds > 0 {
|
||||
taskPolicy.RepeatIntervalSeconds = pbTaskPolicy.RepeatIntervalSeconds
|
||||
}
|
||||
if pbTaskPolicy.CheckIntervalSeconds > 0 {
|
||||
taskPolicy.CheckIntervalSeconds = pbTaskPolicy.CheckIntervalSeconds
|
||||
}
|
||||
}
|
||||
|
||||
// Also get defaults from scheduler if available (using types.TaskScheduler explicitly)
|
||||
var scheduler types.TaskScheduler = typesRegistry.GetScheduler(taskType)
|
||||
if scheduler != nil {
|
||||
if taskPolicy.MaxConcurrent <= 0 {
|
||||
taskPolicy.MaxConcurrent = int32(scheduler.GetMaxConcurrent())
|
||||
}
|
||||
// Convert default repeat interval to seconds
|
||||
if repeatInterval := scheduler.GetDefaultRepeatInterval(); repeatInterval > 0 {
|
||||
taskPolicy.RepeatIntervalSeconds = int32(repeatInterval.Seconds())
|
||||
}
|
||||
}
|
||||
|
||||
// Also get defaults from detector if available (using types.TaskDetector explicitly)
|
||||
var detector types.TaskDetector = typesRegistry.GetDetector(taskType)
|
||||
if detector != nil {
|
||||
// Convert scan interval to check interval (seconds)
|
||||
if scanInterval := detector.ScanInterval(); scanInterval > 0 {
|
||||
taskPolicy.CheckIntervalSeconds = int32(scanInterval.Seconds())
|
||||
}
|
||||
}
|
||||
|
||||
policy.TaskPolicies[string(maintenanceTaskType)] = taskPolicy
|
||||
glog.V(3).Infof("Built policy for task type %s: enabled=%v, max_concurrent=%d",
|
||||
maintenanceTaskType, taskPolicy.Enabled, taskPolicy.MaxConcurrent)
|
||||
}
|
||||
|
||||
glog.V(2).Infof("Built maintenance policy with %d task configurations", len(policy.TaskPolicies))
|
||||
return policy
|
||||
}
|
||||
|
||||
// SetPolicyFromTasks sets the maintenance policy from registered tasks
|
||||
func SetPolicyFromTasks(policy *MaintenancePolicy) {
|
||||
if policy == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Build new policy from tasks
|
||||
newPolicy := BuildMaintenancePolicyFromTasks()
|
||||
|
||||
// Copy task policies
|
||||
policy.TaskPolicies = newPolicy.TaskPolicies
|
||||
|
||||
glog.V(1).Infof("Updated maintenance policy with %d task configurations from registered tasks", len(policy.TaskPolicies))
|
||||
}
|
||||
|
||||
// GetTaskIcon returns the icon CSS class for a task type from its UI provider
|
||||
func GetTaskIcon(taskType MaintenanceTaskType) string {
|
||||
typesRegistry := tasks.GetGlobalTypesRegistry()
|
||||
uiRegistry := tasks.GetGlobalUIRegistry()
|
||||
|
||||
// Convert MaintenanceTaskType to TaskType
|
||||
for workerTaskType := range typesRegistry.GetAllDetectors() {
|
||||
if string(workerTaskType) == string(taskType) {
|
||||
// Get the UI provider for this task type
|
||||
provider := uiRegistry.GetProvider(workerTaskType)
|
||||
if provider != nil {
|
||||
return provider.GetIcon()
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Default icon if no UI provider found
|
||||
return "fas fa-cog text-muted"
|
||||
}
|
||||
|
||||
// GetTaskDisplayName returns the display name for a task type from its UI provider
|
||||
func GetTaskDisplayName(taskType MaintenanceTaskType) string {
|
||||
typesRegistry := tasks.GetGlobalTypesRegistry()
|
||||
uiRegistry := tasks.GetGlobalUIRegistry()
|
||||
|
||||
// Convert MaintenanceTaskType to TaskType
|
||||
for workerTaskType := range typesRegistry.GetAllDetectors() {
|
||||
if string(workerTaskType) == string(taskType) {
|
||||
// Get the UI provider for this task type
|
||||
provider := uiRegistry.GetProvider(workerTaskType)
|
||||
if provider != nil {
|
||||
return provider.GetDisplayName()
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to the task type string
|
||||
return string(taskType)
|
||||
}
|
||||
|
||||
// GetTaskDescription returns the description for a task type from its UI provider
|
||||
func GetTaskDescription(taskType MaintenanceTaskType) string {
|
||||
typesRegistry := tasks.GetGlobalTypesRegistry()
|
||||
uiRegistry := tasks.GetGlobalUIRegistry()
|
||||
|
||||
// Convert MaintenanceTaskType to TaskType
|
||||
for workerTaskType := range typesRegistry.GetAllDetectors() {
|
||||
if string(workerTaskType) == string(taskType) {
|
||||
// Get the UI provider for this task type
|
||||
provider := uiRegistry.GetProvider(workerTaskType)
|
||||
if provider != nil {
|
||||
return provider.GetDescription()
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to a generic description
|
||||
return "Configure detailed settings for " + string(taskType) + " tasks."
|
||||
}
|
||||
|
||||
// BuildMaintenanceMenuItems creates menu items for all registered task types
|
||||
func BuildMaintenanceMenuItems() []*MaintenanceMenuItem {
|
||||
var menuItems []*MaintenanceMenuItem
|
||||
|
||||
// Get all registered task types
|
||||
registeredTypes := GetRegisteredMaintenanceTaskTypes()
|
||||
|
||||
for _, taskType := range registeredTypes {
|
||||
menuItem := &MaintenanceMenuItem{
|
||||
TaskType: taskType,
|
||||
DisplayName: GetTaskDisplayName(taskType),
|
||||
Description: GetTaskDescription(taskType),
|
||||
Icon: GetTaskIcon(taskType),
|
||||
IsEnabled: IsMaintenanceTaskTypeRegistered(taskType),
|
||||
Path: "/maintenance/config/" + string(taskType),
|
||||
}
|
||||
|
||||
menuItems = append(menuItems, menuItem)
|
||||
}
|
||||
|
||||
return menuItems
|
||||
}
|
||||
|
||||
// Helper functions to extract configuration fields
|
||||
|
||||
// Note: Removed getVacuumConfigField, getErasureCodingConfigField, getBalanceConfigField, getReplicationConfigField
|
||||
|
||||
@@ -1,421 +0,0 @@
|
||||
package maintenance
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/glog"
|
||||
"github.com/seaweedfs/seaweedfs/weed/worker"
|
||||
"github.com/seaweedfs/seaweedfs/weed/worker/tasks"
|
||||
"github.com/seaweedfs/seaweedfs/weed/worker/types"
|
||||
|
||||
// Import task packages to trigger their auto-registration
|
||||
_ "github.com/seaweedfs/seaweedfs/weed/worker/tasks/balance"
|
||||
_ "github.com/seaweedfs/seaweedfs/weed/worker/tasks/erasure_coding"
|
||||
_ "github.com/seaweedfs/seaweedfs/weed/worker/tasks/vacuum"
|
||||
)
|
||||
|
||||
// MaintenanceWorkerService manages maintenance task execution
|
||||
// TaskExecutor defines the function signature for task execution
|
||||
type TaskExecutor func(*MaintenanceWorkerService, *MaintenanceTask) error
|
||||
|
||||
// TaskExecutorFactory creates a task executor for a given worker service
|
||||
type TaskExecutorFactory func() TaskExecutor
|
||||
|
||||
// Global registry for task executor factories
|
||||
var taskExecutorFactories = make(map[MaintenanceTaskType]TaskExecutorFactory)
|
||||
var executorRegistryMutex sync.RWMutex
|
||||
var executorRegistryInitOnce sync.Once
|
||||
|
||||
// initializeExecutorFactories dynamically registers executor factories for all auto-registered task types
|
||||
func initializeExecutorFactories() {
|
||||
executorRegistryInitOnce.Do(func() {
|
||||
// Get all registered task types from the global registry
|
||||
typesRegistry := tasks.GetGlobalTypesRegistry()
|
||||
|
||||
var taskTypes []MaintenanceTaskType
|
||||
for workerTaskType := range typesRegistry.GetAllDetectors() {
|
||||
// Convert types.TaskType to MaintenanceTaskType by string conversion
|
||||
maintenanceTaskType := MaintenanceTaskType(string(workerTaskType))
|
||||
taskTypes = append(taskTypes, maintenanceTaskType)
|
||||
}
|
||||
|
||||
// Register generic executor for all task types
|
||||
for _, taskType := range taskTypes {
|
||||
RegisterTaskExecutorFactory(taskType, createGenericTaskExecutor)
|
||||
}
|
||||
|
||||
glog.V(1).Infof("Dynamically registered generic task executor for %d task types: %v", len(taskTypes), taskTypes)
|
||||
})
|
||||
}
|
||||
|
||||
// RegisterTaskExecutorFactory registers a factory function for creating task executors
|
||||
func RegisterTaskExecutorFactory(taskType MaintenanceTaskType, factory TaskExecutorFactory) {
|
||||
executorRegistryMutex.Lock()
|
||||
defer executorRegistryMutex.Unlock()
|
||||
taskExecutorFactories[taskType] = factory
|
||||
glog.V(2).Infof("Registered executor factory for task type: %s", taskType)
|
||||
}
|
||||
|
||||
// GetTaskExecutorFactory returns the factory for a task type
|
||||
func GetTaskExecutorFactory(taskType MaintenanceTaskType) (TaskExecutorFactory, bool) {
|
||||
// Ensure executor factories are initialized
|
||||
initializeExecutorFactories()
|
||||
|
||||
executorRegistryMutex.RLock()
|
||||
defer executorRegistryMutex.RUnlock()
|
||||
factory, exists := taskExecutorFactories[taskType]
|
||||
return factory, exists
|
||||
}
|
||||
|
||||
// GetSupportedExecutorTaskTypes returns all task types with registered executor factories
|
||||
func GetSupportedExecutorTaskTypes() []MaintenanceTaskType {
|
||||
// Ensure executor factories are initialized
|
||||
initializeExecutorFactories()
|
||||
|
||||
executorRegistryMutex.RLock()
|
||||
defer executorRegistryMutex.RUnlock()
|
||||
|
||||
taskTypes := make([]MaintenanceTaskType, 0, len(taskExecutorFactories))
|
||||
for taskType := range taskExecutorFactories {
|
||||
taskTypes = append(taskTypes, taskType)
|
||||
}
|
||||
return taskTypes
|
||||
}
|
||||
|
||||
// createGenericTaskExecutor creates a generic task executor that uses the task registry
|
||||
func createGenericTaskExecutor() TaskExecutor {
|
||||
return func(mws *MaintenanceWorkerService, task *MaintenanceTask) error {
|
||||
return mws.executeGenericTask(task)
|
||||
}
|
||||
}
|
||||
|
||||
// init does minimal initialization - actual registration happens lazily
|
||||
func init() {
|
||||
// Executor factory registration will happen lazily when first accessed
|
||||
glog.V(1).Infof("Maintenance worker initialized - executor factories will be registered on first access")
|
||||
}
|
||||
|
||||
type MaintenanceWorkerService struct {
|
||||
workerID string
|
||||
address string
|
||||
adminServer string
|
||||
capabilities []MaintenanceTaskType
|
||||
maxConcurrent int
|
||||
currentTasks map[string]*MaintenanceTask
|
||||
queue *MaintenanceQueue
|
||||
adminClient AdminClient
|
||||
running bool
|
||||
stopChan chan struct{}
|
||||
|
||||
// Task execution registry
|
||||
taskExecutors map[MaintenanceTaskType]TaskExecutor
|
||||
|
||||
// Task registry for creating task instances
|
||||
taskRegistry *tasks.TaskRegistry
|
||||
}
|
||||
|
||||
// NewMaintenanceWorkerService creates a new maintenance worker service
|
||||
func NewMaintenanceWorkerService(workerID, address, adminServer string) *MaintenanceWorkerService {
|
||||
// Get all registered maintenance task types dynamically
|
||||
capabilities := GetRegisteredMaintenanceTaskTypes()
|
||||
|
||||
worker := &MaintenanceWorkerService{
|
||||
workerID: workerID,
|
||||
address: address,
|
||||
adminServer: adminServer,
|
||||
capabilities: capabilities,
|
||||
maxConcurrent: 2, // Default concurrent task limit
|
||||
currentTasks: make(map[string]*MaintenanceTask),
|
||||
stopChan: make(chan struct{}),
|
||||
taskExecutors: make(map[MaintenanceTaskType]TaskExecutor),
|
||||
taskRegistry: tasks.GetGlobalTaskRegistry(), // Use global registry with auto-registered tasks
|
||||
}
|
||||
|
||||
// Initialize task executor registry
|
||||
worker.initializeTaskExecutors()
|
||||
|
||||
glog.V(1).Infof("Created maintenance worker with %d registered task types", len(worker.taskRegistry.GetAll()))
|
||||
|
||||
return worker
|
||||
}
|
||||
|
||||
// executeGenericTask executes a task using the task registry instead of hardcoded methods
|
||||
func (mws *MaintenanceWorkerService) executeGenericTask(task *MaintenanceTask) error {
|
||||
glog.V(2).Infof("Executing generic task %s: %s for volume %d", task.ID, task.Type, task.VolumeID)
|
||||
|
||||
// Validate that task has proper typed parameters
|
||||
if task.TypedParams == nil {
|
||||
return fmt.Errorf("task %s has no typed parameters - task was not properly planned (insufficient destinations)", task.ID)
|
||||
}
|
||||
|
||||
// Convert MaintenanceTask to types.TaskType
|
||||
taskType := types.TaskType(string(task.Type))
|
||||
|
||||
// Create task instance using the registry
|
||||
taskInstance, err := mws.taskRegistry.Get(taskType).Create(task.TypedParams)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create task instance: %w", err)
|
||||
}
|
||||
|
||||
// Update progress to show task has started
|
||||
mws.updateTaskProgress(task.ID, 5)
|
||||
|
||||
// Execute the task
|
||||
err = taskInstance.Execute(context.Background(), task.TypedParams)
|
||||
if err != nil {
|
||||
return fmt.Errorf("task execution failed: %w", err)
|
||||
}
|
||||
|
||||
// Update progress to show completion
|
||||
mws.updateTaskProgress(task.ID, 100)
|
||||
|
||||
glog.V(2).Infof("Generic task %s completed successfully", task.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// initializeTaskExecutors sets up the task execution registry dynamically
|
||||
func (mws *MaintenanceWorkerService) initializeTaskExecutors() {
|
||||
mws.taskExecutors = make(map[MaintenanceTaskType]TaskExecutor)
|
||||
|
||||
// Get all registered executor factories and create executors
|
||||
executorRegistryMutex.RLock()
|
||||
defer executorRegistryMutex.RUnlock()
|
||||
|
||||
for taskType, factory := range taskExecutorFactories {
|
||||
executor := factory()
|
||||
mws.taskExecutors[taskType] = executor
|
||||
glog.V(3).Infof("Initialized executor for task type: %s", taskType)
|
||||
}
|
||||
|
||||
glog.V(2).Infof("Initialized %d task executors", len(mws.taskExecutors))
|
||||
}
|
||||
|
||||
// RegisterTaskExecutor allows dynamic registration of new task executors
|
||||
func (mws *MaintenanceWorkerService) RegisterTaskExecutor(taskType MaintenanceTaskType, executor TaskExecutor) {
|
||||
if mws.taskExecutors == nil {
|
||||
mws.taskExecutors = make(map[MaintenanceTaskType]TaskExecutor)
|
||||
}
|
||||
mws.taskExecutors[taskType] = executor
|
||||
glog.V(1).Infof("Registered executor for task type: %s", taskType)
|
||||
}
|
||||
|
||||
// GetSupportedTaskTypes returns all task types that this worker can execute
|
||||
func (mws *MaintenanceWorkerService) GetSupportedTaskTypes() []MaintenanceTaskType {
|
||||
return GetSupportedExecutorTaskTypes()
|
||||
}
|
||||
|
||||
// Start begins the worker service
|
||||
func (mws *MaintenanceWorkerService) Start() error {
|
||||
mws.running = true
|
||||
|
||||
// Register with admin server
|
||||
worker := &MaintenanceWorker{
|
||||
ID: mws.workerID,
|
||||
Address: mws.address,
|
||||
Capabilities: mws.capabilities,
|
||||
MaxConcurrent: mws.maxConcurrent,
|
||||
}
|
||||
|
||||
if mws.queue != nil {
|
||||
mws.queue.RegisterWorker(worker)
|
||||
}
|
||||
|
||||
// Start worker loop
|
||||
go mws.workerLoop()
|
||||
|
||||
glog.Infof("Maintenance worker %s started at %s", mws.workerID, mws.address)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop terminates the worker service
|
||||
func (mws *MaintenanceWorkerService) Stop() {
|
||||
mws.running = false
|
||||
close(mws.stopChan)
|
||||
|
||||
// Wait for current tasks to complete or timeout
|
||||
timeout := time.NewTimer(30 * time.Second)
|
||||
defer timeout.Stop()
|
||||
|
||||
for len(mws.currentTasks) > 0 {
|
||||
select {
|
||||
case <-timeout.C:
|
||||
glog.Warningf("Worker %s stopping with %d tasks still running", mws.workerID, len(mws.currentTasks))
|
||||
return
|
||||
case <-time.After(time.Second):
|
||||
// Check again
|
||||
}
|
||||
}
|
||||
|
||||
glog.Infof("Maintenance worker %s stopped", mws.workerID)
|
||||
}
|
||||
|
||||
// workerLoop is the main worker event loop
|
||||
func (mws *MaintenanceWorkerService) workerLoop() {
|
||||
heartbeatTicker := time.NewTicker(30 * time.Second)
|
||||
defer heartbeatTicker.Stop()
|
||||
|
||||
taskRequestTicker := time.NewTicker(5 * time.Second)
|
||||
defer taskRequestTicker.Stop()
|
||||
|
||||
for mws.running {
|
||||
select {
|
||||
case <-mws.stopChan:
|
||||
return
|
||||
case <-heartbeatTicker.C:
|
||||
mws.sendHeartbeat()
|
||||
case <-taskRequestTicker.C:
|
||||
mws.requestTasks()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendHeartbeat sends heartbeat to admin server
|
||||
func (mws *MaintenanceWorkerService) sendHeartbeat() {
|
||||
if mws.queue != nil {
|
||||
mws.queue.UpdateWorkerHeartbeat(mws.workerID)
|
||||
}
|
||||
}
|
||||
|
||||
// requestTasks requests new tasks from the admin server
|
||||
func (mws *MaintenanceWorkerService) requestTasks() {
|
||||
if len(mws.currentTasks) >= mws.maxConcurrent {
|
||||
return // Already at capacity
|
||||
}
|
||||
|
||||
if mws.queue != nil {
|
||||
task := mws.queue.GetNextTask(mws.workerID, mws.capabilities)
|
||||
if task != nil {
|
||||
mws.executeTask(task)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// executeTask executes a maintenance task
|
||||
func (mws *MaintenanceWorkerService) executeTask(task *MaintenanceTask) {
|
||||
mws.currentTasks[task.ID] = task
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
delete(mws.currentTasks, task.ID)
|
||||
}()
|
||||
|
||||
glog.Infof("Worker %s executing task %s: %s", mws.workerID, task.ID, task.Type)
|
||||
|
||||
// Execute task using dynamic executor registry
|
||||
var err error
|
||||
if executor, exists := mws.taskExecutors[task.Type]; exists {
|
||||
err = executor(mws, task)
|
||||
} else {
|
||||
err = fmt.Errorf("unsupported task type: %s", task.Type)
|
||||
glog.Errorf("No executor registered for task type: %s", task.Type)
|
||||
}
|
||||
|
||||
// Report task completion
|
||||
if mws.queue != nil {
|
||||
errorMsg := ""
|
||||
if err != nil {
|
||||
errorMsg = err.Error()
|
||||
}
|
||||
mws.queue.CompleteTask(task.ID, errorMsg)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
glog.Errorf("Worker %s failed to execute task %s: %v", mws.workerID, task.ID, err)
|
||||
} else {
|
||||
glog.Infof("Worker %s completed task %s successfully", mws.workerID, task.ID)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// updateTaskProgress updates the progress of a task
|
||||
func (mws *MaintenanceWorkerService) updateTaskProgress(taskID string, progress float64) {
|
||||
if mws.queue != nil {
|
||||
mws.queue.UpdateTaskProgress(taskID, progress)
|
||||
}
|
||||
}
|
||||
|
||||
// GetStatus returns the current status of the worker
|
||||
func (mws *MaintenanceWorkerService) GetStatus() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"worker_id": mws.workerID,
|
||||
"address": mws.address,
|
||||
"running": mws.running,
|
||||
"capabilities": mws.capabilities,
|
||||
"max_concurrent": mws.maxConcurrent,
|
||||
"current_tasks": len(mws.currentTasks),
|
||||
"task_details": mws.currentTasks,
|
||||
}
|
||||
}
|
||||
|
||||
// SetQueue sets the maintenance queue for the worker
|
||||
func (mws *MaintenanceWorkerService) SetQueue(queue *MaintenanceQueue) {
|
||||
mws.queue = queue
|
||||
}
|
||||
|
||||
// SetAdminClient sets the admin client for the worker
|
||||
func (mws *MaintenanceWorkerService) SetAdminClient(client AdminClient) {
|
||||
mws.adminClient = client
|
||||
}
|
||||
|
||||
// SetCapabilities sets the worker capabilities
|
||||
func (mws *MaintenanceWorkerService) SetCapabilities(capabilities []MaintenanceTaskType) {
|
||||
mws.capabilities = capabilities
|
||||
}
|
||||
|
||||
// SetMaxConcurrent sets the maximum concurrent tasks
|
||||
func (mws *MaintenanceWorkerService) SetMaxConcurrent(max int) {
|
||||
mws.maxConcurrent = max
|
||||
}
|
||||
|
||||
// SetHeartbeatInterval sets the heartbeat interval (placeholder for future use)
|
||||
func (mws *MaintenanceWorkerService) SetHeartbeatInterval(interval time.Duration) {
|
||||
// Future implementation for configurable heartbeat
|
||||
}
|
||||
|
||||
// SetTaskRequestInterval sets the task request interval (placeholder for future use)
|
||||
func (mws *MaintenanceWorkerService) SetTaskRequestInterval(interval time.Duration) {
|
||||
// Future implementation for configurable task requests
|
||||
}
|
||||
|
||||
// MaintenanceWorkerCommand represents a standalone maintenance worker command
|
||||
type MaintenanceWorkerCommand struct {
|
||||
workerService *MaintenanceWorkerService
|
||||
}
|
||||
|
||||
// NewMaintenanceWorkerCommand creates a new worker command
|
||||
func NewMaintenanceWorkerCommand(workerID, address, adminServer string) *MaintenanceWorkerCommand {
|
||||
return &MaintenanceWorkerCommand{
|
||||
workerService: NewMaintenanceWorkerService(workerID, address, adminServer),
|
||||
}
|
||||
}
|
||||
|
||||
// Run starts the maintenance worker as a standalone service
|
||||
func (mwc *MaintenanceWorkerCommand) Run() error {
|
||||
// Generate or load persistent worker ID if not provided
|
||||
if mwc.workerService.workerID == "" {
|
||||
// Get current working directory for worker ID persistence
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get working directory: %w", err)
|
||||
}
|
||||
|
||||
workerID, err := worker.GenerateOrLoadWorkerID(wd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate or load worker ID: %w", err)
|
||||
}
|
||||
mwc.workerService.workerID = workerID
|
||||
}
|
||||
|
||||
// Start the worker service
|
||||
err := mwc.workerService.Start()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start maintenance worker: %w", err)
|
||||
}
|
||||
|
||||
// Wait for interrupt signal
|
||||
select {}
|
||||
}
|
||||
@@ -122,6 +122,7 @@ type Plugin struct {
|
||||
type streamSession struct {
|
||||
workerID string
|
||||
outgoing chan *plugin_pb.AdminToWorkerMessage
|
||||
done chan struct{}
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
@@ -274,6 +275,7 @@ func (r *Plugin) WorkerStream(stream plugin_pb.PluginControlService_WorkerStream
|
||||
session := &streamSession{
|
||||
workerID: workerID,
|
||||
outgoing: make(chan *plugin_pb.AdminToWorkerMessage, r.outgoingBuffer),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
r.putSession(session)
|
||||
defer r.cleanupSession(workerID)
|
||||
@@ -908,8 +910,10 @@ func (r *Plugin) sendLoop(
|
||||
return nil
|
||||
case <-r.shutdownCh:
|
||||
return nil
|
||||
case msg, ok := <-session.outgoing:
|
||||
if !ok {
|
||||
case <-session.done:
|
||||
return nil
|
||||
case msg := <-session.outgoing:
|
||||
if msg == nil {
|
||||
return nil
|
||||
}
|
||||
if err := stream.Send(msg); err != nil {
|
||||
@@ -930,6 +934,8 @@ func (r *Plugin) sendToWorker(workerID string, message *plugin_pb.AdminToWorkerM
|
||||
select {
|
||||
case <-r.shutdownCh:
|
||||
return fmt.Errorf("plugin is shutting down")
|
||||
case <-session.done:
|
||||
return fmt.Errorf("worker %s session is closed", workerID)
|
||||
case session.outgoing <- message:
|
||||
return nil
|
||||
case <-time.After(r.sendTimeout):
|
||||
@@ -1425,7 +1431,7 @@ func CloneConfigValueMap(in map[string]*plugin_pb.ConfigValue) map[string]*plugi
|
||||
|
||||
func (s *streamSession) close() {
|
||||
s.closeOnce.Do(func() {
|
||||
close(s.outgoing)
|
||||
close(s.done)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ func TestRunDetectionSendsCancelOnContextDone(t *testing.T) {
|
||||
{JobType: jobType, CanDetect: true, MaxDetectionConcurrency: 1},
|
||||
},
|
||||
})
|
||||
session := &streamSession{workerID: workerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 4)}
|
||||
session := &streamSession{workerID: workerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 4), done: make(chan struct{})}
|
||||
pluginSvc.putSession(session)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -77,7 +77,7 @@ func TestExecuteJobSendsCancelOnContextDone(t *testing.T) {
|
||||
{JobType: jobType, CanExecute: true, MaxExecutionConcurrency: 1},
|
||||
},
|
||||
})
|
||||
session := &streamSession{workerID: workerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 4)}
|
||||
session := &streamSession{workerID: workerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 4), done: make(chan struct{})}
|
||||
pluginSvc.putSession(session)
|
||||
|
||||
job := &plugin_pb.JobSpec{JobId: "job-1", JobType: jobType}
|
||||
@@ -135,8 +135,8 @@ func TestAdminScriptExecutionBlocksOtherDetection(t *testing.T) {
|
||||
{JobType: "vacuum", CanDetect: true, MaxDetectionConcurrency: 1},
|
||||
},
|
||||
})
|
||||
adminSession := &streamSession{workerID: adminWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8)}
|
||||
otherSession := &streamSession{workerID: otherWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8)}
|
||||
adminSession := &streamSession{workerID: adminWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8), done: make(chan struct{})}
|
||||
otherSession := &streamSession{workerID: otherWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8), done: make(chan struct{})}
|
||||
pluginSvc.putSession(adminSession)
|
||||
pluginSvc.putSession(otherSession)
|
||||
|
||||
@@ -214,8 +214,8 @@ func TestAdminScriptExecutionBlocksOtherExecution(t *testing.T) {
|
||||
{JobType: "vacuum", CanExecute: true, MaxExecutionConcurrency: 1},
|
||||
},
|
||||
})
|
||||
adminSession := &streamSession{workerID: adminWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8)}
|
||||
otherSession := &streamSession{workerID: otherWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8)}
|
||||
adminSession := &streamSession{workerID: adminWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8), done: make(chan struct{})}
|
||||
otherSession := &streamSession{workerID: otherWorkerID, outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 8), done: make(chan struct{})}
|
||||
pluginSvc.putSession(adminSession)
|
||||
pluginSvc.putSession(otherSession)
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ func TestRunDetectionIncludesLatestSuccessfulRun(t *testing.T) {
|
||||
{JobType: jobType, CanDetect: true, MaxDetectionConcurrency: 1},
|
||||
},
|
||||
})
|
||||
session := &streamSession{workerID: "worker-a", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1)}
|
||||
session := &streamSession{workerID: "worker-a", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1), done: make(chan struct{})}
|
||||
pluginSvc.putSession(session)
|
||||
|
||||
oldSuccess := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
@@ -80,7 +80,7 @@ func TestRunDetectionOmitsLastSuccessfulRunWhenNoSuccessHistory(t *testing.T) {
|
||||
{JobType: jobType, CanDetect: true, MaxDetectionConcurrency: 1},
|
||||
},
|
||||
})
|
||||
session := &streamSession{workerID: "worker-a", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1)}
|
||||
session := &streamSession{workerID: "worker-a", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1), done: make(chan struct{})}
|
||||
pluginSvc.putSession(session)
|
||||
|
||||
if err := pluginSvc.store.AppendRunRecord(jobType, &JobRunRecord{
|
||||
@@ -130,7 +130,7 @@ func TestRunDetectionWithReportCapturesDetectionActivities(t *testing.T) {
|
||||
{JobType: jobType, CanDetect: true, MaxDetectionConcurrency: 1},
|
||||
},
|
||||
})
|
||||
session := &streamSession{workerID: "worker-a", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1)}
|
||||
session := &streamSession{workerID: "worker-a", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1), done: make(chan struct{})}
|
||||
pluginSvc.putSession(session)
|
||||
|
||||
reportCh := make(chan *DetectionReport, 1)
|
||||
@@ -210,7 +210,7 @@ func TestRunDetectionAdminScriptUsesLastCompletedRun(t *testing.T) {
|
||||
{JobType: jobType, CanDetect: true, MaxDetectionConcurrency: 1},
|
||||
},
|
||||
})
|
||||
session := &streamSession{workerID: "worker-admin-script", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1)}
|
||||
session := &streamSession{workerID: "worker-admin-script", outgoing: make(chan *plugin_pb.AdminToWorkerMessage, 1), done: make(chan struct{})}
|
||||
pluginSvc.putSession(session)
|
||||
|
||||
successCompleted := time.Date(2026, 2, 1, 10, 0, 0, 0, time.UTC)
|
||||
|
||||
@@ -95,16 +95,6 @@ func (r *Plugin) laneSchedulerLoop(ls *schedulerLaneState) {
|
||||
}
|
||||
}
|
||||
|
||||
// schedulerLoop is kept for backward compatibility; it delegates to
|
||||
// laneSchedulerLoop with the default lane. New code should not call this.
|
||||
func (r *Plugin) schedulerLoop() {
|
||||
ls := r.lanes[LaneDefault]
|
||||
if ls == nil {
|
||||
ls = newLaneState(LaneDefault)
|
||||
}
|
||||
r.laneSchedulerLoop(ls)
|
||||
}
|
||||
|
||||
// runLaneSchedulerIteration runs one scheduling pass for a single lane,
|
||||
// processing only the job types assigned to that lane.
|
||||
//
|
||||
@@ -229,82 +219,6 @@ func (r *Plugin) runLaneSchedulerIterationConcurrent(ls *schedulerLaneState, job
|
||||
return hadJobs.Load()
|
||||
}
|
||||
|
||||
// runSchedulerIteration is kept for backward compatibility. It runs a
|
||||
// single iteration across ALL job types (equivalent to the old single-loop
|
||||
// behavior). It is only used by the legacy schedulerLoop() fallback.
|
||||
func (r *Plugin) runSchedulerIteration() bool {
|
||||
ls := r.lanes[LaneDefault]
|
||||
if ls == nil {
|
||||
ls = newLaneState(LaneDefault)
|
||||
}
|
||||
// For backward compat, the old function processes all job types.
|
||||
r.expireStaleJobs(time.Now().UTC())
|
||||
|
||||
jobTypes := r.registry.DetectableJobTypes()
|
||||
if len(jobTypes) == 0 {
|
||||
r.setSchedulerLoopState("", "idle")
|
||||
return false
|
||||
}
|
||||
|
||||
r.setSchedulerLoopState("", "waiting_for_lock")
|
||||
releaseLock, err := r.acquireAdminLock("plugin scheduler iteration")
|
||||
if err != nil {
|
||||
glog.Warningf("Plugin scheduler failed to acquire lock: %v", err)
|
||||
r.setSchedulerLoopState("", "idle")
|
||||
return false
|
||||
}
|
||||
if releaseLock != nil {
|
||||
defer releaseLock()
|
||||
}
|
||||
|
||||
active := make(map[string]struct{}, len(jobTypes))
|
||||
hadJobs := false
|
||||
|
||||
for _, jobType := range jobTypes {
|
||||
active[jobType] = struct{}{}
|
||||
|
||||
policy, enabled, err := r.loadSchedulerPolicy(jobType)
|
||||
if err != nil {
|
||||
glog.Warningf("Plugin scheduler failed to load policy for %s: %v", jobType, err)
|
||||
continue
|
||||
}
|
||||
if !enabled {
|
||||
r.clearSchedulerJobType(jobType)
|
||||
continue
|
||||
}
|
||||
initialDelay := time.Duration(0)
|
||||
if runInfo := r.snapshotSchedulerRun(jobType); runInfo.lastRunStartedAt.IsZero() {
|
||||
initialDelay = 5 * time.Second
|
||||
}
|
||||
if !r.markDetectionDue(jobType, policy.DetectionInterval, initialDelay) {
|
||||
continue
|
||||
}
|
||||
|
||||
detected := r.runJobTypeIteration(jobType, policy)
|
||||
if detected {
|
||||
hadJobs = true
|
||||
}
|
||||
}
|
||||
|
||||
r.pruneSchedulerState(active)
|
||||
r.pruneDetectorLeases(active)
|
||||
r.setSchedulerLoopState("", "idle")
|
||||
return hadJobs
|
||||
}
|
||||
|
||||
// wakeLane wakes the scheduler goroutine for a specific lane.
|
||||
func (r *Plugin) wakeLane(lane SchedulerLane) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
if ls, ok := r.lanes[lane]; ok {
|
||||
select {
|
||||
case ls.wakeCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// wakeAllLanes wakes all lane scheduler goroutines.
|
||||
func (r *Plugin) wakeAllLanes() {
|
||||
if r == nil {
|
||||
|
||||
@@ -210,16 +210,6 @@ func (r *Plugin) setSchedulerLoopStateForJobType(jobType, phase string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Plugin) recordSchedulerIterationComplete(hadJobs bool) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
r.schedulerLoopMu.Lock()
|
||||
r.schedulerLoopState.lastIterationHadJobs = hadJobs
|
||||
r.schedulerLoopState.lastIterationCompleted = time.Now().UTC()
|
||||
r.schedulerLoopMu.Unlock()
|
||||
}
|
||||
|
||||
func (r *Plugin) snapshotSchedulerLoopState() schedulerLoopState {
|
||||
if r == nil {
|
||||
return schedulerLoopState{}
|
||||
|
||||
@@ -6,20 +6,6 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// getStatusColor returns Bootstrap color class for status
|
||||
func getStatusColor(status string) string {
|
||||
switch status {
|
||||
case "active", "healthy":
|
||||
return "success"
|
||||
case "warning":
|
||||
return "warning"
|
||||
case "critical", "unreachable":
|
||||
return "danger"
|
||||
default:
|
||||
return "secondary"
|
||||
}
|
||||
}
|
||||
|
||||
// formatBytes converts bytes to human readable format
|
||||
func formatBytes(bytes int64) string {
|
||||
if bytes == 0 {
|
||||
|
||||
@@ -95,18 +95,6 @@ func NewCluster() *Cluster {
|
||||
}
|
||||
}
|
||||
|
||||
func (cluster *Cluster) getGroupMembers(filerGroup FilerGroupName, nodeType string, createIfNotFound bool) *GroupMembers {
|
||||
switch nodeType {
|
||||
case FilerType:
|
||||
return cluster.filerGroups.getGroupMembers(filerGroup, createIfNotFound)
|
||||
case BrokerType:
|
||||
return cluster.brokerGroups.getGroupMembers(filerGroup, createIfNotFound)
|
||||
case S3Type:
|
||||
return cluster.s3Groups.getGroupMembers(filerGroup, createIfNotFound)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cluster *Cluster) AddClusterNode(ns, nodeType string, dataCenter DataCenter, rack Rack, address pb.ServerAddress, version string) []*master_pb.KeepConnectedResponse {
|
||||
filerGroup := FilerGroupName(ns)
|
||||
switch nodeType {
|
||||
|
||||
@@ -511,11 +511,6 @@ func recoveryMiddleware(next http.Handler) http.Handler {
|
||||
})
|
||||
}
|
||||
|
||||
// GetAdminOptions returns the admin command options for testing
|
||||
func GetAdminOptions() *AdminOptions {
|
||||
return &AdminOptions{}
|
||||
}
|
||||
|
||||
// loadOrGenerateSessionKeys loads or creates authentication/encryption keys for session cookies.
|
||||
func loadOrGenerateSessionKeys(dataDir string) ([]byte, []byte, error) {
|
||||
const keyLen = 32
|
||||
|
||||
@@ -132,16 +132,3 @@ func fetchContent(masterFn operation.GetMasterFn, grpcDialOption grpc.DialOption
|
||||
content, e = io.ReadAll(rc.Body)
|
||||
return
|
||||
}
|
||||
|
||||
func WriteFile(filename string, data []byte, perm os.FileMode) error {
|
||||
f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n, err := f.Write(data)
|
||||
f.Close()
|
||||
if err == nil && n < len(data) {
|
||||
err = io.ErrShortWrite
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -57,42 +57,6 @@ func LoadCredentialConfiguration() (*CredentialConfig, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetCredentialStoreConfig extracts credential store configuration from command line flags
|
||||
// This is used when credential store is configured via command line instead of credential.toml
|
||||
func GetCredentialStoreConfig(store string, config util.Configuration, prefix string) *CredentialConfig {
|
||||
if store == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &CredentialConfig{
|
||||
Store: store,
|
||||
Config: config,
|
||||
Prefix: prefix,
|
||||
}
|
||||
}
|
||||
|
||||
// MergeCredentialConfig merges command line credential config with credential.toml config
|
||||
// Command line flags take priority over credential.toml
|
||||
func MergeCredentialConfig(cmdLineStore string, cmdLineConfig util.Configuration, cmdLinePrefix string) (*CredentialConfig, error) {
|
||||
// If command line credential store is specified, use it
|
||||
if cmdLineStore != "" {
|
||||
glog.V(0).Infof("Using command line credential configuration: store=%s", cmdLineStore)
|
||||
return GetCredentialStoreConfig(cmdLineStore, cmdLineConfig, cmdLinePrefix), nil
|
||||
}
|
||||
|
||||
// Otherwise, try to load from credential.toml
|
||||
config, err := LoadCredentialConfiguration()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
glog.V(1).Info("No credential store configured")
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// NewCredentialManagerWithDefaults creates a credential manager with fallback to defaults
|
||||
// If explicitStore is provided, it will be used regardless of credential.toml
|
||||
// If explicitStore is empty, it tries credential.toml first, then defaults to "filer_etc"
|
||||
|
||||
@@ -207,32 +207,6 @@ func (store *FilerEtcStore) loadPoliciesFromMultiFile(ctx context.Context, polic
|
||||
})
|
||||
}
|
||||
|
||||
func (store *FilerEtcStore) migratePoliciesToMultiFile(ctx context.Context, policies map[string]policy_engine.PolicyDocument) error {
|
||||
glog.Infof("Migrating IAM policies to multi-file layout...")
|
||||
|
||||
// 1. Save all policies to individual files
|
||||
for name, policy := range policies {
|
||||
if err := store.savePolicy(ctx, name, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Rename legacy file
|
||||
return store.withFilerClient(func(client filer_pb.SeaweedFilerClient) error {
|
||||
_, err := client.AtomicRenameEntry(ctx, &filer_pb.AtomicRenameEntryRequest{
|
||||
OldDirectory: filer.IamConfigDirectory,
|
||||
OldName: filer.IamPoliciesFile,
|
||||
NewDirectory: filer.IamConfigDirectory,
|
||||
NewName: IamLegacyPoliciesOldFile,
|
||||
})
|
||||
if err != nil {
|
||||
glog.Errorf("Failed to rename legacy IAM policies file %s/%s to %s: %v",
|
||||
filer.IamConfigDirectory, filer.IamPoliciesFile, IamLegacyPoliciesOldFile, err)
|
||||
}
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (store *FilerEtcStore) savePolicy(ctx context.Context, name string, document policy_engine.PolicyDocument) error {
|
||||
if err := validatePolicyName(name); err != nil {
|
||||
return err
|
||||
|
||||
@@ -1,221 +0,0 @@
|
||||
package credential
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/glog"
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/iam_pb"
|
||||
"github.com/seaweedfs/seaweedfs/weed/util"
|
||||
)
|
||||
|
||||
// MigrateCredentials migrates credentials from one store to another
|
||||
func MigrateCredentials(fromStoreName, toStoreName CredentialStoreTypeName, configuration util.Configuration, fromPrefix, toPrefix string) error {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create source credential manager
|
||||
fromCM, err := NewCredentialManager(fromStoreName, configuration, fromPrefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create source credential manager (%s): %v", fromStoreName, err)
|
||||
}
|
||||
defer fromCM.Shutdown()
|
||||
|
||||
// Create destination credential manager
|
||||
toCM, err := NewCredentialManager(toStoreName, configuration, toPrefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create destination credential manager (%s): %v", toStoreName, err)
|
||||
}
|
||||
defer toCM.Shutdown()
|
||||
|
||||
// Load configuration from source
|
||||
glog.Infof("Loading configuration from %s store...", fromStoreName)
|
||||
config, err := fromCM.LoadConfiguration(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load configuration from source store: %w", err)
|
||||
}
|
||||
|
||||
if config == nil || len(config.Identities) == 0 {
|
||||
glog.Info("No identities found in source store")
|
||||
return nil
|
||||
}
|
||||
|
||||
glog.Infof("Found %d identities in source store", len(config.Identities))
|
||||
|
||||
// Migrate each identity
|
||||
var migrated, failed int
|
||||
for _, identity := range config.Identities {
|
||||
glog.V(1).Infof("Migrating user: %s", identity.Name)
|
||||
|
||||
// Check if user already exists in destination
|
||||
existingUser, err := toCM.GetUser(ctx, identity.Name)
|
||||
if err != nil && err != ErrUserNotFound {
|
||||
glog.Errorf("Failed to check if user %s exists in destination: %v", identity.Name, err)
|
||||
failed++
|
||||
continue
|
||||
}
|
||||
|
||||
if existingUser != nil {
|
||||
glog.Warningf("User %s already exists in destination store, skipping", identity.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
// Create user in destination
|
||||
err = toCM.CreateUser(ctx, identity)
|
||||
if err != nil {
|
||||
glog.Errorf("Failed to create user %s in destination store: %v", identity.Name, err)
|
||||
failed++
|
||||
continue
|
||||
}
|
||||
|
||||
migrated++
|
||||
glog.V(1).Infof("Successfully migrated user: %s", identity.Name)
|
||||
}
|
||||
|
||||
glog.Infof("Migration completed: %d migrated, %d failed", migrated, failed)
|
||||
|
||||
if failed > 0 {
|
||||
return fmt.Errorf("migration completed with %d failures", failed)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExportCredentials exports credentials from a store to a configuration
|
||||
func ExportCredentials(storeName CredentialStoreTypeName, configuration util.Configuration, prefix string) (*iam_pb.S3ApiConfiguration, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create credential manager
|
||||
cm, err := NewCredentialManager(storeName, configuration, prefix)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create credential manager (%s): %v", storeName, err)
|
||||
}
|
||||
defer cm.Shutdown()
|
||||
|
||||
// Load configuration
|
||||
config, err := cm.LoadConfiguration(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load configuration: %w", err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// ImportCredentials imports credentials from a configuration to a store
|
||||
func ImportCredentials(storeName CredentialStoreTypeName, configuration util.Configuration, prefix string, config *iam_pb.S3ApiConfiguration) error {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create credential manager
|
||||
cm, err := NewCredentialManager(storeName, configuration, prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create credential manager (%s): %v", storeName, err)
|
||||
}
|
||||
defer cm.Shutdown()
|
||||
|
||||
// Import each identity
|
||||
var imported, failed int
|
||||
for _, identity := range config.Identities {
|
||||
glog.V(1).Infof("Importing user: %s", identity.Name)
|
||||
|
||||
// Check if user already exists
|
||||
existingUser, err := cm.GetUser(ctx, identity.Name)
|
||||
if err != nil && err != ErrUserNotFound {
|
||||
glog.Errorf("Failed to check if user %s exists: %v", identity.Name, err)
|
||||
failed++
|
||||
continue
|
||||
}
|
||||
|
||||
if existingUser != nil {
|
||||
glog.Warningf("User %s already exists, skipping", identity.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
// Create user
|
||||
err = cm.CreateUser(ctx, identity)
|
||||
if err != nil {
|
||||
glog.Errorf("Failed to create user %s: %v", identity.Name, err)
|
||||
failed++
|
||||
continue
|
||||
}
|
||||
|
||||
imported++
|
||||
glog.V(1).Infof("Successfully imported user: %s", identity.Name)
|
||||
}
|
||||
|
||||
glog.Infof("Import completed: %d imported, %d failed", imported, failed)
|
||||
|
||||
if failed > 0 {
|
||||
return fmt.Errorf("import completed with %d failures", failed)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateCredentials validates that all credentials in a store are accessible
|
||||
func ValidateCredentials(storeName CredentialStoreTypeName, configuration util.Configuration, prefix string) error {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create credential manager
|
||||
cm, err := NewCredentialManager(storeName, configuration, prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create credential manager (%s): %v", storeName, err)
|
||||
}
|
||||
defer cm.Shutdown()
|
||||
|
||||
// Load configuration
|
||||
config, err := cm.LoadConfiguration(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load configuration: %w", err)
|
||||
}
|
||||
|
||||
if config == nil || len(config.Identities) == 0 {
|
||||
glog.Info("No identities found in store")
|
||||
return nil
|
||||
}
|
||||
|
||||
glog.Infof("Validating %d identities...", len(config.Identities))
|
||||
|
||||
// Validate each identity
|
||||
var validated, failed int
|
||||
for _, identity := range config.Identities {
|
||||
// Check if user can be retrieved
|
||||
user, err := cm.GetUser(ctx, identity.Name)
|
||||
if err != nil {
|
||||
glog.Errorf("Failed to retrieve user %s: %v", identity.Name, err)
|
||||
failed++
|
||||
continue
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
glog.Errorf("User %s not found", identity.Name)
|
||||
failed++
|
||||
continue
|
||||
}
|
||||
|
||||
// Validate access keys
|
||||
for _, credential := range identity.Credentials {
|
||||
accessKeyUser, err := cm.GetUserByAccessKey(ctx, credential.AccessKey)
|
||||
if err != nil {
|
||||
glog.Errorf("Failed to retrieve user by access key %s: %v", credential.AccessKey, err)
|
||||
failed++
|
||||
continue
|
||||
}
|
||||
|
||||
if accessKeyUser == nil || accessKeyUser.Name != identity.Name {
|
||||
glog.Errorf("Access key %s does not map to correct user %s", credential.AccessKey, identity.Name)
|
||||
failed++
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
validated++
|
||||
glog.V(1).Infof("Successfully validated user: %s", identity.Name)
|
||||
}
|
||||
|
||||
glog.Infof("Validation completed: %d validated, %d failed", validated, failed)
|
||||
|
||||
if failed > 0 {
|
||||
return fmt.Errorf("validation completed with %d failures", failed)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -246,10 +246,6 @@ func NewLogFileEntryCollector(f *Filer, startPosition log_buffer.MessagePosition
|
||||
}
|
||||
}
|
||||
|
||||
func (c *LogFileEntryCollector) hasMore() bool {
|
||||
return c.dayEntryQueue.Len() > 0
|
||||
}
|
||||
|
||||
func (c *LogFileEntryCollector) collectMore(v *OrderedLogVisitor) (err error) {
|
||||
dayEntry := c.dayEntryQueue.Dequeue()
|
||||
if dayEntry == nil {
|
||||
|
||||
@@ -2,7 +2,6 @@ package filer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/glog"
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
|
||||
@@ -36,39 +35,3 @@ func Replay(filerStore FilerStore, resp *filer_pb.SubscribeMetadataResponse) err
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParallelProcessDirectoryStructure processes each entry in parallel, and also ensure parent directories are processed first.
|
||||
// This also assumes the parent directories are in the entryChan already.
|
||||
func ParallelProcessDirectoryStructure(entryChan chan *Entry, concurrency int, eachEntryFn func(entry *Entry) error) (firstErr error) {
|
||||
|
||||
executors := util.NewLimitedConcurrentExecutor(concurrency)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for entry := range entryChan {
|
||||
wg.Add(1)
|
||||
if entry.IsDirectory() {
|
||||
func() {
|
||||
defer wg.Done()
|
||||
if err := eachEntryFn(entry); err != nil {
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
executors.Execute(func() {
|
||||
defer wg.Done()
|
||||
if err := eachEntryFn(entry); err != nil {
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
if firstErr != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -16,15 +16,6 @@ type ItemList struct {
|
||||
prefix string
|
||||
}
|
||||
|
||||
func newItemList(client redis.UniversalClient, prefix string, store skiplist.ListStore, batchSize int) *ItemList {
|
||||
return &ItemList{
|
||||
skipList: skiplist.New(store),
|
||||
batchSize: batchSize,
|
||||
client: client,
|
||||
prefix: prefix,
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
Be reluctant to create new nodes. Try to fit into either previous node or next node.
|
||||
Prefer to add to previous node.
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
package redis_lua
|
||||
|
||||
import (
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/seaweedfs/seaweedfs/weed/filer"
|
||||
"github.com/seaweedfs/seaweedfs/weed/util"
|
||||
)
|
||||
|
||||
func init() {
|
||||
filer.Stores = append(filer.Stores, &RedisLuaClusterStore{})
|
||||
}
|
||||
|
||||
type RedisLuaClusterStore struct {
|
||||
UniversalRedisLuaStore
|
||||
}
|
||||
|
||||
func (store *RedisLuaClusterStore) GetName() string {
|
||||
return "redis_lua_cluster"
|
||||
}
|
||||
|
||||
func (store *RedisLuaClusterStore) Initialize(configuration util.Configuration, prefix string) (err error) {
|
||||
|
||||
configuration.SetDefault(prefix+"useReadOnly", false)
|
||||
configuration.SetDefault(prefix+"routeByLatency", false)
|
||||
|
||||
return store.initialize(
|
||||
configuration.GetStringSlice(prefix+"addresses"),
|
||||
configuration.GetString(prefix+"username"),
|
||||
configuration.GetString(prefix+"password"),
|
||||
configuration.GetString(prefix+"keyPrefix"),
|
||||
configuration.GetBool(prefix+"useReadOnly"),
|
||||
configuration.GetBool(prefix+"routeByLatency"),
|
||||
configuration.GetStringSlice(prefix+"superLargeDirectories"),
|
||||
)
|
||||
}
|
||||
|
||||
func (store *RedisLuaClusterStore) initialize(addresses []string, username string, password string, keyPrefix string, readOnly, routeByLatency bool, superLargeDirectories []string) (err error) {
|
||||
store.Client = redis.NewClusterClient(&redis.ClusterOptions{
|
||||
Addrs: addresses,
|
||||
Username: username,
|
||||
Password: password,
|
||||
ReadOnly: readOnly,
|
||||
RouteByLatency: routeByLatency,
|
||||
})
|
||||
store.keyPrefix = keyPrefix
|
||||
store.loadSuperLargeDirectories(superLargeDirectories)
|
||||
return
|
||||
}
|
||||
@@ -1,48 +0,0 @@
|
||||
package redis_lua
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/seaweedfs/seaweedfs/weed/filer"
|
||||
"github.com/seaweedfs/seaweedfs/weed/util"
|
||||
)
|
||||
|
||||
func init() {
|
||||
filer.Stores = append(filer.Stores, &RedisLuaSentinelStore{})
|
||||
}
|
||||
|
||||
type RedisLuaSentinelStore struct {
|
||||
UniversalRedisLuaStore
|
||||
}
|
||||
|
||||
func (store *RedisLuaSentinelStore) GetName() string {
|
||||
return "redis_lua_sentinel"
|
||||
}
|
||||
|
||||
func (store *RedisLuaSentinelStore) Initialize(configuration util.Configuration, prefix string) (err error) {
|
||||
return store.initialize(
|
||||
configuration.GetStringSlice(prefix+"addresses"),
|
||||
configuration.GetString(prefix+"masterName"),
|
||||
configuration.GetString(prefix+"username"),
|
||||
configuration.GetString(prefix+"password"),
|
||||
configuration.GetInt(prefix+"database"),
|
||||
configuration.GetString(prefix+"keyPrefix"),
|
||||
)
|
||||
}
|
||||
|
||||
func (store *RedisLuaSentinelStore) initialize(addresses []string, masterName string, username string, password string, database int, keyPrefix string) (err error) {
|
||||
store.Client = redis.NewFailoverClient(&redis.FailoverOptions{
|
||||
MasterName: masterName,
|
||||
SentinelAddrs: addresses,
|
||||
Username: username,
|
||||
Password: password,
|
||||
DB: database,
|
||||
MinRetryBackoff: time.Millisecond * 100,
|
||||
MaxRetryBackoff: time.Minute * 1,
|
||||
ReadTimeout: time.Second * 30,
|
||||
WriteTimeout: time.Second * 5,
|
||||
})
|
||||
store.keyPrefix = keyPrefix
|
||||
return
|
||||
}
|
||||
@@ -1,42 +0,0 @@
|
||||
package redis_lua
|
||||
|
||||
import (
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/seaweedfs/seaweedfs/weed/filer"
|
||||
"github.com/seaweedfs/seaweedfs/weed/util"
|
||||
)
|
||||
|
||||
func init() {
|
||||
filer.Stores = append(filer.Stores, &RedisLuaStore{})
|
||||
}
|
||||
|
||||
type RedisLuaStore struct {
|
||||
UniversalRedisLuaStore
|
||||
}
|
||||
|
||||
func (store *RedisLuaStore) GetName() string {
|
||||
return "redis_lua"
|
||||
}
|
||||
|
||||
func (store *RedisLuaStore) Initialize(configuration util.Configuration, prefix string) (err error) {
|
||||
return store.initialize(
|
||||
configuration.GetString(prefix+"address"),
|
||||
configuration.GetString(prefix+"username"),
|
||||
configuration.GetString(prefix+"password"),
|
||||
configuration.GetInt(prefix+"database"),
|
||||
configuration.GetString(prefix+"keyPrefix"),
|
||||
configuration.GetStringSlice(prefix+"superLargeDirectories"),
|
||||
)
|
||||
}
|
||||
|
||||
func (store *RedisLuaStore) initialize(hostPort string, username string, password string, database int, keyPrefix string, superLargeDirectories []string) (err error) {
|
||||
store.Client = redis.NewClient(&redis.Options{
|
||||
Addr: hostPort,
|
||||
Username: username,
|
||||
Password: password,
|
||||
DB: database,
|
||||
})
|
||||
store.keyPrefix = keyPrefix
|
||||
store.loadSuperLargeDirectories(superLargeDirectories)
|
||||
return
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
-- KEYS[1]: full path of entry
|
||||
local fullpath = KEYS[1]
|
||||
-- KEYS[2]: full path of entry
|
||||
local fullpath_list_key = KEYS[2]
|
||||
-- KEYS[3]: dir of the entry
|
||||
local dir_list_key = KEYS[3]
|
||||
|
||||
-- ARGV[1]: isSuperLargeDirectory
|
||||
local isSuperLargeDirectory = ARGV[1] == "1"
|
||||
-- ARGV[2]: name of the entry
|
||||
local name = ARGV[2]
|
||||
|
||||
redis.call("DEL", fullpath, fullpath_list_key)
|
||||
|
||||
if not isSuperLargeDirectory and name ~= "" then
|
||||
redis.call("ZREM", dir_list_key, name)
|
||||
end
|
||||
|
||||
return 0
|
||||
@@ -1,15 +0,0 @@
|
||||
-- KEYS[1]: full path of entry
|
||||
local fullpath = KEYS[1]
|
||||
|
||||
if fullpath ~= "" and string.sub(fullpath, -1) == "/" then
|
||||
fullpath = string.sub(fullpath, 0, -2)
|
||||
end
|
||||
|
||||
local files = redis.call("ZRANGE", fullpath .. "\0", "0", "-1")
|
||||
|
||||
for _, name in ipairs(files) do
|
||||
local file_path = fullpath .. "/" .. name
|
||||
redis.call("DEL", file_path, file_path .. "\0")
|
||||
end
|
||||
|
||||
return 0
|
||||
@@ -1,25 +0,0 @@
|
||||
package stored_procedure
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
func init() {
|
||||
InsertEntryScript = redis.NewScript(insertEntry)
|
||||
DeleteEntryScript = redis.NewScript(deleteEntry)
|
||||
DeleteFolderChildrenScript = redis.NewScript(deleteFolderChildren)
|
||||
}
|
||||
|
||||
//go:embed insert_entry.lua
|
||||
var insertEntry string
|
||||
var InsertEntryScript *redis.Script
|
||||
|
||||
//go:embed delete_entry.lua
|
||||
var deleteEntry string
|
||||
var DeleteEntryScript *redis.Script
|
||||
|
||||
//go:embed delete_folder_children.lua
|
||||
var deleteFolderChildren string
|
||||
var DeleteFolderChildrenScript *redis.Script
|
||||
@@ -1,27 +0,0 @@
|
||||
-- KEYS[1]: full path of entry
|
||||
local full_path = KEYS[1]
|
||||
-- KEYS[2]: dir of the entry
|
||||
local dir_list_key = KEYS[2]
|
||||
|
||||
-- ARGV[1]: content of the entry
|
||||
local entry = ARGV[1]
|
||||
-- ARGV[2]: TTL of the entry
|
||||
local ttlSec = tonumber(ARGV[2])
|
||||
-- ARGV[3]: isSuperLargeDirectory
|
||||
local isSuperLargeDirectory = ARGV[3] == "1"
|
||||
-- ARGV[4]: zscore of the entry in zset
|
||||
local zscore = tonumber(ARGV[4])
|
||||
-- ARGV[5]: name of the entry
|
||||
local name = ARGV[5]
|
||||
|
||||
if ttlSec > 0 then
|
||||
redis.call("SET", full_path, entry, "EX", ttlSec)
|
||||
else
|
||||
redis.call("SET", full_path, entry)
|
||||
end
|
||||
|
||||
if not isSuperLargeDirectory and name ~= "" then
|
||||
redis.call("ZADD", dir_list_key, "NX", zscore, name)
|
||||
end
|
||||
|
||||
return 0
|
||||
@@ -1,206 +0,0 @@
|
||||
package redis_lua
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/filer"
|
||||
"github.com/seaweedfs/seaweedfs/weed/filer/redis_lua/stored_procedure"
|
||||
"github.com/seaweedfs/seaweedfs/weed/glog"
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
|
||||
"github.com/seaweedfs/seaweedfs/weed/util"
|
||||
)
|
||||
|
||||
const (
|
||||
DIR_LIST_MARKER = "\x00"
|
||||
)
|
||||
|
||||
type UniversalRedisLuaStore struct {
|
||||
Client redis.UniversalClient
|
||||
keyPrefix string
|
||||
superLargeDirectoryHash map[string]bool
|
||||
}
|
||||
|
||||
func (store *UniversalRedisLuaStore) isSuperLargeDirectory(dir string) (isSuperLargeDirectory bool) {
|
||||
_, isSuperLargeDirectory = store.superLargeDirectoryHash[dir]
|
||||
return
|
||||
}
|
||||
|
||||
func (store *UniversalRedisLuaStore) loadSuperLargeDirectories(superLargeDirectories []string) {
|
||||
// set directory hash
|
||||
store.superLargeDirectoryHash = make(map[string]bool)
|
||||
for _, dir := range superLargeDirectories {
|
||||
store.superLargeDirectoryHash[dir] = true
|
||||
}
|
||||
}
|
||||
|
||||
func (store *UniversalRedisLuaStore) getKey(key string) string {
|
||||
if store.keyPrefix == "" {
|
||||
return key
|
||||
}
|
||||
return store.keyPrefix + key
|
||||
}
|
||||
|
||||
func (store *UniversalRedisLuaStore) BeginTransaction(ctx context.Context) (context.Context, error) {
|
||||
return ctx, nil
|
||||
}
|
||||
func (store *UniversalRedisLuaStore) CommitTransaction(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
func (store *UniversalRedisLuaStore) RollbackTransaction(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (store *UniversalRedisLuaStore) InsertEntry(ctx context.Context, entry *filer.Entry) (err error) {
|
||||
|
||||
value, err := entry.EncodeAttributesAndChunks()
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoding %s %+v: %v", entry.FullPath, entry.Attr, err)
|
||||
}
|
||||
|
||||
if len(entry.GetChunks()) > filer.CountEntryChunksForGzip {
|
||||
value = util.MaybeGzipData(value)
|
||||
}
|
||||
|
||||
dir, name := entry.FullPath.DirAndName()
|
||||
|
||||
err = stored_procedure.InsertEntryScript.Run(ctx, store.Client,
|
||||
[]string{store.getKey(string(entry.FullPath)), store.getKey(genDirectoryListKey(dir))},
|
||||
value, entry.TtlSec,
|
||||
store.isSuperLargeDirectory(dir), 0, name).Err()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("persisting %s : %v", entry.FullPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (store *UniversalRedisLuaStore) UpdateEntry(ctx context.Context, entry *filer.Entry) (err error) {
|
||||
|
||||
return store.InsertEntry(ctx, entry)
|
||||
}
|
||||
|
||||
func (store *UniversalRedisLuaStore) FindEntry(ctx context.Context, fullpath util.FullPath) (entry *filer.Entry, err error) {
|
||||
|
||||
data, err := store.Client.Get(ctx, store.getKey(string(fullpath))).Result()
|
||||
if err == redis.Nil {
|
||||
return nil, filer_pb.ErrNotFound
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get %s : %v", fullpath, err)
|
||||
}
|
||||
|
||||
entry = &filer.Entry{
|
||||
FullPath: fullpath,
|
||||
}
|
||||
err = entry.DecodeAttributesAndChunks(util.MaybeDecompressData([]byte(data)))
|
||||
if err != nil {
|
||||
return entry, fmt.Errorf("decode %s : %v", entry.FullPath, err)
|
||||
}
|
||||
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
func (store *UniversalRedisLuaStore) DeleteEntry(ctx context.Context, fullpath util.FullPath) (err error) {
|
||||
|
||||
dir, name := fullpath.DirAndName()
|
||||
|
||||
err = stored_procedure.DeleteEntryScript.Run(ctx, store.Client,
|
||||
[]string{store.getKey(string(fullpath)), store.getKey(genDirectoryListKey(string(fullpath))), store.getKey(genDirectoryListKey(dir))},
|
||||
store.isSuperLargeDirectory(dir), name).Err()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("DeleteEntry %s : %v", fullpath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (store *UniversalRedisLuaStore) DeleteFolderChildren(ctx context.Context, fullpath util.FullPath) (err error) {
|
||||
|
||||
if store.isSuperLargeDirectory(string(fullpath)) {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = stored_procedure.DeleteFolderChildrenScript.Run(ctx, store.Client,
|
||||
[]string{store.getKey(string(fullpath))}).Err()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("DeleteFolderChildren %s : %v", fullpath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (store *UniversalRedisLuaStore) ListDirectoryPrefixedEntries(ctx context.Context, dirPath util.FullPath, startFileName string, includeStartFile bool, limit int64, prefix string, eachEntryFunc filer.ListEachEntryFunc) (lastFileName string, err error) {
|
||||
return lastFileName, filer.ErrUnsupportedListDirectoryPrefixed
|
||||
}
|
||||
|
||||
func (store *UniversalRedisLuaStore) ListDirectoryEntries(ctx context.Context, dirPath util.FullPath, startFileName string, includeStartFile bool, limit int64, eachEntryFunc filer.ListEachEntryFunc) (lastFileName string, err error) {
|
||||
|
||||
dirListKey := store.getKey(genDirectoryListKey(string(dirPath)))
|
||||
|
||||
min := "-"
|
||||
if startFileName != "" {
|
||||
if includeStartFile {
|
||||
min = "[" + startFileName
|
||||
} else {
|
||||
min = "(" + startFileName
|
||||
}
|
||||
}
|
||||
|
||||
members, err := store.Client.ZRangeByLex(ctx, dirListKey, &redis.ZRangeBy{
|
||||
Min: min,
|
||||
Max: "+",
|
||||
Offset: 0,
|
||||
Count: limit,
|
||||
}).Result()
|
||||
if err != nil {
|
||||
return lastFileName, fmt.Errorf("list %s : %v", dirPath, err)
|
||||
}
|
||||
|
||||
// fetch entry meta
|
||||
for _, fileName := range members {
|
||||
path := util.NewFullPath(string(dirPath), fileName)
|
||||
entry, err := store.FindEntry(ctx, path)
|
||||
lastFileName = fileName
|
||||
if err != nil {
|
||||
glog.V(0).InfofCtx(ctx, "list %s : %v", path, err)
|
||||
if err == filer_pb.ErrNotFound {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
if entry.TtlSec > 0 {
|
||||
if entry.Attr.Crtime.Add(time.Duration(entry.TtlSec) * time.Second).Before(time.Now()) {
|
||||
store.DeleteEntry(ctx, path)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
resEachEntryFunc, resEachEntryFuncErr := eachEntryFunc(entry)
|
||||
if resEachEntryFuncErr != nil {
|
||||
err = fmt.Errorf("failed to process eachEntryFunc: %w", resEachEntryFuncErr)
|
||||
break
|
||||
}
|
||||
|
||||
if !resEachEntryFunc {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return lastFileName, err
|
||||
}
|
||||
|
||||
func genDirectoryListKey(dir string) (dirList string) {
|
||||
return dir + DIR_LIST_MARKER
|
||||
}
|
||||
|
||||
func (store *UniversalRedisLuaStore) Shutdown() {
|
||||
store.Client.Close()
|
||||
}
|
||||
@@ -1,42 +0,0 @@
|
||||
package redis_lua
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/seaweedfs/seaweedfs/weed/filer"
|
||||
)
|
||||
|
||||
func (store *UniversalRedisLuaStore) KvPut(ctx context.Context, key []byte, value []byte) (err error) {
|
||||
|
||||
_, err = store.Client.Set(ctx, string(key), value, 0).Result()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("kv put: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (store *UniversalRedisLuaStore) KvGet(ctx context.Context, key []byte) (value []byte, err error) {
|
||||
|
||||
data, err := store.Client.Get(ctx, string(key)).Result()
|
||||
|
||||
if err == redis.Nil {
|
||||
return nil, filer.ErrKvNotFound
|
||||
}
|
||||
|
||||
return []byte(data), err
|
||||
}
|
||||
|
||||
func (store *UniversalRedisLuaStore) KvDelete(ctx context.Context, key []byte) (err error) {
|
||||
|
||||
_, err = store.Client.Del(ctx, string(key)).Result()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("kv delete: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -102,10 +102,6 @@ func PrepareStreamContent(masterClient wdclient.HasLookupFileIdFunction, jwtFunc
|
||||
|
||||
type VolumeServerJwtFunction func(fileId string) string
|
||||
|
||||
func noJwtFunc(string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
type CacheInvalidator interface {
|
||||
InvalidateCache(fileId string)
|
||||
}
|
||||
@@ -276,33 +272,6 @@ func writeZero(w io.Writer, size int64) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func ReadAll(ctx context.Context, buffer []byte, masterClient *wdclient.MasterClient, chunks []*filer_pb.FileChunk) error {
|
||||
|
||||
lookupFileIdFn := func(ctx context.Context, fileId string) (targetUrls []string, err error) {
|
||||
return masterClient.LookupFileId(ctx, fileId)
|
||||
}
|
||||
|
||||
chunkViews := ViewFromChunks(ctx, lookupFileIdFn, chunks, 0, int64(len(buffer)))
|
||||
|
||||
idx := 0
|
||||
|
||||
for x := chunkViews.Front(); x != nil; x = x.Next {
|
||||
chunkView := x.Value
|
||||
urlStrings, err := lookupFileIdFn(ctx, chunkView.FileId)
|
||||
if err != nil {
|
||||
glog.V(1).InfofCtx(ctx, "operation LookupFileId %s failed, err: %v", chunkView.FileId, err)
|
||||
return err
|
||||
}
|
||||
|
||||
n, err := util_http.RetriedFetchChunkData(ctx, buffer[idx:idx+int(chunkView.ViewSize)], urlStrings, chunkView.CipherKey, chunkView.IsGzipped, chunkView.IsFullChunk(), chunkView.OffsetInChunk, chunkView.FileId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
idx += n
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------------- ChunkStreamReader ----------------------------------
|
||||
type ChunkStreamReader struct {
|
||||
head *Interval[*ChunkView]
|
||||
|
||||
@@ -1,281 +0,0 @@
|
||||
package filer
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
|
||||
"github.com/seaweedfs/seaweedfs/weed/wdclient"
|
||||
)
|
||||
|
||||
// mockMasterClient implements HasLookupFileIdFunction and CacheInvalidator
|
||||
type mockMasterClient struct {
|
||||
lookupFunc func(ctx context.Context, fileId string) ([]string, error)
|
||||
invalidatedFileIds []string
|
||||
}
|
||||
|
||||
func (m *mockMasterClient) GetLookupFileIdFunction() wdclient.LookupFileIdFunctionType {
|
||||
return m.lookupFunc
|
||||
}
|
||||
|
||||
func (m *mockMasterClient) InvalidateCache(fileId string) {
|
||||
m.invalidatedFileIds = append(m.invalidatedFileIds, fileId)
|
||||
}
|
||||
|
||||
// Test urlSlicesEqual helper function
|
||||
func TestUrlSlicesEqual(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
a []string
|
||||
b []string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "identical slices",
|
||||
a: []string{"http://server1", "http://server2"},
|
||||
b: []string{"http://server1", "http://server2"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "same URLs different order",
|
||||
a: []string{"http://server1", "http://server2"},
|
||||
b: []string{"http://server2", "http://server1"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "different URLs",
|
||||
a: []string{"http://server1", "http://server2"},
|
||||
b: []string{"http://server1", "http://server3"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "different lengths",
|
||||
a: []string{"http://server1"},
|
||||
b: []string{"http://server1", "http://server2"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty slices",
|
||||
a: []string{},
|
||||
b: []string{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "duplicates in both",
|
||||
a: []string{"http://server1", "http://server1"},
|
||||
b: []string{"http://server1", "http://server1"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "different duplicate counts",
|
||||
a: []string{"http://server1", "http://server1"},
|
||||
b: []string{"http://server1", "http://server2"},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := urlSlicesEqual(tt.a, tt.b)
|
||||
if result != tt.expected {
|
||||
t.Errorf("urlSlicesEqual(%v, %v) = %v; want %v", tt.a, tt.b, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test cache invalidation when read fails
|
||||
func TestStreamContentWithCacheInvalidation(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
fileId := "3,01234567890"
|
||||
|
||||
callCount := 0
|
||||
oldUrls := []string{"http://failed-server:8080"}
|
||||
newUrls := []string{"http://working-server:8080"}
|
||||
|
||||
mock := &mockMasterClient{
|
||||
lookupFunc: func(ctx context.Context, fid string) ([]string, error) {
|
||||
callCount++
|
||||
if callCount == 1 {
|
||||
// First call returns failing server
|
||||
return oldUrls, nil
|
||||
}
|
||||
// After invalidation, return working server
|
||||
return newUrls, nil
|
||||
},
|
||||
}
|
||||
|
||||
// Create a simple chunk
|
||||
chunks := []*filer_pb.FileChunk{
|
||||
{
|
||||
FileId: fileId,
|
||||
Offset: 0,
|
||||
Size: 10,
|
||||
},
|
||||
}
|
||||
|
||||
streamFn, err := PrepareStreamContentWithThrottler(ctx, mock, noJwtFunc, chunks, 0, 10, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("PrepareStreamContentWithThrottler failed: %v", err)
|
||||
}
|
||||
|
||||
// Note: This test can't fully execute streamFn because it would require actual HTTP servers
|
||||
// However, we can verify the setup was created correctly
|
||||
if streamFn == nil {
|
||||
t.Fatal("Expected non-nil stream function")
|
||||
}
|
||||
|
||||
// Verify the lookup was called
|
||||
if callCount != 1 {
|
||||
t.Errorf("Expected 1 lookup call, got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that InvalidateCache is called on read failure
|
||||
func TestCacheInvalidationInterface(t *testing.T) {
|
||||
mock := &mockMasterClient{
|
||||
lookupFunc: func(ctx context.Context, fileId string) ([]string, error) {
|
||||
return []string{"http://server:8080"}, nil
|
||||
},
|
||||
}
|
||||
|
||||
fileId := "3,test123"
|
||||
|
||||
// Simulate invalidation
|
||||
if invalidator, ok := interface{}(mock).(CacheInvalidator); ok {
|
||||
invalidator.InvalidateCache(fileId)
|
||||
} else {
|
||||
t.Fatal("mockMasterClient should implement CacheInvalidator")
|
||||
}
|
||||
|
||||
// Check that the file ID was recorded as invalidated
|
||||
if len(mock.invalidatedFileIds) != 1 {
|
||||
t.Fatalf("Expected 1 invalidated file ID, got %d", len(mock.invalidatedFileIds))
|
||||
}
|
||||
if mock.invalidatedFileIds[0] != fileId {
|
||||
t.Errorf("Expected invalidated file ID %s, got %s", fileId, mock.invalidatedFileIds[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Test retry logic doesn't retry with same URLs
|
||||
func TestRetryLogicSkipsSameUrls(t *testing.T) {
|
||||
// This test verifies that the urlSlicesEqual check prevents infinite retries
|
||||
sameUrls := []string{"http://server1:8080", "http://server2:8080"}
|
||||
differentUrls := []string{"http://server3:8080", "http://server4:8080"}
|
||||
|
||||
// Same URLs should return true (and thus skip retry)
|
||||
if !urlSlicesEqual(sameUrls, sameUrls) {
|
||||
t.Error("Expected same URLs to be equal")
|
||||
}
|
||||
|
||||
// Different URLs should return false (and thus allow retry)
|
||||
if urlSlicesEqual(sameUrls, differentUrls) {
|
||||
t.Error("Expected different URLs to not be equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanceledStreamSkipsCacheInvalidation(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
fileId := "3,canceled"
|
||||
|
||||
mock := &mockMasterClient{
|
||||
lookupFunc: func(ctx context.Context, fid string) ([]string, error) {
|
||||
return []string{"http://server:8080"}, nil
|
||||
},
|
||||
}
|
||||
|
||||
chunks := []*filer_pb.FileChunk{
|
||||
{
|
||||
FileId: fileId,
|
||||
Offset: 0,
|
||||
Size: 10,
|
||||
},
|
||||
}
|
||||
|
||||
streamFn, err := PrepareStreamContentWithThrottler(ctx, mock, noJwtFunc, chunks, 0, 10, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("PrepareStreamContentWithThrottler failed: %v", err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
|
||||
err = streamFn(&bytes.Buffer{})
|
||||
if err != context.Canceled {
|
||||
t.Fatalf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
if len(mock.invalidatedFileIds) != 0 {
|
||||
t.Fatalf("expected no cache invalidation on cancellation, got %v", mock.invalidatedFileIds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareStreamContentSkipsLookupWhenContextAlreadyCanceled(t *testing.T) {
|
||||
oldSchedule := getLookupFileIdBackoffSchedule
|
||||
getLookupFileIdBackoffSchedule = []time.Duration{time.Millisecond}
|
||||
t.Cleanup(func() {
|
||||
getLookupFileIdBackoffSchedule = oldSchedule
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
lookupCalls := 0
|
||||
mock := &mockMasterClient{
|
||||
lookupFunc: func(ctx context.Context, fileId string) ([]string, error) {
|
||||
lookupCalls++
|
||||
return nil, errors.New("lookup should not run")
|
||||
},
|
||||
}
|
||||
|
||||
chunks := []*filer_pb.FileChunk{
|
||||
{
|
||||
FileId: "3,precanceled",
|
||||
Offset: 0,
|
||||
Size: 10,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := PrepareStreamContentWithThrottler(ctx, mock, noJwtFunc, chunks, 0, 10, 0)
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
if lookupCalls != 0 {
|
||||
t.Fatalf("expected no lookup calls after cancellation, got %d", lookupCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareStreamContentStopsLookupRetriesAfterContextCancellation(t *testing.T) {
|
||||
oldSchedule := getLookupFileIdBackoffSchedule
|
||||
getLookupFileIdBackoffSchedule = []time.Duration{time.Millisecond, time.Millisecond, time.Millisecond}
|
||||
t.Cleanup(func() {
|
||||
getLookupFileIdBackoffSchedule = oldSchedule
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
lookupCalls := 0
|
||||
mock := &mockMasterClient{
|
||||
lookupFunc: func(ctx context.Context, fileId string) ([]string, error) {
|
||||
lookupCalls++
|
||||
cancel()
|
||||
return nil, context.Canceled
|
||||
},
|
||||
}
|
||||
|
||||
chunks := []*filer_pb.FileChunk{
|
||||
{
|
||||
FileId: "3,cancel-during-lookup",
|
||||
Offset: 0,
|
||||
Size: 10,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := PrepareStreamContentWithThrottler(ctx, mock, noJwtFunc, chunks, 0, 10, 0)
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
if lookupCalls != 1 {
|
||||
t.Fatalf("expected lookup retries to stop after cancellation, got %d calls", lookupCalls)
|
||||
}
|
||||
}
|
||||
@@ -37,11 +37,6 @@ func GenerateRandomString(length int, charset string) (string, error) {
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// GenerateAccessKeyId generates a new access key ID.
|
||||
func GenerateAccessKeyId() (string, error) {
|
||||
return GenerateRandomString(AccessKeyIdLength, CharsetUpper)
|
||||
}
|
||||
|
||||
// GenerateSecretAccessKey generates a new secret access key.
|
||||
func GenerateSecretAccessKey() (string, error) {
|
||||
return GenerateRandomString(SecretAccessKeyLength, Charset)
|
||||
@@ -179,11 +174,3 @@ func MapToIdentitiesAction(action string) string {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// MaskAccessKey masks an access key for logging, showing only the first 4 characters.
|
||||
func MaskAccessKey(accessKeyId string) string {
|
||||
if len(accessKeyId) > 4 {
|
||||
return accessKeyId[:4] + "***"
|
||||
}
|
||||
return accessKeyId
|
||||
}
|
||||
|
||||
@@ -1,164 +0,0 @@
|
||||
package iam
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestHash(t *testing.T) {
|
||||
input := "test"
|
||||
result := Hash(&input)
|
||||
assert.NotEmpty(t, result)
|
||||
assert.Len(t, result, 40) // SHA1 hex is 40 chars
|
||||
|
||||
// Same input should produce same hash
|
||||
result2 := Hash(&input)
|
||||
assert.Equal(t, result, result2)
|
||||
|
||||
// Different input should produce different hash
|
||||
different := "different"
|
||||
result3 := Hash(&different)
|
||||
assert.NotEqual(t, result, result3)
|
||||
}
|
||||
|
||||
func TestGenerateRandomString(t *testing.T) {
|
||||
// Valid generation
|
||||
result, err := GenerateRandomString(10, CharsetUpper)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, 10)
|
||||
|
||||
// Different calls should produce different results (with high probability)
|
||||
result2, err := GenerateRandomString(10, CharsetUpper)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, result, result2)
|
||||
|
||||
// Invalid length
|
||||
_, err = GenerateRandomString(0, CharsetUpper)
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = GenerateRandomString(-1, CharsetUpper)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Empty charset
|
||||
_, err = GenerateRandomString(10, "")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestGenerateAccessKeyId(t *testing.T) {
|
||||
keyId, err := GenerateAccessKeyId()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, keyId, AccessKeyIdLength)
|
||||
}
|
||||
|
||||
func TestGenerateSecretAccessKey(t *testing.T) {
|
||||
secretKey, err := GenerateSecretAccessKey()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, secretKey, SecretAccessKeyLength)
|
||||
}
|
||||
|
||||
func TestGenerateSecretAccessKey_URLSafe(t *testing.T) {
|
||||
// Generate multiple keys to increase probability of catching unsafe chars
|
||||
for i := 0; i < 100; i++ {
|
||||
secretKey, err := GenerateSecretAccessKey()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify no URL-unsafe characters that would cause authentication issues
|
||||
assert.NotContains(t, secretKey, "/", "Secret key should not contain /")
|
||||
assert.NotContains(t, secretKey, "+", "Secret key should not contain +")
|
||||
|
||||
// Verify only expected characters are present
|
||||
for _, char := range secretKey {
|
||||
assert.Contains(t, Charset, string(char), "Secret key contains unexpected character: %c", char)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringSlicesEqual(t *testing.T) {
|
||||
tests := []struct {
|
||||
a []string
|
||||
b []string
|
||||
expected bool
|
||||
}{
|
||||
{[]string{"a", "b", "c"}, []string{"a", "b", "c"}, true},
|
||||
{[]string{"c", "b", "a"}, []string{"a", "b", "c"}, true}, // Order independent
|
||||
{[]string{"a", "b"}, []string{"a", "b", "c"}, false},
|
||||
{[]string{}, []string{}, true},
|
||||
{nil, nil, true},
|
||||
{[]string{"a"}, []string{"b"}, false},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result := StringSlicesEqual(test.a, test.b)
|
||||
assert.Equal(t, test.expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapToStatementAction(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{StatementActionAdmin, s3_constants.ACTION_ADMIN},
|
||||
{StatementActionWrite, s3_constants.ACTION_WRITE},
|
||||
{StatementActionRead, s3_constants.ACTION_READ},
|
||||
{StatementActionList, s3_constants.ACTION_LIST},
|
||||
{StatementActionDelete, s3_constants.ACTION_DELETE_BUCKET},
|
||||
// Test fine-grained S3 action mappings (Issue #7864)
|
||||
{"DeleteObject", s3_constants.ACTION_WRITE},
|
||||
{"s3:DeleteObject", s3_constants.ACTION_WRITE},
|
||||
{"PutObject", s3_constants.ACTION_WRITE},
|
||||
{"s3:PutObject", s3_constants.ACTION_WRITE},
|
||||
{"GetObject", s3_constants.ACTION_READ},
|
||||
{"s3:GetObject", s3_constants.ACTION_READ},
|
||||
{"ListBucket", s3_constants.ACTION_LIST},
|
||||
{"s3:ListBucket", s3_constants.ACTION_LIST},
|
||||
{"PutObjectAcl", s3_constants.ACTION_WRITE_ACP},
|
||||
{"s3:PutObjectAcl", s3_constants.ACTION_WRITE_ACP},
|
||||
{"GetObjectAcl", s3_constants.ACTION_READ_ACP},
|
||||
{"s3:GetObjectAcl", s3_constants.ACTION_READ_ACP},
|
||||
{"unknown", ""},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result := MapToStatementAction(test.input)
|
||||
assert.Equal(t, test.expected, result, "Failed for input: %s", test.input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapToIdentitiesAction(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{s3_constants.ACTION_ADMIN, StatementActionAdmin},
|
||||
{s3_constants.ACTION_WRITE, StatementActionWrite},
|
||||
{s3_constants.ACTION_READ, StatementActionRead},
|
||||
{s3_constants.ACTION_LIST, StatementActionList},
|
||||
{s3_constants.ACTION_DELETE_BUCKET, StatementActionDelete},
|
||||
{"unknown", ""},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result := MapToIdentitiesAction(test.input)
|
||||
assert.Equal(t, test.expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaskAccessKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"AKIAIOSFODNN7EXAMPLE", "AKIA***"},
|
||||
{"AKIA", "AKIA"},
|
||||
{"AKI", "AKI"},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result := MaskAccessKey(test.input)
|
||||
assert.Equal(t, test.expected, result)
|
||||
}
|
||||
}
|
||||
@@ -202,32 +202,6 @@ func (m *IAMManager) getFilerAddress() string {
|
||||
return "" // Fallback to empty string if no provider is set
|
||||
}
|
||||
|
||||
// createRoleStore creates a role store based on configuration
|
||||
func (m *IAMManager) createRoleStore(config *RoleStoreConfig) (RoleStore, error) {
|
||||
if config == nil {
|
||||
// Default to generic cached filer role store when no config provided
|
||||
return NewGenericCachedRoleStore(nil, nil)
|
||||
}
|
||||
|
||||
switch config.StoreType {
|
||||
case "", "filer":
|
||||
// Check if caching is explicitly disabled
|
||||
if config.StoreConfig != nil {
|
||||
if noCache, ok := config.StoreConfig["noCache"].(bool); ok && noCache {
|
||||
return NewFilerRoleStore(config.StoreConfig, nil)
|
||||
}
|
||||
}
|
||||
// Default to generic cached filer store for better performance
|
||||
return NewGenericCachedRoleStore(config.StoreConfig, nil)
|
||||
case "cached-filer", "generic-cached":
|
||||
return NewGenericCachedRoleStore(config.StoreConfig, nil)
|
||||
case "memory":
|
||||
return NewMemoryRoleStore(), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported role store type: %s", config.StoreType)
|
||||
}
|
||||
}
|
||||
|
||||
// createRoleStoreWithProvider creates a role store with a filer address provider function
|
||||
func (m *IAMManager) createRoleStoreWithProvider(config *RoleStoreConfig, filerAddressProvider func() string) (RoleStore, error) {
|
||||
if config == nil {
|
||||
|
||||
@@ -388,157 +388,3 @@ type CachedFilerRoleStoreConfig struct {
|
||||
ListTTL string `json:"listTtl,omitempty"` // e.g., "1m", "30s"
|
||||
MaxCacheSize int `json:"maxCacheSize,omitempty"` // Maximum number of cached roles
|
||||
}
|
||||
|
||||
// NewCachedFilerRoleStore creates a new cached filer-based role store
|
||||
func NewCachedFilerRoleStore(config map[string]interface{}) (*CachedFilerRoleStore, error) {
|
||||
// Create underlying filer store
|
||||
filerStore, err := NewFilerRoleStore(config, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create filer role store: %w", err)
|
||||
}
|
||||
|
||||
// Parse cache configuration with defaults
|
||||
cacheTTL := 5 * time.Minute // Default 5 minutes for role cache
|
||||
listTTL := 1 * time.Minute // Default 1 minute for list cache
|
||||
maxCacheSize := 1000 // Default max 1000 cached roles
|
||||
|
||||
if config != nil {
|
||||
if ttlStr, ok := config["ttl"].(string); ok && ttlStr != "" {
|
||||
if parsed, err := time.ParseDuration(ttlStr); err == nil {
|
||||
cacheTTL = parsed
|
||||
}
|
||||
}
|
||||
if listTTLStr, ok := config["listTtl"].(string); ok && listTTLStr != "" {
|
||||
if parsed, err := time.ParseDuration(listTTLStr); err == nil {
|
||||
listTTL = parsed
|
||||
}
|
||||
}
|
||||
if maxSize, ok := config["maxCacheSize"].(int); ok && maxSize > 0 {
|
||||
maxCacheSize = maxSize
|
||||
}
|
||||
}
|
||||
|
||||
// Create ccache instances with appropriate configurations
|
||||
pruneCount := int64(maxCacheSize) >> 3
|
||||
if pruneCount <= 0 {
|
||||
pruneCount = 100
|
||||
}
|
||||
|
||||
store := &CachedFilerRoleStore{
|
||||
filerStore: filerStore,
|
||||
cache: ccache.New(ccache.Configure().MaxSize(int64(maxCacheSize)).ItemsToPrune(uint32(pruneCount))),
|
||||
listCache: ccache.New(ccache.Configure().MaxSize(100).ItemsToPrune(10)), // Smaller cache for lists
|
||||
ttl: cacheTTL,
|
||||
listTTL: listTTL,
|
||||
}
|
||||
|
||||
glog.V(2).Infof("Initialized CachedFilerRoleStore with TTL %v, List TTL %v, Max Cache Size %d",
|
||||
cacheTTL, listTTL, maxCacheSize)
|
||||
|
||||
return store, nil
|
||||
}
|
||||
|
||||
// StoreRole stores a role definition and invalidates the cache
|
||||
func (c *CachedFilerRoleStore) StoreRole(ctx context.Context, filerAddress string, roleName string, role *RoleDefinition) error {
|
||||
// Store in filer
|
||||
err := c.filerStore.StoreRole(ctx, filerAddress, roleName, role)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Invalidate cache entries
|
||||
c.cache.Delete(roleName)
|
||||
c.listCache.Clear() // Invalidate list cache
|
||||
|
||||
glog.V(3).Infof("Stored and invalidated cache for role %s", roleName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRole retrieves a role definition with caching
|
||||
func (c *CachedFilerRoleStore) GetRole(ctx context.Context, filerAddress string, roleName string) (*RoleDefinition, error) {
|
||||
// Try to get from cache first
|
||||
item := c.cache.Get(roleName)
|
||||
if item != nil {
|
||||
// Cache hit - return cached role (DO NOT extend TTL)
|
||||
role := item.Value().(*RoleDefinition)
|
||||
glog.V(4).Infof("Cache hit for role %s", roleName)
|
||||
return copyRoleDefinition(role), nil
|
||||
}
|
||||
|
||||
// Cache miss - fetch from filer
|
||||
glog.V(4).Infof("Cache miss for role %s, fetching from filer", roleName)
|
||||
role, err := c.filerStore.GetRole(ctx, filerAddress, roleName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Cache the result with TTL
|
||||
c.cache.Set(roleName, copyRoleDefinition(role), c.ttl)
|
||||
glog.V(3).Infof("Cached role %s with TTL %v", roleName, c.ttl)
|
||||
return role, nil
|
||||
}
|
||||
|
||||
// ListRoles lists all role names with caching
|
||||
func (c *CachedFilerRoleStore) ListRoles(ctx context.Context, filerAddress string) ([]string, error) {
|
||||
// Use a constant key for the role list cache
|
||||
const listCacheKey = "role_list"
|
||||
|
||||
// Try to get from list cache first
|
||||
item := c.listCache.Get(listCacheKey)
|
||||
if item != nil {
|
||||
// Cache hit - return cached list (DO NOT extend TTL)
|
||||
roles := item.Value().([]string)
|
||||
glog.V(4).Infof("List cache hit, returning %d roles", len(roles))
|
||||
return append([]string(nil), roles...), nil // Return a copy
|
||||
}
|
||||
|
||||
// Cache miss - fetch from filer
|
||||
glog.V(4).Infof("List cache miss, fetching from filer")
|
||||
roles, err := c.filerStore.ListRoles(ctx, filerAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Cache the result with TTL (store a copy)
|
||||
rolesCopy := append([]string(nil), roles...)
|
||||
c.listCache.Set(listCacheKey, rolesCopy, c.listTTL)
|
||||
glog.V(3).Infof("Cached role list with %d entries, TTL %v", len(roles), c.listTTL)
|
||||
return roles, nil
|
||||
}
|
||||
|
||||
// DeleteRole deletes a role definition and invalidates the cache
|
||||
func (c *CachedFilerRoleStore) DeleteRole(ctx context.Context, filerAddress string, roleName string) error {
|
||||
// Delete from filer
|
||||
err := c.filerStore.DeleteRole(ctx, filerAddress, roleName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Invalidate cache entries
|
||||
c.cache.Delete(roleName)
|
||||
c.listCache.Clear() // Invalidate list cache
|
||||
|
||||
glog.V(3).Infof("Deleted and invalidated cache for role %s", roleName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearCache clears all cached entries (for testing or manual cache invalidation)
|
||||
func (c *CachedFilerRoleStore) ClearCache() {
|
||||
c.cache.Clear()
|
||||
c.listCache.Clear()
|
||||
glog.V(2).Infof("Cleared all role cache entries")
|
||||
}
|
||||
|
||||
// GetCacheStats returns cache statistics
|
||||
func (c *CachedFilerRoleStore) GetCacheStats() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"roleCache": map[string]interface{}{
|
||||
"size": c.cache.ItemCount(),
|
||||
"ttl": c.ttl.String(),
|
||||
},
|
||||
"listCache": map[string]interface{}{
|
||||
"size": c.listCache.ItemCount(),
|
||||
"ttl": c.listTTL.String(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,687 +0,0 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConditionSetOperators(t *testing.T) {
|
||||
engine := setupTestPolicyEngine(t)
|
||||
|
||||
t.Run("ForAnyValue:StringEquals", func(t *testing.T) {
|
||||
trustPolicy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowOIDC",
|
||||
Effect: "Allow",
|
||||
Action: []string{"sts:AssumeRoleWithWebIdentity"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"ForAnyValue:StringEquals": {
|
||||
"oidc:roles": []string{"Dev.SeaweedFS.TestBucket.ReadWrite", "Dev.SeaweedFS.Admin"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Match: Admin is in the requested roles
|
||||
evalCtxMatch := &EvaluationContext{
|
||||
Principal: "web-identity-user",
|
||||
Action: "sts:AssumeRoleWithWebIdentity",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"oidc:roles": []string{"Dev.SeaweedFS.Admin", "OtherRole"},
|
||||
},
|
||||
}
|
||||
resultMatch, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicy, evalCtxMatch)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultMatch.Effect)
|
||||
|
||||
// No Match
|
||||
evalCtxNoMatch := &EvaluationContext{
|
||||
Principal: "web-identity-user",
|
||||
Action: "sts:AssumeRoleWithWebIdentity",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"oidc:roles": []string{"OtherRole1", "OtherRole2"},
|
||||
},
|
||||
}
|
||||
resultNoMatch, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicy, evalCtxNoMatch)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectDeny, resultNoMatch.Effect)
|
||||
|
||||
// No Match: Empty context for ForAnyValue (should deny)
|
||||
evalCtxEmpty := &EvaluationContext{
|
||||
Principal: "web-identity-user",
|
||||
Action: "sts:AssumeRoleWithWebIdentity",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"oidc:roles": []string{},
|
||||
},
|
||||
}
|
||||
resultEmpty, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicy, evalCtxEmpty)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectDeny, resultEmpty.Effect, "ForAnyValue should deny when context is empty")
|
||||
})
|
||||
|
||||
t.Run("ForAllValues:StringEquals", func(t *testing.T) {
|
||||
trustPolicyAll := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowOIDCAll",
|
||||
Effect: "Allow",
|
||||
Action: []string{"sts:AssumeRoleWithWebIdentity"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"ForAllValues:StringEquals": {
|
||||
"oidc:roles": []string{"RoleA", "RoleB", "RoleC"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Match: All requested roles ARE in the allowed set
|
||||
evalCtxAllMatch := &EvaluationContext{
|
||||
Principal: "web-identity-user",
|
||||
Action: "sts:AssumeRoleWithWebIdentity",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"oidc:roles": []string{"RoleA", "RoleB"},
|
||||
},
|
||||
}
|
||||
resultAllMatch, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicyAll, evalCtxAllMatch)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultAllMatch.Effect)
|
||||
|
||||
// Fail: RoleD is NOT in the allowed set
|
||||
evalCtxAllFail := &EvaluationContext{
|
||||
Principal: "web-identity-user",
|
||||
Action: "sts:AssumeRoleWithWebIdentity",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"oidc:roles": []string{"RoleA", "RoleD"},
|
||||
},
|
||||
}
|
||||
resultAllFail, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicyAll, evalCtxAllFail)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectDeny, resultAllFail.Effect)
|
||||
|
||||
// Vacuously true: Request has NO roles
|
||||
evalCtxEmpty := &EvaluationContext{
|
||||
Principal: "web-identity-user",
|
||||
Action: "sts:AssumeRoleWithWebIdentity",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"oidc:roles": []string{},
|
||||
},
|
||||
}
|
||||
resultEmpty, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicyAll, evalCtxEmpty)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultEmpty.Effect)
|
||||
})
|
||||
|
||||
t.Run("ForAllValues:NumericEqualsVacuouslyTrue", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowNumericAll",
|
||||
Effect: "Allow",
|
||||
Action: []string{"sts:AssumeRole"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"ForAllValues:NumericEquals": {
|
||||
"aws:MultiFactorAuthAge": []string{"3600", "7200"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Vacuously true: Request has NO MFA age info
|
||||
evalCtxEmpty := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "sts:AssumeRole",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:MultiFactorAuthAge": []string{},
|
||||
},
|
||||
}
|
||||
resultEmpty, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxEmpty)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultEmpty.Effect, "Should allow when numeric context is empty for ForAllValues")
|
||||
})
|
||||
|
||||
t.Run("ForAllValues:BoolVacuouslyTrue", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowBoolAll",
|
||||
Effect: "Allow",
|
||||
Action: []string{"sts:AssumeRole"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"ForAllValues:Bool": {
|
||||
"aws:SecureTransport": "true",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Vacuously true
|
||||
evalCtxEmpty := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "sts:AssumeRole",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:SecureTransport": []interface{}{},
|
||||
},
|
||||
}
|
||||
resultEmpty, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxEmpty)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultEmpty.Effect, "Should allow when bool context is empty for ForAllValues")
|
||||
})
|
||||
|
||||
t.Run("ForAllValues:DateVacuouslyTrue", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowDateAll",
|
||||
Effect: "Allow",
|
||||
Action: []string{"sts:AssumeRole"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"ForAllValues:DateGreaterThan": {
|
||||
"aws:CurrentTime": "2020-01-01T00:00:00Z",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Vacuously true
|
||||
evalCtxEmpty := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "sts:AssumeRole",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:CurrentTime": []interface{}{},
|
||||
},
|
||||
}
|
||||
resultEmpty, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxEmpty)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultEmpty.Effect, "Should allow when date context is empty for ForAllValues")
|
||||
})
|
||||
|
||||
t.Run("ForAllValues:DateWithLabelsAsStrings", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowDateStrings",
|
||||
Effect: "Allow",
|
||||
Action: []string{"sts:AssumeRole"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"ForAllValues:DateGreaterThan": {
|
||||
"aws:CurrentTime": "2020-01-01T00:00:00Z",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
evalCtx := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "sts:AssumeRole",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:CurrentTime": []string{"2021-01-01T00:00:00Z", "2022-01-01T00:00:00Z"},
|
||||
},
|
||||
}
|
||||
result, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, result.Effect, "Should allow when date context is a slice of strings")
|
||||
})
|
||||
|
||||
t.Run("ForAllValues:BoolWithLabelsAsStrings", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowBoolStrings",
|
||||
Effect: "Allow",
|
||||
Action: []string{"sts:AssumeRole"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"ForAllValues:Bool": {
|
||||
"aws:SecureTransport": "true",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
evalCtx := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "sts:AssumeRole",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:SecureTransport": []string{"true", "true"},
|
||||
},
|
||||
}
|
||||
result, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, result.Effect, "Should allow when bool context is a slice of strings")
|
||||
})
|
||||
|
||||
t.Run("StringEqualsIgnoreCaseWithVariable", func(t *testing.T) {
|
||||
policyDoc := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowVar",
|
||||
Effect: "Allow",
|
||||
Action: []string{"s3:GetObject"},
|
||||
Resource: []string{"arn:aws:s3:::bucket/*"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"StringEqualsIgnoreCase": {
|
||||
"s3:prefix": "${aws:username}/",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := engine.AddPolicy("", "var-policy", policyDoc)
|
||||
require.NoError(t, err)
|
||||
|
||||
evalCtx := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::bucket/ALICE/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"s3:prefix": "ALICE/",
|
||||
"aws:username": "alice",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := engine.Evaluate(context.Background(), "", evalCtx, []string{"var-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, result.Effect, "Should allow when variable expands and matches case-insensitively")
|
||||
})
|
||||
|
||||
t.Run("StringLike:CaseSensitivity", func(t *testing.T) {
|
||||
policyDoc := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowCaseSensitiveLike",
|
||||
Effect: "Allow",
|
||||
Action: []string{"s3:GetObject"},
|
||||
Resource: []string{"arn:aws:s3:::bucket/*"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"StringLike": {
|
||||
"s3:prefix": "Project/*",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := engine.AddPolicy("", "like-policy", policyDoc)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Match: Case sensitive match
|
||||
evalCtxMatch := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::bucket/Project/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"s3:prefix": "Project/data",
|
||||
},
|
||||
}
|
||||
resultMatch, err := engine.Evaluate(context.Background(), "", evalCtxMatch, []string{"like-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultMatch.Effect, "Should allow when case matches exactly")
|
||||
|
||||
// Fail: Case insensitive match (should fail for StringLike)
|
||||
evalCtxFail := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::bucket/project/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"s3:prefix": "project/data", // lowercase 'p'
|
||||
},
|
||||
}
|
||||
resultFail, err := engine.Evaluate(context.Background(), "", evalCtxFail, []string{"like-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectDeny, resultFail.Effect, "Should deny when case does not match for StringLike")
|
||||
})
|
||||
|
||||
t.Run("NumericNotEquals:Logic", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "DenySpecificAges",
|
||||
Effect: "Allow",
|
||||
Action: []string{"sts:AssumeRole"},
|
||||
Resource: []string{"*"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"ForAllValues:NumericNotEquals": {
|
||||
"aws:MultiFactorAuthAge": []string{"3600", "7200"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := engine.AddPolicy("", "numeric-not-equals-policy", policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fail: One age matches an excluded value (3600)
|
||||
evalCtxFail := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "sts:AssumeRole",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:MultiFactorAuthAge": []string{"3600", "1800"},
|
||||
},
|
||||
}
|
||||
resultFail, err := engine.Evaluate(context.Background(), "", evalCtxFail, []string{"numeric-not-equals-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectDeny, resultFail.Effect, "Should deny when one age matches an excluded value")
|
||||
|
||||
// Pass: No age matches any excluded value
|
||||
evalCtxPass := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "sts:AssumeRole",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:MultiFactorAuthAge": []string{"1800", "900"},
|
||||
},
|
||||
}
|
||||
resultPass, err := engine.Evaluate(context.Background(), "", evalCtxPass, []string{"numeric-not-equals-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultPass.Effect, "Should allow when no age matches excluded values")
|
||||
})
|
||||
|
||||
t.Run("DateNotEquals:Logic", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "DenySpecificTimes",
|
||||
Effect: "Allow",
|
||||
Action: []string{"sts:AssumeRole"},
|
||||
Resource: []string{"*"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"ForAllValues:DateNotEquals": {
|
||||
"aws:CurrentTime": []string{"2024-01-01T00:00:00Z", "2024-01-02T00:00:00Z"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := engine.AddPolicy("", "date-not-equals-policy", policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fail: One time matches an excluded value
|
||||
evalCtxFail := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "sts:AssumeRole",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:CurrentTime": []string{"2024-01-01T00:00:00Z", "2024-01-03T00:00:00Z"},
|
||||
},
|
||||
}
|
||||
resultFail, err := engine.Evaluate(context.Background(), "", evalCtxFail, []string{"date-not-equals-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectDeny, resultFail.Effect, "Should deny when one date matches an excluded value")
|
||||
})
|
||||
|
||||
t.Run("IpAddress:SetOperators", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowSpecificIPs",
|
||||
Effect: "Allow",
|
||||
Action: []string{"s3:GetObject"},
|
||||
Resource: []string{"*"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"ForAllValues:IpAddress": {
|
||||
"aws:SourceIp": []string{"192.168.1.0/24", "10.0.0.1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := engine.AddPolicy("", "ip-set-policy", policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Match: All source IPs are in allowed ranges
|
||||
evalCtxMatch := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::bucket/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:SourceIp": []string{"192.168.1.10", "10.0.0.1"},
|
||||
},
|
||||
}
|
||||
resultMatch, err := engine.Evaluate(context.Background(), "", evalCtxMatch, []string{"ip-set-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultMatch.Effect)
|
||||
|
||||
// Fail: One source IP is NOT in allowed ranges
|
||||
evalCtxFail := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::bucket/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:SourceIp": []string{"192.168.1.10", "172.16.0.1"},
|
||||
},
|
||||
}
|
||||
resultFail, err := engine.Evaluate(context.Background(), "", evalCtxFail, []string{"ip-set-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectDeny, resultFail.Effect)
|
||||
|
||||
// ForAnyValue: IPAddress
|
||||
policyAny := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowAnySpecificIPs",
|
||||
Effect: "Allow",
|
||||
Action: []string{"s3:GetObject"},
|
||||
Resource: []string{"*"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"ForAnyValue:IpAddress": {
|
||||
"aws:SourceIp": []string{"192.168.1.0/24"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err = engine.AddPolicy("", "ip-any-policy", policyAny)
|
||||
require.NoError(t, err)
|
||||
|
||||
evalCtxAnyMatch := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::bucket/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:SourceIp": []string{"192.168.1.10", "172.16.0.1"},
|
||||
},
|
||||
}
|
||||
resultAnyMatch, err := engine.Evaluate(context.Background(), "", evalCtxAnyMatch, []string{"ip-any-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultAnyMatch.Effect)
|
||||
})
|
||||
|
||||
t.Run("IpAddress:SingleStringValue", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowSingleIP",
|
||||
Effect: "Allow",
|
||||
Action: []string{"s3:GetObject"},
|
||||
Resource: []string{"*"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"IpAddress": {
|
||||
"aws:SourceIp": "192.168.1.1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := engine.AddPolicy("", "ip-single-policy", policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
evalCtxMatch := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::bucket/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:SourceIp": "192.168.1.1",
|
||||
},
|
||||
}
|
||||
resultMatch, err := engine.Evaluate(context.Background(), "", evalCtxMatch, []string{"ip-single-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultMatch.Effect)
|
||||
|
||||
evalCtxNoMatch := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::bucket/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:SourceIp": "10.0.0.1",
|
||||
},
|
||||
}
|
||||
resultNoMatch, err := engine.Evaluate(context.Background(), "", evalCtxNoMatch, []string{"ip-single-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectDeny, resultNoMatch.Effect)
|
||||
})
|
||||
|
||||
t.Run("Bool:StringSlicePolicyValues", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowWithBoolStrings",
|
||||
Effect: "Allow",
|
||||
Action: []string{"s3:GetObject"},
|
||||
Resource: []string{"*"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"Bool": {
|
||||
"aws:SecureTransport": []string{"true", "false"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := engine.AddPolicy("", "bool-string-slice-policy", policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
evalCtx := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::bucket/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:SecureTransport": "true",
|
||||
},
|
||||
}
|
||||
result, err := engine.Evaluate(context.Background(), "", evalCtx, []string{"bool-string-slice-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, result.Effect)
|
||||
})
|
||||
|
||||
t.Run("StringEqualsIgnoreCase:StringSlicePolicyValues", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowWithIgnoreCaseStrings",
|
||||
Effect: "Allow",
|
||||
Action: []string{"s3:GetObject"},
|
||||
Resource: []string{"*"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"StringEqualsIgnoreCase": {
|
||||
"s3:x-amz-server-side-encryption": []string{"AES256", "aws:kms"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := engine.AddPolicy("", "string-ignorecase-slice-policy", policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
evalCtx := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::bucket/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"s3:x-amz-server-side-encryption": "aes256",
|
||||
},
|
||||
}
|
||||
result, err := engine.Evaluate(context.Background(), "", evalCtx, []string{"string-ignorecase-slice-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, result.Effect)
|
||||
})
|
||||
|
||||
t.Run("IpAddress:CustomContextKey", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowCustomIPKey",
|
||||
Effect: "Allow",
|
||||
Action: []string{"s3:GetObject"},
|
||||
Resource: []string{"*"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"IpAddress": {
|
||||
"custom:VpcIp": "10.0.0.0/16",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := engine.AddPolicy("", "ip-custom-key-policy", policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
evalCtxMatch := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::bucket/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"custom:VpcIp": "10.0.5.1",
|
||||
},
|
||||
}
|
||||
resultMatch, err := engine.Evaluate(context.Background(), "", evalCtxMatch, []string{"ip-custom-key-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultMatch.Effect)
|
||||
|
||||
evalCtxNoMatch := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::bucket/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"custom:VpcIp": "192.168.1.1",
|
||||
},
|
||||
}
|
||||
resultNoMatch, err := engine.Evaluate(context.Background(), "", evalCtxNoMatch, []string{"ip-custom-key-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectDeny, resultNoMatch.Effect)
|
||||
})
|
||||
}
|
||||
@@ -1,101 +0,0 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNegationSetOperators(t *testing.T) {
|
||||
engine := setupTestPolicyEngine(t)
|
||||
|
||||
t.Run("ForAllValues:StringNotEquals", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "DenyAdmin",
|
||||
Effect: "Allow",
|
||||
Action: []string{"sts:AssumeRole"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"ForAllValues:StringNotEquals": {
|
||||
"oidc:roles": []string{"Admin"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// All roles are NOT "Admin" -> Should Allow
|
||||
evalCtxAllow := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "sts:AssumeRole",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"oidc:roles": []string{"User", "Developer"},
|
||||
},
|
||||
}
|
||||
resultAllow, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxAllow)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultAllow.Effect, "Should allow when ALL roles satisfy StringNotEquals Admin")
|
||||
|
||||
// One role is "Admin" -> Should Deny
|
||||
evalCtxDeny := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "sts:AssumeRole",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"oidc:roles": []string{"Admin", "User"},
|
||||
},
|
||||
}
|
||||
resultDeny, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxDeny)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectDeny, resultDeny.Effect, "Should deny when one role is Admin and fails StringNotEquals")
|
||||
})
|
||||
|
||||
t.Run("ForAnyValue:StringNotEquals", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "Requirement",
|
||||
Effect: "Allow",
|
||||
Action: []string{"sts:AssumeRole"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"ForAnyValue:StringNotEquals": {
|
||||
"oidc:roles": []string{"Prohibited"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// At least one role is NOT prohibited -> Should Allow
|
||||
evalCtxAllow := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "sts:AssumeRole",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"oidc:roles": []string{"Prohibited", "Allowed"},
|
||||
},
|
||||
}
|
||||
resultAllow, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxAllow)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultAllow.Effect, "Should allow when at least one role is NOT Prohibited")
|
||||
|
||||
// All roles are Prohibited -> Should Deny
|
||||
evalCtxDeny := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "sts:AssumeRole",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"oidc:roles": []string{"Prohibited", "Prohibited"},
|
||||
},
|
||||
}
|
||||
resultDeny, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxDeny)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectDeny, resultDeny.Effect, "Should deny when ALL roles are Prohibited")
|
||||
})
|
||||
}
|
||||
@@ -1155,11 +1155,6 @@ func ValidatePolicyDocumentWithType(policy *PolicyDocument, policyType string) e
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateStatement validates a single statement (for backward compatibility)
|
||||
func validateStatement(statement *Statement) error {
|
||||
return validateStatementWithType(statement, "resource")
|
||||
}
|
||||
|
||||
// validateStatementWithType validates a single statement based on policy type
|
||||
func validateStatementWithType(statement *Statement, policyType string) error {
|
||||
if statement.Effect != "Allow" && statement.Effect != "Deny" {
|
||||
@@ -1198,29 +1193,6 @@ func validateStatementWithType(statement *Statement, policyType string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// matchResource checks if a resource pattern matches a requested resource
|
||||
// Uses hybrid approach: simple suffix wildcards for compatibility, filepath.Match for complex patterns
|
||||
func matchResource(pattern, resource string) bool {
|
||||
if pattern == resource {
|
||||
return true
|
||||
}
|
||||
|
||||
// Handle simple suffix wildcard (backward compatibility)
|
||||
if strings.HasSuffix(pattern, "*") {
|
||||
prefix := pattern[:len(pattern)-1]
|
||||
return strings.HasPrefix(resource, prefix)
|
||||
}
|
||||
|
||||
// For complex patterns, use filepath.Match for advanced wildcard support (*, ?, [])
|
||||
matched, err := filepath.Match(pattern, resource)
|
||||
if err != nil {
|
||||
// Fallback to exact match if pattern is malformed
|
||||
return pattern == resource
|
||||
}
|
||||
|
||||
return matched
|
||||
}
|
||||
|
||||
// awsIAMMatch performs AWS IAM-compliant pattern matching with case-insensitivity and policy variable support
|
||||
func awsIAMMatch(pattern, value string, evalCtx *EvaluationContext) bool {
|
||||
// Step 1: Substitute policy variables (e.g., ${aws:username}, ${saml:username})
|
||||
@@ -1274,16 +1246,6 @@ func expandPolicyVariables(pattern string, evalCtx *EvaluationContext) string {
|
||||
return result
|
||||
}
|
||||
|
||||
// getContextValue safely gets a value from the evaluation context
|
||||
func getContextValue(evalCtx *EvaluationContext, key, defaultValue string) string {
|
||||
if value, exists := evalCtx.RequestContext[key]; exists {
|
||||
if str, ok := value.(string); ok {
|
||||
return str
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// AwsWildcardMatch performs case-insensitive wildcard matching like AWS IAM
|
||||
func AwsWildcardMatch(pattern, value string) bool {
|
||||
// Create regex pattern key for caching
|
||||
@@ -1322,29 +1284,6 @@ func AwsWildcardMatch(pattern, value string) bool {
|
||||
return regex.MatchString(value)
|
||||
}
|
||||
|
||||
// matchAction checks if an action pattern matches a requested action
|
||||
// Uses hybrid approach: simple suffix wildcards for compatibility, filepath.Match for complex patterns
|
||||
func matchAction(pattern, action string) bool {
|
||||
if pattern == action {
|
||||
return true
|
||||
}
|
||||
|
||||
// Handle simple suffix wildcard (backward compatibility)
|
||||
if strings.HasSuffix(pattern, "*") {
|
||||
prefix := pattern[:len(pattern)-1]
|
||||
return strings.HasPrefix(action, prefix)
|
||||
}
|
||||
|
||||
// For complex patterns, use filepath.Match for advanced wildcard support (*, ?, [])
|
||||
matched, err := filepath.Match(pattern, action)
|
||||
if err != nil {
|
||||
// Fallback to exact match if pattern is malformed
|
||||
return pattern == action
|
||||
}
|
||||
|
||||
return matched
|
||||
}
|
||||
|
||||
// evaluateStringConditionIgnoreCase evaluates string conditions with case insensitivity
|
||||
func (e *PolicyEngine) evaluateStringConditionIgnoreCase(block map[string]interface{}, evalCtx *EvaluationContext, shouldMatch bool, useWildcard bool, forAllValues bool) bool {
|
||||
for key, expectedValues := range block {
|
||||
|
||||
@@ -1,421 +0,0 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestPrincipalMatching tests the matchesPrincipal method
|
||||
func TestPrincipalMatching(t *testing.T) {
|
||||
engine := setupTestPolicyEngine(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
principal interface{}
|
||||
evalCtx *EvaluationContext
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "plain wildcard principal",
|
||||
principal: "*",
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "structured wildcard federated principal",
|
||||
principal: map[string]interface{}{
|
||||
"Federated": "*",
|
||||
},
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard in array",
|
||||
principal: map[string]interface{}{
|
||||
"Federated": []interface{}{"specific-provider", "*"},
|
||||
},
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "specific federated provider match",
|
||||
principal: map[string]interface{}{
|
||||
"Federated": "https://example.com/oidc",
|
||||
},
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:FederatedProvider": "https://example.com/oidc",
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "specific federated provider no match",
|
||||
principal: map[string]interface{}{
|
||||
"Federated": "https://example.com/oidc",
|
||||
},
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:FederatedProvider": "https://other.com/oidc",
|
||||
},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "array with specific provider match",
|
||||
principal: map[string]interface{}{
|
||||
"Federated": []string{"https://provider1.com", "https://provider2.com"},
|
||||
},
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:FederatedProvider": "https://provider2.com",
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "AWS principal match",
|
||||
principal: map[string]interface{}{
|
||||
"AWS": "arn:aws:iam::123456789012:user/alice",
|
||||
},
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:PrincipalArn": "arn:aws:iam::123456789012:user/alice",
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Service principal match",
|
||||
principal: map[string]interface{}{
|
||||
"Service": "s3.amazonaws.com",
|
||||
},
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:PrincipalServiceName": "s3.amazonaws.com",
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := engine.matchesPrincipal(tt.principal, tt.evalCtx)
|
||||
assert.Equal(t, tt.want, result, "Principal matching failed for: %s", tt.name)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEvaluatePrincipalValue tests the evaluatePrincipalValue method
|
||||
func TestEvaluatePrincipalValue(t *testing.T) {
|
||||
engine := setupTestPolicyEngine(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
principalValue interface{}
|
||||
contextKey string
|
||||
evalCtx *EvaluationContext
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "wildcard string",
|
||||
principalValue: "*",
|
||||
contextKey: "aws:FederatedProvider",
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "specific string match",
|
||||
principalValue: "https://example.com",
|
||||
contextKey: "aws:FederatedProvider",
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:FederatedProvider": "https://example.com",
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "specific string no match",
|
||||
principalValue: "https://example.com",
|
||||
contextKey: "aws:FederatedProvider",
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:FederatedProvider": "https://other.com",
|
||||
},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard in array",
|
||||
principalValue: []interface{}{"provider1", "*"},
|
||||
contextKey: "aws:FederatedProvider",
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "array match",
|
||||
principalValue: []string{"provider1", "provider2", "provider3"},
|
||||
contextKey: "aws:FederatedProvider",
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:FederatedProvider": "provider2",
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "array no match",
|
||||
principalValue: []string{"provider1", "provider2"},
|
||||
contextKey: "aws:FederatedProvider",
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:FederatedProvider": "provider3",
|
||||
},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "missing context key",
|
||||
principalValue: "specific-value",
|
||||
contextKey: "aws:FederatedProvider",
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := engine.evaluatePrincipalValue(tt.principalValue, tt.evalCtx, tt.contextKey)
|
||||
assert.Equal(t, tt.want, result, "Principal value evaluation failed for: %s", tt.name)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTrustPolicyEvaluation tests the EvaluateTrustPolicy method
|
||||
func TestTrustPolicyEvaluation(t *testing.T) {
|
||||
engine := setupTestPolicyEngine(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
trustPolicy *PolicyDocument
|
||||
evalCtx *EvaluationContext
|
||||
wantEffect Effect
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "wildcard federated principal allows any provider",
|
||||
trustPolicy: &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Effect: "Allow",
|
||||
Principal: map[string]interface{}{
|
||||
"Federated": "*",
|
||||
},
|
||||
Action: []string{"sts:AssumeRoleWithWebIdentity"},
|
||||
},
|
||||
},
|
||||
},
|
||||
evalCtx: &EvaluationContext{
|
||||
Action: "sts:AssumeRoleWithWebIdentity",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:FederatedProvider": "https://any-provider.com",
|
||||
},
|
||||
},
|
||||
wantEffect: EffectAllow,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "specific federated principal matches",
|
||||
trustPolicy: &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Effect: "Allow",
|
||||
Principal: map[string]interface{}{
|
||||
"Federated": "https://example.com/oidc",
|
||||
},
|
||||
Action: []string{"sts:AssumeRoleWithWebIdentity"},
|
||||
},
|
||||
},
|
||||
},
|
||||
evalCtx: &EvaluationContext{
|
||||
Action: "sts:AssumeRoleWithWebIdentity",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:FederatedProvider": "https://example.com/oidc",
|
||||
},
|
||||
},
|
||||
wantEffect: EffectAllow,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "specific federated principal does not match",
|
||||
trustPolicy: &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Effect: "Allow",
|
||||
Principal: map[string]interface{}{
|
||||
"Federated": "https://example.com/oidc",
|
||||
},
|
||||
Action: []string{"sts:AssumeRoleWithWebIdentity"},
|
||||
},
|
||||
},
|
||||
},
|
||||
evalCtx: &EvaluationContext{
|
||||
Action: "sts:AssumeRoleWithWebIdentity",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:FederatedProvider": "https://other.com/oidc",
|
||||
},
|
||||
},
|
||||
wantEffect: EffectDeny,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "plain wildcard principal",
|
||||
trustPolicy: &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Effect: "Allow",
|
||||
Principal: "*",
|
||||
Action: []string{"sts:AssumeRoleWithWebIdentity"},
|
||||
},
|
||||
},
|
||||
},
|
||||
evalCtx: &EvaluationContext{
|
||||
Action: "sts:AssumeRoleWithWebIdentity",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:FederatedProvider": "https://any-provider.com",
|
||||
},
|
||||
},
|
||||
wantEffect: EffectAllow,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "trust policy with conditions",
|
||||
trustPolicy: &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Effect: "Allow",
|
||||
Principal: map[string]interface{}{
|
||||
"Federated": "*",
|
||||
},
|
||||
Action: []string{"sts:AssumeRoleWithWebIdentity"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"StringEquals": {
|
||||
"oidc:aud": "my-app-id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
evalCtx: &EvaluationContext{
|
||||
Action: "sts:AssumeRoleWithWebIdentity",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:FederatedProvider": "https://provider.com",
|
||||
"oidc:aud": "my-app-id",
|
||||
},
|
||||
},
|
||||
wantEffect: EffectAllow,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "trust policy condition not met",
|
||||
trustPolicy: &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Effect: "Allow",
|
||||
Principal: map[string]interface{}{
|
||||
"Federated": "*",
|
||||
},
|
||||
Action: []string{"sts:AssumeRoleWithWebIdentity"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"StringEquals": {
|
||||
"oidc:aud": "my-app-id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
evalCtx: &EvaluationContext{
|
||||
Action: "sts:AssumeRoleWithWebIdentity",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:FederatedProvider": "https://provider.com",
|
||||
"oidc:aud": "wrong-app-id",
|
||||
},
|
||||
},
|
||||
wantEffect: EffectDeny,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := engine.EvaluateTrustPolicy(context.Background(), tt.trustPolicy, tt.evalCtx)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantEffect, result.Effect, "Trust policy evaluation failed for: %s", tt.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetPrincipalContextKey tests the context key mapping
|
||||
func TestGetPrincipalContextKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
principalType string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "Federated principal",
|
||||
principalType: "Federated",
|
||||
want: "aws:FederatedProvider",
|
||||
},
|
||||
{
|
||||
name: "AWS principal",
|
||||
principalType: "AWS",
|
||||
want: "aws:PrincipalArn",
|
||||
},
|
||||
{
|
||||
name: "Service principal",
|
||||
principalType: "Service",
|
||||
want: "aws:PrincipalServiceName",
|
||||
},
|
||||
{
|
||||
name: "Custom principal type",
|
||||
principalType: "CustomType",
|
||||
want: "aws:PrincipalCustomType",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := getPrincipalContextKey(tt.principalType)
|
||||
assert.Equal(t, tt.want, result, "Context key mapping failed for: %s", tt.name)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,426 +0,0 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestPolicyEngineInitialization tests policy engine initialization
|
||||
func TestPolicyEngineInitialization(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *PolicyEngineConfig
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
config: &PolicyEngineConfig{
|
||||
DefaultEffect: "Deny",
|
||||
StoreType: "memory",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid default effect",
|
||||
config: &PolicyEngineConfig{
|
||||
DefaultEffect: "Invalid",
|
||||
StoreType: "memory",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "nil config",
|
||||
config: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
engine := NewPolicyEngine()
|
||||
|
||||
err := engine.Initialize(tt.config)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, engine.IsInitialized())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPolicyDocumentValidation tests policy document structure validation
|
||||
func TestPolicyDocumentValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
policy *PolicyDocument
|
||||
wantErr bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid policy document",
|
||||
policy: &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowS3Read",
|
||||
Effect: "Allow",
|
||||
Action: []string{"s3:GetObject", "s3:ListBucket"},
|
||||
Resource: []string{"arn:aws:s3:::mybucket/*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing version",
|
||||
policy: &PolicyDocument{
|
||||
Statement: []Statement{
|
||||
{
|
||||
Effect: "Allow",
|
||||
Action: []string{"s3:GetObject"},
|
||||
Resource: []string{"arn:aws:s3:::mybucket/*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errorMsg: "version is required",
|
||||
},
|
||||
{
|
||||
name: "empty statements",
|
||||
policy: &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{},
|
||||
},
|
||||
wantErr: true,
|
||||
errorMsg: "at least one statement is required",
|
||||
},
|
||||
{
|
||||
name: "invalid effect",
|
||||
policy: &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Effect: "Maybe",
|
||||
Action: []string{"s3:GetObject"},
|
||||
Resource: []string{"arn:aws:s3:::mybucket/*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errorMsg: "invalid effect",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidatePolicyDocument(tt.policy)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPolicyEvaluation tests policy evaluation logic
|
||||
func TestPolicyEvaluation(t *testing.T) {
|
||||
engine := setupTestPolicyEngine(t)
|
||||
|
||||
// Add test policies
|
||||
readPolicy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowS3Read",
|
||||
Effect: "Allow",
|
||||
Action: []string{"s3:GetObject", "s3:ListBucket"},
|
||||
Resource: []string{
|
||||
"arn:aws:s3:::public-bucket/*", // For object operations
|
||||
"arn:aws:s3:::public-bucket", // For bucket operations
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := engine.AddPolicy("", "read-policy", readPolicy)
|
||||
require.NoError(t, err)
|
||||
|
||||
denyPolicy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "DenyS3Delete",
|
||||
Effect: "Deny",
|
||||
Action: []string{"s3:DeleteObject"},
|
||||
Resource: []string{"arn:aws:s3:::*"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = engine.AddPolicy("", "deny-policy", denyPolicy)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
context *EvaluationContext
|
||||
policies []string
|
||||
want Effect
|
||||
}{
|
||||
{
|
||||
name: "allow read access",
|
||||
context: &EvaluationContext{
|
||||
Principal: "user:alice",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::public-bucket/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:SourceIp": "192.168.1.100",
|
||||
},
|
||||
},
|
||||
policies: []string{"read-policy"},
|
||||
want: EffectAllow,
|
||||
},
|
||||
{
|
||||
name: "deny delete access (explicit deny)",
|
||||
context: &EvaluationContext{
|
||||
Principal: "user:alice",
|
||||
Action: "s3:DeleteObject",
|
||||
Resource: "arn:aws:s3:::public-bucket/file.txt",
|
||||
},
|
||||
policies: []string{"read-policy", "deny-policy"},
|
||||
want: EffectDeny,
|
||||
},
|
||||
{
|
||||
name: "deny by default (no matching policy)",
|
||||
context: &EvaluationContext{
|
||||
Principal: "user:alice",
|
||||
Action: "s3:PutObject",
|
||||
Resource: "arn:aws:s3:::public-bucket/file.txt",
|
||||
},
|
||||
policies: []string{"read-policy"},
|
||||
want: EffectDeny,
|
||||
},
|
||||
{
|
||||
name: "allow with wildcard action",
|
||||
context: &EvaluationContext{
|
||||
Principal: "user:admin",
|
||||
Action: "s3:ListBucket",
|
||||
Resource: "arn:aws:s3:::public-bucket",
|
||||
},
|
||||
policies: []string{"read-policy"},
|
||||
want: EffectAllow,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := engine.Evaluate(context.Background(), "", tt.context, tt.policies)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, result.Effect)
|
||||
|
||||
// Verify evaluation details
|
||||
assert.NotNil(t, result.EvaluationDetails)
|
||||
assert.Equal(t, tt.context.Action, result.EvaluationDetails.Action)
|
||||
assert.Equal(t, tt.context.Resource, result.EvaluationDetails.Resource)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConditionEvaluation tests policy conditions
|
||||
func TestConditionEvaluation(t *testing.T) {
|
||||
engine := setupTestPolicyEngine(t)
|
||||
|
||||
// Policy with IP address condition
|
||||
conditionalPolicy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowFromOfficeIP",
|
||||
Effect: "Allow",
|
||||
Action: []string{"s3:*"},
|
||||
Resource: []string{"arn:aws:s3:::*"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"IpAddress": {
|
||||
"aws:SourceIp": []string{"192.168.1.0/24", "10.0.0.0/8"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := engine.AddPolicy("", "ip-conditional", conditionalPolicy)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
context *EvaluationContext
|
||||
want Effect
|
||||
}{
|
||||
{
|
||||
name: "allow from office IP",
|
||||
context: &EvaluationContext{
|
||||
Principal: "user:alice",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::mybucket/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:SourceIp": "192.168.1.100",
|
||||
},
|
||||
},
|
||||
want: EffectAllow,
|
||||
},
|
||||
{
|
||||
name: "deny from external IP",
|
||||
context: &EvaluationContext{
|
||||
Principal: "user:alice",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::mybucket/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:SourceIp": "8.8.8.8",
|
||||
},
|
||||
},
|
||||
want: EffectDeny,
|
||||
},
|
||||
{
|
||||
name: "allow from internal IP",
|
||||
context: &EvaluationContext{
|
||||
Principal: "user:alice",
|
||||
Action: "s3:PutObject",
|
||||
Resource: "arn:aws:s3:::mybucket/newfile.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:SourceIp": "10.1.2.3",
|
||||
},
|
||||
},
|
||||
want: EffectAllow,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := engine.Evaluate(context.Background(), "", tt.context, []string{"ip-conditional"})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, result.Effect)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestResourceMatching tests resource ARN matching
|
||||
func TestResourceMatching(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
policyResource string
|
||||
requestResource string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "exact match",
|
||||
policyResource: "arn:aws:s3:::mybucket/file.txt",
|
||||
requestResource: "arn:aws:s3:::mybucket/file.txt",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard match",
|
||||
policyResource: "arn:aws:s3:::mybucket/*",
|
||||
requestResource: "arn:aws:s3:::mybucket/folder/file.txt",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "bucket wildcard",
|
||||
policyResource: "arn:aws:s3:::*",
|
||||
requestResource: "arn:aws:s3:::anybucket/file.txt",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "no match different bucket",
|
||||
policyResource: "arn:aws:s3:::mybucket/*",
|
||||
requestResource: "arn:aws:s3:::otherbucket/file.txt",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "prefix match",
|
||||
policyResource: "arn:aws:s3:::mybucket/documents/*",
|
||||
requestResource: "arn:aws:s3:::mybucket/documents/secret.txt",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := matchResource(tt.policyResource, tt.requestResource)
|
||||
assert.Equal(t, tt.want, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestActionMatching tests action pattern matching
|
||||
func TestActionMatching(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
policyAction string
|
||||
requestAction string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "exact match",
|
||||
policyAction: "s3:GetObject",
|
||||
requestAction: "s3:GetObject",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard service",
|
||||
policyAction: "s3:*",
|
||||
requestAction: "s3:PutObject",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard all",
|
||||
policyAction: "*",
|
||||
requestAction: "filer:CreateEntry",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "prefix match",
|
||||
policyAction: "s3:Get*",
|
||||
requestAction: "s3:GetObject",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "no match different service",
|
||||
policyAction: "s3:GetObject",
|
||||
requestAction: "filer:GetEntry",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := matchAction(tt.policyAction, tt.requestAction)
|
||||
assert.Equal(t, tt.want, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to set up test policy engine
|
||||
func setupTestPolicyEngine(t *testing.T) *PolicyEngine {
|
||||
engine := NewPolicyEngine()
|
||||
config := &PolicyEngineConfig{
|
||||
DefaultEffect: "Deny",
|
||||
StoreType: "memory",
|
||||
}
|
||||
|
||||
err := engine.Initialize(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
return engine
|
||||
}
|
||||
@@ -1,246 +0,0 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestIdentityProviderInterface tests the core identity provider interface
|
||||
func TestIdentityProviderInterface(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
provider IdentityProvider
|
||||
wantErr bool
|
||||
}{
|
||||
// We'll add test cases as we implement providers
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test provider name
|
||||
name := tt.provider.Name()
|
||||
assert.NotEmpty(t, name, "Provider name should not be empty")
|
||||
|
||||
// Test initialization
|
||||
err := tt.provider.Initialize(nil)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test authentication with invalid token
|
||||
ctx := context.Background()
|
||||
_, err = tt.provider.Authenticate(ctx, "invalid-token")
|
||||
assert.Error(t, err, "Should fail with invalid token")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestExternalIdentityValidation tests external identity structure validation
|
||||
func TestExternalIdentityValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
identity *ExternalIdentity
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid identity",
|
||||
identity: &ExternalIdentity{
|
||||
UserID: "user123",
|
||||
Email: "user@example.com",
|
||||
DisplayName: "Test User",
|
||||
Groups: []string{"group1", "group2"},
|
||||
Attributes: map[string]string{"dept": "engineering"},
|
||||
Provider: "test-provider",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing user id",
|
||||
identity: &ExternalIdentity{
|
||||
Email: "user@example.com",
|
||||
Provider: "test-provider",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing provider",
|
||||
identity: &ExternalIdentity{
|
||||
UserID: "user123",
|
||||
Email: "user@example.com",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid email",
|
||||
identity: &ExternalIdentity{
|
||||
UserID: "user123",
|
||||
Email: "invalid-email",
|
||||
Provider: "test-provider",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.identity.Validate()
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenClaimsValidation tests token claims structure
|
||||
func TestTokenClaimsValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
claims *TokenClaims
|
||||
valid bool
|
||||
}{
|
||||
{
|
||||
name: "valid claims",
|
||||
claims: &TokenClaims{
|
||||
Subject: "user123",
|
||||
Issuer: "https://provider.example.com",
|
||||
Audience: "seaweedfs",
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
IssuedAt: time.Now().Add(-time.Minute),
|
||||
Claims: map[string]interface{}{"email": "user@example.com"},
|
||||
},
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "expired token",
|
||||
claims: &TokenClaims{
|
||||
Subject: "user123",
|
||||
Issuer: "https://provider.example.com",
|
||||
Audience: "seaweedfs",
|
||||
ExpiresAt: time.Now().Add(-time.Hour), // Expired
|
||||
IssuedAt: time.Now().Add(-time.Hour * 2),
|
||||
Claims: map[string]interface{}{"email": "user@example.com"},
|
||||
},
|
||||
valid: false,
|
||||
},
|
||||
{
|
||||
name: "future issued token",
|
||||
claims: &TokenClaims{
|
||||
Subject: "user123",
|
||||
Issuer: "https://provider.example.com",
|
||||
Audience: "seaweedfs",
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
IssuedAt: time.Now().Add(time.Hour), // Future
|
||||
Claims: map[string]interface{}{"email": "user@example.com"},
|
||||
},
|
||||
valid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
valid := tt.claims.IsValid()
|
||||
assert.Equal(t, tt.valid, valid)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProviderRegistry tests provider registration and discovery
|
||||
func TestProviderRegistry(t *testing.T) {
|
||||
// Clear registry for test
|
||||
registry := NewProviderRegistry()
|
||||
|
||||
t.Run("register provider", func(t *testing.T) {
|
||||
mockProvider := &MockProvider{name: "test-provider"}
|
||||
|
||||
err := registry.RegisterProvider(mockProvider)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test duplicate registration
|
||||
err = registry.RegisterProvider(mockProvider)
|
||||
assert.Error(t, err, "Should not allow duplicate registration")
|
||||
})
|
||||
|
||||
t.Run("get provider", func(t *testing.T) {
|
||||
provider, exists := registry.GetProvider("test-provider")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, "test-provider", provider.Name())
|
||||
|
||||
// Test non-existent provider
|
||||
_, exists = registry.GetProvider("non-existent")
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("list providers", func(t *testing.T) {
|
||||
providers := registry.ListProviders()
|
||||
assert.Len(t, providers, 1)
|
||||
assert.Equal(t, "test-provider", providers[0])
|
||||
})
|
||||
}
|
||||
|
||||
// MockProvider for testing
|
||||
type MockProvider struct {
|
||||
name string
|
||||
initialized bool
|
||||
shouldError bool
|
||||
}
|
||||
|
||||
func (m *MockProvider) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *MockProvider) Initialize(config interface{}) error {
|
||||
if m.shouldError {
|
||||
return assert.AnError
|
||||
}
|
||||
m.initialized = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockProvider) Authenticate(ctx context.Context, token string) (*ExternalIdentity, error) {
|
||||
if !m.initialized {
|
||||
return nil, assert.AnError
|
||||
}
|
||||
if token == "invalid-token" {
|
||||
return nil, assert.AnError
|
||||
}
|
||||
return &ExternalIdentity{
|
||||
UserID: "test-user",
|
||||
Email: "test@example.com",
|
||||
DisplayName: "Test User",
|
||||
Provider: m.name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *MockProvider) GetUserInfo(ctx context.Context, userID string) (*ExternalIdentity, error) {
|
||||
if !m.initialized || userID == "" {
|
||||
return nil, assert.AnError
|
||||
}
|
||||
return &ExternalIdentity{
|
||||
UserID: userID,
|
||||
Email: userID + "@example.com",
|
||||
DisplayName: "User " + userID,
|
||||
Provider: m.name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *MockProvider) ValidateToken(ctx context.Context, token string) (*TokenClaims, error) {
|
||||
if !m.initialized || token == "invalid-token" {
|
||||
return nil, assert.AnError
|
||||
}
|
||||
return &TokenClaims{
|
||||
Subject: "test-user",
|
||||
Issuer: "test-issuer",
|
||||
Audience: "seaweedfs",
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
IssuedAt: time.Now(),
|
||||
Claims: map[string]interface{}{"email": "test@example.com"},
|
||||
}, nil
|
||||
}
|
||||
@@ -1,109 +0,0 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ProviderRegistry manages registered identity providers
|
||||
type ProviderRegistry struct {
|
||||
mu sync.RWMutex
|
||||
providers map[string]IdentityProvider
|
||||
}
|
||||
|
||||
// NewProviderRegistry creates a new provider registry
|
||||
func NewProviderRegistry() *ProviderRegistry {
|
||||
return &ProviderRegistry{
|
||||
providers: make(map[string]IdentityProvider),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterProvider registers a new identity provider
|
||||
func (r *ProviderRegistry) RegisterProvider(provider IdentityProvider) error {
|
||||
if provider == nil {
|
||||
return fmt.Errorf("provider cannot be nil")
|
||||
}
|
||||
|
||||
name := provider.Name()
|
||||
if name == "" {
|
||||
return fmt.Errorf("provider name cannot be empty")
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.providers[name]; exists {
|
||||
return fmt.Errorf("provider %s is already registered", name)
|
||||
}
|
||||
|
||||
r.providers[name] = provider
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetProvider retrieves a provider by name
|
||||
func (r *ProviderRegistry) GetProvider(name string) (IdentityProvider, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
provider, exists := r.providers[name]
|
||||
return provider, exists
|
||||
}
|
||||
|
||||
// ListProviders returns all registered provider names
|
||||
func (r *ProviderRegistry) ListProviders() []string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
var names []string
|
||||
for name := range r.providers {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// UnregisterProvider removes a provider from the registry
|
||||
func (r *ProviderRegistry) UnregisterProvider(name string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.providers[name]; !exists {
|
||||
return fmt.Errorf("provider %s is not registered", name)
|
||||
}
|
||||
|
||||
delete(r.providers, name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear removes all providers from the registry
|
||||
func (r *ProviderRegistry) Clear() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.providers = make(map[string]IdentityProvider)
|
||||
}
|
||||
|
||||
// GetProviderCount returns the number of registered providers
|
||||
func (r *ProviderRegistry) GetProviderCount() int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
return len(r.providers)
|
||||
}
|
||||
|
||||
// Default global registry
|
||||
var defaultRegistry = NewProviderRegistry()
|
||||
|
||||
// RegisterProvider registers a provider in the default registry
|
||||
func RegisterProvider(provider IdentityProvider) error {
|
||||
return defaultRegistry.RegisterProvider(provider)
|
||||
}
|
||||
|
||||
// GetProvider retrieves a provider from the default registry
|
||||
func GetProvider(name string) (IdentityProvider, bool) {
|
||||
return defaultRegistry.GetProvider(name)
|
||||
}
|
||||
|
||||
// ListProviders returns all provider names from the default registry
|
||||
func ListProviders() []string {
|
||||
return defaultRegistry.ListProviders()
|
||||
}
|
||||
@@ -1,503 +0,0 @@
|
||||
package sts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/seaweedfs/seaweedfs/weed/iam/oidc"
|
||||
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Test-only constants for mock providers
|
||||
const (
|
||||
ProviderTypeMock = "mock"
|
||||
)
|
||||
|
||||
// createMockOIDCProvider creates a mock OIDC provider for testing
|
||||
// This is only available in test builds
|
||||
func createMockOIDCProvider(name string, config map[string]interface{}) (providers.IdentityProvider, error) {
|
||||
// Convert config to OIDC format
|
||||
factory := NewProviderFactory()
|
||||
oidcConfig, err := factory.convertToOIDCConfig(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set default values for mock provider if not provided
|
||||
if oidcConfig.Issuer == "" {
|
||||
oidcConfig.Issuer = "http://localhost:9999"
|
||||
}
|
||||
|
||||
provider := oidc.NewMockOIDCProvider(name)
|
||||
if err := provider.Initialize(oidcConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set up default test data for the mock provider
|
||||
provider.SetupDefaultTestData()
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// createMockJWT creates a test JWT token with the specified issuer for mock provider testing
|
||||
func createMockJWT(t *testing.T, issuer, subject string) string {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"sub": subject,
|
||||
"aud": "test-client",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
})
|
||||
|
||||
tokenString, err := token.SignedString([]byte("test-signing-key"))
|
||||
require.NoError(t, err)
|
||||
return tokenString
|
||||
}
|
||||
|
||||
// TestCrossInstanceTokenUsage verifies that tokens generated by one STS instance
|
||||
// can be used and validated by other STS instances in a distributed environment
|
||||
func TestCrossInstanceTokenUsage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
// Dummy filer address for testing
|
||||
|
||||
// Common configuration that would be shared across all instances in production
|
||||
sharedConfig := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{12 * time.Hour},
|
||||
Issuer: "distributed-sts-cluster", // SAME across all instances
|
||||
SigningKey: []byte(TestSigningKey32Chars), // SAME across all instances
|
||||
Providers: []*ProviderConfig{
|
||||
{
|
||||
Name: "company-oidc",
|
||||
Type: ProviderTypeOIDC,
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{
|
||||
ConfigFieldIssuer: "https://sso.company.com/realms/production",
|
||||
ConfigFieldClientID: "seaweedfs-cluster",
|
||||
ConfigFieldJWKSUri: "https://sso.company.com/realms/production/protocol/openid-connect/certs",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create multiple STS instances simulating different S3 gateway instances
|
||||
instanceA := NewSTSService() // e.g., s3-gateway-1
|
||||
instanceB := NewSTSService() // e.g., s3-gateway-2
|
||||
instanceC := NewSTSService() // e.g., s3-gateway-3
|
||||
|
||||
// Initialize all instances with IDENTICAL configuration
|
||||
err := instanceA.Initialize(sharedConfig)
|
||||
require.NoError(t, err, "Instance A should initialize")
|
||||
|
||||
err = instanceB.Initialize(sharedConfig)
|
||||
require.NoError(t, err, "Instance B should initialize")
|
||||
|
||||
err = instanceC.Initialize(sharedConfig)
|
||||
require.NoError(t, err, "Instance C should initialize")
|
||||
|
||||
// Set up mock trust policy validator for all instances (required for STS testing)
|
||||
mockValidator := &MockTrustPolicyValidator{}
|
||||
instanceA.SetTrustPolicyValidator(mockValidator)
|
||||
instanceB.SetTrustPolicyValidator(mockValidator)
|
||||
instanceC.SetTrustPolicyValidator(mockValidator)
|
||||
|
||||
// Manually register mock provider for testing (not available in production)
|
||||
mockProviderConfig := map[string]interface{}{
|
||||
ConfigFieldIssuer: "http://test-mock:9999",
|
||||
ConfigFieldClientID: TestClientID,
|
||||
}
|
||||
mockProviderA, err := createMockOIDCProvider("test-mock", mockProviderConfig)
|
||||
require.NoError(t, err)
|
||||
mockProviderB, err := createMockOIDCProvider("test-mock", mockProviderConfig)
|
||||
require.NoError(t, err)
|
||||
mockProviderC, err := createMockOIDCProvider("test-mock", mockProviderConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
instanceA.RegisterProvider(mockProviderA)
|
||||
instanceB.RegisterProvider(mockProviderB)
|
||||
instanceC.RegisterProvider(mockProviderC)
|
||||
|
||||
// Test 1: Token generated on Instance A can be validated on Instance B & C
|
||||
t.Run("cross_instance_token_validation", func(t *testing.T) {
|
||||
// Generate session token on Instance A
|
||||
sessionId := TestSessionID
|
||||
expiresAt := time.Now().Add(time.Hour)
|
||||
|
||||
tokenFromA, err := instanceA.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
|
||||
require.NoError(t, err, "Instance A should generate token")
|
||||
|
||||
// Validate token on Instance B
|
||||
claimsFromB, err := instanceB.GetTokenGenerator().ValidateSessionToken(tokenFromA)
|
||||
require.NoError(t, err, "Instance B should validate token from Instance A")
|
||||
assert.Equal(t, sessionId, claimsFromB.SessionId, "Session ID should match")
|
||||
|
||||
// Validate same token on Instance C
|
||||
claimsFromC, err := instanceC.GetTokenGenerator().ValidateSessionToken(tokenFromA)
|
||||
require.NoError(t, err, "Instance C should validate token from Instance A")
|
||||
assert.Equal(t, sessionId, claimsFromC.SessionId, "Session ID should match")
|
||||
|
||||
// All instances should extract identical claims
|
||||
assert.Equal(t, claimsFromB.SessionId, claimsFromC.SessionId)
|
||||
assert.Equal(t, claimsFromB.ExpiresAt.Unix(), claimsFromC.ExpiresAt.Unix())
|
||||
assert.Equal(t, claimsFromB.IssuedAt.Unix(), claimsFromC.IssuedAt.Unix())
|
||||
})
|
||||
|
||||
// Test 2: Complete assume role flow across instances
|
||||
t.Run("cross_instance_assume_role_flow", func(t *testing.T) {
|
||||
// Step 1: User authenticates and assumes role on Instance A
|
||||
// Create a valid JWT token for the mock provider
|
||||
mockToken := createMockJWT(t, "http://test-mock:9999", "test-user")
|
||||
|
||||
assumeRequest := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: "arn:aws:iam::role/CrossInstanceTestRole",
|
||||
WebIdentityToken: mockToken, // JWT token for mock provider
|
||||
RoleSessionName: "cross-instance-test-session",
|
||||
DurationSeconds: int64ToPtr(3600),
|
||||
}
|
||||
|
||||
// Instance A processes assume role request
|
||||
responseFromA, err := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest)
|
||||
require.NoError(t, err, "Instance A should process assume role")
|
||||
|
||||
sessionToken := responseFromA.Credentials.SessionToken
|
||||
accessKeyId := responseFromA.Credentials.AccessKeyId
|
||||
secretAccessKey := responseFromA.Credentials.SecretAccessKey
|
||||
|
||||
// Verify response structure
|
||||
assert.NotEmpty(t, sessionToken, "Should have session token")
|
||||
assert.NotEmpty(t, accessKeyId, "Should have access key ID")
|
||||
assert.NotEmpty(t, secretAccessKey, "Should have secret access key")
|
||||
assert.NotNil(t, responseFromA.AssumedRoleUser, "Should have assumed role user")
|
||||
|
||||
// Step 2: Use session token on Instance B (different instance)
|
||||
sessionInfoFromB, err := instanceB.ValidateSessionToken(ctx, sessionToken)
|
||||
require.NoError(t, err, "Instance B should validate session token from Instance A")
|
||||
|
||||
assert.Equal(t, assumeRequest.RoleSessionName, sessionInfoFromB.SessionName)
|
||||
assert.Equal(t, assumeRequest.RoleArn, sessionInfoFromB.RoleArn)
|
||||
|
||||
// Step 3: Use same session token on Instance C (yet another instance)
|
||||
sessionInfoFromC, err := instanceC.ValidateSessionToken(ctx, sessionToken)
|
||||
require.NoError(t, err, "Instance C should validate session token from Instance A")
|
||||
|
||||
// All instances should return identical session information
|
||||
assert.Equal(t, sessionInfoFromB.SessionId, sessionInfoFromC.SessionId)
|
||||
assert.Equal(t, sessionInfoFromB.SessionName, sessionInfoFromC.SessionName)
|
||||
assert.Equal(t, sessionInfoFromB.RoleArn, sessionInfoFromC.RoleArn)
|
||||
assert.Equal(t, sessionInfoFromB.Subject, sessionInfoFromC.Subject)
|
||||
assert.Equal(t, sessionInfoFromB.Provider, sessionInfoFromC.Provider)
|
||||
})
|
||||
|
||||
// Test 3: Session revocation across instances
|
||||
t.Run("cross_instance_session_revocation", func(t *testing.T) {
|
||||
// Create session on Instance A
|
||||
mockToken := createMockJWT(t, "http://test-mock:9999", "test-user")
|
||||
|
||||
assumeRequest := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: "arn:aws:iam::role/RevocationTestRole",
|
||||
WebIdentityToken: mockToken,
|
||||
RoleSessionName: "revocation-test-session",
|
||||
}
|
||||
|
||||
response, err := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest)
|
||||
require.NoError(t, err)
|
||||
sessionToken := response.Credentials.SessionToken
|
||||
|
||||
// Verify token works on Instance B
|
||||
_, err = instanceB.ValidateSessionToken(ctx, sessionToken)
|
||||
require.NoError(t, err, "Token should be valid on Instance B initially")
|
||||
|
||||
// Validate session on Instance C to verify cross-instance token compatibility
|
||||
_, err = instanceC.ValidateSessionToken(ctx, sessionToken)
|
||||
require.NoError(t, err, "Instance C should be able to validate session token")
|
||||
|
||||
// In a stateless JWT system, tokens remain valid on all instances since they're self-contained
|
||||
// No revocation is possible without breaking the stateless architecture
|
||||
_, err = instanceA.ValidateSessionToken(ctx, sessionToken)
|
||||
assert.NoError(t, err, "Token should still be valid on Instance A (stateless system)")
|
||||
|
||||
// Verify token is still valid on Instance B
|
||||
_, err = instanceB.ValidateSessionToken(ctx, sessionToken)
|
||||
assert.NoError(t, err, "Token should still be valid on Instance B (stateless system)")
|
||||
})
|
||||
|
||||
// Test 4: Provider consistency across instances
|
||||
t.Run("provider_consistency_affects_token_generation", func(t *testing.T) {
|
||||
// All instances should have same providers and be able to process same OIDC tokens
|
||||
providerNamesA := instanceA.getProviderNames()
|
||||
providerNamesB := instanceB.getProviderNames()
|
||||
providerNamesC := instanceC.getProviderNames()
|
||||
|
||||
assert.ElementsMatch(t, providerNamesA, providerNamesB, "Instance A and B should have same providers")
|
||||
assert.ElementsMatch(t, providerNamesB, providerNamesC, "Instance B and C should have same providers")
|
||||
|
||||
// All instances should be able to process same web identity token
|
||||
testToken := createMockJWT(t, "http://test-mock:9999", "test-user")
|
||||
|
||||
// Try to assume role with same token on different instances
|
||||
assumeRequest := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: "arn:aws:iam::role/ProviderTestRole",
|
||||
WebIdentityToken: testToken,
|
||||
RoleSessionName: "provider-consistency-test",
|
||||
}
|
||||
|
||||
// Should work on any instance
|
||||
responseA, errA := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest)
|
||||
responseB, errB := instanceB.AssumeRoleWithWebIdentity(ctx, assumeRequest)
|
||||
responseC, errC := instanceC.AssumeRoleWithWebIdentity(ctx, assumeRequest)
|
||||
|
||||
require.NoError(t, errA, "Instance A should process OIDC token")
|
||||
require.NoError(t, errB, "Instance B should process OIDC token")
|
||||
require.NoError(t, errC, "Instance C should process OIDC token")
|
||||
|
||||
// All should return valid responses (sessions will have different IDs but same structure)
|
||||
assert.NotEmpty(t, responseA.Credentials.SessionToken)
|
||||
assert.NotEmpty(t, responseB.Credentials.SessionToken)
|
||||
assert.NotEmpty(t, responseC.Credentials.SessionToken)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSTSDistributedConfigurationRequirements tests the configuration requirements
|
||||
// for cross-instance token compatibility
|
||||
func TestSTSDistributedConfigurationRequirements(t *testing.T) {
|
||||
_ = "localhost:8888" // Dummy filer address for testing (not used in these tests)
|
||||
|
||||
t.Run("same_signing_key_required", func(t *testing.T) {
|
||||
// Instance A with signing key 1
|
||||
configA := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{12 * time.Hour},
|
||||
Issuer: "test-sts",
|
||||
SigningKey: []byte("signing-key-1-32-characters-long"),
|
||||
}
|
||||
|
||||
// Instance B with different signing key
|
||||
configB := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{12 * time.Hour},
|
||||
Issuer: "test-sts",
|
||||
SigningKey: []byte("signing-key-2-32-characters-long"), // DIFFERENT!
|
||||
}
|
||||
|
||||
instanceA := NewSTSService()
|
||||
instanceB := NewSTSService()
|
||||
|
||||
err := instanceA.Initialize(configA)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = instanceB.Initialize(configB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate token on Instance A
|
||||
sessionId := "test-session"
|
||||
expiresAt := time.Now().Add(time.Hour)
|
||||
tokenFromA, err := instanceA.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Instance A should validate its own token
|
||||
_, err = instanceA.GetTokenGenerator().ValidateSessionToken(tokenFromA)
|
||||
assert.NoError(t, err, "Instance A should validate own token")
|
||||
|
||||
// Instance B should REJECT token due to different signing key
|
||||
_, err = instanceB.GetTokenGenerator().ValidateSessionToken(tokenFromA)
|
||||
assert.Error(t, err, "Instance B should reject token with different signing key")
|
||||
assert.Contains(t, err.Error(), "invalid token", "Should be signature validation error")
|
||||
})
|
||||
|
||||
t.Run("same_issuer_required", func(t *testing.T) {
|
||||
sharedSigningKey := []byte("shared-signing-key-32-characters-lo")
|
||||
|
||||
// Instance A with issuer 1
|
||||
configA := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{12 * time.Hour},
|
||||
Issuer: "sts-cluster-1",
|
||||
SigningKey: sharedSigningKey,
|
||||
}
|
||||
|
||||
// Instance B with different issuer
|
||||
configB := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{12 * time.Hour},
|
||||
Issuer: "sts-cluster-2", // DIFFERENT!
|
||||
SigningKey: sharedSigningKey,
|
||||
}
|
||||
|
||||
instanceA := NewSTSService()
|
||||
instanceB := NewSTSService()
|
||||
|
||||
err := instanceA.Initialize(configA)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = instanceB.Initialize(configB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate token on Instance A
|
||||
sessionId := "test-session"
|
||||
expiresAt := time.Now().Add(time.Hour)
|
||||
tokenFromA, err := instanceA.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Instance B should REJECT token due to different issuer
|
||||
_, err = instanceB.GetTokenGenerator().ValidateSessionToken(tokenFromA)
|
||||
assert.Error(t, err, "Instance B should reject token with different issuer")
|
||||
assert.Contains(t, err.Error(), "invalid issuer", "Should be issuer validation error")
|
||||
})
|
||||
|
||||
t.Run("identical_configuration_required", func(t *testing.T) {
|
||||
// Identical configuration
|
||||
identicalConfig := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{12 * time.Hour},
|
||||
Issuer: "production-sts-cluster",
|
||||
SigningKey: []byte("production-signing-key-32-chars-l"),
|
||||
}
|
||||
|
||||
// Create multiple instances with identical config
|
||||
instances := make([]*STSService, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
instances[i] = NewSTSService()
|
||||
err := instances[i].Initialize(identicalConfig)
|
||||
require.NoError(t, err, "Instance %d should initialize", i)
|
||||
}
|
||||
|
||||
// Generate token on Instance 0
|
||||
sessionId := "multi-instance-test"
|
||||
expiresAt := time.Now().Add(time.Hour)
|
||||
token, err := instances[0].GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// All other instances should validate the token
|
||||
for i := 1; i < 5; i++ {
|
||||
claims, err := instances[i].GetTokenGenerator().ValidateSessionToken(token)
|
||||
require.NoError(t, err, "Instance %d should validate token", i)
|
||||
assert.Equal(t, sessionId, claims.SessionId, "Instance %d should extract correct session ID", i)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestSTSRealWorldDistributedScenarios tests realistic distributed deployment scenarios
|
||||
func TestSTSRealWorldDistributedScenarios(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("load_balanced_s3_gateway_scenario", func(t *testing.T) {
|
||||
// Simulate real production scenario:
|
||||
// 1. User authenticates with OIDC provider
|
||||
// 2. User calls AssumeRoleWithWebIdentity on S3 Gateway 1
|
||||
// 3. User makes S3 requests that hit S3 Gateway 2 & 3 via load balancer
|
||||
// 4. All instances should handle the session token correctly
|
||||
|
||||
productionConfig := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{2 * time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{24 * time.Hour},
|
||||
Issuer: "seaweedfs-production-sts",
|
||||
SigningKey: []byte("prod-signing-key-32-characters-lon"),
|
||||
|
||||
Providers: []*ProviderConfig{
|
||||
{
|
||||
Name: "corporate-oidc",
|
||||
Type: "oidc",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "https://sso.company.com/realms/production",
|
||||
"clientId": "seaweedfs-prod-cluster",
|
||||
"clientSecret": "supersecret-prod-key",
|
||||
"scopes": []string{"openid", "profile", "email", "groups"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create 3 S3 Gateway instances behind load balancer
|
||||
gateway1 := NewSTSService()
|
||||
gateway2 := NewSTSService()
|
||||
gateway3 := NewSTSService()
|
||||
|
||||
err := gateway1.Initialize(productionConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = gateway2.Initialize(productionConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = gateway3.Initialize(productionConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set up mock trust policy validator for all gateway instances
|
||||
mockValidator := &MockTrustPolicyValidator{}
|
||||
gateway1.SetTrustPolicyValidator(mockValidator)
|
||||
gateway2.SetTrustPolicyValidator(mockValidator)
|
||||
gateway3.SetTrustPolicyValidator(mockValidator)
|
||||
|
||||
// Manually register mock provider for testing (not available in production)
|
||||
mockProviderConfig := map[string]interface{}{
|
||||
ConfigFieldIssuer: "http://test-mock:9999",
|
||||
ConfigFieldClientID: "test-client-id",
|
||||
}
|
||||
mockProvider1, err := createMockOIDCProvider("test-mock", mockProviderConfig)
|
||||
require.NoError(t, err)
|
||||
mockProvider2, err := createMockOIDCProvider("test-mock", mockProviderConfig)
|
||||
require.NoError(t, err)
|
||||
mockProvider3, err := createMockOIDCProvider("test-mock", mockProviderConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
gateway1.RegisterProvider(mockProvider1)
|
||||
gateway2.RegisterProvider(mockProvider2)
|
||||
gateway3.RegisterProvider(mockProvider3)
|
||||
|
||||
// Step 1: User authenticates and hits Gateway 1 for AssumeRole
|
||||
mockToken := createMockJWT(t, "http://test-mock:9999", "production-user")
|
||||
|
||||
assumeRequest := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: "arn:aws:iam::role/ProductionS3User",
|
||||
WebIdentityToken: mockToken, // JWT token from mock provider
|
||||
RoleSessionName: "user-production-session",
|
||||
DurationSeconds: int64ToPtr(7200), // 2 hours
|
||||
}
|
||||
|
||||
stsResponse, err := gateway1.AssumeRoleWithWebIdentity(ctx, assumeRequest)
|
||||
require.NoError(t, err, "Gateway 1 should handle AssumeRole")
|
||||
|
||||
sessionToken := stsResponse.Credentials.SessionToken
|
||||
accessKey := stsResponse.Credentials.AccessKeyId
|
||||
secretKey := stsResponse.Credentials.SecretAccessKey
|
||||
|
||||
// Step 2: User makes S3 requests that hit different gateways via load balancer
|
||||
// Simulate S3 request validation on Gateway 2
|
||||
sessionInfo2, err := gateway2.ValidateSessionToken(ctx, sessionToken)
|
||||
require.NoError(t, err, "Gateway 2 should validate session from Gateway 1")
|
||||
assert.Equal(t, "user-production-session", sessionInfo2.SessionName)
|
||||
assert.Equal(t, "arn:aws:iam::role/ProductionS3User", sessionInfo2.RoleArn)
|
||||
|
||||
// Simulate S3 request validation on Gateway 3
|
||||
sessionInfo3, err := gateway3.ValidateSessionToken(ctx, sessionToken)
|
||||
require.NoError(t, err, "Gateway 3 should validate session from Gateway 1")
|
||||
assert.Equal(t, sessionInfo2.SessionId, sessionInfo3.SessionId, "Should be same session")
|
||||
|
||||
// Step 3: Verify credentials are consistent
|
||||
assert.Equal(t, accessKey, stsResponse.Credentials.AccessKeyId, "Access key should be consistent")
|
||||
assert.Equal(t, secretKey, stsResponse.Credentials.SecretAccessKey, "Secret key should be consistent")
|
||||
|
||||
// Step 4: Session expiration should be honored across all instances
|
||||
assert.True(t, sessionInfo2.ExpiresAt.After(time.Now()), "Session should not be expired")
|
||||
assert.True(t, sessionInfo3.ExpiresAt.After(time.Now()), "Session should not be expired")
|
||||
|
||||
// Step 5: Token should be identical when parsed
|
||||
claims2, err := gateway2.GetTokenGenerator().ValidateSessionToken(sessionToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
claims3, err := gateway3.GetTokenGenerator().ValidateSessionToken(sessionToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, claims2.SessionId, claims3.SessionId, "Session IDs should match")
|
||||
assert.Equal(t, claims2.ExpiresAt.Unix(), claims3.ExpiresAt.Unix(), "Expiration should match")
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function to convert int64 to pointer
|
||||
func int64ToPtr(i int64) *int64 {
|
||||
return &i
|
||||
}
|
||||
@@ -1,340 +0,0 @@
|
||||
package sts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestDistributedSTSService verifies that multiple STS instances with identical configurations
|
||||
// behave consistently across distributed environments
|
||||
func TestDistributedSTSService(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Common configuration for all instances
|
||||
commonConfig := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{12 * time.Hour},
|
||||
Issuer: "distributed-sts-test",
|
||||
SigningKey: []byte("test-signing-key-32-characters-long"),
|
||||
|
||||
Providers: []*ProviderConfig{
|
||||
{
|
||||
Name: "keycloak-oidc",
|
||||
Type: "oidc",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "http://keycloak:8080/realms/seaweedfs-test",
|
||||
"clientId": "seaweedfs-s3",
|
||||
"jwksUri": "http://keycloak:8080/realms/seaweedfs-test/protocol/openid-connect/certs",
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
Name: "disabled-ldap",
|
||||
Type: "oidc", // Use OIDC as placeholder since LDAP isn't implemented
|
||||
Enabled: false,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "ldap://company.com",
|
||||
"clientId": "ldap-client",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create multiple STS instances simulating distributed deployment
|
||||
instance1 := NewSTSService()
|
||||
instance2 := NewSTSService()
|
||||
instance3 := NewSTSService()
|
||||
|
||||
// Initialize all instances with identical configuration
|
||||
err := instance1.Initialize(commonConfig)
|
||||
require.NoError(t, err, "Instance 1 should initialize successfully")
|
||||
|
||||
err = instance2.Initialize(commonConfig)
|
||||
require.NoError(t, err, "Instance 2 should initialize successfully")
|
||||
|
||||
err = instance3.Initialize(commonConfig)
|
||||
require.NoError(t, err, "Instance 3 should initialize successfully")
|
||||
|
||||
// Manually register mock providers for testing (not available in production)
|
||||
mockProviderConfig := map[string]interface{}{
|
||||
"issuer": "http://localhost:9999",
|
||||
"clientId": "test-client",
|
||||
}
|
||||
mockProvider1, err := createMockOIDCProvider("test-mock-provider", mockProviderConfig)
|
||||
require.NoError(t, err)
|
||||
mockProvider2, err := createMockOIDCProvider("test-mock-provider", mockProviderConfig)
|
||||
require.NoError(t, err)
|
||||
mockProvider3, err := createMockOIDCProvider("test-mock-provider", mockProviderConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
instance1.RegisterProvider(mockProvider1)
|
||||
instance2.RegisterProvider(mockProvider2)
|
||||
instance3.RegisterProvider(mockProvider3)
|
||||
|
||||
// Verify all instances have identical provider configurations
|
||||
t.Run("provider_consistency", func(t *testing.T) {
|
||||
// All instances should have same number of providers
|
||||
assert.Len(t, instance1.providers, 2, "Instance 1 should have 2 enabled providers")
|
||||
assert.Len(t, instance2.providers, 2, "Instance 2 should have 2 enabled providers")
|
||||
assert.Len(t, instance3.providers, 2, "Instance 3 should have 2 enabled providers")
|
||||
|
||||
// All instances should have same provider names
|
||||
instance1Names := instance1.getProviderNames()
|
||||
instance2Names := instance2.getProviderNames()
|
||||
instance3Names := instance3.getProviderNames()
|
||||
|
||||
assert.ElementsMatch(t, instance1Names, instance2Names, "Instance 1 and 2 should have same providers")
|
||||
assert.ElementsMatch(t, instance2Names, instance3Names, "Instance 2 and 3 should have same providers")
|
||||
|
||||
// Verify specific providers exist on all instances
|
||||
expectedProviders := []string{"keycloak-oidc", "test-mock-provider"}
|
||||
assert.ElementsMatch(t, instance1Names, expectedProviders, "Instance 1 should have expected providers")
|
||||
assert.ElementsMatch(t, instance2Names, expectedProviders, "Instance 2 should have expected providers")
|
||||
assert.ElementsMatch(t, instance3Names, expectedProviders, "Instance 3 should have expected providers")
|
||||
|
||||
// Verify disabled providers are not loaded
|
||||
assert.NotContains(t, instance1Names, "disabled-ldap", "Disabled providers should not be loaded")
|
||||
assert.NotContains(t, instance2Names, "disabled-ldap", "Disabled providers should not be loaded")
|
||||
assert.NotContains(t, instance3Names, "disabled-ldap", "Disabled providers should not be loaded")
|
||||
})
|
||||
|
||||
// Test token generation consistency across instances
|
||||
t.Run("token_generation_consistency", func(t *testing.T) {
|
||||
sessionId := "test-session-123"
|
||||
expiresAt := time.Now().Add(time.Hour)
|
||||
|
||||
// Generate tokens from different instances
|
||||
token1, err1 := instance1.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
|
||||
token2, err2 := instance2.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
|
||||
token3, err3 := instance3.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
|
||||
|
||||
require.NoError(t, err1, "Instance 1 token generation should succeed")
|
||||
require.NoError(t, err2, "Instance 2 token generation should succeed")
|
||||
require.NoError(t, err3, "Instance 3 token generation should succeed")
|
||||
|
||||
// All tokens should be different (due to timestamp variations)
|
||||
// But they should all be valid JWTs with same signing key
|
||||
assert.NotEmpty(t, token1)
|
||||
assert.NotEmpty(t, token2)
|
||||
assert.NotEmpty(t, token3)
|
||||
})
|
||||
|
||||
// Test token validation consistency - any instance should validate tokens from any other instance
|
||||
t.Run("cross_instance_token_validation", func(t *testing.T) {
|
||||
sessionId := "cross-validation-session"
|
||||
expiresAt := time.Now().Add(time.Hour)
|
||||
|
||||
// Generate token on instance 1
|
||||
token, err := instance1.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validate on all instances
|
||||
claims1, err1 := instance1.GetTokenGenerator().ValidateSessionToken(token)
|
||||
claims2, err2 := instance2.GetTokenGenerator().ValidateSessionToken(token)
|
||||
claims3, err3 := instance3.GetTokenGenerator().ValidateSessionToken(token)
|
||||
|
||||
require.NoError(t, err1, "Instance 1 should validate token from instance 1")
|
||||
require.NoError(t, err2, "Instance 2 should validate token from instance 1")
|
||||
require.NoError(t, err3, "Instance 3 should validate token from instance 1")
|
||||
|
||||
// All instances should extract same session ID
|
||||
assert.Equal(t, sessionId, claims1.SessionId)
|
||||
assert.Equal(t, sessionId, claims2.SessionId)
|
||||
assert.Equal(t, sessionId, claims3.SessionId)
|
||||
|
||||
assert.Equal(t, claims1.SessionId, claims2.SessionId)
|
||||
assert.Equal(t, claims2.SessionId, claims3.SessionId)
|
||||
})
|
||||
|
||||
// Test provider access consistency
|
||||
t.Run("provider_access_consistency", func(t *testing.T) {
|
||||
// All instances should be able to access the same providers
|
||||
provider1, exists1 := instance1.providers["test-mock-provider"]
|
||||
provider2, exists2 := instance2.providers["test-mock-provider"]
|
||||
provider3, exists3 := instance3.providers["test-mock-provider"]
|
||||
|
||||
assert.True(t, exists1, "Instance 1 should have test-mock-provider")
|
||||
assert.True(t, exists2, "Instance 2 should have test-mock-provider")
|
||||
assert.True(t, exists3, "Instance 3 should have test-mock-provider")
|
||||
|
||||
assert.Equal(t, provider1.Name(), provider2.Name())
|
||||
assert.Equal(t, provider2.Name(), provider3.Name())
|
||||
|
||||
// Test authentication with the mock provider on all instances
|
||||
testToken := "valid_test_token"
|
||||
|
||||
identity1, err1 := provider1.Authenticate(ctx, testToken)
|
||||
identity2, err2 := provider2.Authenticate(ctx, testToken)
|
||||
identity3, err3 := provider3.Authenticate(ctx, testToken)
|
||||
|
||||
require.NoError(t, err1, "Instance 1 provider should authenticate successfully")
|
||||
require.NoError(t, err2, "Instance 2 provider should authenticate successfully")
|
||||
require.NoError(t, err3, "Instance 3 provider should authenticate successfully")
|
||||
|
||||
// All instances should return identical identity information
|
||||
assert.Equal(t, identity1.UserID, identity2.UserID)
|
||||
assert.Equal(t, identity2.UserID, identity3.UserID)
|
||||
assert.Equal(t, identity1.Email, identity2.Email)
|
||||
assert.Equal(t, identity2.Email, identity3.Email)
|
||||
assert.Equal(t, identity1.Provider, identity2.Provider)
|
||||
assert.Equal(t, identity2.Provider, identity3.Provider)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSTSConfigurationValidation tests configuration validation for distributed deployments
|
||||
func TestSTSConfigurationValidation(t *testing.T) {
|
||||
t.Run("consistent_signing_keys_required", func(t *testing.T) {
|
||||
// Different signing keys should result in incompatible token validation
|
||||
config1 := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{12 * time.Hour},
|
||||
Issuer: "test-sts",
|
||||
SigningKey: []byte("signing-key-1-32-characters-long"),
|
||||
}
|
||||
|
||||
config2 := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{12 * time.Hour},
|
||||
Issuer: "test-sts",
|
||||
SigningKey: []byte("signing-key-2-32-characters-long"), // Different key!
|
||||
}
|
||||
|
||||
instance1 := NewSTSService()
|
||||
instance2 := NewSTSService()
|
||||
|
||||
err1 := instance1.Initialize(config1)
|
||||
err2 := instance2.Initialize(config2)
|
||||
|
||||
require.NoError(t, err1)
|
||||
require.NoError(t, err2)
|
||||
|
||||
// Generate token on instance 1
|
||||
sessionId := "test-session"
|
||||
expiresAt := time.Now().Add(time.Hour)
|
||||
token, err := instance1.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Instance 1 should validate its own token
|
||||
_, err = instance1.GetTokenGenerator().ValidateSessionToken(token)
|
||||
assert.NoError(t, err, "Instance 1 should validate its own token")
|
||||
|
||||
// Instance 2 should reject token from instance 1 (different signing key)
|
||||
_, err = instance2.GetTokenGenerator().ValidateSessionToken(token)
|
||||
assert.Error(t, err, "Instance 2 should reject token with different signing key")
|
||||
})
|
||||
|
||||
t.Run("consistent_issuer_required", func(t *testing.T) {
|
||||
// Different issuers should result in incompatible tokens
|
||||
commonSigningKey := []byte("shared-signing-key-32-characters-lo")
|
||||
|
||||
config1 := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{12 * time.Hour},
|
||||
Issuer: "sts-instance-1",
|
||||
SigningKey: commonSigningKey,
|
||||
}
|
||||
|
||||
config2 := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{12 * time.Hour},
|
||||
Issuer: "sts-instance-2", // Different issuer!
|
||||
SigningKey: commonSigningKey,
|
||||
}
|
||||
|
||||
instance1 := NewSTSService()
|
||||
instance2 := NewSTSService()
|
||||
|
||||
err1 := instance1.Initialize(config1)
|
||||
err2 := instance2.Initialize(config2)
|
||||
|
||||
require.NoError(t, err1)
|
||||
require.NoError(t, err2)
|
||||
|
||||
// Generate token on instance 1
|
||||
sessionId := "test-session"
|
||||
expiresAt := time.Now().Add(time.Hour)
|
||||
token, err := instance1.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Instance 2 should reject token due to issuer mismatch
|
||||
// (Even though signing key is the same, issuer validation will fail)
|
||||
_, err = instance2.GetTokenGenerator().ValidateSessionToken(token)
|
||||
assert.Error(t, err, "Instance 2 should reject token with different issuer")
|
||||
})
|
||||
}
|
||||
|
||||
// TestProviderFactoryDistributed tests the provider factory in distributed scenarios
|
||||
func TestProviderFactoryDistributed(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
// Simulate configuration that would be identical across all instances
|
||||
configs := []*ProviderConfig{
|
||||
{
|
||||
Name: "production-keycloak",
|
||||
Type: "oidc",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "https://keycloak.company.com/realms/seaweedfs",
|
||||
"clientId": "seaweedfs-prod",
|
||||
"clientSecret": "super-secret-key",
|
||||
"jwksUri": "https://keycloak.company.com/realms/seaweedfs/protocol/openid-connect/certs",
|
||||
"scopes": []string{"openid", "profile", "email", "roles"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "backup-oidc",
|
||||
Type: "oidc",
|
||||
Enabled: false, // Disabled by default
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "https://backup-oidc.company.com",
|
||||
"clientId": "seaweedfs-backup",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create providers multiple times (simulating multiple instances)
|
||||
providers1, err1 := factory.LoadProvidersFromConfig(configs)
|
||||
providers2, err2 := factory.LoadProvidersFromConfig(configs)
|
||||
providers3, err3 := factory.LoadProvidersFromConfig(configs)
|
||||
|
||||
require.NoError(t, err1, "First load should succeed")
|
||||
require.NoError(t, err2, "Second load should succeed")
|
||||
require.NoError(t, err3, "Third load should succeed")
|
||||
|
||||
// All instances should have same provider counts
|
||||
assert.Len(t, providers1, 1, "First instance should have 1 enabled provider")
|
||||
assert.Len(t, providers2, 1, "Second instance should have 1 enabled provider")
|
||||
assert.Len(t, providers3, 1, "Third instance should have 1 enabled provider")
|
||||
|
||||
// All instances should have same provider names
|
||||
names1 := make([]string, 0, len(providers1))
|
||||
names2 := make([]string, 0, len(providers2))
|
||||
names3 := make([]string, 0, len(providers3))
|
||||
|
||||
for name := range providers1 {
|
||||
names1 = append(names1, name)
|
||||
}
|
||||
for name := range providers2 {
|
||||
names2 = append(names2, name)
|
||||
}
|
||||
for name := range providers3 {
|
||||
names3 = append(names3, name)
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, names1, names2, "Instance 1 and 2 should have same provider names")
|
||||
assert.ElementsMatch(t, names2, names3, "Instance 2 and 3 should have same provider names")
|
||||
|
||||
// Verify specific providers
|
||||
expectedProviders := []string{"production-keycloak"}
|
||||
assert.ElementsMatch(t, names1, expectedProviders, "Should have expected enabled providers")
|
||||
|
||||
// Verify disabled providers are not included
|
||||
assert.NotContains(t, names1, "backup-oidc", "Disabled providers should not be loaded")
|
||||
assert.NotContains(t, names2, "backup-oidc", "Disabled providers should not be loaded")
|
||||
assert.NotContains(t, names3, "backup-oidc", "Disabled providers should not be loaded")
|
||||
}
|
||||
@@ -274,69 +274,3 @@ func (f *ProviderFactory) convertToRoleMapping(value interface{}) (*providers.Ro
|
||||
|
||||
return roleMapping, nil
|
||||
}
|
||||
|
||||
// ValidateProviderConfig validates a provider configuration
|
||||
func (f *ProviderFactory) ValidateProviderConfig(config *ProviderConfig) error {
|
||||
if config == nil {
|
||||
return fmt.Errorf("provider config cannot be nil")
|
||||
}
|
||||
|
||||
if config.Name == "" {
|
||||
return fmt.Errorf("provider name cannot be empty")
|
||||
}
|
||||
|
||||
if config.Type == "" {
|
||||
return fmt.Errorf("provider type cannot be empty")
|
||||
}
|
||||
|
||||
if config.Config == nil {
|
||||
return fmt.Errorf("provider config cannot be nil")
|
||||
}
|
||||
|
||||
// Type-specific validation
|
||||
switch config.Type {
|
||||
case "oidc":
|
||||
return f.validateOIDCConfig(config.Config)
|
||||
case "ldap":
|
||||
return f.validateLDAPConfig(config.Config)
|
||||
case "saml":
|
||||
return f.validateSAMLConfig(config.Config)
|
||||
default:
|
||||
return fmt.Errorf("unsupported provider type: %s", config.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// validateOIDCConfig validates OIDC provider configuration
|
||||
func (f *ProviderFactory) validateOIDCConfig(config map[string]interface{}) error {
|
||||
if _, ok := config[ConfigFieldIssuer]; !ok {
|
||||
return fmt.Errorf("OIDC provider requires '%s' field", ConfigFieldIssuer)
|
||||
}
|
||||
|
||||
if _, ok := config[ConfigFieldClientID]; !ok {
|
||||
return fmt.Errorf("OIDC provider requires '%s' field", ConfigFieldClientID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateLDAPConfig validates LDAP provider configuration
|
||||
func (f *ProviderFactory) validateLDAPConfig(config map[string]interface{}) error {
|
||||
if _, ok := config["server"]; !ok {
|
||||
return fmt.Errorf("LDAP provider requires 'server' field")
|
||||
}
|
||||
if _, ok := config["baseDN"]; !ok {
|
||||
return fmt.Errorf("LDAP provider requires 'baseDN' field")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateSAMLConfig validates SAML provider configuration
|
||||
func (f *ProviderFactory) validateSAMLConfig(config map[string]interface{}) error {
|
||||
// TODO: Implement when SAML provider is available
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSupportedProviderTypes returns list of supported provider types
|
||||
func (f *ProviderFactory) GetSupportedProviderTypes() []string {
|
||||
return []string{ProviderTypeOIDC}
|
||||
}
|
||||
|
||||
@@ -1,312 +0,0 @@
|
||||
package sts
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestProviderFactory_CreateOIDCProvider(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
config := &ProviderConfig{
|
||||
Name: "test-oidc",
|
||||
Type: "oidc",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "https://test-issuer.com",
|
||||
"clientId": "test-client",
|
||||
"clientSecret": "test-secret",
|
||||
"jwksUri": "https://test-issuer.com/.well-known/jwks.json",
|
||||
"scopes": []string{"openid", "profile", "email"},
|
||||
},
|
||||
}
|
||||
|
||||
provider, err := factory.CreateProvider(config)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, provider)
|
||||
assert.Equal(t, "test-oidc", provider.Name())
|
||||
}
|
||||
|
||||
// Note: Mock provider tests removed - mock providers are now test-only
|
||||
// and not available through the production ProviderFactory
|
||||
|
||||
func TestProviderFactory_DisabledProvider(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
config := &ProviderConfig{
|
||||
Name: "disabled-provider",
|
||||
Type: "oidc",
|
||||
Enabled: false,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "https://test-issuer.com",
|
||||
"clientId": "test-client",
|
||||
},
|
||||
}
|
||||
|
||||
provider, err := factory.CreateProvider(config)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, provider) // Should return nil for disabled providers
|
||||
}
|
||||
|
||||
func TestProviderFactory_InvalidProviderType(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
config := &ProviderConfig{
|
||||
Name: "invalid-provider",
|
||||
Type: "unsupported-type",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{},
|
||||
}
|
||||
|
||||
provider, err := factory.CreateProvider(config)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, provider)
|
||||
assert.Contains(t, err.Error(), "unsupported provider type")
|
||||
}
|
||||
|
||||
func TestProviderFactory_LoadMultipleProviders(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
configs := []*ProviderConfig{
|
||||
{
|
||||
Name: "oidc-provider",
|
||||
Type: "oidc",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "https://oidc-issuer.com",
|
||||
"clientId": "oidc-client",
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
Name: "disabled-provider",
|
||||
Type: "oidc",
|
||||
Enabled: false,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "https://disabled-issuer.com",
|
||||
"clientId": "disabled-client",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
providers, err := factory.LoadProvidersFromConfig(configs)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, providers, 1) // Only enabled providers should be loaded
|
||||
|
||||
assert.Contains(t, providers, "oidc-provider")
|
||||
assert.NotContains(t, providers, "disabled-provider")
|
||||
}
|
||||
|
||||
func TestProviderFactory_ValidateOIDCConfig(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
t.Run("valid config", func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
Name: "valid-oidc",
|
||||
Type: "oidc",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "https://valid-issuer.com",
|
||||
"clientId": "valid-client",
|
||||
},
|
||||
}
|
||||
|
||||
err := factory.ValidateProviderConfig(config)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("missing issuer", func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
Name: "invalid-oidc",
|
||||
Type: "oidc",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{
|
||||
"clientId": "valid-client",
|
||||
},
|
||||
}
|
||||
|
||||
err := factory.ValidateProviderConfig(config)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "issuer")
|
||||
})
|
||||
|
||||
t.Run("missing clientId", func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
Name: "invalid-oidc",
|
||||
Type: "oidc",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "https://valid-issuer.com",
|
||||
},
|
||||
}
|
||||
|
||||
err := factory.ValidateProviderConfig(config)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "clientId")
|
||||
})
|
||||
}
|
||||
|
||||
func TestProviderFactory_ConvertToStringSlice(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
t.Run("string slice", func(t *testing.T) {
|
||||
input := []string{"a", "b", "c"}
|
||||
result, err := factory.convertToStringSlice(input)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"a", "b", "c"}, result)
|
||||
})
|
||||
|
||||
t.Run("interface slice", func(t *testing.T) {
|
||||
input := []interface{}{"a", "b", "c"}
|
||||
result, err := factory.convertToStringSlice(input)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"a", "b", "c"}, result)
|
||||
})
|
||||
|
||||
t.Run("invalid type", func(t *testing.T) {
|
||||
input := "not-a-slice"
|
||||
result, err := factory.convertToStringSlice(input)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestProviderFactory_ConfigConversionErrors(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
t.Run("invalid scopes type", func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
Name: "invalid-scopes",
|
||||
Type: "oidc",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "https://test-issuer.com",
|
||||
"clientId": "test-client",
|
||||
"scopes": "invalid-not-array", // Should be array
|
||||
},
|
||||
}
|
||||
|
||||
provider, err := factory.CreateProvider(config)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, provider)
|
||||
assert.Contains(t, err.Error(), "failed to convert scopes")
|
||||
})
|
||||
|
||||
t.Run("invalid claimsMapping type", func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
Name: "invalid-claims",
|
||||
Type: "oidc",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "https://test-issuer.com",
|
||||
"clientId": "test-client",
|
||||
"claimsMapping": "invalid-not-map", // Should be map
|
||||
},
|
||||
}
|
||||
|
||||
provider, err := factory.CreateProvider(config)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, provider)
|
||||
assert.Contains(t, err.Error(), "failed to convert claimsMapping")
|
||||
})
|
||||
|
||||
t.Run("invalid roleMapping type", func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
Name: "invalid-roles",
|
||||
Type: "oidc",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "https://test-issuer.com",
|
||||
"clientId": "test-client",
|
||||
"roleMapping": "invalid-not-map", // Should be map
|
||||
},
|
||||
}
|
||||
|
||||
provider, err := factory.CreateProvider(config)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, provider)
|
||||
assert.Contains(t, err.Error(), "failed to convert roleMapping")
|
||||
})
|
||||
}
|
||||
|
||||
func TestProviderFactory_ConvertToStringMap(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
t.Run("string map", func(t *testing.T) {
|
||||
input := map[string]string{"key1": "value1", "key2": "value2"}
|
||||
result, err := factory.convertToStringMap(input)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, map[string]string{"key1": "value1", "key2": "value2"}, result)
|
||||
})
|
||||
|
||||
t.Run("interface map", func(t *testing.T) {
|
||||
input := map[string]interface{}{"key1": "value1", "key2": "value2"}
|
||||
result, err := factory.convertToStringMap(input)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, map[string]string{"key1": "value1", "key2": "value2"}, result)
|
||||
})
|
||||
|
||||
t.Run("invalid type", func(t *testing.T) {
|
||||
input := "not-a-map"
|
||||
result, err := factory.convertToStringMap(input)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestProviderFactory_GetSupportedProviderTypes(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
supportedTypes := factory.GetSupportedProviderTypes()
|
||||
assert.Contains(t, supportedTypes, "oidc")
|
||||
assert.Len(t, supportedTypes, 1) // Currently only OIDC is supported in production
|
||||
}
|
||||
|
||||
func TestSTSService_LoadProvidersFromConfig(t *testing.T) {
|
||||
stsConfig := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{3600 * time.Second},
|
||||
MaxSessionLength: FlexibleDuration{43200 * time.Second},
|
||||
Issuer: "test-issuer",
|
||||
SigningKey: []byte("test-signing-key-32-characters-long"),
|
||||
Providers: []*ProviderConfig{
|
||||
{
|
||||
Name: "test-provider",
|
||||
Type: "oidc",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "https://test-issuer.com",
|
||||
"clientId": "test-client",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
stsService := NewSTSService()
|
||||
err := stsService.Initialize(stsConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that provider was loaded
|
||||
assert.Len(t, stsService.providers, 1)
|
||||
assert.Contains(t, stsService.providers, "test-provider")
|
||||
assert.Equal(t, "test-provider", stsService.providers["test-provider"].Name())
|
||||
}
|
||||
|
||||
func TestSTSService_NoProvidersConfig(t *testing.T) {
|
||||
stsConfig := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{3600 * time.Second},
|
||||
MaxSessionLength: FlexibleDuration{43200 * time.Second},
|
||||
Issuer: "test-issuer",
|
||||
SigningKey: []byte("test-signing-key-32-characters-long"),
|
||||
// No providers configured
|
||||
}
|
||||
|
||||
stsService := NewSTSService()
|
||||
err := stsService.Initialize(stsConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should initialize successfully with no providers
|
||||
assert.Len(t, stsService.providers, 0)
|
||||
}
|
||||
@@ -1,193 +0,0 @@
|
||||
package sts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestSecurityIssuerToProviderMapping tests the security fix that ensures JWT tokens
|
||||
// with specific issuer claims can only be validated by the provider registered for that issuer
|
||||
func TestSecurityIssuerToProviderMapping(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create STS service with two mock providers
|
||||
service := NewSTSService()
|
||||
config := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{time.Hour * 12},
|
||||
Issuer: "test-sts",
|
||||
SigningKey: []byte("test-signing-key-32-characters-long"),
|
||||
}
|
||||
|
||||
err := service.Initialize(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set up mock trust policy validator
|
||||
mockValidator := &MockTrustPolicyValidator{}
|
||||
service.SetTrustPolicyValidator(mockValidator)
|
||||
|
||||
// Create two mock providers with different issuers
|
||||
providerA := &MockIdentityProviderWithIssuer{
|
||||
name: "provider-a",
|
||||
issuer: "https://provider-a.com",
|
||||
validTokens: map[string]bool{
|
||||
"token-for-provider-a": true,
|
||||
},
|
||||
}
|
||||
|
||||
providerB := &MockIdentityProviderWithIssuer{
|
||||
name: "provider-b",
|
||||
issuer: "https://provider-b.com",
|
||||
validTokens: map[string]bool{
|
||||
"token-for-provider-b": true,
|
||||
},
|
||||
}
|
||||
|
||||
// Register both providers
|
||||
err = service.RegisterProvider(providerA)
|
||||
require.NoError(t, err)
|
||||
err = service.RegisterProvider(providerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create JWT tokens with specific issuer claims
|
||||
tokenForProviderA := createTestJWT(t, "https://provider-a.com", "user-a")
|
||||
tokenForProviderB := createTestJWT(t, "https://provider-b.com", "user-b")
|
||||
|
||||
t.Run("jwt_token_with_issuer_a_only_validated_by_provider_a", func(t *testing.T) {
|
||||
// This should succeed - token has issuer A and provider A is registered
|
||||
identity, provider, err := service.validateWebIdentityToken(ctx, tokenForProviderA)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, identity)
|
||||
assert.Equal(t, "provider-a", provider.Name())
|
||||
})
|
||||
|
||||
t.Run("jwt_token_with_issuer_b_only_validated_by_provider_b", func(t *testing.T) {
|
||||
// This should succeed - token has issuer B and provider B is registered
|
||||
identity, provider, err := service.validateWebIdentityToken(ctx, tokenForProviderB)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, identity)
|
||||
assert.Equal(t, "provider-b", provider.Name())
|
||||
})
|
||||
|
||||
t.Run("jwt_token_with_unregistered_issuer_fails", func(t *testing.T) {
|
||||
// Create token with unregistered issuer
|
||||
tokenWithUnknownIssuer := createTestJWT(t, "https://unknown-issuer.com", "user-x")
|
||||
|
||||
// This should fail - no provider registered for this issuer
|
||||
identity, provider, err := service.validateWebIdentityToken(ctx, tokenWithUnknownIssuer)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, identity)
|
||||
assert.Nil(t, provider)
|
||||
assert.Contains(t, err.Error(), "no identity provider registered for issuer: https://unknown-issuer.com")
|
||||
})
|
||||
|
||||
t.Run("non_jwt_tokens_are_rejected", func(t *testing.T) {
|
||||
// Non-JWT tokens should be rejected - no fallback mechanism exists for security
|
||||
identity, provider, err := service.validateWebIdentityToken(ctx, "token-for-provider-a")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, identity)
|
||||
assert.Nil(t, provider)
|
||||
assert.Contains(t, err.Error(), "web identity token must be a valid JWT token")
|
||||
})
|
||||
}
|
||||
|
||||
// createTestJWT creates a test JWT token with the specified issuer and subject
|
||||
func createTestJWT(t *testing.T, issuer, subject string) string {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"sub": subject,
|
||||
"aud": "test-client",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
})
|
||||
|
||||
tokenString, err := token.SignedString([]byte("test-signing-key"))
|
||||
require.NoError(t, err)
|
||||
return tokenString
|
||||
}
|
||||
|
||||
// MockIdentityProviderWithIssuer is a mock provider that supports issuer mapping
|
||||
type MockIdentityProviderWithIssuer struct {
|
||||
name string
|
||||
issuer string
|
||||
validTokens map[string]bool
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithIssuer) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithIssuer) GetIssuer() string {
|
||||
return m.issuer
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithIssuer) Initialize(config interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithIssuer) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
|
||||
// For JWT tokens, parse and validate the token format
|
||||
if len(token) > 50 && strings.Contains(token, ".") {
|
||||
// This looks like a JWT - parse it to get the subject
|
||||
parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT token")
|
||||
}
|
||||
|
||||
claims, ok := parsedToken.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid claims")
|
||||
}
|
||||
|
||||
issuer, _ := claims["iss"].(string)
|
||||
subject, _ := claims["sub"].(string)
|
||||
|
||||
// Verify the issuer matches what we expect
|
||||
if issuer != m.issuer {
|
||||
return nil, fmt.Errorf("token issuer %s does not match provider issuer %s", issuer, m.issuer)
|
||||
}
|
||||
|
||||
return &providers.ExternalIdentity{
|
||||
UserID: subject,
|
||||
Email: subject + "@" + m.name + ".com",
|
||||
Provider: m.name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// For non-JWT tokens, check our simple token list
|
||||
if m.validTokens[token] {
|
||||
return &providers.ExternalIdentity{
|
||||
UserID: "test-user",
|
||||
Email: "test@" + m.name + ".com",
|
||||
Provider: m.name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithIssuer) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
|
||||
return &providers.ExternalIdentity{
|
||||
UserID: userID,
|
||||
Email: userID + "@" + m.name + ".com",
|
||||
Provider: m.name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithIssuer) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
|
||||
if m.validTokens[token] {
|
||||
return &providers.TokenClaims{
|
||||
Subject: "test-user",
|
||||
Issuer: m.issuer,
|
||||
}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
}
|
||||
@@ -1,168 +0,0 @@
|
||||
package sts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// createSessionPolicyTestJWT creates a test JWT token for session policy tests
|
||||
func createSessionPolicyTestJWT(t *testing.T, issuer, subject string) string {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"sub": subject,
|
||||
"aud": "test-client",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
})
|
||||
|
||||
tokenString, err := token.SignedString([]byte("test-signing-key"))
|
||||
require.NoError(t, err)
|
||||
return tokenString
|
||||
}
|
||||
|
||||
// TestAssumeRoleWithWebIdentity_SessionPolicy verifies inline session policies are preserved in tokens.
|
||||
func TestAssumeRoleWithWebIdentity_SessionPolicy(t *testing.T) {
|
||||
service := setupTestSTSService(t)
|
||||
ctx := context.Background()
|
||||
|
||||
sessionPolicy := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:GetObject","Resource":"arn:aws:s3:::example-bucket/*"}]}`
|
||||
testToken := createSessionPolicyTestJWT(t, "test-issuer", "test-user")
|
||||
|
||||
request := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: "arn:aws:iam::role/TestRole",
|
||||
WebIdentityToken: testToken,
|
||||
RoleSessionName: "test-session",
|
||||
Policy: &sessionPolicy,
|
||||
}
|
||||
|
||||
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
|
||||
sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
normalized, err := NormalizeSessionPolicy(sessionPolicy)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, normalized, sessionInfo.SessionPolicy)
|
||||
|
||||
t.Run("should_succeed_without_session_policy", func(t *testing.T) {
|
||||
request := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: "arn:aws:iam::role/TestRole",
|
||||
WebIdentityToken: createSessionPolicyTestJWT(t, "test-issuer", "test-user"),
|
||||
RoleSessionName: "test-session",
|
||||
}
|
||||
|
||||
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
|
||||
sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, sessionInfo.SessionPolicy)
|
||||
})
|
||||
}
|
||||
|
||||
// Test edge case scenarios for the Policy field handling
|
||||
func TestAssumeRoleWithWebIdentity_SessionPolicy_EdgeCases(t *testing.T) {
|
||||
service := setupTestSTSService(t)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("malformed_json_policy_rejected", func(t *testing.T) {
|
||||
malformedPolicy := `{"Version": "2012-10-17", "Statement": [` // Incomplete JSON
|
||||
|
||||
request := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: "arn:aws:iam::role/TestRole",
|
||||
WebIdentityToken: createSessionPolicyTestJWT(t, "test-issuer", "test-user"),
|
||||
RoleSessionName: "test-session",
|
||||
Policy: &malformedPolicy,
|
||||
}
|
||||
|
||||
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, response)
|
||||
assert.Contains(t, err.Error(), "invalid session policy JSON")
|
||||
})
|
||||
|
||||
t.Run("invalid_policy_document_rejected", func(t *testing.T) {
|
||||
invalidPolicy := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow"}]}`
|
||||
|
||||
request := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: "arn:aws:iam::role/TestRole",
|
||||
WebIdentityToken: createSessionPolicyTestJWT(t, "test-issuer", "test-user"),
|
||||
RoleSessionName: "test-session",
|
||||
Policy: &invalidPolicy,
|
||||
}
|
||||
|
||||
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, response)
|
||||
assert.Contains(t, err.Error(), "invalid session policy document")
|
||||
})
|
||||
|
||||
t.Run("whitespace_policy_ignored", func(t *testing.T) {
|
||||
whitespacePolicy := " \t\n "
|
||||
|
||||
request := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: "arn:aws:iam::role/TestRole",
|
||||
WebIdentityToken: createSessionPolicyTestJWT(t, "test-issuer", "test-user"),
|
||||
RoleSessionName: "test-session",
|
||||
Policy: &whitespacePolicy,
|
||||
}
|
||||
|
||||
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
|
||||
sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, sessionInfo.SessionPolicy)
|
||||
})
|
||||
}
|
||||
|
||||
// TestAssumeRoleWithWebIdentity_PolicyFieldDocumentation verifies that the struct field exists and is optional.
|
||||
func TestAssumeRoleWithWebIdentity_PolicyFieldDocumentation(t *testing.T) {
|
||||
request := &AssumeRoleWithWebIdentityRequest{}
|
||||
|
||||
assert.IsType(t, (*string)(nil), request.Policy,
|
||||
"Policy field should be *string type for optional JSON policy")
|
||||
assert.Nil(t, request.Policy,
|
||||
"Policy field should default to nil (no session policy)")
|
||||
|
||||
policyValue := `{"Version": "2012-10-17"}`
|
||||
request.Policy = &policyValue
|
||||
assert.NotNil(t, request.Policy, "Should be able to assign policy value")
|
||||
assert.Equal(t, policyValue, *request.Policy, "Policy value should be preserved")
|
||||
}
|
||||
|
||||
// TestAssumeRoleWithCredentials_SessionPolicy verifies session policy support for credentials-based flow.
|
||||
func TestAssumeRoleWithCredentials_SessionPolicy(t *testing.T) {
|
||||
service := setupTestSTSService(t)
|
||||
ctx := context.Background()
|
||||
|
||||
sessionPolicy := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"filer:CreateEntry","Resource":"arn:aws:filer::path/user-docs/*"}]}`
|
||||
request := &AssumeRoleWithCredentialsRequest{
|
||||
RoleArn: "arn:aws:iam::role/TestRole",
|
||||
Username: "testuser",
|
||||
Password: "testpass",
|
||||
RoleSessionName: "test-session",
|
||||
ProviderName: "test-ldap",
|
||||
Policy: &sessionPolicy,
|
||||
}
|
||||
|
||||
response, err := service.AssumeRoleWithCredentials(ctx, request)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
|
||||
sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
normalized, err := NormalizeSessionPolicy(sessionPolicy)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, normalized, sessionInfo.SessionPolicy)
|
||||
}
|
||||
@@ -879,21 +879,6 @@ func (s *STSService) calculateSessionDuration(durationSeconds *int64, tokenExpir
|
||||
return duration
|
||||
}
|
||||
|
||||
// extractSessionIdFromToken extracts session ID from JWT session token
|
||||
func (s *STSService) extractSessionIdFromToken(sessionToken string) string {
|
||||
// Validate JWT and extract session claims
|
||||
claims, err := s.tokenGenerator.ValidateJWTWithClaims(sessionToken)
|
||||
if err != nil {
|
||||
// For test compatibility, also handle direct session IDs
|
||||
if len(sessionToken) == 32 { // Typical session ID length
|
||||
return sessionToken
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
return claims.SessionId
|
||||
}
|
||||
|
||||
// validateAssumeRoleWithCredentialsRequest validates the credentials request parameters
|
||||
func (s *STSService) validateAssumeRoleWithCredentialsRequest(request *AssumeRoleWithCredentialsRequest) error {
|
||||
if request.RoleArn == "" {
|
||||
|
||||
@@ -1,778 +0,0 @@
|
||||
package sts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// createSTSTestJWT creates a test JWT token for STS service tests
|
||||
func createSTSTestJWT(t *testing.T, issuer, subject string) string {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"sub": subject,
|
||||
"aud": "test-client",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
})
|
||||
|
||||
tokenString, err := token.SignedString([]byte("test-signing-key"))
|
||||
require.NoError(t, err)
|
||||
return tokenString
|
||||
}
|
||||
|
||||
// TestSTSServiceInitialization tests STS service initialization
|
||||
func TestSTSServiceInitialization(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *STSConfig
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
config: &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{time.Hour * 12},
|
||||
Issuer: "seaweedfs-sts",
|
||||
SigningKey: []byte("test-signing-key"),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing signing key",
|
||||
config: &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
Issuer: "seaweedfs-sts",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid token duration",
|
||||
config: &STSConfig{
|
||||
TokenDuration: FlexibleDuration{-time.Hour},
|
||||
Issuer: "seaweedfs-sts",
|
||||
SigningKey: []byte("test-key"),
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
service := NewSTSService()
|
||||
|
||||
err := service.Initialize(tt.config)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, service.IsInitialized())
|
||||
|
||||
// Verify defaults if applicable
|
||||
if tt.config.Issuer == "" {
|
||||
assert.Equal(t, DefaultIssuer, service.Config.Issuer)
|
||||
}
|
||||
if tt.config.TokenDuration.Duration == 0 {
|
||||
assert.Equal(t, time.Duration(DefaultTokenDuration)*time.Second, service.Config.TokenDuration.Duration)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSTSServiceDefaults(t *testing.T) {
|
||||
service := NewSTSService()
|
||||
config := &STSConfig{
|
||||
SigningKey: []byte("test-signing-key"),
|
||||
// Missing duration and issuer
|
||||
}
|
||||
|
||||
err := service.Initialize(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, DefaultIssuer, config.Issuer)
|
||||
assert.Equal(t, time.Duration(DefaultTokenDuration)*time.Second, config.TokenDuration.Duration)
|
||||
assert.Equal(t, time.Duration(DefaultMaxSessionLength)*time.Second, config.MaxSessionLength.Duration)
|
||||
}
|
||||
|
||||
// TestAssumeRoleWithWebIdentity tests role assumption with OIDC tokens
|
||||
func TestAssumeRoleWithWebIdentity(t *testing.T) {
|
||||
service := setupTestSTSService(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
roleArn string
|
||||
webIdentityToken string
|
||||
sessionName string
|
||||
durationSeconds *int64
|
||||
wantErr bool
|
||||
expectedSubject string
|
||||
}{
|
||||
{
|
||||
name: "successful role assumption",
|
||||
roleArn: "arn:aws:iam::role/TestRole",
|
||||
webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user-id"),
|
||||
sessionName: "test-session",
|
||||
durationSeconds: nil, // Use default
|
||||
wantErr: false,
|
||||
expectedSubject: "test-user-id",
|
||||
},
|
||||
{
|
||||
name: "invalid web identity token",
|
||||
roleArn: "arn:aws:iam::role/TestRole",
|
||||
webIdentityToken: "invalid-token",
|
||||
sessionName: "test-session",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "non-existent role",
|
||||
roleArn: "arn:aws:iam::role/NonExistentRole",
|
||||
webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
|
||||
sessionName: "test-session",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "custom session duration",
|
||||
roleArn: "arn:aws:iam::role/TestRole",
|
||||
webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
|
||||
sessionName: "test-session",
|
||||
durationSeconds: int64Ptr(7200), // 2 hours
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
request := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: tt.roleArn,
|
||||
WebIdentityToken: tt.webIdentityToken,
|
||||
RoleSessionName: tt.sessionName,
|
||||
DurationSeconds: tt.durationSeconds,
|
||||
}
|
||||
|
||||
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, response)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, response)
|
||||
assert.NotNil(t, response.Credentials)
|
||||
assert.NotNil(t, response.AssumedRoleUser)
|
||||
|
||||
// Verify credentials
|
||||
creds := response.Credentials
|
||||
assert.NotEmpty(t, creds.AccessKeyId)
|
||||
assert.NotEmpty(t, creds.SecretAccessKey)
|
||||
assert.NotEmpty(t, creds.SessionToken)
|
||||
assert.True(t, creds.Expiration.After(time.Now()))
|
||||
|
||||
// Verify assumed role user
|
||||
user := response.AssumedRoleUser
|
||||
assert.Equal(t, tt.roleArn, user.AssumedRoleId)
|
||||
assert.Contains(t, user.Arn, tt.sessionName)
|
||||
|
||||
if tt.expectedSubject != "" {
|
||||
assert.Equal(t, tt.expectedSubject, user.Subject)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAssumeRoleWithLDAP tests role assumption with LDAP credentials
|
||||
func TestAssumeRoleWithLDAP(t *testing.T) {
|
||||
service := setupTestSTSService(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
roleArn string
|
||||
username string
|
||||
password string
|
||||
sessionName string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "successful LDAP role assumption",
|
||||
roleArn: "arn:aws:iam::role/LDAPRole",
|
||||
username: "testuser",
|
||||
password: "testpass",
|
||||
sessionName: "ldap-session",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid LDAP credentials",
|
||||
roleArn: "arn:aws:iam::role/LDAPRole",
|
||||
username: "testuser",
|
||||
password: "wrongpass",
|
||||
sessionName: "ldap-session",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
request := &AssumeRoleWithCredentialsRequest{
|
||||
RoleArn: tt.roleArn,
|
||||
Username: tt.username,
|
||||
Password: tt.password,
|
||||
RoleSessionName: tt.sessionName,
|
||||
ProviderName: "test-ldap",
|
||||
}
|
||||
|
||||
response, err := service.AssumeRoleWithCredentials(ctx, request)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, response)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, response)
|
||||
assert.NotNil(t, response.Credentials)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionTokenValidation tests session token validation
|
||||
func TestSessionTokenValidation(t *testing.T) {
|
||||
service := setupTestSTSService(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// First, create a session
|
||||
request := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: "arn:aws:iam::role/TestRole",
|
||||
WebIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
|
||||
RoleSessionName: "test-session",
|
||||
}
|
||||
|
||||
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
|
||||
sessionToken := response.Credentials.SessionToken
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid session token",
|
||||
token: sessionToken,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid session token",
|
||||
token: "invalid-session-token",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty session token",
|
||||
token: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
session, err := service.ValidateSessionToken(ctx, tt.token)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, session)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, session)
|
||||
assert.Equal(t, "test-session", session.SessionName)
|
||||
assert.Equal(t, "arn:aws:iam::role/TestRole", session.RoleArn)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionTokenPersistence tests that JWT tokens remain valid throughout their lifetime
|
||||
// Note: In the stateless JWT design, tokens cannot be revoked and remain valid until expiration
|
||||
func TestSessionTokenPersistence(t *testing.T) {
|
||||
service := setupTestSTSService(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a session first
|
||||
request := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: "arn:aws:iam::role/TestRole",
|
||||
WebIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
|
||||
RoleSessionName: "test-session",
|
||||
}
|
||||
|
||||
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
|
||||
require.NoError(t, err)
|
||||
|
||||
sessionToken := response.Credentials.SessionToken
|
||||
|
||||
// Verify token is valid initially
|
||||
session, err := service.ValidateSessionToken(ctx, sessionToken)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, session)
|
||||
assert.Equal(t, "test-session", session.SessionName)
|
||||
|
||||
// In a stateless JWT system, tokens remain valid throughout their lifetime
|
||||
// Multiple validations should all succeed as long as the token hasn't expired
|
||||
session2, err := service.ValidateSessionToken(ctx, sessionToken)
|
||||
assert.NoError(t, err, "Token should remain valid in stateless system")
|
||||
assert.NotNil(t, session2, "Session should be returned from JWT token")
|
||||
assert.Equal(t, session.SessionId, session2.SessionId, "Session ID should be consistent")
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func setupTestSTSService(t *testing.T) *STSService {
|
||||
service := NewSTSService()
|
||||
|
||||
config := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{time.Hour * 12},
|
||||
Issuer: "test-sts",
|
||||
SigningKey: []byte("test-signing-key-32-characters-long"),
|
||||
}
|
||||
|
||||
err := service.Initialize(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set up mock trust policy validator (required for STS testing)
|
||||
mockValidator := &MockTrustPolicyValidator{}
|
||||
service.SetTrustPolicyValidator(mockValidator)
|
||||
|
||||
// Register test providers
|
||||
mockOIDCProvider := &MockIdentityProvider{
|
||||
name: "test-oidc",
|
||||
validTokens: map[string]*providers.TokenClaims{
|
||||
createSTSTestJWT(t, "test-issuer", "test-user"): {
|
||||
Subject: "test-user-id",
|
||||
Issuer: "test-issuer",
|
||||
Claims: map[string]interface{}{
|
||||
"email": "test@example.com",
|
||||
"name": "Test User",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mockLDAPProvider := &MockIdentityProvider{
|
||||
name: "test-ldap",
|
||||
validCredentials: map[string]string{
|
||||
"testuser": "testpass",
|
||||
},
|
||||
}
|
||||
|
||||
service.RegisterProvider(mockOIDCProvider)
|
||||
service.RegisterProvider(mockLDAPProvider)
|
||||
|
||||
return service
|
||||
}
|
||||
|
||||
func int64Ptr(v int64) *int64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// Mock identity provider for testing
|
||||
type MockIdentityProvider struct {
|
||||
name string
|
||||
validTokens map[string]*providers.TokenClaims
|
||||
validCredentials map[string]string
|
||||
}
|
||||
|
||||
func (m *MockIdentityProvider) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *MockIdentityProvider) GetIssuer() string {
|
||||
return "test-issuer" // This matches the issuer in the token claims
|
||||
}
|
||||
|
||||
func (m *MockIdentityProvider) Initialize(config interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockIdentityProvider) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
|
||||
// First try to parse as JWT token
|
||||
if len(token) > 20 && strings.Count(token, ".") >= 2 {
|
||||
parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
|
||||
if err == nil {
|
||||
if claims, ok := parsedToken.Claims.(jwt.MapClaims); ok {
|
||||
issuer, _ := claims["iss"].(string)
|
||||
subject, _ := claims["sub"].(string)
|
||||
|
||||
// Verify the issuer matches what we expect
|
||||
if issuer == "test-issuer" && subject != "" {
|
||||
return &providers.ExternalIdentity{
|
||||
UserID: subject,
|
||||
Email: subject + "@test-domain.com",
|
||||
DisplayName: "Test User " + subject,
|
||||
Provider: m.name,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle legacy OIDC tokens (for backwards compatibility)
|
||||
if claims, exists := m.validTokens[token]; exists {
|
||||
email, _ := claims.GetClaimString("email")
|
||||
name, _ := claims.GetClaimString("name")
|
||||
|
||||
return &providers.ExternalIdentity{
|
||||
UserID: claims.Subject,
|
||||
Email: email,
|
||||
DisplayName: name,
|
||||
Provider: m.name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Handle LDAP credentials (username:password format)
|
||||
if m.validCredentials != nil {
|
||||
parts := strings.Split(token, ":")
|
||||
if len(parts) == 2 {
|
||||
username, password := parts[0], parts[1]
|
||||
if expectedPassword, exists := m.validCredentials[username]; exists && expectedPassword == password {
|
||||
return &providers.ExternalIdentity{
|
||||
UserID: username,
|
||||
Email: username + "@" + m.name + ".com",
|
||||
DisplayName: "Test User " + username,
|
||||
Provider: m.name,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unknown test token: %s", token)
|
||||
}
|
||||
|
||||
func (m *MockIdentityProvider) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
|
||||
return &providers.ExternalIdentity{
|
||||
UserID: userID,
|
||||
Email: userID + "@" + m.name + ".com",
|
||||
Provider: m.name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *MockIdentityProvider) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
|
||||
if claims, exists := m.validTokens[token]; exists {
|
||||
return claims, nil
|
||||
}
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
// TestSessionDurationCappedByTokenExpiration tests that session duration is capped by the source token's exp claim
|
||||
func TestSessionDurationCappedByTokenExpiration(t *testing.T) {
|
||||
service := NewSTSService()
|
||||
|
||||
config := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour}, // Default: 1 hour
|
||||
MaxSessionLength: FlexibleDuration{time.Hour * 12},
|
||||
Issuer: "test-sts",
|
||||
SigningKey: []byte("test-signing-key-32-characters-long"),
|
||||
}
|
||||
|
||||
err := service.Initialize(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
durationSeconds *int64
|
||||
tokenExpiration *time.Time
|
||||
expectedMaxSeconds int64
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "no token expiration - use default duration",
|
||||
durationSeconds: nil,
|
||||
tokenExpiration: nil,
|
||||
expectedMaxSeconds: 3600, // 1 hour default
|
||||
description: "When no token expiration is set, use the configured default duration",
|
||||
},
|
||||
{
|
||||
name: "token expires before default duration",
|
||||
durationSeconds: nil,
|
||||
tokenExpiration: timePtr(time.Now().Add(30 * time.Minute)),
|
||||
expectedMaxSeconds: 30 * 60, // 30 minutes
|
||||
description: "When token expires in 30 min, session should be capped at 30 min",
|
||||
},
|
||||
{
|
||||
name: "token expires after default duration - use default",
|
||||
durationSeconds: nil,
|
||||
tokenExpiration: timePtr(time.Now().Add(2 * time.Hour)),
|
||||
expectedMaxSeconds: 3600, // 1 hour default, since it's less than 2 hour token expiry
|
||||
description: "When token expires after default duration, use the default duration",
|
||||
},
|
||||
{
|
||||
name: "requested duration shorter than token expiry",
|
||||
durationSeconds: int64Ptr(1800), // 30 min requested
|
||||
tokenExpiration: timePtr(time.Now().Add(time.Hour)),
|
||||
expectedMaxSeconds: 1800, // 30 minutes as requested
|
||||
description: "When requested duration is shorter than token expiry, use requested duration",
|
||||
},
|
||||
{
|
||||
name: "requested duration longer than token expiry - cap at token expiry",
|
||||
durationSeconds: int64Ptr(3600), // 1 hour requested
|
||||
tokenExpiration: timePtr(time.Now().Add(15 * time.Minute)),
|
||||
expectedMaxSeconds: 15 * 60, // Capped at 15 minutes
|
||||
description: "When requested duration exceeds token expiry, cap at token expiry",
|
||||
},
|
||||
{
|
||||
name: "GitLab CI short-lived token scenario",
|
||||
durationSeconds: nil,
|
||||
tokenExpiration: timePtr(time.Now().Add(5 * time.Minute)),
|
||||
expectedMaxSeconds: 5 * 60, // 5 minutes
|
||||
description: "GitLab CI job with 5 minute timeout should result in 5 minute session",
|
||||
},
|
||||
{
|
||||
name: "already expired token - defense in depth",
|
||||
durationSeconds: nil,
|
||||
tokenExpiration: timePtr(time.Now().Add(-5 * time.Minute)), // Expired 5 minutes ago
|
||||
expectedMaxSeconds: 60, // 1 minute minimum
|
||||
description: "Already expired token should result in minimal 1 minute session",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
duration := service.calculateSessionDuration(tt.durationSeconds, tt.tokenExpiration)
|
||||
|
||||
// Allow 5 second tolerance for time calculations
|
||||
maxExpected := time.Duration(tt.expectedMaxSeconds+5) * time.Second
|
||||
minExpected := time.Duration(tt.expectedMaxSeconds-5) * time.Second
|
||||
|
||||
assert.GreaterOrEqual(t, duration, minExpected,
|
||||
"%s: duration %v should be >= %v", tt.description, duration, minExpected)
|
||||
assert.LessOrEqual(t, duration, maxExpected,
|
||||
"%s: duration %v should be <= %v", tt.description, duration, maxExpected)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAssumeRoleWithWebIdentityRespectsTokenExpiration tests end-to-end that session duration is capped
|
||||
func TestAssumeRoleWithWebIdentityRespectsTokenExpiration(t *testing.T) {
|
||||
service := NewSTSService()
|
||||
|
||||
config := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{time.Hour * 12},
|
||||
Issuer: "test-sts",
|
||||
SigningKey: []byte("test-signing-key-32-characters-long"),
|
||||
}
|
||||
|
||||
err := service.Initialize(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set up mock trust policy validator
|
||||
mockValidator := &MockTrustPolicyValidator{}
|
||||
service.SetTrustPolicyValidator(mockValidator)
|
||||
|
||||
// Create a mock provider that returns tokens with short expiration
|
||||
shortLivedTokenExpiration := time.Now().Add(10 * time.Minute)
|
||||
mockProvider := &MockIdentityProviderWithExpiration{
|
||||
name: "short-lived-issuer",
|
||||
tokenExpiration: &shortLivedTokenExpiration,
|
||||
}
|
||||
service.RegisterProvider(mockProvider)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a JWT token with short expiration
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"iss": "short-lived-issuer",
|
||||
"sub": "test-user",
|
||||
"aud": "test-client",
|
||||
"exp": shortLivedTokenExpiration.Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
})
|
||||
tokenString, err := token.SignedString([]byte("test-signing-key"))
|
||||
require.NoError(t, err)
|
||||
|
||||
request := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: "arn:aws:iam::role/TestRole",
|
||||
WebIdentityToken: tokenString,
|
||||
RoleSessionName: "test-session",
|
||||
}
|
||||
|
||||
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
|
||||
// Verify the session expires at or before the token expiration
|
||||
// Allow 5 second tolerance
|
||||
assert.True(t, response.Credentials.Expiration.Before(shortLivedTokenExpiration.Add(5*time.Second)),
|
||||
"Session expiration (%v) should not exceed token expiration (%v)",
|
||||
response.Credentials.Expiration, shortLivedTokenExpiration)
|
||||
}
|
||||
|
||||
// MockIdentityProviderWithExpiration is a mock provider that returns tokens with configurable expiration
|
||||
type MockIdentityProviderWithExpiration struct {
|
||||
name string
|
||||
tokenExpiration *time.Time
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithExpiration) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithExpiration) GetIssuer() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithExpiration) Initialize(config interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithExpiration) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
|
||||
// Parse the token to get subject
|
||||
parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
claims, ok := parsedToken.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid claims")
|
||||
}
|
||||
|
||||
subject, _ := claims["sub"].(string)
|
||||
|
||||
identity := &providers.ExternalIdentity{
|
||||
UserID: subject,
|
||||
Email: subject + "@example.com",
|
||||
DisplayName: "Test User",
|
||||
Provider: m.name,
|
||||
TokenExpiration: m.tokenExpiration,
|
||||
}
|
||||
|
||||
return identity, nil
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithExpiration) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
|
||||
return &providers.ExternalIdentity{
|
||||
UserID: userID,
|
||||
Provider: m.name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithExpiration) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
|
||||
claims := &providers.TokenClaims{
|
||||
Subject: "test-user",
|
||||
Issuer: m.name,
|
||||
}
|
||||
if m.tokenExpiration != nil {
|
||||
claims.ExpiresAt = *m.tokenExpiration
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func timePtr(t time.Time) *time.Time {
|
||||
return &t
|
||||
}
|
||||
|
||||
// TestAssumeRoleWithWebIdentity_PreservesAttributes tests that attributes from the identity provider
|
||||
// are correctly propagated to the session token's request context
|
||||
func TestAssumeRoleWithWebIdentity_PreservesAttributes(t *testing.T) {
|
||||
service := setupTestSTSService(t)
|
||||
|
||||
// Create a mock provider that returns a user with attributes
|
||||
mockProvider := &MockIdentityProviderWithAttributes{
|
||||
name: "attr-provider",
|
||||
attributes: map[string]string{
|
||||
"preferred_username": "my-user",
|
||||
"department": "engineering",
|
||||
"project": "seaweedfs",
|
||||
},
|
||||
}
|
||||
service.RegisterProvider(mockProvider)
|
||||
|
||||
// Create a valid JWT token for the provider
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"iss": "attr-provider",
|
||||
"sub": "test-user-id",
|
||||
"aud": "test-client",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
})
|
||||
tokenString, err := token.SignedString([]byte("test-signing-key"))
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
request := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: "arn:aws:iam::role/TestRole",
|
||||
WebIdentityToken: tokenString,
|
||||
RoleSessionName: "test-session",
|
||||
}
|
||||
|
||||
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
|
||||
// Validate the session token to check claims
|
||||
sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that attributes are present in RequestContext
|
||||
require.NotNil(t, sessionInfo.RequestContext, "RequestContext should not be nil")
|
||||
assert.Equal(t, "my-user", sessionInfo.RequestContext["preferred_username"])
|
||||
assert.Equal(t, "engineering", sessionInfo.RequestContext["department"])
|
||||
assert.Equal(t, "seaweedfs", sessionInfo.RequestContext["project"])
|
||||
|
||||
// Check standard claims are also present
|
||||
assert.Equal(t, "test-user-id", sessionInfo.RequestContext["sub"])
|
||||
assert.Equal(t, "test@example.com", sessionInfo.RequestContext["email"])
|
||||
assert.Equal(t, "Test User", sessionInfo.RequestContext["name"])
|
||||
}
|
||||
|
||||
// MockIdentityProviderWithAttributes is a mock provider that returns configured attributes
|
||||
type MockIdentityProviderWithAttributes struct {
|
||||
name string
|
||||
attributes map[string]string
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithAttributes) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithAttributes) GetIssuer() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithAttributes) Initialize(config interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithAttributes) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
|
||||
return &providers.ExternalIdentity{
|
||||
UserID: "test-user-id",
|
||||
Email: "test@example.com",
|
||||
DisplayName: "Test User",
|
||||
Provider: m.name,
|
||||
Attributes: m.attributes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithAttributes) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithAttributes) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
|
||||
return &providers.TokenClaims{
|
||||
Subject: "test-user-id",
|
||||
Issuer: m.name,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,53 +1,4 @@
|
||||
package sts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
|
||||
)
|
||||
|
||||
// MockTrustPolicyValidator is a simple mock for testing STS functionality
|
||||
type MockTrustPolicyValidator struct{}
|
||||
|
||||
// ValidateTrustPolicyForWebIdentity allows valid JWT test tokens for STS testing
|
||||
func (m *MockTrustPolicyValidator) ValidateTrustPolicyForWebIdentity(ctx context.Context, roleArn string, webIdentityToken string, durationSeconds *int64) error {
|
||||
// Reject non-existent roles for testing
|
||||
if strings.Contains(roleArn, "NonExistentRole") {
|
||||
return fmt.Errorf("trust policy validation failed: role does not exist")
|
||||
}
|
||||
|
||||
// For STS unit tests, allow JWT tokens that look valid (contain dots for JWT structure)
|
||||
// In real implementation, this would validate against actual trust policies
|
||||
if len(webIdentityToken) > 20 && strings.Count(webIdentityToken, ".") >= 2 {
|
||||
// This appears to be a JWT token - allow it for testing
|
||||
return nil
|
||||
}
|
||||
|
||||
// Legacy support for specific test tokens during migration
|
||||
if webIdentityToken == "valid_test_token" || webIdentityToken == "valid-oidc-token" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reject invalid tokens
|
||||
if webIdentityToken == "invalid_token" || webIdentityToken == "expired_token" || webIdentityToken == "invalid-token" {
|
||||
return fmt.Errorf("trust policy denies token")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateTrustPolicyForCredentials allows valid test identities for STS testing
|
||||
func (m *MockTrustPolicyValidator) ValidateTrustPolicyForCredentials(ctx context.Context, roleArn string, identity *providers.ExternalIdentity) error {
|
||||
// Reject non-existent roles for testing
|
||||
if strings.Contains(roleArn, "NonExistentRole") {
|
||||
return fmt.Errorf("trust policy validation failed: role does not exist")
|
||||
}
|
||||
|
||||
// For STS unit tests, allow test identities
|
||||
if identity != nil && identity.UserID != "" {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("invalid identity for role assumption")
|
||||
}
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
package images
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
/*
|
||||
* Preprocess image files on client side.
|
||||
* 1. possibly adjust the orientation
|
||||
* 2. resize the image to a width or height limit
|
||||
* 3. remove the exif data
|
||||
* Call this function on any file uploaded to SeaweedFS
|
||||
*
|
||||
*/
|
||||
func MaybePreprocessImage(filename string, data []byte, width, height int) (resized io.ReadSeeker, w int, h int) {
|
||||
ext := filepath.Ext(filename)
|
||||
ext = strings.ToLower(ext)
|
||||
switch ext {
|
||||
case ".png", ".gif":
|
||||
return Resized(ext, bytes.NewReader(data), width, height, "")
|
||||
case ".jpg", ".jpeg":
|
||||
data = FixJpgOrientation(data)
|
||||
return Resized(ext, bytes.NewReader(data), width, height, "")
|
||||
}
|
||||
return bytes.NewReader(data), 0, 0
|
||||
}
|
||||
@@ -290,15 +290,6 @@ func (loader *ConfigLoader) ValidateConfiguration() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadKMSFromFilerToml is a convenience function to load KMS configuration from filer.toml
|
||||
func LoadKMSFromFilerToml(v ViperConfig) error {
|
||||
loader := NewConfigLoader(v)
|
||||
if err := loader.LoadConfigurations(); err != nil {
|
||||
return err
|
||||
}
|
||||
return loader.ValidateConfiguration()
|
||||
}
|
||||
|
||||
// LoadKMSFromConfig loads KMS configuration directly from parsed JSON data
|
||||
func LoadKMSFromConfig(kmsConfig interface{}) error {
|
||||
kmsMap, ok := kmsConfig.(map[string]interface{})
|
||||
@@ -415,12 +406,3 @@ func getIntFromConfig(config map[string]interface{}, key string, defaultValue in
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getStringFromConfig(config map[string]interface{}, key string, defaultValue string) string {
|
||||
if value, exists := config[key]; exists {
|
||||
if stringValue, ok := value.(string); ok {
|
||||
return stringValue
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
@@ -147,13 +147,6 @@ func (fh *FileHandle) ReleaseHandle() {
|
||||
}
|
||||
}
|
||||
|
||||
func lessThan(a, b *filer_pb.FileChunk) bool {
|
||||
if a.ModifiedTsNs == b.ModifiedTsNs {
|
||||
return a.Fid.FileKey < b.Fid.FileKey
|
||||
}
|
||||
return a.ModifiedTsNs < b.ModifiedTsNs
|
||||
}
|
||||
|
||||
// getCumulativeOffsets returns cached cumulative offsets for chunks, computing them if necessary
|
||||
func (fh *FileHandle) getCumulativeOffsets(chunks []*filer_pb.FileChunk) []int64 {
|
||||
fh.chunkCacheLock.RLock()
|
||||
|
||||
@@ -21,9 +21,3 @@ func min(x, y int64) int64 {
|
||||
}
|
||||
return y
|
||||
}
|
||||
func minInt(x, y int) int {
|
||||
if x < y {
|
||||
return x
|
||||
}
|
||||
return y
|
||||
}
|
||||
|
||||
@@ -119,13 +119,6 @@ func (c *RDMAMountClient) lookupVolumeLocationByFileID(ctx context.Context, file
|
||||
return bestAddress, nil
|
||||
}
|
||||
|
||||
// lookupVolumeLocation finds the best volume server for a given volume ID (legacy method)
|
||||
func (c *RDMAMountClient) lookupVolumeLocation(ctx context.Context, volumeID uint32, needleID uint64, cookie uint32) (string, error) {
|
||||
// Create a file ID for lookup (format: volumeId,needleId,cookie)
|
||||
fileID := fmt.Sprintf("%d,%x,%d", volumeID, needleID, cookie)
|
||||
return c.lookupVolumeLocationByFileID(ctx, fileID)
|
||||
}
|
||||
|
||||
// healthCheck verifies that the RDMA sidecar is available and functioning
|
||||
func (c *RDMAMountClient) healthCheck() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
|
||||
|
||||
@@ -117,11 +117,6 @@ func GetBrokerErrorInfo(code int32) BrokerErrorInfo {
|
||||
}
|
||||
}
|
||||
|
||||
// GetKafkaErrorCode returns the corresponding Kafka protocol error code for a broker error
|
||||
func GetKafkaErrorCode(brokerErrorCode int32) int16 {
|
||||
return GetBrokerErrorInfo(brokerErrorCode).KafkaCode
|
||||
}
|
||||
|
||||
// CreateBrokerError creates a structured broker error with both error code and message
|
||||
func CreateBrokerError(code int32, message string) (int32, string) {
|
||||
info := GetBrokerErrorInfo(code)
|
||||
|
||||
@@ -1,351 +0,0 @@
|
||||
package broker
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/mq/topic"
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
|
||||
)
|
||||
|
||||
func createTestTopic() topic.Topic {
|
||||
return topic.Topic{
|
||||
Namespace: "test",
|
||||
Name: "offset-test",
|
||||
}
|
||||
}
|
||||
|
||||
func createTestPartition() topic.Partition {
|
||||
return topic.Partition{
|
||||
RingSize: 1024,
|
||||
RangeStart: 0,
|
||||
RangeStop: 31,
|
||||
UnixTimeNs: time.Now().UnixNano(),
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerOffsetManager_AssignOffset(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorageForTesting()
|
||||
manager := NewBrokerOffsetManagerWithStorage(storage)
|
||||
testTopic := createTestTopic()
|
||||
testPartition := createTestPartition()
|
||||
|
||||
// Test sequential offset assignment
|
||||
for i := int64(0); i < 10; i++ {
|
||||
assignedOffset, err := manager.AssignOffset(testTopic, testPartition)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to assign offset %d: %v", i, err)
|
||||
}
|
||||
|
||||
if assignedOffset != i {
|
||||
t.Errorf("Expected offset %d, got %d", i, assignedOffset)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerOffsetManager_AssignBatchOffsets(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorageForTesting()
|
||||
manager := NewBrokerOffsetManagerWithStorage(storage)
|
||||
testTopic := createTestTopic()
|
||||
testPartition := createTestPartition()
|
||||
|
||||
// Assign batch of offsets
|
||||
baseOffset, lastOffset, err := manager.AssignBatchOffsets(testTopic, testPartition, 5)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to assign batch offsets: %v", err)
|
||||
}
|
||||
|
||||
if baseOffset != 0 {
|
||||
t.Errorf("Expected base offset 0, got %d", baseOffset)
|
||||
}
|
||||
|
||||
if lastOffset != 4 {
|
||||
t.Errorf("Expected last offset 4, got %d", lastOffset)
|
||||
}
|
||||
|
||||
// Assign another batch
|
||||
baseOffset2, lastOffset2, err := manager.AssignBatchOffsets(testTopic, testPartition, 3)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to assign second batch offsets: %v", err)
|
||||
}
|
||||
|
||||
if baseOffset2 != 5 {
|
||||
t.Errorf("Expected base offset 5, got %d", baseOffset2)
|
||||
}
|
||||
|
||||
if lastOffset2 != 7 {
|
||||
t.Errorf("Expected last offset 7, got %d", lastOffset2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerOffsetManager_GetHighWaterMark(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorageForTesting()
|
||||
manager := NewBrokerOffsetManagerWithStorage(storage)
|
||||
testTopic := createTestTopic()
|
||||
testPartition := createTestPartition()
|
||||
|
||||
// Initially should be 0
|
||||
hwm, err := manager.GetHighWaterMark(testTopic, testPartition)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get initial high water mark: %v", err)
|
||||
}
|
||||
|
||||
if hwm != 0 {
|
||||
t.Errorf("Expected initial high water mark 0, got %d", hwm)
|
||||
}
|
||||
|
||||
// Assign some offsets
|
||||
manager.AssignBatchOffsets(testTopic, testPartition, 10)
|
||||
|
||||
// High water mark should be updated
|
||||
hwm, err = manager.GetHighWaterMark(testTopic, testPartition)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get high water mark after assignment: %v", err)
|
||||
}
|
||||
|
||||
if hwm != 10 {
|
||||
t.Errorf("Expected high water mark 10, got %d", hwm)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerOffsetManager_CreateSubscription(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorageForTesting()
|
||||
manager := NewBrokerOffsetManagerWithStorage(storage)
|
||||
testTopic := createTestTopic()
|
||||
testPartition := createTestPartition()
|
||||
|
||||
// Assign some offsets first
|
||||
manager.AssignBatchOffsets(testTopic, testPartition, 5)
|
||||
|
||||
// Create subscription
|
||||
sub, err := manager.CreateSubscription(
|
||||
"test-sub",
|
||||
testTopic,
|
||||
testPartition,
|
||||
schema_pb.OffsetType_RESET_TO_EARLIEST,
|
||||
0,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create subscription: %v", err)
|
||||
}
|
||||
|
||||
if sub.ID != "test-sub" {
|
||||
t.Errorf("Expected subscription ID 'test-sub', got %s", sub.ID)
|
||||
}
|
||||
|
||||
if sub.StartOffset != 0 {
|
||||
t.Errorf("Expected start offset 0, got %d", sub.StartOffset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerOffsetManager_GetPartitionOffsetInfo(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorageForTesting()
|
||||
manager := NewBrokerOffsetManagerWithStorage(storage)
|
||||
testTopic := createTestTopic()
|
||||
testPartition := createTestPartition()
|
||||
|
||||
// Test empty partition
|
||||
info, err := manager.GetPartitionOffsetInfo(testTopic, testPartition)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get partition offset info: %v", err)
|
||||
}
|
||||
|
||||
if info.EarliestOffset != 0 {
|
||||
t.Errorf("Expected earliest offset 0, got %d", info.EarliestOffset)
|
||||
}
|
||||
|
||||
if info.LatestOffset != -1 {
|
||||
t.Errorf("Expected latest offset -1 for empty partition, got %d", info.LatestOffset)
|
||||
}
|
||||
|
||||
// Assign offsets and test again
|
||||
manager.AssignBatchOffsets(testTopic, testPartition, 5)
|
||||
|
||||
info, err = manager.GetPartitionOffsetInfo(testTopic, testPartition)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get partition offset info after assignment: %v", err)
|
||||
}
|
||||
|
||||
if info.LatestOffset != 4 {
|
||||
t.Errorf("Expected latest offset 4, got %d", info.LatestOffset)
|
||||
}
|
||||
|
||||
if info.HighWaterMark != 5 {
|
||||
t.Errorf("Expected high water mark 5, got %d", info.HighWaterMark)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerOffsetManager_MultiplePartitions(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorageForTesting()
|
||||
manager := NewBrokerOffsetManagerWithStorage(storage)
|
||||
testTopic := createTestTopic()
|
||||
|
||||
// Create different partitions
|
||||
partition1 := topic.Partition{
|
||||
RingSize: 1024,
|
||||
RangeStart: 0,
|
||||
RangeStop: 31,
|
||||
UnixTimeNs: time.Now().UnixNano(),
|
||||
}
|
||||
|
||||
partition2 := topic.Partition{
|
||||
RingSize: 1024,
|
||||
RangeStart: 32,
|
||||
RangeStop: 63,
|
||||
UnixTimeNs: time.Now().UnixNano(),
|
||||
}
|
||||
|
||||
// Assign offsets to different partitions
|
||||
assignedOffset1, err := manager.AssignOffset(testTopic, partition1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to assign offset to partition1: %v", err)
|
||||
}
|
||||
|
||||
assignedOffset2, err := manager.AssignOffset(testTopic, partition2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to assign offset to partition2: %v", err)
|
||||
}
|
||||
|
||||
// Both should start at 0
|
||||
if assignedOffset1 != 0 {
|
||||
t.Errorf("Expected offset 0 for partition1, got %d", assignedOffset1)
|
||||
}
|
||||
|
||||
if assignedOffset2 != 0 {
|
||||
t.Errorf("Expected offset 0 for partition2, got %d", assignedOffset2)
|
||||
}
|
||||
|
||||
// Assign more offsets to partition1
|
||||
assignedOffset1_2, err := manager.AssignOffset(testTopic, partition1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to assign second offset to partition1: %v", err)
|
||||
}
|
||||
|
||||
if assignedOffset1_2 != 1 {
|
||||
t.Errorf("Expected offset 1 for partition1, got %d", assignedOffset1_2)
|
||||
}
|
||||
|
||||
// Partition2 should still be at 0 for next assignment
|
||||
assignedOffset2_2, err := manager.AssignOffset(testTopic, partition2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to assign second offset to partition2: %v", err)
|
||||
}
|
||||
|
||||
if assignedOffset2_2 != 1 {
|
||||
t.Errorf("Expected offset 1 for partition2, got %d", assignedOffset2_2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOffsetAwarePublisher(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorageForTesting()
|
||||
manager := NewBrokerOffsetManagerWithStorage(storage)
|
||||
testTopic := createTestTopic()
|
||||
testPartition := createTestPartition()
|
||||
|
||||
// Create a mock local partition (simplified for testing)
|
||||
localPartition := &topic.LocalPartition{}
|
||||
|
||||
// Create offset assignment function
|
||||
assignOffsetFn := func() (int64, error) {
|
||||
return manager.AssignOffset(testTopic, testPartition)
|
||||
}
|
||||
|
||||
// Create offset-aware publisher
|
||||
publisher := topic.NewOffsetAwarePublisher(localPartition, assignOffsetFn)
|
||||
|
||||
if publisher.GetPartition() != localPartition {
|
||||
t.Error("Publisher should return the correct partition")
|
||||
}
|
||||
|
||||
// Test would require more setup to actually publish messages
|
||||
// This tests the basic structure
|
||||
}
|
||||
|
||||
func TestBrokerOffsetManager_GetOffsetMetrics(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorageForTesting()
|
||||
manager := NewBrokerOffsetManagerWithStorage(storage)
|
||||
testTopic := createTestTopic()
|
||||
testPartition := createTestPartition()
|
||||
|
||||
// Initial metrics
|
||||
metrics := manager.GetOffsetMetrics()
|
||||
if metrics.TotalOffsets != 0 {
|
||||
t.Errorf("Expected 0 total offsets initially, got %d", metrics.TotalOffsets)
|
||||
}
|
||||
|
||||
// Assign some offsets
|
||||
manager.AssignBatchOffsets(testTopic, testPartition, 5)
|
||||
|
||||
// Create subscription
|
||||
manager.CreateSubscription("test-sub", testTopic, testPartition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
|
||||
|
||||
// Check updated metrics
|
||||
metrics = manager.GetOffsetMetrics()
|
||||
if metrics.PartitionCount != 1 {
|
||||
t.Errorf("Expected 1 partition, got %d", metrics.PartitionCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerOffsetManager_AssignOffsetsWithResult(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorageForTesting()
|
||||
manager := NewBrokerOffsetManagerWithStorage(storage)
|
||||
testTopic := createTestTopic()
|
||||
testPartition := createTestPartition()
|
||||
|
||||
// Assign offsets with result
|
||||
result := manager.AssignOffsetsWithResult(testTopic, testPartition, 3)
|
||||
|
||||
if result.Error != nil {
|
||||
t.Fatalf("Expected no error, got: %v", result.Error)
|
||||
}
|
||||
|
||||
if result.BaseOffset != 0 {
|
||||
t.Errorf("Expected base offset 0, got %d", result.BaseOffset)
|
||||
}
|
||||
|
||||
if result.LastOffset != 2 {
|
||||
t.Errorf("Expected last offset 2, got %d", result.LastOffset)
|
||||
}
|
||||
|
||||
if result.Count != 3 {
|
||||
t.Errorf("Expected count 3, got %d", result.Count)
|
||||
}
|
||||
|
||||
if result.Topic != testTopic {
|
||||
t.Error("Topic mismatch in result")
|
||||
}
|
||||
|
||||
if result.Partition != testPartition {
|
||||
t.Error("Partition mismatch in result")
|
||||
}
|
||||
|
||||
if result.Timestamp <= 0 {
|
||||
t.Error("Timestamp should be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrokerOffsetManager_Shutdown(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorageForTesting()
|
||||
manager := NewBrokerOffsetManagerWithStorage(storage)
|
||||
testTopic := createTestTopic()
|
||||
testPartition := createTestPartition()
|
||||
|
||||
// Assign some offsets and create subscriptions
|
||||
manager.AssignBatchOffsets(testTopic, testPartition, 5)
|
||||
manager.CreateSubscription("test-sub", testTopic, testPartition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
|
||||
|
||||
// Shutdown should not panic
|
||||
manager.Shutdown()
|
||||
|
||||
// After shutdown, operations should still work (using new managers)
|
||||
offset, err := manager.AssignOffset(testTopic, testPartition)
|
||||
if err != nil {
|
||||
t.Fatalf("Operations should still work after shutdown: %v", err)
|
||||
}
|
||||
|
||||
// Should start from 0 again (new manager)
|
||||
if offset != 0 {
|
||||
t.Errorf("Expected offset 0 after shutdown, got %d", offset)
|
||||
}
|
||||
}
|
||||
@@ -203,14 +203,6 @@ func (b *MessageQueueBroker) GetDataCenter() string {
|
||||
|
||||
}
|
||||
|
||||
func (b *MessageQueueBroker) withMasterClient(streamingMode bool, master pb.ServerAddress, fn func(client master_pb.SeaweedClient) error) error {
|
||||
|
||||
return pb.WithMasterClient(streamingMode, master, b.grpcDialOption, false, func(client master_pb.SeaweedClient) error {
|
||||
return fn(client)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func (b *MessageQueueBroker) withBrokerClient(streamingMode bool, server pb.ServerAddress, fn func(client mq_pb.SeaweedMessagingClient) error) error {
|
||||
|
||||
return pb.WithBrokerGrpcClient(streamingMode, server.String(), b.grpcDialOption, func(client mq_pb.SeaweedMessagingClient) error {
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/filer_client"
|
||||
@@ -192,10 +191,6 @@ func (f *FilerStorage) getOffsetPath(group, topic string, partition int32) strin
|
||||
return fmt.Sprintf("%s/offset", f.getPartitionPath(group, topic, partition))
|
||||
}
|
||||
|
||||
func (f *FilerStorage) getMetadataPath(group, topic string, partition int32) string {
|
||||
return fmt.Sprintf("%s/metadata", f.getPartitionPath(group, topic, partition))
|
||||
}
|
||||
|
||||
func (f *FilerStorage) writeFile(path string, data []byte) error {
|
||||
fullPath := util.FullPath(path)
|
||||
dir, name := fullPath.DirAndName()
|
||||
@@ -311,16 +306,3 @@ func (f *FilerStorage) deleteDirectory(path string) error {
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// normalizePath removes leading/trailing slashes and collapses multiple slashes
|
||||
func normalizePath(path string) string {
|
||||
path = strings.Trim(path, "/")
|
||||
parts := strings.Split(path, "/")
|
||||
normalized := []string{}
|
||||
for _, part := range parts {
|
||||
if part != "" {
|
||||
normalized = append(normalized, part)
|
||||
}
|
||||
}
|
||||
return "/" + strings.Join(normalized, "/")
|
||||
}
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
package consumer_offset
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Note: These tests require a running filer instance
|
||||
// They are marked as integration tests and should be run with:
|
||||
// go test -tags=integration
|
||||
|
||||
func TestFilerStorageCommitAndFetch(t *testing.T) {
|
||||
t.Skip("Requires running filer - integration test")
|
||||
|
||||
// This will be implemented once we have test infrastructure
|
||||
// Test will:
|
||||
// 1. Create filer storage
|
||||
// 2. Commit offset
|
||||
// 3. Fetch offset
|
||||
// 4. Verify values match
|
||||
}
|
||||
|
||||
func TestFilerStoragePersistence(t *testing.T) {
|
||||
t.Skip("Requires running filer - integration test")
|
||||
|
||||
// Test will:
|
||||
// 1. Commit offset with first storage instance
|
||||
// 2. Close first instance
|
||||
// 3. Create new storage instance
|
||||
// 4. Fetch offset and verify it persisted
|
||||
}
|
||||
|
||||
func TestFilerStorageMultipleGroups(t *testing.T) {
|
||||
t.Skip("Requires running filer - integration test")
|
||||
|
||||
// Test will:
|
||||
// 1. Commit offsets for multiple groups
|
||||
// 2. Fetch all offsets per group
|
||||
// 3. Verify isolation between groups
|
||||
}
|
||||
|
||||
func TestFilerStoragePath(t *testing.T) {
|
||||
// Test path generation (doesn't require filer)
|
||||
storage := &FilerStorage{}
|
||||
|
||||
group := "test-group"
|
||||
topic := "test-topic"
|
||||
partition := int32(5)
|
||||
|
||||
groupPath := storage.getGroupPath(group)
|
||||
assert.Equal(t, ConsumerOffsetsBasePath+"/test-group", groupPath)
|
||||
|
||||
topicPath := storage.getTopicPath(group, topic)
|
||||
assert.Equal(t, ConsumerOffsetsBasePath+"/test-group/test-topic", topicPath)
|
||||
|
||||
partitionPath := storage.getPartitionPath(group, topic, partition)
|
||||
assert.Equal(t, ConsumerOffsetsBasePath+"/test-group/test-topic/5", partitionPath)
|
||||
|
||||
offsetPath := storage.getOffsetPath(group, topic, partition)
|
||||
assert.Equal(t, ConsumerOffsetsBasePath+"/test-group/test-topic/5/offset", offsetPath)
|
||||
|
||||
metadataPath := storage.getMetadataPath(group, topic, partition)
|
||||
assert.Equal(t, ConsumerOffsetsBasePath+"/test-group/test-topic/5/metadata", metadataPath)
|
||||
}
|
||||
@@ -278,38 +278,3 @@ func (h *SeaweedMQHandler) checkTopicInFiler(topicName string) bool {
|
||||
|
||||
return exists
|
||||
}
|
||||
|
||||
// listTopicsFromFiler lists all topics from the filer
|
||||
func (h *SeaweedMQHandler) listTopicsFromFiler() []string {
|
||||
if h.filerClientAccessor == nil {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
var topics []string
|
||||
|
||||
h.filerClientAccessor.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error {
|
||||
request := &filer_pb.ListEntriesRequest{
|
||||
Directory: "/topics/kafka",
|
||||
}
|
||||
|
||||
stream, err := client.ListEntries(context.Background(), request)
|
||||
if err != nil {
|
||||
return nil // Don't propagate error, just return empty list
|
||||
}
|
||||
|
||||
for {
|
||||
resp, err := stream.Recv()
|
||||
if err != nil {
|
||||
break // End of stream or error
|
||||
}
|
||||
|
||||
if resp.Entry != nil && resp.Entry.IsDirectory {
|
||||
topics = append(topics, resp.Entry.Name)
|
||||
} else if resp.Entry != nil {
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return topics
|
||||
}
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
package kafka
|
||||
|
||||
import (
|
||||
"github.com/seaweedfs/seaweedfs/weed/mq/pub_balancer"
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
|
||||
)
|
||||
|
||||
// Convenience functions for partition mapping used by production code
|
||||
// The full PartitionMapper implementation is in partition_mapping_test.go for testing
|
||||
|
||||
// MapKafkaPartitionToSMQRange maps a Kafka partition to SeaweedMQ ring range
|
||||
func MapKafkaPartitionToSMQRange(kafkaPartition int32) (rangeStart, rangeStop int32) {
|
||||
// Use a range size that divides evenly into MaxPartitionCount (2520)
|
||||
// Range size 35 gives us exactly 72 Kafka partitions: 2520 / 35 = 72
|
||||
rangeSize := int32(35)
|
||||
rangeStart = kafkaPartition * rangeSize
|
||||
rangeStop = rangeStart + rangeSize - 1
|
||||
return rangeStart, rangeStop
|
||||
}
|
||||
|
||||
// CreateSMQPartition creates a SeaweedMQ partition from a Kafka partition
|
||||
func CreateSMQPartition(kafkaPartition int32, unixTimeNs int64) *schema_pb.Partition {
|
||||
rangeStart, rangeStop := MapKafkaPartitionToSMQRange(kafkaPartition)
|
||||
|
||||
return &schema_pb.Partition{
|
||||
RingSize: pub_balancer.MaxPartitionCount,
|
||||
RangeStart: rangeStart,
|
||||
RangeStop: rangeStop,
|
||||
UnixTimeNs: unixTimeNs,
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractKafkaPartitionFromSMQRange extracts the Kafka partition from SeaweedMQ range
|
||||
func ExtractKafkaPartitionFromSMQRange(rangeStart int32) int32 {
|
||||
rangeSize := int32(35)
|
||||
return rangeStart / rangeSize
|
||||
}
|
||||
|
||||
// ValidateKafkaPartition validates that a Kafka partition is within supported range
|
||||
func ValidateKafkaPartition(kafkaPartition int32) bool {
|
||||
maxPartitions := int32(pub_balancer.MaxPartitionCount) / 35 // 72 partitions
|
||||
return kafkaPartition >= 0 && kafkaPartition < maxPartitions
|
||||
}
|
||||
|
||||
// GetRangeSize returns the range size used for partition mapping
|
||||
func GetRangeSize() int32 {
|
||||
return 35
|
||||
}
|
||||
|
||||
// GetMaxKafkaPartitions returns the maximum number of Kafka partitions supported
|
||||
func GetMaxKafkaPartitions() int32 {
|
||||
return int32(pub_balancer.MaxPartitionCount) / 35 // 72 partitions
|
||||
}
|
||||
@@ -1,294 +0,0 @@
|
||||
package kafka
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/mq/pub_balancer"
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
|
||||
)
|
||||
|
||||
// PartitionMapper provides consistent Kafka partition to SeaweedMQ ring mapping
|
||||
// NOTE: This is test-only code and not used in the actual Kafka Gateway implementation
|
||||
type PartitionMapper struct{}
|
||||
|
||||
// NewPartitionMapper creates a new partition mapper
|
||||
func NewPartitionMapper() *PartitionMapper {
|
||||
return &PartitionMapper{}
|
||||
}
|
||||
|
||||
// GetRangeSize returns the consistent range size for Kafka partition mapping
|
||||
// This ensures all components use the same calculation
|
||||
func (pm *PartitionMapper) GetRangeSize() int32 {
|
||||
// Use a range size that divides evenly into MaxPartitionCount (2520)
|
||||
// Range size 35 gives us exactly 72 Kafka partitions: 2520 / 35 = 72
|
||||
// This provides a good balance between partition granularity and ring utilization
|
||||
return 35
|
||||
}
|
||||
|
||||
// GetMaxKafkaPartitions returns the maximum number of Kafka partitions supported
|
||||
func (pm *PartitionMapper) GetMaxKafkaPartitions() int32 {
|
||||
// With range size 35, we can support: 2520 / 35 = 72 Kafka partitions
|
||||
return int32(pub_balancer.MaxPartitionCount) / pm.GetRangeSize()
|
||||
}
|
||||
|
||||
// MapKafkaPartitionToSMQRange maps a Kafka partition to SeaweedMQ ring range
|
||||
func (pm *PartitionMapper) MapKafkaPartitionToSMQRange(kafkaPartition int32) (rangeStart, rangeStop int32) {
|
||||
rangeSize := pm.GetRangeSize()
|
||||
rangeStart = kafkaPartition * rangeSize
|
||||
rangeStop = rangeStart + rangeSize - 1
|
||||
return rangeStart, rangeStop
|
||||
}
|
||||
|
||||
// CreateSMQPartition creates a SeaweedMQ partition from a Kafka partition
|
||||
func (pm *PartitionMapper) CreateSMQPartition(kafkaPartition int32, unixTimeNs int64) *schema_pb.Partition {
|
||||
rangeStart, rangeStop := pm.MapKafkaPartitionToSMQRange(kafkaPartition)
|
||||
|
||||
return &schema_pb.Partition{
|
||||
RingSize: pub_balancer.MaxPartitionCount,
|
||||
RangeStart: rangeStart,
|
||||
RangeStop: rangeStop,
|
||||
UnixTimeNs: unixTimeNs,
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractKafkaPartitionFromSMQRange extracts the Kafka partition from SeaweedMQ range
|
||||
func (pm *PartitionMapper) ExtractKafkaPartitionFromSMQRange(rangeStart int32) int32 {
|
||||
rangeSize := pm.GetRangeSize()
|
||||
return rangeStart / rangeSize
|
||||
}
|
||||
|
||||
// ValidateKafkaPartition validates that a Kafka partition is within supported range
|
||||
func (pm *PartitionMapper) ValidateKafkaPartition(kafkaPartition int32) bool {
|
||||
return kafkaPartition >= 0 && kafkaPartition < pm.GetMaxKafkaPartitions()
|
||||
}
|
||||
|
||||
// GetPartitionMappingInfo returns debug information about the partition mapping
|
||||
func (pm *PartitionMapper) GetPartitionMappingInfo() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"ring_size": pub_balancer.MaxPartitionCount,
|
||||
"range_size": pm.GetRangeSize(),
|
||||
"max_kafka_partitions": pm.GetMaxKafkaPartitions(),
|
||||
"ring_utilization": float64(pm.GetMaxKafkaPartitions()*pm.GetRangeSize()) / float64(pub_balancer.MaxPartitionCount),
|
||||
}
|
||||
}
|
||||
|
||||
// Global instance for consistent usage across the test codebase
|
||||
var DefaultPartitionMapper = NewPartitionMapper()
|
||||
|
||||
func TestPartitionMapper_GetRangeSize(t *testing.T) {
|
||||
mapper := NewPartitionMapper()
|
||||
rangeSize := mapper.GetRangeSize()
|
||||
|
||||
if rangeSize != 35 {
|
||||
t.Errorf("Expected range size 35, got %d", rangeSize)
|
||||
}
|
||||
|
||||
// Verify that the range size divides evenly into available partitions
|
||||
maxPartitions := mapper.GetMaxKafkaPartitions()
|
||||
totalUsed := maxPartitions * rangeSize
|
||||
|
||||
if totalUsed > int32(pub_balancer.MaxPartitionCount) {
|
||||
t.Errorf("Total used slots (%d) exceeds MaxPartitionCount (%d)", totalUsed, pub_balancer.MaxPartitionCount)
|
||||
}
|
||||
|
||||
t.Logf("Range size: %d, Max Kafka partitions: %d, Ring utilization: %.2f%%",
|
||||
rangeSize, maxPartitions, float64(totalUsed)/float64(pub_balancer.MaxPartitionCount)*100)
|
||||
}
|
||||
|
||||
func TestPartitionMapper_MapKafkaPartitionToSMQRange(t *testing.T) {
|
||||
mapper := NewPartitionMapper()
|
||||
|
||||
tests := []struct {
|
||||
kafkaPartition int32
|
||||
expectedStart int32
|
||||
expectedStop int32
|
||||
}{
|
||||
{0, 0, 34},
|
||||
{1, 35, 69},
|
||||
{2, 70, 104},
|
||||
{10, 350, 384},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run("", func(t *testing.T) {
|
||||
start, stop := mapper.MapKafkaPartitionToSMQRange(tt.kafkaPartition)
|
||||
|
||||
if start != tt.expectedStart {
|
||||
t.Errorf("Kafka partition %d: expected start %d, got %d", tt.kafkaPartition, tt.expectedStart, start)
|
||||
}
|
||||
|
||||
if stop != tt.expectedStop {
|
||||
t.Errorf("Kafka partition %d: expected stop %d, got %d", tt.kafkaPartition, tt.expectedStop, stop)
|
||||
}
|
||||
|
||||
// Verify range size is consistent
|
||||
rangeSize := stop - start + 1
|
||||
if rangeSize != mapper.GetRangeSize() {
|
||||
t.Errorf("Inconsistent range size: expected %d, got %d", mapper.GetRangeSize(), rangeSize)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPartitionMapper_ExtractKafkaPartitionFromSMQRange(t *testing.T) {
|
||||
mapper := NewPartitionMapper()
|
||||
|
||||
tests := []struct {
|
||||
rangeStart int32
|
||||
expectedKafka int32
|
||||
}{
|
||||
{0, 0},
|
||||
{35, 1},
|
||||
{70, 2},
|
||||
{350, 10},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run("", func(t *testing.T) {
|
||||
kafkaPartition := mapper.ExtractKafkaPartitionFromSMQRange(tt.rangeStart)
|
||||
|
||||
if kafkaPartition != tt.expectedKafka {
|
||||
t.Errorf("Range start %d: expected Kafka partition %d, got %d",
|
||||
tt.rangeStart, tt.expectedKafka, kafkaPartition)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPartitionMapper_RoundTrip(t *testing.T) {
|
||||
mapper := NewPartitionMapper()
|
||||
|
||||
// Test round-trip conversion for all valid Kafka partitions
|
||||
maxPartitions := mapper.GetMaxKafkaPartitions()
|
||||
|
||||
for kafkaPartition := int32(0); kafkaPartition < maxPartitions; kafkaPartition++ {
|
||||
// Kafka -> SMQ -> Kafka
|
||||
rangeStart, rangeStop := mapper.MapKafkaPartitionToSMQRange(kafkaPartition)
|
||||
extractedKafka := mapper.ExtractKafkaPartitionFromSMQRange(rangeStart)
|
||||
|
||||
if extractedKafka != kafkaPartition {
|
||||
t.Errorf("Round-trip failed for partition %d: got %d", kafkaPartition, extractedKafka)
|
||||
}
|
||||
|
||||
// Verify no overlap with next partition
|
||||
if kafkaPartition < maxPartitions-1 {
|
||||
nextStart, _ := mapper.MapKafkaPartitionToSMQRange(kafkaPartition + 1)
|
||||
if rangeStop >= nextStart {
|
||||
t.Errorf("Partition %d range [%d,%d] overlaps with partition %d start %d",
|
||||
kafkaPartition, rangeStart, rangeStop, kafkaPartition+1, nextStart)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPartitionMapper_CreateSMQPartition(t *testing.T) {
|
||||
mapper := NewPartitionMapper()
|
||||
|
||||
kafkaPartition := int32(5)
|
||||
unixTimeNs := time.Now().UnixNano()
|
||||
|
||||
partition := mapper.CreateSMQPartition(kafkaPartition, unixTimeNs)
|
||||
|
||||
if partition.RingSize != pub_balancer.MaxPartitionCount {
|
||||
t.Errorf("Expected ring size %d, got %d", pub_balancer.MaxPartitionCount, partition.RingSize)
|
||||
}
|
||||
|
||||
expectedStart, expectedStop := mapper.MapKafkaPartitionToSMQRange(kafkaPartition)
|
||||
if partition.RangeStart != expectedStart {
|
||||
t.Errorf("Expected range start %d, got %d", expectedStart, partition.RangeStart)
|
||||
}
|
||||
|
||||
if partition.RangeStop != expectedStop {
|
||||
t.Errorf("Expected range stop %d, got %d", expectedStop, partition.RangeStop)
|
||||
}
|
||||
|
||||
if partition.UnixTimeNs != unixTimeNs {
|
||||
t.Errorf("Expected timestamp %d, got %d", unixTimeNs, partition.UnixTimeNs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPartitionMapper_ValidateKafkaPartition(t *testing.T) {
|
||||
mapper := NewPartitionMapper()
|
||||
|
||||
tests := []struct {
|
||||
partition int32
|
||||
valid bool
|
||||
}{
|
||||
{-1, false},
|
||||
{0, true},
|
||||
{1, true},
|
||||
{mapper.GetMaxKafkaPartitions() - 1, true},
|
||||
{mapper.GetMaxKafkaPartitions(), false},
|
||||
{1000, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run("", func(t *testing.T) {
|
||||
valid := mapper.ValidateKafkaPartition(tt.partition)
|
||||
if valid != tt.valid {
|
||||
t.Errorf("Partition %d: expected valid=%v, got %v", tt.partition, tt.valid, valid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPartitionMapper_ConsistencyWithGlobalFunctions(t *testing.T) {
|
||||
mapper := NewPartitionMapper()
|
||||
|
||||
kafkaPartition := int32(7)
|
||||
unixTimeNs := time.Now().UnixNano()
|
||||
|
||||
// Test that global functions produce same results as mapper methods
|
||||
start1, stop1 := mapper.MapKafkaPartitionToSMQRange(kafkaPartition)
|
||||
start2, stop2 := MapKafkaPartitionToSMQRange(kafkaPartition)
|
||||
|
||||
if start1 != start2 || stop1 != stop2 {
|
||||
t.Errorf("Global function inconsistent: mapper=(%d,%d), global=(%d,%d)",
|
||||
start1, stop1, start2, stop2)
|
||||
}
|
||||
|
||||
partition1 := mapper.CreateSMQPartition(kafkaPartition, unixTimeNs)
|
||||
partition2 := CreateSMQPartition(kafkaPartition, unixTimeNs)
|
||||
|
||||
if partition1.RangeStart != partition2.RangeStart || partition1.RangeStop != partition2.RangeStop {
|
||||
t.Errorf("Global CreateSMQPartition inconsistent")
|
||||
}
|
||||
|
||||
extracted1 := mapper.ExtractKafkaPartitionFromSMQRange(start1)
|
||||
extracted2 := ExtractKafkaPartitionFromSMQRange(start1)
|
||||
|
||||
if extracted1 != extracted2 {
|
||||
t.Errorf("Global ExtractKafkaPartitionFromSMQRange inconsistent: %d vs %d", extracted1, extracted2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPartitionMapper_GetPartitionMappingInfo(t *testing.T) {
|
||||
mapper := NewPartitionMapper()
|
||||
|
||||
info := mapper.GetPartitionMappingInfo()
|
||||
|
||||
// Verify all expected keys are present
|
||||
expectedKeys := []string{"ring_size", "range_size", "max_kafka_partitions", "ring_utilization"}
|
||||
for _, key := range expectedKeys {
|
||||
if _, exists := info[key]; !exists {
|
||||
t.Errorf("Missing key in mapping info: %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify values are reasonable
|
||||
if info["ring_size"].(int) != pub_balancer.MaxPartitionCount {
|
||||
t.Errorf("Incorrect ring_size in info")
|
||||
}
|
||||
|
||||
if info["range_size"].(int32) != mapper.GetRangeSize() {
|
||||
t.Errorf("Incorrect range_size in info")
|
||||
}
|
||||
|
||||
utilization := info["ring_utilization"].(float64)
|
||||
if utilization <= 0 || utilization > 1 {
|
||||
t.Errorf("Invalid ring utilization: %f", utilization)
|
||||
}
|
||||
|
||||
t.Logf("Partition mapping info: %+v", info)
|
||||
}
|
||||
@@ -2,11 +2,9 @@ package offset
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
|
||||
)
|
||||
|
||||
@@ -62,151 +60,6 @@ func BenchmarkBatchOffsetAssignment(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSQLOffsetStorage benchmarks SQL storage operations
|
||||
func BenchmarkSQLOffsetStorage(b *testing.B) {
|
||||
// Create temporary database
|
||||
tmpFile, err := os.CreateTemp("", "benchmark_*.db")
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to create temp database: %v", err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
db, err := CreateDatabase(tmpFile.Name())
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
storage, err := NewSQLOffsetStorage(db)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to create SQL storage: %v", err)
|
||||
}
|
||||
defer storage.Close()
|
||||
|
||||
partition := &schema_pb.Partition{
|
||||
RingSize: 1024,
|
||||
RangeStart: 0,
|
||||
RangeStop: 31,
|
||||
UnixTimeNs: time.Now().UnixNano(),
|
||||
}
|
||||
|
||||
partitionKey := partitionKey(partition)
|
||||
|
||||
b.Run("SaveCheckpoint", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
storage.SaveCheckpoint("test-namespace", "test-topic", partition, int64(i))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("LoadCheckpoint", func(b *testing.B) {
|
||||
storage.SaveCheckpoint("test-namespace", "test-topic", partition, 1000)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
storage.LoadCheckpoint("test-namespace", "test-topic", partition)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("SaveOffsetMapping", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
storage.SaveOffsetMapping(partitionKey, int64(i), int64(i*1000), 100)
|
||||
}
|
||||
})
|
||||
|
||||
// Pre-populate for read benchmarks
|
||||
for i := 0; i < 1000; i++ {
|
||||
storage.SaveOffsetMapping(partitionKey, int64(i), int64(i*1000), 100)
|
||||
}
|
||||
|
||||
b.Run("GetHighestOffset", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
storage.GetHighestOffset("test-namespace", "test-topic", partition)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("LoadOffsetMappings", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
storage.LoadOffsetMappings(partitionKey)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("GetOffsetMappingsByRange", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
start := int64(i % 900)
|
||||
end := start + 100
|
||||
storage.GetOffsetMappingsByRange(partitionKey, start, end)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("GetPartitionStats", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
storage.GetPartitionStats(partitionKey)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkInMemoryVsSQL compares in-memory and SQL storage performance
|
||||
func BenchmarkInMemoryVsSQL(b *testing.B) {
|
||||
partition := &schema_pb.Partition{
|
||||
RingSize: 1024,
|
||||
RangeStart: 0,
|
||||
RangeStop: 31,
|
||||
UnixTimeNs: time.Now().UnixNano(),
|
||||
}
|
||||
|
||||
// In-memory storage benchmark
|
||||
b.Run("InMemory", func(b *testing.B) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
manager, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to create partition manager: %v", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.AssignOffset()
|
||||
}
|
||||
})
|
||||
|
||||
// SQL storage benchmark
|
||||
b.Run("SQL", func(b *testing.B) {
|
||||
tmpFile, err := os.CreateTemp("", "benchmark_sql_*.db")
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to create temp database: %v", err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
db, err := CreateDatabase(tmpFile.Name())
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
storage, err := NewSQLOffsetStorage(db)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to create SQL storage: %v", err)
|
||||
}
|
||||
defer storage.Close()
|
||||
|
||||
manager, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to create partition manager: %v", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.AssignOffset()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkOffsetSubscription benchmarks subscription operations
|
||||
func BenchmarkOffsetSubscription(b *testing.B) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
|
||||
@@ -1,473 +0,0 @@
|
||||
package offset
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
|
||||
)
|
||||
|
||||
// TestEndToEndOffsetFlow tests the complete offset management flow
|
||||
func TestEndToEndOffsetFlow(t *testing.T) {
|
||||
// Create temporary database
|
||||
tmpFile, err := os.CreateTemp("", "e2e_offset_test_*.db")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp database: %v", err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
// Create database with migrations
|
||||
db, err := CreateDatabase(tmpFile.Name())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create SQL storage
|
||||
storage, err := NewSQLOffsetStorage(db)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SQL storage: %v", err)
|
||||
}
|
||||
defer storage.Close()
|
||||
|
||||
// Create SMQ offset integration
|
||||
integration := NewSMQOffsetIntegration(storage)
|
||||
|
||||
// Test partition
|
||||
partition := &schema_pb.Partition{
|
||||
RingSize: 1024,
|
||||
RangeStart: 0,
|
||||
RangeStop: 31,
|
||||
UnixTimeNs: time.Now().UnixNano(),
|
||||
}
|
||||
|
||||
t.Run("PublishAndAssignOffsets", func(t *testing.T) {
|
||||
// Simulate publishing messages with offset assignment
|
||||
records := []PublishRecordRequest{
|
||||
{Key: []byte("user1"), Value: &schema_pb.RecordValue{}},
|
||||
{Key: []byte("user2"), Value: &schema_pb.RecordValue{}},
|
||||
{Key: []byte("user3"), Value: &schema_pb.RecordValue{}},
|
||||
}
|
||||
|
||||
response, err := integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to publish record batch: %v", err)
|
||||
}
|
||||
|
||||
if response.BaseOffset != 0 {
|
||||
t.Errorf("Expected base offset 0, got %d", response.BaseOffset)
|
||||
}
|
||||
|
||||
if response.LastOffset != 2 {
|
||||
t.Errorf("Expected last offset 2, got %d", response.LastOffset)
|
||||
}
|
||||
|
||||
// Verify high water mark
|
||||
hwm, err := integration.GetHighWaterMark("test-namespace", "test-topic", partition)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get high water mark: %v", err)
|
||||
}
|
||||
|
||||
if hwm != 3 {
|
||||
t.Errorf("Expected high water mark 3, got %d", hwm)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CreateAndUseSubscription", func(t *testing.T) {
|
||||
// Create subscription from earliest
|
||||
sub, err := integration.CreateSubscription(
|
||||
"e2e-test-sub",
|
||||
"test-namespace", "test-topic",
|
||||
partition,
|
||||
schema_pb.OffsetType_RESET_TO_EARLIEST,
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create subscription: %v", err)
|
||||
}
|
||||
|
||||
// Subscribe to records
|
||||
responses, err := integration.SubscribeRecords(sub, 2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to subscribe to records: %v", err)
|
||||
}
|
||||
|
||||
if len(responses) != 2 {
|
||||
t.Errorf("Expected 2 responses, got %d", len(responses))
|
||||
}
|
||||
|
||||
// Check subscription advancement
|
||||
if sub.CurrentOffset != 2 {
|
||||
t.Errorf("Expected current offset 2, got %d", sub.CurrentOffset)
|
||||
}
|
||||
|
||||
// Get subscription lag
|
||||
lag, err := sub.GetLag()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get lag: %v", err)
|
||||
}
|
||||
|
||||
if lag != 1 { // 3 (hwm) - 2 (current) = 1
|
||||
t.Errorf("Expected lag 1, got %d", lag)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OffsetSeekingAndRanges", func(t *testing.T) {
|
||||
// Create subscription at specific offset
|
||||
sub, err := integration.CreateSubscription(
|
||||
"seek-test-sub",
|
||||
"test-namespace", "test-topic",
|
||||
partition,
|
||||
schema_pb.OffsetType_EXACT_OFFSET,
|
||||
1,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create subscription at offset 1: %v", err)
|
||||
}
|
||||
|
||||
// Verify starting position
|
||||
if sub.CurrentOffset != 1 {
|
||||
t.Errorf("Expected current offset 1, got %d", sub.CurrentOffset)
|
||||
}
|
||||
|
||||
// Get offset range
|
||||
offsetRange, err := sub.GetOffsetRange(2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get offset range: %v", err)
|
||||
}
|
||||
|
||||
if offsetRange.StartOffset != 1 {
|
||||
t.Errorf("Expected start offset 1, got %d", offsetRange.StartOffset)
|
||||
}
|
||||
|
||||
if offsetRange.Count != 2 {
|
||||
t.Errorf("Expected count 2, got %d", offsetRange.Count)
|
||||
}
|
||||
|
||||
// Seek to different offset
|
||||
err = sub.SeekToOffset(0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to seek to offset 0: %v", err)
|
||||
}
|
||||
|
||||
if sub.CurrentOffset != 0 {
|
||||
t.Errorf("Expected current offset 0 after seek, got %d", sub.CurrentOffset)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PartitionInformationAndMetrics", func(t *testing.T) {
|
||||
// Get partition offset info
|
||||
info, err := integration.GetPartitionOffsetInfo("test-namespace", "test-topic", partition)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get partition offset info: %v", err)
|
||||
}
|
||||
|
||||
if info.EarliestOffset != 0 {
|
||||
t.Errorf("Expected earliest offset 0, got %d", info.EarliestOffset)
|
||||
}
|
||||
|
||||
if info.LatestOffset != 2 {
|
||||
t.Errorf("Expected latest offset 2, got %d", info.LatestOffset)
|
||||
}
|
||||
|
||||
if info.HighWaterMark != 3 {
|
||||
t.Errorf("Expected high water mark 3, got %d", info.HighWaterMark)
|
||||
}
|
||||
|
||||
if info.ActiveSubscriptions != 2 { // Two subscriptions created above
|
||||
t.Errorf("Expected 2 active subscriptions, got %d", info.ActiveSubscriptions)
|
||||
}
|
||||
|
||||
// Get offset metrics
|
||||
metrics := integration.GetOffsetMetrics()
|
||||
if metrics.PartitionCount != 1 {
|
||||
t.Errorf("Expected 1 partition, got %d", metrics.PartitionCount)
|
||||
}
|
||||
|
||||
if metrics.ActiveSubscriptions != 2 {
|
||||
t.Errorf("Expected 2 active subscriptions in metrics, got %d", metrics.ActiveSubscriptions)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestOffsetPersistenceAcrossRestarts tests that offsets persist across system restarts
|
||||
func TestOffsetPersistenceAcrossRestarts(t *testing.T) {
|
||||
// Create temporary database
|
||||
tmpFile, err := os.CreateTemp("", "persistence_test_*.db")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp database: %v", err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
partition := &schema_pb.Partition{
|
||||
RingSize: 1024,
|
||||
RangeStart: 0,
|
||||
RangeStop: 31,
|
||||
UnixTimeNs: time.Now().UnixNano(),
|
||||
}
|
||||
|
||||
var lastOffset int64
|
||||
|
||||
// First session: Create database and assign offsets
|
||||
{
|
||||
db, err := CreateDatabase(tmpFile.Name())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
|
||||
storage, err := NewSQLOffsetStorage(db)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SQL storage: %v", err)
|
||||
}
|
||||
|
||||
integration := NewSMQOffsetIntegration(storage)
|
||||
|
||||
// Publish some records
|
||||
records := []PublishRecordRequest{
|
||||
{Key: []byte("msg1"), Value: &schema_pb.RecordValue{}},
|
||||
{Key: []byte("msg2"), Value: &schema_pb.RecordValue{}},
|
||||
{Key: []byte("msg3"), Value: &schema_pb.RecordValue{}},
|
||||
}
|
||||
|
||||
response, err := integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to publish records: %v", err)
|
||||
}
|
||||
|
||||
lastOffset = response.LastOffset
|
||||
|
||||
// Close connections - Close integration first to trigger final checkpoint
|
||||
integration.Close()
|
||||
storage.Close()
|
||||
db.Close()
|
||||
}
|
||||
|
||||
// Second session: Reopen database and verify persistence
|
||||
{
|
||||
db, err := CreateDatabase(tmpFile.Name())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to reopen database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
storage, err := NewSQLOffsetStorage(db)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SQL storage: %v", err)
|
||||
}
|
||||
defer storage.Close()
|
||||
|
||||
integration := NewSMQOffsetIntegration(storage)
|
||||
|
||||
// Verify high water mark persisted
|
||||
hwm, err := integration.GetHighWaterMark("test-namespace", "test-topic", partition)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get high water mark after restart: %v", err)
|
||||
}
|
||||
|
||||
if hwm != lastOffset+1 {
|
||||
t.Errorf("Expected high water mark %d after restart, got %d", lastOffset+1, hwm)
|
||||
}
|
||||
|
||||
// Assign new offsets and verify continuity
|
||||
newResponse, err := integration.PublishRecord("test-namespace", "test-topic", partition, []byte("msg4"), &schema_pb.RecordValue{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to publish new record after restart: %v", err)
|
||||
}
|
||||
|
||||
expectedNextOffset := lastOffset + 1
|
||||
if newResponse.BaseOffset != expectedNextOffset {
|
||||
t.Errorf("Expected next offset %d after restart, got %d", expectedNextOffset, newResponse.BaseOffset)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentOffsetOperations tests concurrent offset operations
|
||||
func TestConcurrentOffsetOperations(t *testing.T) {
|
||||
// Create temporary database
|
||||
tmpFile, err := os.CreateTemp("", "concurrent_test_*.db")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp database: %v", err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
db, err := CreateDatabase(tmpFile.Name())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
storage, err := NewSQLOffsetStorage(db)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SQL storage: %v", err)
|
||||
}
|
||||
defer storage.Close()
|
||||
|
||||
integration := NewSMQOffsetIntegration(storage)
|
||||
|
||||
partition := &schema_pb.Partition{
|
||||
RingSize: 1024,
|
||||
RangeStart: 0,
|
||||
RangeStop: 31,
|
||||
UnixTimeNs: time.Now().UnixNano(),
|
||||
}
|
||||
|
||||
// Concurrent publishers
|
||||
const numPublishers = 5
|
||||
const recordsPerPublisher = 10
|
||||
|
||||
done := make(chan bool, numPublishers)
|
||||
|
||||
for i := 0; i < numPublishers; i++ {
|
||||
go func(publisherID int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
for j := 0; j < recordsPerPublisher; j++ {
|
||||
key := fmt.Sprintf("publisher-%d-msg-%d", publisherID, j)
|
||||
_, err := integration.PublishRecord("test-namespace", "test-topic", partition, []byte(key), &schema_pb.RecordValue{})
|
||||
if err != nil {
|
||||
t.Errorf("Publisher %d failed to publish message %d: %v", publisherID, j, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all publishers to complete
|
||||
for i := 0; i < numPublishers; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Verify total records
|
||||
hwm, err := integration.GetHighWaterMark("test-namespace", "test-topic", partition)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get high water mark: %v", err)
|
||||
}
|
||||
|
||||
expectedTotal := int64(numPublishers * recordsPerPublisher)
|
||||
if hwm != expectedTotal {
|
||||
t.Errorf("Expected high water mark %d, got %d", expectedTotal, hwm)
|
||||
}
|
||||
|
||||
// Verify no duplicate offsets
|
||||
info, err := integration.GetPartitionOffsetInfo("test-namespace", "test-topic", partition)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get partition info: %v", err)
|
||||
}
|
||||
|
||||
if info.RecordCount != expectedTotal {
|
||||
t.Errorf("Expected record count %d, got %d", expectedTotal, info.RecordCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOffsetValidationAndErrorHandling tests error conditions and validation
|
||||
func TestOffsetValidationAndErrorHandling(t *testing.T) {
|
||||
// Create temporary database
|
||||
tmpFile, err := os.CreateTemp("", "validation_test_*.db")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp database: %v", err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
db, err := CreateDatabase(tmpFile.Name())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
storage, err := NewSQLOffsetStorage(db)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SQL storage: %v", err)
|
||||
}
|
||||
defer storage.Close()
|
||||
|
||||
integration := NewSMQOffsetIntegration(storage)
|
||||
|
||||
partition := &schema_pb.Partition{
|
||||
RingSize: 1024,
|
||||
RangeStart: 0,
|
||||
RangeStop: 31,
|
||||
UnixTimeNs: time.Now().UnixNano(),
|
||||
}
|
||||
|
||||
t.Run("InvalidOffsetSubscription", func(t *testing.T) {
|
||||
// Try to create subscription with invalid offset
|
||||
_, err := integration.CreateSubscription(
|
||||
"invalid-sub",
|
||||
"test-namespace", "test-topic",
|
||||
partition,
|
||||
schema_pb.OffsetType_EXACT_OFFSET,
|
||||
100, // Beyond any existing data
|
||||
)
|
||||
if err == nil {
|
||||
t.Error("Expected error for subscription beyond high water mark")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NegativeOffsetValidation", func(t *testing.T) {
|
||||
// Try to create subscription with negative offset
|
||||
_, err := integration.CreateSubscription(
|
||||
"negative-sub",
|
||||
"test-namespace", "test-topic",
|
||||
partition,
|
||||
schema_pb.OffsetType_EXACT_OFFSET,
|
||||
-1,
|
||||
)
|
||||
if err == nil {
|
||||
t.Error("Expected error for negative offset")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DuplicateSubscriptionID", func(t *testing.T) {
|
||||
// Create first subscription
|
||||
_, err := integration.CreateSubscription(
|
||||
"duplicate-id",
|
||||
"test-namespace", "test-topic",
|
||||
partition,
|
||||
schema_pb.OffsetType_RESET_TO_EARLIEST,
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create first subscription: %v", err)
|
||||
}
|
||||
|
||||
// Try to create duplicate
|
||||
_, err = integration.CreateSubscription(
|
||||
"duplicate-id",
|
||||
"test-namespace", "test-topic",
|
||||
partition,
|
||||
schema_pb.OffsetType_RESET_TO_EARLIEST,
|
||||
0,
|
||||
)
|
||||
if err == nil {
|
||||
t.Error("Expected error for duplicate subscription ID")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OffsetRangeValidation", func(t *testing.T) {
|
||||
// Add some data first
|
||||
integration.PublishRecord("test-namespace", "test-topic", partition, []byte("test"), &schema_pb.RecordValue{})
|
||||
|
||||
// Test invalid range validation
|
||||
err := integration.ValidateOffsetRange("test-namespace", "test-topic", partition, 5, 10) // Beyond high water mark
|
||||
if err == nil {
|
||||
t.Error("Expected error for range beyond high water mark")
|
||||
}
|
||||
|
||||
err = integration.ValidateOffsetRange("test-namespace", "test-topic", partition, 10, 5) // End before start
|
||||
if err == nil {
|
||||
t.Error("Expected error for end offset before start offset")
|
||||
}
|
||||
|
||||
err = integration.ValidateOffsetRange("test-namespace", "test-topic", partition, -1, 5) // Negative start
|
||||
if err == nil {
|
||||
t.Error("Expected error for negative start offset")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -93,9 +93,3 @@ func (f *FilerOffsetStorage) getPartitionDir(namespace, topicName string, partit
|
||||
|
||||
return fmt.Sprintf("%s/%s/%s/%s/%s", filer.TopicsDir, namespace, topicName, version, partitionRange)
|
||||
}
|
||||
|
||||
// getPartitionKey generates a unique key for a partition
|
||||
func (f *FilerOffsetStorage) getPartitionKey(partition *schema_pb.Partition) string {
|
||||
return fmt.Sprintf("ring:%d:range:%d-%d:time:%d",
|
||||
partition.RingSize, partition.RangeStart, partition.RangeStop, partition.UnixTimeNs)
|
||||
}
|
||||
|
||||
@@ -1,544 +0,0 @@
|
||||
package offset
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
|
||||
)
|
||||
|
||||
func TestSMQOffsetIntegration_PublishRecord(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
integration := NewSMQOffsetIntegration(storage)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Publish a single record
|
||||
response, err := integration.PublishRecord(
|
||||
"test-namespace", "test-topic",
|
||||
partition,
|
||||
[]byte("test-key"),
|
||||
&schema_pb.RecordValue{},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to publish record: %v", err)
|
||||
}
|
||||
|
||||
if response.Error != "" {
|
||||
t.Errorf("Expected no error, got: %s", response.Error)
|
||||
}
|
||||
|
||||
if response.BaseOffset != 0 {
|
||||
t.Errorf("Expected base offset 0, got %d", response.BaseOffset)
|
||||
}
|
||||
|
||||
if response.LastOffset != 0 {
|
||||
t.Errorf("Expected last offset 0, got %d", response.LastOffset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSMQOffsetIntegration_PublishRecordBatch(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
integration := NewSMQOffsetIntegration(storage)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Create batch of records
|
||||
records := []PublishRecordRequest{
|
||||
{Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
|
||||
{Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
|
||||
{Key: []byte("key3"), Value: &schema_pb.RecordValue{}},
|
||||
}
|
||||
|
||||
// Publish batch
|
||||
response, err := integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to publish record batch: %v", err)
|
||||
}
|
||||
|
||||
if response.Error != "" {
|
||||
t.Errorf("Expected no error, got: %s", response.Error)
|
||||
}
|
||||
|
||||
if response.BaseOffset != 0 {
|
||||
t.Errorf("Expected base offset 0, got %d", response.BaseOffset)
|
||||
}
|
||||
|
||||
if response.LastOffset != 2 {
|
||||
t.Errorf("Expected last offset 2, got %d", response.LastOffset)
|
||||
}
|
||||
|
||||
// Verify high water mark
|
||||
hwm, err := integration.GetHighWaterMark("test-namespace", "test-topic", partition)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get high water mark: %v", err)
|
||||
}
|
||||
|
||||
if hwm != 3 {
|
||||
t.Errorf("Expected high water mark 3, got %d", hwm)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSMQOffsetIntegration_EmptyBatch(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
integration := NewSMQOffsetIntegration(storage)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Publish empty batch
|
||||
response, err := integration.PublishRecordBatch("test-namespace", "test-topic", partition, []PublishRecordRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to publish empty batch: %v", err)
|
||||
}
|
||||
|
||||
if response.Error == "" {
|
||||
t.Error("Expected error for empty batch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSMQOffsetIntegration_CreateSubscription(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
integration := NewSMQOffsetIntegration(storage)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Publish some records first
|
||||
records := []PublishRecordRequest{
|
||||
{Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
|
||||
{Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
|
||||
}
|
||||
integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
|
||||
|
||||
// Create subscription
|
||||
sub, err := integration.CreateSubscription(
|
||||
"test-sub",
|
||||
"test-namespace", "test-topic",
|
||||
partition,
|
||||
schema_pb.OffsetType_RESET_TO_EARLIEST,
|
||||
0,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create subscription: %v", err)
|
||||
}
|
||||
|
||||
if sub.ID != "test-sub" {
|
||||
t.Errorf("Expected subscription ID 'test-sub', got %s", sub.ID)
|
||||
}
|
||||
|
||||
if sub.StartOffset != 0 {
|
||||
t.Errorf("Expected start offset 0, got %d", sub.StartOffset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSMQOffsetIntegration_SubscribeRecords(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
integration := NewSMQOffsetIntegration(storage)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Publish some records
|
||||
records := []PublishRecordRequest{
|
||||
{Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
|
||||
{Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
|
||||
{Key: []byte("key3"), Value: &schema_pb.RecordValue{}},
|
||||
}
|
||||
integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
|
||||
|
||||
// Create subscription
|
||||
sub, err := integration.CreateSubscription(
|
||||
"test-sub",
|
||||
"test-namespace", "test-topic",
|
||||
partition,
|
||||
schema_pb.OffsetType_RESET_TO_EARLIEST,
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create subscription: %v", err)
|
||||
}
|
||||
|
||||
// Subscribe to records
|
||||
responses, err := integration.SubscribeRecords(sub, 2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to subscribe to records: %v", err)
|
||||
}
|
||||
|
||||
if len(responses) != 2 {
|
||||
t.Errorf("Expected 2 responses, got %d", len(responses))
|
||||
}
|
||||
|
||||
// Check offset progression
|
||||
if responses[0].Offset != 0 {
|
||||
t.Errorf("Expected first record offset 0, got %d", responses[0].Offset)
|
||||
}
|
||||
|
||||
if responses[1].Offset != 1 {
|
||||
t.Errorf("Expected second record offset 1, got %d", responses[1].Offset)
|
||||
}
|
||||
|
||||
// Check subscription advancement
|
||||
if sub.CurrentOffset != 2 {
|
||||
t.Errorf("Expected subscription current offset 2, got %d", sub.CurrentOffset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSMQOffsetIntegration_SubscribeEmptyPartition(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
integration := NewSMQOffsetIntegration(storage)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Create subscription on empty partition
|
||||
sub, err := integration.CreateSubscription(
|
||||
"empty-sub",
|
||||
"test-namespace", "test-topic",
|
||||
partition,
|
||||
schema_pb.OffsetType_RESET_TO_EARLIEST,
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create subscription: %v", err)
|
||||
}
|
||||
|
||||
// Subscribe to records (should return empty)
|
||||
responses, err := integration.SubscribeRecords(sub, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to subscribe to empty partition: %v", err)
|
||||
}
|
||||
|
||||
if len(responses) != 0 {
|
||||
t.Errorf("Expected 0 responses from empty partition, got %d", len(responses))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSMQOffsetIntegration_SeekSubscription(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
integration := NewSMQOffsetIntegration(storage)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Publish records
|
||||
records := []PublishRecordRequest{
|
||||
{Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
|
||||
{Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
|
||||
{Key: []byte("key3"), Value: &schema_pb.RecordValue{}},
|
||||
{Key: []byte("key4"), Value: &schema_pb.RecordValue{}},
|
||||
{Key: []byte("key5"), Value: &schema_pb.RecordValue{}},
|
||||
}
|
||||
integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
|
||||
|
||||
// Create subscription
|
||||
sub, err := integration.CreateSubscription(
|
||||
"seek-sub",
|
||||
"test-namespace", "test-topic",
|
||||
partition,
|
||||
schema_pb.OffsetType_RESET_TO_EARLIEST,
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create subscription: %v", err)
|
||||
}
|
||||
|
||||
// Seek to offset 3
|
||||
err = integration.SeekSubscription("seek-sub", 3)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to seek subscription: %v", err)
|
||||
}
|
||||
|
||||
if sub.CurrentOffset != 3 {
|
||||
t.Errorf("Expected current offset 3 after seek, got %d", sub.CurrentOffset)
|
||||
}
|
||||
|
||||
// Subscribe from new position
|
||||
responses, err := integration.SubscribeRecords(sub, 2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to subscribe after seek: %v", err)
|
||||
}
|
||||
|
||||
if len(responses) != 2 {
|
||||
t.Errorf("Expected 2 responses after seek, got %d", len(responses))
|
||||
}
|
||||
|
||||
if responses[0].Offset != 3 {
|
||||
t.Errorf("Expected first record offset 3 after seek, got %d", responses[0].Offset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSMQOffsetIntegration_GetSubscriptionLag(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
integration := NewSMQOffsetIntegration(storage)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Publish records
|
||||
records := []PublishRecordRequest{
|
||||
{Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
|
||||
{Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
|
||||
{Key: []byte("key3"), Value: &schema_pb.RecordValue{}},
|
||||
}
|
||||
integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
|
||||
|
||||
// Create subscription at offset 1
|
||||
sub, err := integration.CreateSubscription(
|
||||
"lag-sub",
|
||||
"test-namespace", "test-topic",
|
||||
partition,
|
||||
schema_pb.OffsetType_EXACT_OFFSET,
|
||||
1,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create subscription: %v", err)
|
||||
}
|
||||
|
||||
// Get lag
|
||||
lag, err := integration.GetSubscriptionLag("lag-sub")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get subscription lag: %v", err)
|
||||
}
|
||||
|
||||
expectedLag := int64(3 - 1) // hwm - current
|
||||
if lag != expectedLag {
|
||||
t.Errorf("Expected lag %d, got %d", expectedLag, lag)
|
||||
}
|
||||
|
||||
// Advance subscription and check lag again
|
||||
integration.SubscribeRecords(sub, 1)
|
||||
|
||||
lag, err = integration.GetSubscriptionLag("lag-sub")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get lag after advance: %v", err)
|
||||
}
|
||||
|
||||
expectedLag = int64(3 - 2) // hwm - current
|
||||
if lag != expectedLag {
|
||||
t.Errorf("Expected lag %d after advance, got %d", expectedLag, lag)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSMQOffsetIntegration_CloseSubscription(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
integration := NewSMQOffsetIntegration(storage)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Create subscription
|
||||
_, err := integration.CreateSubscription(
|
||||
"close-sub",
|
||||
"test-namespace", "test-topic",
|
||||
partition,
|
||||
schema_pb.OffsetType_RESET_TO_EARLIEST,
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create subscription: %v", err)
|
||||
}
|
||||
|
||||
// Close subscription
|
||||
err = integration.CloseSubscription("close-sub")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to close subscription: %v", err)
|
||||
}
|
||||
|
||||
// Try to get lag (should fail)
|
||||
_, err = integration.GetSubscriptionLag("close-sub")
|
||||
if err == nil {
|
||||
t.Error("Expected error when getting lag for closed subscription")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSMQOffsetIntegration_ValidateOffsetRange(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
integration := NewSMQOffsetIntegration(storage)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Publish some records
|
||||
records := []PublishRecordRequest{
|
||||
{Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
|
||||
{Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
|
||||
{Key: []byte("key3"), Value: &schema_pb.RecordValue{}},
|
||||
}
|
||||
integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
|
||||
|
||||
// Test valid range
|
||||
err := integration.ValidateOffsetRange("test-namespace", "test-topic", partition, 0, 2)
|
||||
if err != nil {
|
||||
t.Errorf("Valid range should not return error: %v", err)
|
||||
}
|
||||
|
||||
// Test invalid range (beyond hwm)
|
||||
err = integration.ValidateOffsetRange("test-namespace", "test-topic", partition, 0, 5)
|
||||
if err == nil {
|
||||
t.Error("Expected error for range beyond high water mark")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSMQOffsetIntegration_GetAvailableOffsetRange(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
integration := NewSMQOffsetIntegration(storage)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Test empty partition
|
||||
offsetRange, err := integration.GetAvailableOffsetRange("test-namespace", "test-topic", partition)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get available range for empty partition: %v", err)
|
||||
}
|
||||
|
||||
if offsetRange.Count != 0 {
|
||||
t.Errorf("Expected empty range for empty partition, got count %d", offsetRange.Count)
|
||||
}
|
||||
|
||||
// Publish records
|
||||
records := []PublishRecordRequest{
|
||||
{Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
|
||||
{Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
|
||||
}
|
||||
integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
|
||||
|
||||
// Test with data
|
||||
offsetRange, err = integration.GetAvailableOffsetRange("test-namespace", "test-topic", partition)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get available range: %v", err)
|
||||
}
|
||||
|
||||
if offsetRange.StartOffset != 0 {
|
||||
t.Errorf("Expected start offset 0, got %d", offsetRange.StartOffset)
|
||||
}
|
||||
|
||||
if offsetRange.EndOffset != 1 {
|
||||
t.Errorf("Expected end offset 1, got %d", offsetRange.EndOffset)
|
||||
}
|
||||
|
||||
if offsetRange.Count != 2 {
|
||||
t.Errorf("Expected count 2, got %d", offsetRange.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSMQOffsetIntegration_GetOffsetMetrics(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
integration := NewSMQOffsetIntegration(storage)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Initial metrics
|
||||
metrics := integration.GetOffsetMetrics()
|
||||
if metrics.TotalOffsets != 0 {
|
||||
t.Errorf("Expected 0 total offsets initially, got %d", metrics.TotalOffsets)
|
||||
}
|
||||
|
||||
if metrics.ActiveSubscriptions != 0 {
|
||||
t.Errorf("Expected 0 active subscriptions initially, got %d", metrics.ActiveSubscriptions)
|
||||
}
|
||||
|
||||
// Publish records
|
||||
records := []PublishRecordRequest{
|
||||
{Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
|
||||
{Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
|
||||
}
|
||||
integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
|
||||
|
||||
// Create subscriptions
|
||||
integration.CreateSubscription("sub1", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
|
||||
integration.CreateSubscription("sub2", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
|
||||
|
||||
// Check updated metrics
|
||||
metrics = integration.GetOffsetMetrics()
|
||||
if metrics.TotalOffsets != 2 {
|
||||
t.Errorf("Expected 2 total offsets, got %d", metrics.TotalOffsets)
|
||||
}
|
||||
|
||||
if metrics.ActiveSubscriptions != 2 {
|
||||
t.Errorf("Expected 2 active subscriptions, got %d", metrics.ActiveSubscriptions)
|
||||
}
|
||||
|
||||
if metrics.PartitionCount != 1 {
|
||||
t.Errorf("Expected 1 partition, got %d", metrics.PartitionCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSMQOffsetIntegration_GetOffsetInfo(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
integration := NewSMQOffsetIntegration(storage)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Test non-existent offset
|
||||
info, err := integration.GetOffsetInfo("test-namespace", "test-topic", partition, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get offset info: %v", err)
|
||||
}
|
||||
|
||||
if info.Exists {
|
||||
t.Error("Offset should not exist in empty partition")
|
||||
}
|
||||
|
||||
// Publish record
|
||||
integration.PublishRecord("test-namespace", "test-topic", partition, []byte("key1"), &schema_pb.RecordValue{})
|
||||
|
||||
// Test existing offset
|
||||
info, err = integration.GetOffsetInfo("test-namespace", "test-topic", partition, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get offset info for existing offset: %v", err)
|
||||
}
|
||||
|
||||
if !info.Exists {
|
||||
t.Error("Offset should exist after publishing")
|
||||
}
|
||||
|
||||
if info.Offset != 0 {
|
||||
t.Errorf("Expected offset 0, got %d", info.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSMQOffsetIntegration_GetPartitionOffsetInfo(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
integration := NewSMQOffsetIntegration(storage)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Test empty partition
|
||||
info, err := integration.GetPartitionOffsetInfo("test-namespace", "test-topic", partition)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get partition offset info: %v", err)
|
||||
}
|
||||
|
||||
if info.EarliestOffset != 0 {
|
||||
t.Errorf("Expected earliest offset 0, got %d", info.EarliestOffset)
|
||||
}
|
||||
|
||||
if info.LatestOffset != -1 {
|
||||
t.Errorf("Expected latest offset -1 for empty partition, got %d", info.LatestOffset)
|
||||
}
|
||||
|
||||
if info.HighWaterMark != 0 {
|
||||
t.Errorf("Expected high water mark 0, got %d", info.HighWaterMark)
|
||||
}
|
||||
|
||||
if info.RecordCount != 0 {
|
||||
t.Errorf("Expected record count 0, got %d", info.RecordCount)
|
||||
}
|
||||
|
||||
// Publish records
|
||||
records := []PublishRecordRequest{
|
||||
{Key: []byte("key1"), Value: &schema_pb.RecordValue{}},
|
||||
{Key: []byte("key2"), Value: &schema_pb.RecordValue{}},
|
||||
{Key: []byte("key3"), Value: &schema_pb.RecordValue{}},
|
||||
}
|
||||
integration.PublishRecordBatch("test-namespace", "test-topic", partition, records)
|
||||
|
||||
// Create subscription
|
||||
integration.CreateSubscription("test-sub", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
|
||||
|
||||
// Test with data
|
||||
info, err = integration.GetPartitionOffsetInfo("test-namespace", "test-topic", partition)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get partition offset info with data: %v", err)
|
||||
}
|
||||
|
||||
if info.EarliestOffset != 0 {
|
||||
t.Errorf("Expected earliest offset 0, got %d", info.EarliestOffset)
|
||||
}
|
||||
|
||||
if info.LatestOffset != 2 {
|
||||
t.Errorf("Expected latest offset 2, got %d", info.LatestOffset)
|
||||
}
|
||||
|
||||
if info.HighWaterMark != 3 {
|
||||
t.Errorf("Expected high water mark 3, got %d", info.HighWaterMark)
|
||||
}
|
||||
|
||||
if info.RecordCount != 3 {
|
||||
t.Errorf("Expected record count 3, got %d", info.RecordCount)
|
||||
}
|
||||
|
||||
if info.ActiveSubscriptions != 1 {
|
||||
t.Errorf("Expected 1 active subscription, got %d", info.ActiveSubscriptions)
|
||||
}
|
||||
}
|
||||
@@ -338,13 +338,6 @@ type OffsetAssigner struct {
|
||||
registry *PartitionOffsetRegistry
|
||||
}
|
||||
|
||||
// NewOffsetAssigner creates a new offset assigner
|
||||
func NewOffsetAssigner(storage OffsetStorage) *OffsetAssigner {
|
||||
return &OffsetAssigner{
|
||||
registry: NewPartitionOffsetRegistry(storage),
|
||||
}
|
||||
}
|
||||
|
||||
// AssignSingleOffset assigns a single offset with timestamp
|
||||
func (a *OffsetAssigner) AssignSingleOffset(namespace, topicName string, partition *schema_pb.Partition) *AssignmentResult {
|
||||
offset, err := a.registry.AssignOffset(namespace, topicName, partition)
|
||||
|
||||
@@ -1,388 +0,0 @@
|
||||
package offset
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
|
||||
)
|
||||
|
||||
func createTestPartition() *schema_pb.Partition {
|
||||
return &schema_pb.Partition{
|
||||
RingSize: 1024,
|
||||
RangeStart: 0,
|
||||
RangeStop: 31,
|
||||
UnixTimeNs: time.Now().UnixNano(),
|
||||
}
|
||||
}
|
||||
|
||||
func TestPartitionOffsetManager_BasicAssignment(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
partition := createTestPartition()
|
||||
|
||||
manager, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create offset manager: %v", err)
|
||||
}
|
||||
|
||||
// Test sequential offset assignment
|
||||
for i := int64(0); i < 10; i++ {
|
||||
offset := manager.AssignOffset()
|
||||
if offset != i {
|
||||
t.Errorf("Expected offset %d, got %d", i, offset)
|
||||
}
|
||||
}
|
||||
|
||||
// Test high water mark
|
||||
hwm := manager.GetHighWaterMark()
|
||||
if hwm != 10 {
|
||||
t.Errorf("Expected high water mark 10, got %d", hwm)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPartitionOffsetManager_BatchAssignment(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
partition := createTestPartition()
|
||||
|
||||
manager, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create offset manager: %v", err)
|
||||
}
|
||||
|
||||
// Assign batch of 5 offsets
|
||||
baseOffset, lastOffset := manager.AssignOffsets(5)
|
||||
if baseOffset != 0 {
|
||||
t.Errorf("Expected base offset 0, got %d", baseOffset)
|
||||
}
|
||||
if lastOffset != 4 {
|
||||
t.Errorf("Expected last offset 4, got %d", lastOffset)
|
||||
}
|
||||
|
||||
// Assign another batch
|
||||
baseOffset, lastOffset = manager.AssignOffsets(3)
|
||||
if baseOffset != 5 {
|
||||
t.Errorf("Expected base offset 5, got %d", baseOffset)
|
||||
}
|
||||
if lastOffset != 7 {
|
||||
t.Errorf("Expected last offset 7, got %d", lastOffset)
|
||||
}
|
||||
|
||||
// Check high water mark
|
||||
hwm := manager.GetHighWaterMark()
|
||||
if hwm != 8 {
|
||||
t.Errorf("Expected high water mark 8, got %d", hwm)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPartitionOffsetManager_Recovery(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
partition := createTestPartition()
|
||||
|
||||
// Create manager and assign some offsets
|
||||
manager1, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create offset manager: %v", err)
|
||||
}
|
||||
|
||||
// Assign offsets and simulate records
|
||||
for i := 0; i < 150; i++ { // More than checkpoint interval
|
||||
offset := manager1.AssignOffset()
|
||||
storage.AddRecord("test-namespace", "test-topic", partition, offset)
|
||||
}
|
||||
|
||||
// Wait for checkpoint to complete
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create new manager (simulates restart)
|
||||
manager2, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create offset manager after recovery: %v", err)
|
||||
}
|
||||
|
||||
// Next offset should continue from checkpoint + 1
|
||||
// With checkpoint interval 100, checkpoint happens at offset 100
|
||||
// So recovery should start from 101, but we assigned 150 offsets (0-149)
|
||||
// The checkpoint should be at 100, so next offset should be 101
|
||||
// But since we have records up to 149, it should recover from storage scan
|
||||
nextOffset := manager2.AssignOffset()
|
||||
if nextOffset != 150 {
|
||||
t.Errorf("Expected next offset 150 after recovery, got %d", nextOffset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPartitionOffsetManager_RecoveryFromStorage(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
partition := createTestPartition()
|
||||
|
||||
// Simulate existing records in storage without checkpoint
|
||||
for i := int64(0); i < 50; i++ {
|
||||
storage.AddRecord("test-namespace", "test-topic", partition, i)
|
||||
}
|
||||
|
||||
// Create manager - should recover from storage scan
|
||||
manager, err := NewPartitionOffsetManager("test-namespace", "test-topic", partition, storage)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create offset manager: %v", err)
|
||||
}
|
||||
|
||||
// Next offset should be 50
|
||||
nextOffset := manager.AssignOffset()
|
||||
if nextOffset != 50 {
|
||||
t.Errorf("Expected next offset 50 after storage recovery, got %d", nextOffset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPartitionOffsetRegistry_MultiplePartitions(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
registry := NewPartitionOffsetRegistry(storage)
|
||||
|
||||
// Create different partitions
|
||||
partition1 := &schema_pb.Partition{
|
||||
RingSize: 1024,
|
||||
RangeStart: 0,
|
||||
RangeStop: 31,
|
||||
UnixTimeNs: time.Now().UnixNano(),
|
||||
}
|
||||
|
||||
partition2 := &schema_pb.Partition{
|
||||
RingSize: 1024,
|
||||
RangeStart: 32,
|
||||
RangeStop: 63,
|
||||
UnixTimeNs: time.Now().UnixNano(),
|
||||
}
|
||||
|
||||
// Assign offsets to different partitions
|
||||
offset1, err := registry.AssignOffset("test-namespace", "test-topic", partition1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to assign offset to partition1: %v", err)
|
||||
}
|
||||
if offset1 != 0 {
|
||||
t.Errorf("Expected offset 0 for partition1, got %d", offset1)
|
||||
}
|
||||
|
||||
offset2, err := registry.AssignOffset("test-namespace", "test-topic", partition2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to assign offset to partition2: %v", err)
|
||||
}
|
||||
if offset2 != 0 {
|
||||
t.Errorf("Expected offset 0 for partition2, got %d", offset2)
|
||||
}
|
||||
|
||||
// Assign more offsets to partition1
|
||||
offset1_2, err := registry.AssignOffset("test-namespace", "test-topic", partition1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to assign second offset to partition1: %v", err)
|
||||
}
|
||||
if offset1_2 != 1 {
|
||||
t.Errorf("Expected offset 1 for partition1, got %d", offset1_2)
|
||||
}
|
||||
|
||||
// Partition2 should still be at 0 for next assignment
|
||||
offset2_2, err := registry.AssignOffset("test-namespace", "test-topic", partition2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to assign second offset to partition2: %v", err)
|
||||
}
|
||||
if offset2_2 != 1 {
|
||||
t.Errorf("Expected offset 1 for partition2, got %d", offset2_2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPartitionOffsetRegistry_BatchAssignment(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
registry := NewPartitionOffsetRegistry(storage)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Assign batch of offsets
|
||||
baseOffset, lastOffset, err := registry.AssignOffsets("test-namespace", "test-topic", partition, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to assign batch offsets: %v", err)
|
||||
}
|
||||
|
||||
if baseOffset != 0 {
|
||||
t.Errorf("Expected base offset 0, got %d", baseOffset)
|
||||
}
|
||||
if lastOffset != 9 {
|
||||
t.Errorf("Expected last offset 9, got %d", lastOffset)
|
||||
}
|
||||
|
||||
// Get high water mark
|
||||
hwm, err := registry.GetHighWaterMark("test-namespace", "test-topic", partition)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get high water mark: %v", err)
|
||||
}
|
||||
if hwm != 10 {
|
||||
t.Errorf("Expected high water mark 10, got %d", hwm)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOffsetAssigner_SingleAssignment(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
assigner := NewOffsetAssigner(storage)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Assign single offset
|
||||
result := assigner.AssignSingleOffset("test-namespace", "test-topic", partition)
|
||||
if result.Error != nil {
|
||||
t.Fatalf("Failed to assign single offset: %v", result.Error)
|
||||
}
|
||||
|
||||
if result.Assignment == nil {
|
||||
t.Fatal("Assignment result is nil")
|
||||
}
|
||||
|
||||
if result.Assignment.Offset != 0 {
|
||||
t.Errorf("Expected offset 0, got %d", result.Assignment.Offset)
|
||||
}
|
||||
|
||||
if result.Assignment.Partition != partition {
|
||||
t.Error("Partition mismatch in assignment")
|
||||
}
|
||||
|
||||
if result.Assignment.Timestamp <= 0 {
|
||||
t.Error("Timestamp should be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOffsetAssigner_BatchAssignment(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
assigner := NewOffsetAssigner(storage)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Assign batch of offsets
|
||||
result := assigner.AssignBatchOffsets("test-namespace", "test-topic", partition, 5)
|
||||
if result.Error != nil {
|
||||
t.Fatalf("Failed to assign batch offsets: %v", result.Error)
|
||||
}
|
||||
|
||||
if result.Batch == nil {
|
||||
t.Fatal("Batch result is nil")
|
||||
}
|
||||
|
||||
if result.Batch.BaseOffset != 0 {
|
||||
t.Errorf("Expected base offset 0, got %d", result.Batch.BaseOffset)
|
||||
}
|
||||
|
||||
if result.Batch.LastOffset != 4 {
|
||||
t.Errorf("Expected last offset 4, got %d", result.Batch.LastOffset)
|
||||
}
|
||||
|
||||
if result.Batch.Count != 5 {
|
||||
t.Errorf("Expected count 5, got %d", result.Batch.Count)
|
||||
}
|
||||
|
||||
if result.Batch.Timestamp <= 0 {
|
||||
t.Error("Timestamp should be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOffsetAssigner_HighWaterMark(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
assigner := NewOffsetAssigner(storage)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Initially should be 0
|
||||
hwm, err := assigner.GetHighWaterMark("test-namespace", "test-topic", partition)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get initial high water mark: %v", err)
|
||||
}
|
||||
if hwm != 0 {
|
||||
t.Errorf("Expected initial high water mark 0, got %d", hwm)
|
||||
}
|
||||
|
||||
// Assign some offsets
|
||||
assigner.AssignBatchOffsets("test-namespace", "test-topic", partition, 10)
|
||||
|
||||
// High water mark should be updated
|
||||
hwm, err = assigner.GetHighWaterMark("test-namespace", "test-topic", partition)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get high water mark after assignment: %v", err)
|
||||
}
|
||||
if hwm != 10 {
|
||||
t.Errorf("Expected high water mark 10, got %d", hwm)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPartitionKey(t *testing.T) {
|
||||
partition1 := &schema_pb.Partition{
|
||||
RingSize: 1024,
|
||||
RangeStart: 0,
|
||||
RangeStop: 31,
|
||||
UnixTimeNs: 1234567890,
|
||||
}
|
||||
|
||||
partition2 := &schema_pb.Partition{
|
||||
RingSize: 1024,
|
||||
RangeStart: 0,
|
||||
RangeStop: 31,
|
||||
UnixTimeNs: 1234567890,
|
||||
}
|
||||
|
||||
partition3 := &schema_pb.Partition{
|
||||
RingSize: 1024,
|
||||
RangeStart: 32,
|
||||
RangeStop: 63,
|
||||
UnixTimeNs: 1234567890,
|
||||
}
|
||||
|
||||
key1 := partitionKey(partition1)
|
||||
key2 := partitionKey(partition2)
|
||||
key3 := partitionKey(partition3)
|
||||
|
||||
// Same partitions should have same key
|
||||
if key1 != key2 {
|
||||
t.Errorf("Same partitions should have same key: %s vs %s", key1, key2)
|
||||
}
|
||||
|
||||
// Different partitions should have different keys
|
||||
if key1 == key3 {
|
||||
t.Errorf("Different partitions should have different keys: %s vs %s", key1, key3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentOffsetAssignment(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
registry := NewPartitionOffsetRegistry(storage)
|
||||
partition := createTestPartition()
|
||||
|
||||
const numGoroutines = 10
|
||||
const offsetsPerGoroutine = 100
|
||||
|
||||
results := make(chan int64, numGoroutines*offsetsPerGoroutine)
|
||||
|
||||
// Start concurrent offset assignments
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func() {
|
||||
for j := 0; j < offsetsPerGoroutine; j++ {
|
||||
offset, err := registry.AssignOffset("test-namespace", "test-topic", partition)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to assign offset: %v", err)
|
||||
return
|
||||
}
|
||||
results <- offset
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Collect all results
|
||||
offsets := make(map[int64]bool)
|
||||
for i := 0; i < numGoroutines*offsetsPerGoroutine; i++ {
|
||||
offset := <-results
|
||||
if offsets[offset] {
|
||||
t.Errorf("Duplicate offset assigned: %d", offset)
|
||||
}
|
||||
offsets[offset] = true
|
||||
}
|
||||
|
||||
// Verify we got all expected offsets
|
||||
expectedCount := numGoroutines * offsetsPerGoroutine
|
||||
if len(offsets) != expectedCount {
|
||||
t.Errorf("Expected %d unique offsets, got %d", expectedCount, len(offsets))
|
||||
}
|
||||
|
||||
// Verify offsets are in expected range
|
||||
for offset := range offsets {
|
||||
if offset < 0 || offset >= int64(expectedCount) {
|
||||
t.Errorf("Offset %d is out of expected range [0, %d)", offset, expectedCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,302 +0,0 @@
|
||||
package offset
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MigrationVersion represents a database migration version
|
||||
type MigrationVersion struct {
|
||||
Version int
|
||||
Description string
|
||||
SQL string
|
||||
}
|
||||
|
||||
// GetMigrations returns all available migrations for offset storage
|
||||
func GetMigrations() []MigrationVersion {
|
||||
return []MigrationVersion{
|
||||
{
|
||||
Version: 1,
|
||||
Description: "Create initial offset storage tables",
|
||||
SQL: `
|
||||
-- Partition offset checkpoints table
|
||||
-- TODO: Add _index as computed column when supported by database
|
||||
CREATE TABLE IF NOT EXISTS partition_offset_checkpoints (
|
||||
partition_key TEXT PRIMARY KEY,
|
||||
ring_size INTEGER NOT NULL,
|
||||
range_start INTEGER NOT NULL,
|
||||
range_stop INTEGER NOT NULL,
|
||||
unix_time_ns INTEGER NOT NULL,
|
||||
checkpoint_offset INTEGER NOT NULL,
|
||||
updated_at INTEGER NOT NULL
|
||||
);
|
||||
|
||||
-- Offset mappings table for detailed tracking
|
||||
-- TODO: Add _index as computed column when supported by database
|
||||
CREATE TABLE IF NOT EXISTS offset_mappings (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
partition_key TEXT NOT NULL,
|
||||
kafka_offset INTEGER NOT NULL,
|
||||
smq_timestamp INTEGER NOT NULL,
|
||||
message_size INTEGER NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
UNIQUE(partition_key, kafka_offset)
|
||||
);
|
||||
|
||||
-- Schema migrations tracking table
|
||||
CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
version INTEGER PRIMARY KEY,
|
||||
description TEXT NOT NULL,
|
||||
applied_at INTEGER NOT NULL
|
||||
);
|
||||
`,
|
||||
},
|
||||
{
|
||||
Version: 2,
|
||||
Description: "Add indexes for performance optimization",
|
||||
SQL: `
|
||||
-- Indexes for performance
|
||||
CREATE INDEX IF NOT EXISTS idx_partition_offset_checkpoints_partition
|
||||
ON partition_offset_checkpoints(partition_key);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_offset_mappings_partition_offset
|
||||
ON offset_mappings(partition_key, kafka_offset);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_offset_mappings_timestamp
|
||||
ON offset_mappings(partition_key, smq_timestamp);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_offset_mappings_created_at
|
||||
ON offset_mappings(created_at);
|
||||
`,
|
||||
},
|
||||
{
|
||||
Version: 3,
|
||||
Description: "Add partition metadata table for enhanced tracking",
|
||||
SQL: `
|
||||
-- Partition metadata table
|
||||
CREATE TABLE IF NOT EXISTS partition_metadata (
|
||||
partition_key TEXT PRIMARY KEY,
|
||||
ring_size INTEGER NOT NULL,
|
||||
range_start INTEGER NOT NULL,
|
||||
range_stop INTEGER NOT NULL,
|
||||
unix_time_ns INTEGER NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
last_activity_at INTEGER NOT NULL,
|
||||
record_count INTEGER DEFAULT 0,
|
||||
total_size INTEGER DEFAULT 0
|
||||
);
|
||||
|
||||
-- Index for partition metadata
|
||||
CREATE INDEX IF NOT EXISTS idx_partition_metadata_activity
|
||||
ON partition_metadata(last_activity_at);
|
||||
`,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// MigrationManager handles database schema migrations
|
||||
type MigrationManager struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewMigrationManager creates a new migration manager
|
||||
func NewMigrationManager(db *sql.DB) *MigrationManager {
|
||||
return &MigrationManager{db: db}
|
||||
}
|
||||
|
||||
// GetCurrentVersion returns the current schema version
|
||||
func (m *MigrationManager) GetCurrentVersion() (int, error) {
|
||||
// First, ensure the migrations table exists
|
||||
_, err := m.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
version INTEGER PRIMARY KEY,
|
||||
description TEXT NOT NULL,
|
||||
applied_at INTEGER NOT NULL
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to create migrations table: %w", err)
|
||||
}
|
||||
|
||||
var version sql.NullInt64
|
||||
err = m.db.QueryRow("SELECT MAX(version) FROM schema_migrations").Scan(&version)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get current version: %w", err)
|
||||
}
|
||||
|
||||
if !version.Valid {
|
||||
return 0, nil // No migrations applied yet
|
||||
}
|
||||
|
||||
return int(version.Int64), nil
|
||||
}
|
||||
|
||||
// ApplyMigrations applies all pending migrations
|
||||
func (m *MigrationManager) ApplyMigrations() error {
|
||||
currentVersion, err := m.GetCurrentVersion()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get current version: %w", err)
|
||||
}
|
||||
|
||||
migrations := GetMigrations()
|
||||
|
||||
for _, migration := range migrations {
|
||||
if migration.Version <= currentVersion {
|
||||
continue // Already applied
|
||||
}
|
||||
|
||||
fmt.Printf("Applying migration %d: %s\n", migration.Version, migration.Description)
|
||||
|
||||
// Begin transaction
|
||||
tx, err := m.db.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction for migration %d: %w", migration.Version, err)
|
||||
}
|
||||
|
||||
// Execute migration SQL
|
||||
_, err = tx.Exec(migration.SQL)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to execute migration %d: %w", migration.Version, err)
|
||||
}
|
||||
|
||||
// Record migration as applied
|
||||
_, err = tx.Exec(
|
||||
"INSERT INTO schema_migrations (version, description, applied_at) VALUES (?, ?, ?)",
|
||||
migration.Version,
|
||||
migration.Description,
|
||||
getCurrentTimestamp(),
|
||||
)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to record migration %d: %w", migration.Version, err)
|
||||
}
|
||||
|
||||
// Commit transaction
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to commit migration %d: %w", migration.Version, err)
|
||||
}
|
||||
|
||||
fmt.Printf("Successfully applied migration %d\n", migration.Version)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RollbackMigration rolls back a specific migration (if supported)
|
||||
func (m *MigrationManager) RollbackMigration(version int) error {
|
||||
// TODO: Implement rollback functionality
|
||||
// ASSUMPTION: For now, rollbacks are not supported as they require careful planning
|
||||
return fmt.Errorf("migration rollbacks not implemented - manual intervention required")
|
||||
}
|
||||
|
||||
// GetAppliedMigrations returns a list of all applied migrations
|
||||
func (m *MigrationManager) GetAppliedMigrations() ([]AppliedMigration, error) {
|
||||
rows, err := m.db.Query(`
|
||||
SELECT version, description, applied_at
|
||||
FROM schema_migrations
|
||||
ORDER BY version
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query applied migrations: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var migrations []AppliedMigration
|
||||
for rows.Next() {
|
||||
var migration AppliedMigration
|
||||
err := rows.Scan(&migration.Version, &migration.Description, &migration.AppliedAt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan migration: %w", err)
|
||||
}
|
||||
migrations = append(migrations, migration)
|
||||
}
|
||||
|
||||
return migrations, nil
|
||||
}
|
||||
|
||||
// ValidateSchema validates that the database schema is up to date
|
||||
func (m *MigrationManager) ValidateSchema() error {
|
||||
currentVersion, err := m.GetCurrentVersion()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get current version: %w", err)
|
||||
}
|
||||
|
||||
migrations := GetMigrations()
|
||||
if len(migrations) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
latestVersion := migrations[len(migrations)-1].Version
|
||||
if currentVersion < latestVersion {
|
||||
return fmt.Errorf("schema is outdated: current version %d, latest version %d", currentVersion, latestVersion)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AppliedMigration represents a migration that has been applied
|
||||
type AppliedMigration struct {
|
||||
Version int
|
||||
Description string
|
||||
AppliedAt int64
|
||||
}
|
||||
|
||||
// getCurrentTimestamp returns the current timestamp in nanoseconds
|
||||
func getCurrentTimestamp() int64 {
|
||||
return time.Now().UnixNano()
|
||||
}
|
||||
|
||||
// CreateDatabase creates and initializes a new offset storage database
|
||||
func CreateDatabase(dbPath string) (*sql.DB, error) {
|
||||
// TODO: Support different database types (PostgreSQL, MySQL, etc.)
|
||||
// ASSUMPTION: Using SQLite for now, can be extended for other databases
|
||||
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
// Configure SQLite for better performance
|
||||
pragmas := []string{
|
||||
"PRAGMA journal_mode=WAL", // Write-Ahead Logging for better concurrency
|
||||
"PRAGMA synchronous=NORMAL", // Balance between safety and performance
|
||||
"PRAGMA cache_size=10000", // Increase cache size
|
||||
"PRAGMA foreign_keys=ON", // Enable foreign key constraints
|
||||
"PRAGMA temp_store=MEMORY", // Store temporary tables in memory
|
||||
}
|
||||
|
||||
for _, pragma := range pragmas {
|
||||
_, err := db.Exec(pragma)
|
||||
if err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("failed to set pragma %s: %w", pragma, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply migrations
|
||||
migrationManager := NewMigrationManager(db)
|
||||
err = migrationManager.ApplyMigrations()
|
||||
if err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("failed to apply migrations: %w", err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// BackupDatabase creates a backup of the offset storage database
|
||||
func BackupDatabase(sourceDB *sql.DB, backupPath string) error {
|
||||
// TODO: Implement database backup functionality
|
||||
// ASSUMPTION: This would use database-specific backup mechanisms
|
||||
return fmt.Errorf("database backup not implemented yet")
|
||||
}
|
||||
|
||||
// RestoreDatabase restores a database from a backup
|
||||
func RestoreDatabase(backupPath, targetPath string) error {
|
||||
// TODO: Implement database restore functionality
|
||||
// ASSUMPTION: This would use database-specific restore mechanisms
|
||||
return fmt.Errorf("database restore not implemented yet")
|
||||
}
|
||||
@@ -1,394 +0,0 @@
|
||||
package offset
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
|
||||
)
|
||||
|
||||
// OffsetEntry represents a mapping between Kafka offset and SMQ timestamp
|
||||
type OffsetEntry struct {
|
||||
KafkaOffset int64
|
||||
SMQTimestamp int64
|
||||
MessageSize int32
|
||||
}
|
||||
|
||||
// SQLOffsetStorage implements OffsetStorage using SQL database with _index column
|
||||
type SQLOffsetStorage struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewSQLOffsetStorage creates a new SQL-based offset storage
|
||||
func NewSQLOffsetStorage(db *sql.DB) (*SQLOffsetStorage, error) {
|
||||
storage := &SQLOffsetStorage{db: db}
|
||||
|
||||
// Initialize database schema
|
||||
if err := storage.initializeSchema(); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize schema: %w", err)
|
||||
}
|
||||
|
||||
return storage, nil
|
||||
}
|
||||
|
||||
// initializeSchema creates the necessary tables for offset storage
|
||||
func (s *SQLOffsetStorage) initializeSchema() error {
|
||||
// TODO: Create offset storage tables with _index as hidden column
|
||||
// ASSUMPTION: Using SQLite-compatible syntax, may need adaptation for other databases
|
||||
|
||||
queries := []string{
|
||||
// Partition offset checkpoints table
|
||||
// TODO: Add _index as computed column when supported by database
|
||||
// ASSUMPTION: Using regular columns for now, _index concept preserved for future enhancement
|
||||
`CREATE TABLE IF NOT EXISTS partition_offset_checkpoints (
|
||||
partition_key TEXT PRIMARY KEY,
|
||||
ring_size INTEGER NOT NULL,
|
||||
range_start INTEGER NOT NULL,
|
||||
range_stop INTEGER NOT NULL,
|
||||
unix_time_ns INTEGER NOT NULL,
|
||||
checkpoint_offset INTEGER NOT NULL,
|
||||
updated_at INTEGER NOT NULL
|
||||
)`,
|
||||
|
||||
// Offset mappings table for detailed tracking
|
||||
// TODO: Add _index as computed column when supported by database
|
||||
`CREATE TABLE IF NOT EXISTS offset_mappings (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
partition_key TEXT NOT NULL,
|
||||
kafka_offset INTEGER NOT NULL,
|
||||
smq_timestamp INTEGER NOT NULL,
|
||||
message_size INTEGER NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
UNIQUE(partition_key, kafka_offset)
|
||||
)`,
|
||||
|
||||
// Indexes for performance
|
||||
`CREATE INDEX IF NOT EXISTS idx_partition_offset_checkpoints_partition
|
||||
ON partition_offset_checkpoints(partition_key)`,
|
||||
|
||||
`CREATE INDEX IF NOT EXISTS idx_offset_mappings_partition_offset
|
||||
ON offset_mappings(partition_key, kafka_offset)`,
|
||||
|
||||
`CREATE INDEX IF NOT EXISTS idx_offset_mappings_timestamp
|
||||
ON offset_mappings(partition_key, smq_timestamp)`,
|
||||
}
|
||||
|
||||
for _, query := range queries {
|
||||
if _, err := s.db.Exec(query); err != nil {
|
||||
return fmt.Errorf("failed to execute schema query: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveCheckpoint saves the checkpoint for a partition
|
||||
func (s *SQLOffsetStorage) SaveCheckpoint(namespace, topicName string, partition *schema_pb.Partition, offset int64) error {
|
||||
// Use TopicPartitionKey to ensure each topic has isolated checkpoint storage
|
||||
partitionKey := TopicPartitionKey(namespace, topicName, partition)
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
// TODO: Use UPSERT for better performance
|
||||
// ASSUMPTION: SQLite REPLACE syntax, may need adaptation for other databases
|
||||
query := `
|
||||
REPLACE INTO partition_offset_checkpoints
|
||||
(partition_key, ring_size, range_start, range_stop, unix_time_ns, checkpoint_offset, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
_, err := s.db.Exec(query,
|
||||
partitionKey,
|
||||
partition.RingSize,
|
||||
partition.RangeStart,
|
||||
partition.RangeStop,
|
||||
partition.UnixTimeNs,
|
||||
offset,
|
||||
now,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save checkpoint: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadCheckpoint loads the checkpoint for a partition
|
||||
func (s *SQLOffsetStorage) LoadCheckpoint(namespace, topicName string, partition *schema_pb.Partition) (int64, error) {
|
||||
// Use TopicPartitionKey to match SaveCheckpoint
|
||||
partitionKey := TopicPartitionKey(namespace, topicName, partition)
|
||||
|
||||
query := `
|
||||
SELECT checkpoint_offset
|
||||
FROM partition_offset_checkpoints
|
||||
WHERE partition_key = ?
|
||||
`
|
||||
|
||||
var checkpointOffset int64
|
||||
err := s.db.QueryRow(query, partitionKey).Scan(&checkpointOffset)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return -1, fmt.Errorf("no checkpoint found")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("failed to load checkpoint: %w", err)
|
||||
}
|
||||
|
||||
return checkpointOffset, nil
|
||||
}
|
||||
|
||||
// GetHighestOffset finds the highest offset in storage for a partition
|
||||
func (s *SQLOffsetStorage) GetHighestOffset(namespace, topicName string, partition *schema_pb.Partition) (int64, error) {
|
||||
// Use TopicPartitionKey to match SaveCheckpoint
|
||||
partitionKey := TopicPartitionKey(namespace, topicName, partition)
|
||||
|
||||
// TODO: Use _index column for efficient querying
|
||||
// ASSUMPTION: kafka_offset represents the sequential offset we're tracking
|
||||
query := `
|
||||
SELECT MAX(kafka_offset)
|
||||
FROM offset_mappings
|
||||
WHERE partition_key = ?
|
||||
`
|
||||
|
||||
var highestOffset sql.NullInt64
|
||||
err := s.db.QueryRow(query, partitionKey).Scan(&highestOffset)
|
||||
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("failed to get highest offset: %w", err)
|
||||
}
|
||||
|
||||
if !highestOffset.Valid {
|
||||
return -1, fmt.Errorf("no records found")
|
||||
}
|
||||
|
||||
return highestOffset.Int64, nil
|
||||
}
|
||||
|
||||
// SaveOffsetMapping stores an offset mapping (extends OffsetStorage interface)
|
||||
func (s *SQLOffsetStorage) SaveOffsetMapping(partitionKey string, kafkaOffset, smqTimestamp int64, size int32) error {
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
// TODO: Handle duplicate key conflicts gracefully
|
||||
// ASSUMPTION: Using INSERT OR REPLACE for conflict resolution
|
||||
query := `
|
||||
INSERT OR REPLACE INTO offset_mappings
|
||||
(partition_key, kafka_offset, smq_timestamp, message_size, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
_, err := s.db.Exec(query, partitionKey, kafkaOffset, smqTimestamp, size, now)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save offset mapping: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadOffsetMappings retrieves all offset mappings for a partition
|
||||
func (s *SQLOffsetStorage) LoadOffsetMappings(partitionKey string) ([]OffsetEntry, error) {
|
||||
// TODO: Add pagination for large result sets
|
||||
// ASSUMPTION: Loading all mappings for now, should be paginated in production
|
||||
query := `
|
||||
SELECT kafka_offset, smq_timestamp, message_size
|
||||
FROM offset_mappings
|
||||
WHERE partition_key = ?
|
||||
ORDER BY kafka_offset ASC
|
||||
`
|
||||
|
||||
rows, err := s.db.Query(query, partitionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query offset mappings: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var entries []OffsetEntry
|
||||
for rows.Next() {
|
||||
var entry OffsetEntry
|
||||
err := rows.Scan(&entry.KafkaOffset, &entry.SMQTimestamp, &entry.MessageSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan offset entry: %w", err)
|
||||
}
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating offset mappings: %w", err)
|
||||
}
|
||||
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
// GetOffsetMappingsByRange retrieves offset mappings within a specific range
|
||||
func (s *SQLOffsetStorage) GetOffsetMappingsByRange(partitionKey string, startOffset, endOffset int64) ([]OffsetEntry, error) {
|
||||
// TODO: Use _index column for efficient range queries
|
||||
query := `
|
||||
SELECT kafka_offset, smq_timestamp, message_size
|
||||
FROM offset_mappings
|
||||
WHERE partition_key = ? AND kafka_offset >= ? AND kafka_offset <= ?
|
||||
ORDER BY kafka_offset ASC
|
||||
`
|
||||
|
||||
rows, err := s.db.Query(query, partitionKey, startOffset, endOffset)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query offset range: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var entries []OffsetEntry
|
||||
for rows.Next() {
|
||||
var entry OffsetEntry
|
||||
err := rows.Scan(&entry.KafkaOffset, &entry.SMQTimestamp, &entry.MessageSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan offset entry: %w", err)
|
||||
}
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
// GetPartitionStats returns statistics about a partition's offset usage
|
||||
func (s *SQLOffsetStorage) GetPartitionStats(partitionKey string) (*PartitionStats, error) {
|
||||
query := `
|
||||
SELECT
|
||||
COUNT(*) as record_count,
|
||||
MIN(kafka_offset) as earliest_offset,
|
||||
MAX(kafka_offset) as latest_offset,
|
||||
SUM(message_size) as total_size,
|
||||
MIN(created_at) as first_record_time,
|
||||
MAX(created_at) as last_record_time
|
||||
FROM offset_mappings
|
||||
WHERE partition_key = ?
|
||||
`
|
||||
|
||||
var stats PartitionStats
|
||||
var earliestOffset, latestOffset sql.NullInt64
|
||||
var totalSize sql.NullInt64
|
||||
var firstRecordTime, lastRecordTime sql.NullInt64
|
||||
|
||||
err := s.db.QueryRow(query, partitionKey).Scan(
|
||||
&stats.RecordCount,
|
||||
&earliestOffset,
|
||||
&latestOffset,
|
||||
&totalSize,
|
||||
&firstRecordTime,
|
||||
&lastRecordTime,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get partition stats: %w", err)
|
||||
}
|
||||
|
||||
stats.PartitionKey = partitionKey
|
||||
|
||||
if earliestOffset.Valid {
|
||||
stats.EarliestOffset = earliestOffset.Int64
|
||||
} else {
|
||||
stats.EarliestOffset = -1
|
||||
}
|
||||
|
||||
if latestOffset.Valid {
|
||||
stats.LatestOffset = latestOffset.Int64
|
||||
stats.HighWaterMark = latestOffset.Int64 + 1
|
||||
} else {
|
||||
stats.LatestOffset = -1
|
||||
stats.HighWaterMark = 0
|
||||
}
|
||||
|
||||
if firstRecordTime.Valid {
|
||||
stats.FirstRecordTime = firstRecordTime.Int64
|
||||
}
|
||||
|
||||
if lastRecordTime.Valid {
|
||||
stats.LastRecordTime = lastRecordTime.Int64
|
||||
}
|
||||
|
||||
if totalSize.Valid {
|
||||
stats.TotalSize = totalSize.Int64
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// CleanupOldMappings removes offset mappings older than the specified time
|
||||
func (s *SQLOffsetStorage) CleanupOldMappings(olderThanNs int64) error {
|
||||
// TODO: Add configurable cleanup policies
|
||||
// ASSUMPTION: Simple time-based cleanup, could be enhanced with retention policies
|
||||
query := `
|
||||
DELETE FROM offset_mappings
|
||||
WHERE created_at < ?
|
||||
`
|
||||
|
||||
result, err := s.db.Exec(query, olderThanNs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to cleanup old mappings: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, _ := result.RowsAffected()
|
||||
if rowsAffected > 0 {
|
||||
// Log cleanup activity
|
||||
fmt.Printf("Cleaned up %d old offset mappings\n", rowsAffected)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the database connection
|
||||
func (s *SQLOffsetStorage) Close() error {
|
||||
if s.db != nil {
|
||||
return s.db.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PartitionStats provides statistics about a partition's offset usage
|
||||
type PartitionStats struct {
|
||||
PartitionKey string
|
||||
RecordCount int64
|
||||
EarliestOffset int64
|
||||
LatestOffset int64
|
||||
HighWaterMark int64
|
||||
TotalSize int64
|
||||
FirstRecordTime int64
|
||||
LastRecordTime int64
|
||||
}
|
||||
|
||||
// GetAllPartitions returns a list of all partitions with offset data
|
||||
func (s *SQLOffsetStorage) GetAllPartitions() ([]string, error) {
|
||||
query := `
|
||||
SELECT DISTINCT partition_key
|
||||
FROM offset_mappings
|
||||
ORDER BY partition_key
|
||||
`
|
||||
|
||||
rows, err := s.db.Query(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get all partitions: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var partitions []string
|
||||
for rows.Next() {
|
||||
var partitionKey string
|
||||
if err := rows.Scan(&partitionKey); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan partition key: %w", err)
|
||||
}
|
||||
partitions = append(partitions, partitionKey)
|
||||
}
|
||||
|
||||
return partitions, nil
|
||||
}
|
||||
|
||||
// Vacuum performs database maintenance operations
|
||||
func (s *SQLOffsetStorage) Vacuum() error {
|
||||
// TODO: Add database-specific optimization commands
|
||||
// ASSUMPTION: SQLite VACUUM command, may need adaptation for other databases
|
||||
_, err := s.db.Exec("VACUUM")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to vacuum database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,516 +0,0 @@
|
||||
package offset
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
|
||||
)
|
||||
|
||||
func createTestDB(t *testing.T) *sql.DB {
|
||||
// Create temporary database file
|
||||
tmpFile, err := os.CreateTemp("", "offset_test_*.db")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp database file: %v", err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
|
||||
// Clean up the file when test completes
|
||||
t.Cleanup(func() {
|
||||
os.Remove(tmpFile.Name())
|
||||
})
|
||||
|
||||
db, err := sql.Open("sqlite3", tmpFile.Name())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
db.Close()
|
||||
})
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func createTestPartitionForSQL() *schema_pb.Partition {
|
||||
return &schema_pb.Partition{
|
||||
RingSize: 1024,
|
||||
RangeStart: 0,
|
||||
RangeStop: 31,
|
||||
UnixTimeNs: time.Now().UnixNano(),
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLOffsetStorage_InitializeSchema(t *testing.T) {
|
||||
db := createTestDB(t)
|
||||
|
||||
storage, err := NewSQLOffsetStorage(db)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SQL storage: %v", err)
|
||||
}
|
||||
defer storage.Close()
|
||||
|
||||
// Verify tables were created
|
||||
tables := []string{
|
||||
"partition_offset_checkpoints",
|
||||
"offset_mappings",
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
var count int
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", table).Scan(&count)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to check table %s: %v", table, err)
|
||||
}
|
||||
|
||||
if count != 1 {
|
||||
t.Errorf("Table %s was not created", table)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLOffsetStorage_SaveLoadCheckpoint(t *testing.T) {
|
||||
db := createTestDB(t)
|
||||
storage, err := NewSQLOffsetStorage(db)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SQL storage: %v", err)
|
||||
}
|
||||
defer storage.Close()
|
||||
|
||||
partition := createTestPartitionForSQL()
|
||||
|
||||
// Test saving checkpoint
|
||||
err = storage.SaveCheckpoint("test-namespace", "test-topic", partition, 100)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save checkpoint: %v", err)
|
||||
}
|
||||
|
||||
// Test loading checkpoint
|
||||
checkpoint, err := storage.LoadCheckpoint("test-namespace", "test-topic", partition)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load checkpoint: %v", err)
|
||||
}
|
||||
|
||||
if checkpoint != 100 {
|
||||
t.Errorf("Expected checkpoint 100, got %d", checkpoint)
|
||||
}
|
||||
|
||||
// Test updating checkpoint
|
||||
err = storage.SaveCheckpoint("test-namespace", "test-topic", partition, 200)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update checkpoint: %v", err)
|
||||
}
|
||||
|
||||
checkpoint, err = storage.LoadCheckpoint("test-namespace", "test-topic", partition)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load updated checkpoint: %v", err)
|
||||
}
|
||||
|
||||
if checkpoint != 200 {
|
||||
t.Errorf("Expected updated checkpoint 200, got %d", checkpoint)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLOffsetStorage_LoadCheckpointNotFound(t *testing.T) {
|
||||
db := createTestDB(t)
|
||||
storage, err := NewSQLOffsetStorage(db)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SQL storage: %v", err)
|
||||
}
|
||||
defer storage.Close()
|
||||
|
||||
partition := createTestPartitionForSQL()
|
||||
|
||||
// Test loading non-existent checkpoint
|
||||
_, err = storage.LoadCheckpoint("test-namespace", "test-topic", partition)
|
||||
if err == nil {
|
||||
t.Error("Expected error for non-existent checkpoint")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLOffsetStorage_SaveLoadOffsetMappings(t *testing.T) {
|
||||
db := createTestDB(t)
|
||||
storage, err := NewSQLOffsetStorage(db)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SQL storage: %v", err)
|
||||
}
|
||||
defer storage.Close()
|
||||
|
||||
partition := createTestPartitionForSQL()
|
||||
partitionKey := partitionKey(partition)
|
||||
|
||||
// Save multiple offset mappings
|
||||
mappings := []struct {
|
||||
offset int64
|
||||
timestamp int64
|
||||
size int32
|
||||
}{
|
||||
{0, 1000, 100},
|
||||
{1, 2000, 150},
|
||||
{2, 3000, 200},
|
||||
}
|
||||
|
||||
for _, mapping := range mappings {
|
||||
err := storage.SaveOffsetMapping(partitionKey, mapping.offset, mapping.timestamp, mapping.size)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save offset mapping: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Load offset mappings
|
||||
entries, err := storage.LoadOffsetMappings(partitionKey)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load offset mappings: %v", err)
|
||||
}
|
||||
|
||||
if len(entries) != len(mappings) {
|
||||
t.Errorf("Expected %d entries, got %d", len(mappings), len(entries))
|
||||
}
|
||||
|
||||
// Verify entries are sorted by offset
|
||||
for i, entry := range entries {
|
||||
expected := mappings[i]
|
||||
if entry.KafkaOffset != expected.offset {
|
||||
t.Errorf("Entry %d: expected offset %d, got %d", i, expected.offset, entry.KafkaOffset)
|
||||
}
|
||||
if entry.SMQTimestamp != expected.timestamp {
|
||||
t.Errorf("Entry %d: expected timestamp %d, got %d", i, expected.timestamp, entry.SMQTimestamp)
|
||||
}
|
||||
if entry.MessageSize != expected.size {
|
||||
t.Errorf("Entry %d: expected size %d, got %d", i, expected.size, entry.MessageSize)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLOffsetStorage_GetHighestOffset(t *testing.T) {
|
||||
db := createTestDB(t)
|
||||
storage, err := NewSQLOffsetStorage(db)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SQL storage: %v", err)
|
||||
}
|
||||
defer storage.Close()
|
||||
|
||||
partition := createTestPartitionForSQL()
|
||||
partitionKey := TopicPartitionKey("test-namespace", "test-topic", partition)
|
||||
|
||||
// Test empty partition
|
||||
_, err = storage.GetHighestOffset("test-namespace", "test-topic", partition)
|
||||
if err == nil {
|
||||
t.Error("Expected error for empty partition")
|
||||
}
|
||||
|
||||
// Add some offset mappings
|
||||
offsets := []int64{5, 1, 3, 2, 4}
|
||||
for _, offset := range offsets {
|
||||
err := storage.SaveOffsetMapping(partitionKey, offset, offset*1000, 100)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save offset mapping: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get highest offset
|
||||
highest, err := storage.GetHighestOffset("test-namespace", "test-topic", partition)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get highest offset: %v", err)
|
||||
}
|
||||
|
||||
if highest != 5 {
|
||||
t.Errorf("Expected highest offset 5, got %d", highest)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLOffsetStorage_GetOffsetMappingsByRange(t *testing.T) {
|
||||
db := createTestDB(t)
|
||||
storage, err := NewSQLOffsetStorage(db)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SQL storage: %v", err)
|
||||
}
|
||||
defer storage.Close()
|
||||
|
||||
partition := createTestPartitionForSQL()
|
||||
partitionKey := partitionKey(partition)
|
||||
|
||||
// Add offset mappings
|
||||
for i := int64(0); i < 10; i++ {
|
||||
err := storage.SaveOffsetMapping(partitionKey, i, i*1000, 100)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save offset mapping: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get range of offsets
|
||||
entries, err := storage.GetOffsetMappingsByRange(partitionKey, 3, 7)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get offset range: %v", err)
|
||||
}
|
||||
|
||||
expectedCount := 5 // offsets 3, 4, 5, 6, 7
|
||||
if len(entries) != expectedCount {
|
||||
t.Errorf("Expected %d entries, got %d", expectedCount, len(entries))
|
||||
}
|
||||
|
||||
// Verify range
|
||||
for i, entry := range entries {
|
||||
expectedOffset := int64(3 + i)
|
||||
if entry.KafkaOffset != expectedOffset {
|
||||
t.Errorf("Entry %d: expected offset %d, got %d", i, expectedOffset, entry.KafkaOffset)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLOffsetStorage_GetPartitionStats(t *testing.T) {
|
||||
db := createTestDB(t)
|
||||
storage, err := NewSQLOffsetStorage(db)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SQL storage: %v", err)
|
||||
}
|
||||
defer storage.Close()
|
||||
|
||||
partition := createTestPartitionForSQL()
|
||||
partitionKey := partitionKey(partition)
|
||||
|
||||
// Test empty partition stats
|
||||
stats, err := storage.GetPartitionStats(partitionKey)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get empty partition stats: %v", err)
|
||||
}
|
||||
|
||||
if stats.RecordCount != 0 {
|
||||
t.Errorf("Expected record count 0, got %d", stats.RecordCount)
|
||||
}
|
||||
|
||||
if stats.EarliestOffset != -1 {
|
||||
t.Errorf("Expected earliest offset -1, got %d", stats.EarliestOffset)
|
||||
}
|
||||
|
||||
// Add some data
|
||||
sizes := []int32{100, 150, 200}
|
||||
for i, size := range sizes {
|
||||
err := storage.SaveOffsetMapping(partitionKey, int64(i), int64(i*1000), size)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save offset mapping: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get stats with data
|
||||
stats, err = storage.GetPartitionStats(partitionKey)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get partition stats: %v", err)
|
||||
}
|
||||
|
||||
if stats.RecordCount != 3 {
|
||||
t.Errorf("Expected record count 3, got %d", stats.RecordCount)
|
||||
}
|
||||
|
||||
if stats.EarliestOffset != 0 {
|
||||
t.Errorf("Expected earliest offset 0, got %d", stats.EarliestOffset)
|
||||
}
|
||||
|
||||
if stats.LatestOffset != 2 {
|
||||
t.Errorf("Expected latest offset 2, got %d", stats.LatestOffset)
|
||||
}
|
||||
|
||||
if stats.HighWaterMark != 3 {
|
||||
t.Errorf("Expected high water mark 3, got %d", stats.HighWaterMark)
|
||||
}
|
||||
|
||||
expectedTotalSize := int64(100 + 150 + 200)
|
||||
if stats.TotalSize != expectedTotalSize {
|
||||
t.Errorf("Expected total size %d, got %d", expectedTotalSize, stats.TotalSize)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLOffsetStorage_GetAllPartitions(t *testing.T) {
|
||||
db := createTestDB(t)
|
||||
storage, err := NewSQLOffsetStorage(db)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SQL storage: %v", err)
|
||||
}
|
||||
defer storage.Close()
|
||||
|
||||
// Test empty database
|
||||
partitions, err := storage.GetAllPartitions()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get all partitions: %v", err)
|
||||
}
|
||||
|
||||
if len(partitions) != 0 {
|
||||
t.Errorf("Expected 0 partitions, got %d", len(partitions))
|
||||
}
|
||||
|
||||
// Add data for multiple partitions
|
||||
partition1 := createTestPartitionForSQL()
|
||||
partition2 := &schema_pb.Partition{
|
||||
RingSize: 1024,
|
||||
RangeStart: 32,
|
||||
RangeStop: 63,
|
||||
UnixTimeNs: time.Now().UnixNano(),
|
||||
}
|
||||
|
||||
partitionKey1 := partitionKey(partition1)
|
||||
partitionKey2 := partitionKey(partition2)
|
||||
|
||||
storage.SaveOffsetMapping(partitionKey1, 0, 1000, 100)
|
||||
storage.SaveOffsetMapping(partitionKey2, 0, 2000, 150)
|
||||
|
||||
// Get all partitions
|
||||
partitions, err = storage.GetAllPartitions()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get all partitions: %v", err)
|
||||
}
|
||||
|
||||
if len(partitions) != 2 {
|
||||
t.Errorf("Expected 2 partitions, got %d", len(partitions))
|
||||
}
|
||||
|
||||
// Verify partition keys are present
|
||||
partitionMap := make(map[string]bool)
|
||||
for _, p := range partitions {
|
||||
partitionMap[p] = true
|
||||
}
|
||||
|
||||
if !partitionMap[partitionKey1] {
|
||||
t.Errorf("Partition key %s not found", partitionKey1)
|
||||
}
|
||||
|
||||
if !partitionMap[partitionKey2] {
|
||||
t.Errorf("Partition key %s not found", partitionKey2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLOffsetStorage_CleanupOldMappings(t *testing.T) {
|
||||
db := createTestDB(t)
|
||||
storage, err := NewSQLOffsetStorage(db)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SQL storage: %v", err)
|
||||
}
|
||||
defer storage.Close()
|
||||
|
||||
partition := createTestPartitionForSQL()
|
||||
partitionKey := partitionKey(partition)
|
||||
|
||||
// Add mappings with different timestamps
|
||||
now := time.Now().UnixNano()
|
||||
|
||||
// Add old mapping by directly inserting with old timestamp
|
||||
oldTime := now - (24 * time.Hour).Nanoseconds() // 24 hours ago
|
||||
_, err = db.Exec(`
|
||||
INSERT INTO offset_mappings
|
||||
(partition_key, kafka_offset, smq_timestamp, message_size, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
`, partitionKey, 0, oldTime, 100, oldTime)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert old mapping: %v", err)
|
||||
}
|
||||
|
||||
// Add recent mapping
|
||||
storage.SaveOffsetMapping(partitionKey, 1, now, 150)
|
||||
|
||||
// Verify both mappings exist
|
||||
entries, err := storage.LoadOffsetMappings(partitionKey)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load mappings: %v", err)
|
||||
}
|
||||
|
||||
if len(entries) != 2 {
|
||||
t.Errorf("Expected 2 mappings before cleanup, got %d", len(entries))
|
||||
}
|
||||
|
||||
// Cleanup old mappings (older than 12 hours)
|
||||
cutoffTime := now - (12 * time.Hour).Nanoseconds()
|
||||
err = storage.CleanupOldMappings(cutoffTime)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to cleanup old mappings: %v", err)
|
||||
}
|
||||
|
||||
// Verify only recent mapping remains
|
||||
entries, err = storage.LoadOffsetMappings(partitionKey)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load mappings after cleanup: %v", err)
|
||||
}
|
||||
|
||||
if len(entries) != 1 {
|
||||
t.Errorf("Expected 1 mapping after cleanup, got %d", len(entries))
|
||||
}
|
||||
|
||||
if entries[0].KafkaOffset != 1 {
|
||||
t.Errorf("Expected remaining mapping offset 1, got %d", entries[0].KafkaOffset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLOffsetStorage_Vacuum(t *testing.T) {
|
||||
db := createTestDB(t)
|
||||
storage, err := NewSQLOffsetStorage(db)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SQL storage: %v", err)
|
||||
}
|
||||
defer storage.Close()
|
||||
|
||||
// Vacuum should not fail on empty database
|
||||
err = storage.Vacuum()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to vacuum database: %v", err)
|
||||
}
|
||||
|
||||
// Add some data and vacuum again
|
||||
partition := createTestPartitionForSQL()
|
||||
partitionKey := partitionKey(partition)
|
||||
storage.SaveOffsetMapping(partitionKey, 0, 1000, 100)
|
||||
|
||||
err = storage.Vacuum()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to vacuum database with data: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLOffsetStorage_ConcurrentAccess(t *testing.T) {
|
||||
db := createTestDB(t)
|
||||
storage, err := NewSQLOffsetStorage(db)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SQL storage: %v", err)
|
||||
}
|
||||
defer storage.Close()
|
||||
|
||||
partition := createTestPartitionForSQL()
|
||||
partitionKey := partitionKey(partition)
|
||||
|
||||
// Test concurrent writes
|
||||
const numGoroutines = 10
|
||||
const offsetsPerGoroutine = 10
|
||||
|
||||
done := make(chan bool, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(goroutineID int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
for j := 0; j < offsetsPerGoroutine; j++ {
|
||||
offset := int64(goroutineID*offsetsPerGoroutine + j)
|
||||
err := storage.SaveOffsetMapping(partitionKey, offset, offset*1000, 100)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to save offset mapping %d: %v", offset, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Verify all mappings were saved
|
||||
entries, err := storage.LoadOffsetMappings(partitionKey)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load mappings: %v", err)
|
||||
}
|
||||
|
||||
expectedCount := numGoroutines * offsetsPerGoroutine
|
||||
if len(entries) != expectedCount {
|
||||
t.Errorf("Expected %d mappings, got %d", expectedCount, len(entries))
|
||||
}
|
||||
}
|
||||
@@ -1,457 +0,0 @@
|
||||
package offset
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
|
||||
)
|
||||
|
||||
func TestOffsetSubscriber_CreateSubscription(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
registry := NewPartitionOffsetRegistry(storage)
|
||||
subscriber := NewOffsetSubscriber(registry)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Assign some offsets first
|
||||
registry.AssignOffsets("test-namespace", "test-topic", partition, 10)
|
||||
|
||||
// Test EXACT_OFFSET subscription
|
||||
sub, err := subscriber.CreateSubscription("test-sub-1", "test-namespace", "test-topic", partition, schema_pb.OffsetType_EXACT_OFFSET, 5)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create EXACT_OFFSET subscription: %v", err)
|
||||
}
|
||||
|
||||
if sub.StartOffset != 5 {
|
||||
t.Errorf("Expected start offset 5, got %d", sub.StartOffset)
|
||||
}
|
||||
if sub.CurrentOffset != 5 {
|
||||
t.Errorf("Expected current offset 5, got %d", sub.CurrentOffset)
|
||||
}
|
||||
|
||||
// Test RESET_TO_LATEST subscription
|
||||
sub2, err := subscriber.CreateSubscription("test-sub-2", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_LATEST, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create RESET_TO_LATEST subscription: %v", err)
|
||||
}
|
||||
|
||||
if sub2.StartOffset != 10 { // Should be at high water mark
|
||||
t.Errorf("Expected start offset 10, got %d", sub2.StartOffset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOffsetSubscriber_InvalidSubscription(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
registry := NewPartitionOffsetRegistry(storage)
|
||||
subscriber := NewOffsetSubscriber(registry)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Assign some offsets
|
||||
registry.AssignOffsets("test-namespace", "test-topic", partition, 5)
|
||||
|
||||
// Test invalid offset (beyond high water mark)
|
||||
_, err := subscriber.CreateSubscription("invalid-sub", "test-namespace", "test-topic", partition, schema_pb.OffsetType_EXACT_OFFSET, 10)
|
||||
if err == nil {
|
||||
t.Error("Expected error for offset beyond high water mark")
|
||||
}
|
||||
|
||||
// Test negative offset
|
||||
_, err = subscriber.CreateSubscription("invalid-sub-2", "test-namespace", "test-topic", partition, schema_pb.OffsetType_EXACT_OFFSET, -1)
|
||||
if err == nil {
|
||||
t.Error("Expected error for negative offset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOffsetSubscriber_DuplicateSubscription(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
registry := NewPartitionOffsetRegistry(storage)
|
||||
subscriber := NewOffsetSubscriber(registry)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Create first subscription
|
||||
_, err := subscriber.CreateSubscription("duplicate-sub", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create first subscription: %v", err)
|
||||
}
|
||||
|
||||
// Try to create duplicate
|
||||
_, err = subscriber.CreateSubscription("duplicate-sub", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
|
||||
if err == nil {
|
||||
t.Error("Expected error for duplicate subscription ID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOffsetSubscription_SeekToOffset(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
registry := NewPartitionOffsetRegistry(storage)
|
||||
subscriber := NewOffsetSubscriber(registry)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Assign offsets
|
||||
registry.AssignOffsets("test-namespace", "test-topic", partition, 20)
|
||||
|
||||
// Create subscription
|
||||
sub, err := subscriber.CreateSubscription("seek-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create subscription: %v", err)
|
||||
}
|
||||
|
||||
// Test valid seek
|
||||
err = sub.SeekToOffset(10)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to seek to offset 10: %v", err)
|
||||
}
|
||||
|
||||
if sub.CurrentOffset != 10 {
|
||||
t.Errorf("Expected current offset 10, got %d", sub.CurrentOffset)
|
||||
}
|
||||
|
||||
// Test invalid seek (beyond high water mark)
|
||||
err = sub.SeekToOffset(25)
|
||||
if err == nil {
|
||||
t.Error("Expected error for seek beyond high water mark")
|
||||
}
|
||||
|
||||
// Test negative seek
|
||||
err = sub.SeekToOffset(-1)
|
||||
if err == nil {
|
||||
t.Error("Expected error for negative seek offset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOffsetSubscription_AdvanceOffset(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
registry := NewPartitionOffsetRegistry(storage)
|
||||
subscriber := NewOffsetSubscriber(registry)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Create subscription
|
||||
sub, err := subscriber.CreateSubscription("advance-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create subscription: %v", err)
|
||||
}
|
||||
|
||||
// Test single advance
|
||||
initialOffset := sub.GetNextOffset()
|
||||
sub.AdvanceOffset()
|
||||
|
||||
if sub.GetNextOffset() != initialOffset+1 {
|
||||
t.Errorf("Expected offset %d, got %d", initialOffset+1, sub.GetNextOffset())
|
||||
}
|
||||
|
||||
// Test batch advance
|
||||
sub.AdvanceOffsetBy(5)
|
||||
|
||||
if sub.GetNextOffset() != initialOffset+6 {
|
||||
t.Errorf("Expected offset %d, got %d", initialOffset+6, sub.GetNextOffset())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOffsetSubscription_GetLag(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
registry := NewPartitionOffsetRegistry(storage)
|
||||
subscriber := NewOffsetSubscriber(registry)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Assign offsets
|
||||
registry.AssignOffsets("test-namespace", "test-topic", partition, 15)
|
||||
|
||||
// Create subscription at offset 5
|
||||
sub, err := subscriber.CreateSubscription("lag-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_EXACT_OFFSET, 5)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create subscription: %v", err)
|
||||
}
|
||||
|
||||
// Check initial lag
|
||||
lag, err := sub.GetLag()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get lag: %v", err)
|
||||
}
|
||||
|
||||
expectedLag := int64(15 - 5) // hwm - current
|
||||
if lag != expectedLag {
|
||||
t.Errorf("Expected lag %d, got %d", expectedLag, lag)
|
||||
}
|
||||
|
||||
// Advance and check lag again
|
||||
sub.AdvanceOffsetBy(3)
|
||||
|
||||
lag, err = sub.GetLag()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get lag after advance: %v", err)
|
||||
}
|
||||
|
||||
expectedLag = int64(15 - 8) // hwm - current
|
||||
if lag != expectedLag {
|
||||
t.Errorf("Expected lag %d after advance, got %d", expectedLag, lag)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOffsetSubscription_IsAtEnd(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
registry := NewPartitionOffsetRegistry(storage)
|
||||
subscriber := NewOffsetSubscriber(registry)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Assign offsets
|
||||
registry.AssignOffsets("test-namespace", "test-topic", partition, 10)
|
||||
|
||||
// Create subscription at end
|
||||
sub, err := subscriber.CreateSubscription("end-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_LATEST, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create subscription: %v", err)
|
||||
}
|
||||
|
||||
// Should be at end
|
||||
atEnd, err := sub.IsAtEnd()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to check if at end: %v", err)
|
||||
}
|
||||
|
||||
if !atEnd {
|
||||
t.Error("Expected subscription to be at end")
|
||||
}
|
||||
|
||||
// Seek to middle and check again
|
||||
sub.SeekToOffset(5)
|
||||
|
||||
atEnd, err = sub.IsAtEnd()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to check if at end after seek: %v", err)
|
||||
}
|
||||
|
||||
if atEnd {
|
||||
t.Error("Expected subscription not to be at end after seek")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOffsetSubscription_GetOffsetRange(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
registry := NewPartitionOffsetRegistry(storage)
|
||||
subscriber := NewOffsetSubscriber(registry)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Assign offsets
|
||||
registry.AssignOffsets("test-namespace", "test-topic", partition, 20)
|
||||
|
||||
// Create subscription
|
||||
sub, err := subscriber.CreateSubscription("range-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_EXACT_OFFSET, 5)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create subscription: %v", err)
|
||||
}
|
||||
|
||||
// Test normal range
|
||||
offsetRange, err := sub.GetOffsetRange(10)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get offset range: %v", err)
|
||||
}
|
||||
|
||||
if offsetRange.StartOffset != 5 {
|
||||
t.Errorf("Expected start offset 5, got %d", offsetRange.StartOffset)
|
||||
}
|
||||
if offsetRange.EndOffset != 14 {
|
||||
t.Errorf("Expected end offset 14, got %d", offsetRange.EndOffset)
|
||||
}
|
||||
if offsetRange.Count != 10 {
|
||||
t.Errorf("Expected count 10, got %d", offsetRange.Count)
|
||||
}
|
||||
|
||||
// Test range that exceeds high water mark
|
||||
sub.SeekToOffset(15)
|
||||
offsetRange, err = sub.GetOffsetRange(10)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get offset range near end: %v", err)
|
||||
}
|
||||
|
||||
if offsetRange.StartOffset != 15 {
|
||||
t.Errorf("Expected start offset 15, got %d", offsetRange.StartOffset)
|
||||
}
|
||||
if offsetRange.EndOffset != 19 { // Should be capped at hwm-1
|
||||
t.Errorf("Expected end offset 19, got %d", offsetRange.EndOffset)
|
||||
}
|
||||
if offsetRange.Count != 5 {
|
||||
t.Errorf("Expected count 5, got %d", offsetRange.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOffsetSubscription_EmptyRange(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
registry := NewPartitionOffsetRegistry(storage)
|
||||
subscriber := NewOffsetSubscriber(registry)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Assign offsets
|
||||
registry.AssignOffsets("test-namespace", "test-topic", partition, 10)
|
||||
|
||||
// Create subscription at end
|
||||
sub, err := subscriber.CreateSubscription("empty-range-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_LATEST, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create subscription: %v", err)
|
||||
}
|
||||
|
||||
// Request range when at end
|
||||
offsetRange, err := sub.GetOffsetRange(5)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get offset range at end: %v", err)
|
||||
}
|
||||
|
||||
if offsetRange.Count != 0 {
|
||||
t.Errorf("Expected empty range (count 0), got count %d", offsetRange.Count)
|
||||
}
|
||||
|
||||
if offsetRange.StartOffset != 10 {
|
||||
t.Errorf("Expected start offset 10, got %d", offsetRange.StartOffset)
|
||||
}
|
||||
|
||||
if offsetRange.EndOffset != 9 { // Empty range: end < start
|
||||
t.Errorf("Expected end offset 9 (empty range), got %d", offsetRange.EndOffset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOffsetSeeker_ValidateOffsetRange(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
registry := NewPartitionOffsetRegistry(storage)
|
||||
seeker := NewOffsetSeeker(registry)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Assign offsets
|
||||
registry.AssignOffsets("test-namespace", "test-topic", partition, 15)
|
||||
|
||||
// Test valid range
|
||||
err := seeker.ValidateOffsetRange("test-namespace", "test-topic", partition, 5, 10)
|
||||
if err != nil {
|
||||
t.Errorf("Valid range should not return error: %v", err)
|
||||
}
|
||||
|
||||
// Test invalid ranges
|
||||
testCases := []struct {
|
||||
name string
|
||||
startOffset int64
|
||||
endOffset int64
|
||||
expectError bool
|
||||
}{
|
||||
{"negative start", -1, 5, true},
|
||||
{"end before start", 10, 5, true},
|
||||
{"start beyond hwm", 20, 25, true},
|
||||
{"valid range", 0, 14, false},
|
||||
{"single offset", 5, 5, false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := seeker.ValidateOffsetRange("test-namespace", "test-topic", partition, tc.startOffset, tc.endOffset)
|
||||
if tc.expectError && err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
if !tc.expectError && err != nil {
|
||||
t.Errorf("Expected no error but got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOffsetSeeker_GetAvailableOffsetRange(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
registry := NewPartitionOffsetRegistry(storage)
|
||||
seeker := NewOffsetSeeker(registry)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Test empty partition
|
||||
offsetRange, err := seeker.GetAvailableOffsetRange("test-namespace", "test-topic", partition)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get available range for empty partition: %v", err)
|
||||
}
|
||||
|
||||
if offsetRange.Count != 0 {
|
||||
t.Errorf("Expected empty range for empty partition, got count %d", offsetRange.Count)
|
||||
}
|
||||
|
||||
// Assign offsets and test again
|
||||
registry.AssignOffsets("test-namespace", "test-topic", partition, 25)
|
||||
|
||||
offsetRange, err = seeker.GetAvailableOffsetRange("test-namespace", "test-topic", partition)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get available range: %v", err)
|
||||
}
|
||||
|
||||
if offsetRange.StartOffset != 0 {
|
||||
t.Errorf("Expected start offset 0, got %d", offsetRange.StartOffset)
|
||||
}
|
||||
if offsetRange.EndOffset != 24 {
|
||||
t.Errorf("Expected end offset 24, got %d", offsetRange.EndOffset)
|
||||
}
|
||||
if offsetRange.Count != 25 {
|
||||
t.Errorf("Expected count 25, got %d", offsetRange.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOffsetSubscriber_CloseSubscription(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
registry := NewPartitionOffsetRegistry(storage)
|
||||
subscriber := NewOffsetSubscriber(registry)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Create subscription
|
||||
sub, err := subscriber.CreateSubscription("close-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create subscription: %v", err)
|
||||
}
|
||||
|
||||
// Verify subscription exists
|
||||
_, err = subscriber.GetSubscription("close-test")
|
||||
if err != nil {
|
||||
t.Fatalf("Subscription should exist: %v", err)
|
||||
}
|
||||
|
||||
// Close subscription
|
||||
err = subscriber.CloseSubscription("close-test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to close subscription: %v", err)
|
||||
}
|
||||
|
||||
// Verify subscription is gone
|
||||
_, err = subscriber.GetSubscription("close-test")
|
||||
if err == nil {
|
||||
t.Error("Subscription should not exist after close")
|
||||
}
|
||||
|
||||
// Verify subscription is marked inactive
|
||||
if sub.IsActive {
|
||||
t.Error("Subscription should be marked inactive after close")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOffsetSubscription_InactiveOperations(t *testing.T) {
|
||||
storage := NewInMemoryOffsetStorage()
|
||||
registry := NewPartitionOffsetRegistry(storage)
|
||||
subscriber := NewOffsetSubscriber(registry)
|
||||
partition := createTestPartition()
|
||||
|
||||
// Create and close subscription
|
||||
sub, err := subscriber.CreateSubscription("inactive-test", "test-namespace", "test-topic", partition, schema_pb.OffsetType_RESET_TO_EARLIEST, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create subscription: %v", err)
|
||||
}
|
||||
|
||||
subscriber.CloseSubscription("inactive-test")
|
||||
|
||||
// Test operations on inactive subscription
|
||||
err = sub.SeekToOffset(5)
|
||||
if err == nil {
|
||||
t.Error("Expected error for seek on inactive subscription")
|
||||
}
|
||||
|
||||
_, err = sub.GetLag()
|
||||
if err == nil {
|
||||
t.Error("Expected error for GetLag on inactive subscription")
|
||||
}
|
||||
|
||||
_, err = sub.IsAtEnd()
|
||||
if err == nil {
|
||||
t.Error("Expected error for IsAtEnd on inactive subscription")
|
||||
}
|
||||
|
||||
_, err = sub.GetOffsetRange(10)
|
||||
if err == nil {
|
||||
t.Error("Expected error for GetOffsetRange on inactive subscription")
|
||||
}
|
||||
}
|
||||
@@ -1,13 +1,6 @@
|
||||
package pub_balancer
|
||||
|
||||
import (
|
||||
"math/rand/v2"
|
||||
"sort"
|
||||
|
||||
cmap "github.com/orcaman/concurrent-map/v2"
|
||||
"github.com/seaweedfs/seaweedfs/weed/mq/topic"
|
||||
"modernc.org/mathutil"
|
||||
)
|
||||
import ()
|
||||
|
||||
func (balancer *PubBalancer) RepairTopics() []BalanceAction {
|
||||
action := BalanceTopicPartitionOnBrokers(balancer.Brokers)
|
||||
@@ -17,107 +10,3 @@ func (balancer *PubBalancer) RepairTopics() []BalanceAction {
|
||||
type TopicPartitionInfo struct {
|
||||
Broker string
|
||||
}
|
||||
|
||||
// RepairMissingTopicPartitions check the stats of all brokers,
|
||||
// and repair the missing topic partitions on the brokers.
|
||||
func RepairMissingTopicPartitions(brokers cmap.ConcurrentMap[string, *BrokerStats]) (actions []BalanceAction) {
|
||||
|
||||
// find all topic partitions
|
||||
topicToTopicPartitions := make(map[topic.Topic]map[topic.Partition]*TopicPartitionInfo)
|
||||
for brokerStatsItem := range brokers.IterBuffered() {
|
||||
broker, brokerStats := brokerStatsItem.Key, brokerStatsItem.Val
|
||||
for topicPartitionStatsItem := range brokerStats.TopicPartitionStats.IterBuffered() {
|
||||
topicPartitionStat := topicPartitionStatsItem.Val
|
||||
topicPartitionToInfo, found := topicToTopicPartitions[topicPartitionStat.Topic]
|
||||
if !found {
|
||||
topicPartitionToInfo = make(map[topic.Partition]*TopicPartitionInfo)
|
||||
topicToTopicPartitions[topicPartitionStat.Topic] = topicPartitionToInfo
|
||||
}
|
||||
tpi, found := topicPartitionToInfo[topicPartitionStat.Partition]
|
||||
if !found {
|
||||
tpi = &TopicPartitionInfo{}
|
||||
topicPartitionToInfo[topicPartitionStat.Partition] = tpi
|
||||
}
|
||||
tpi.Broker = broker
|
||||
}
|
||||
}
|
||||
|
||||
// collect all brokers as candidates
|
||||
candidates := make([]string, 0, brokers.Count())
|
||||
for brokerStatsItem := range brokers.IterBuffered() {
|
||||
candidates = append(candidates, brokerStatsItem.Key)
|
||||
}
|
||||
|
||||
// find the missing topic partitions
|
||||
for t, topicPartitionToInfo := range topicToTopicPartitions {
|
||||
missingPartitions := EachTopicRepairMissingTopicPartitions(t, topicPartitionToInfo)
|
||||
for _, partition := range missingPartitions {
|
||||
actions = append(actions, BalanceActionCreate{
|
||||
TopicPartition: topic.TopicPartition{
|
||||
Topic: t,
|
||||
Partition: partition,
|
||||
},
|
||||
TargetBroker: candidates[rand.IntN(len(candidates))],
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return actions
|
||||
}
|
||||
|
||||
func EachTopicRepairMissingTopicPartitions(t topic.Topic, info map[topic.Partition]*TopicPartitionInfo) (missingPartitions []topic.Partition) {
|
||||
|
||||
// find the missing topic partitions
|
||||
var partitions []topic.Partition
|
||||
for partition := range info {
|
||||
partitions = append(partitions, partition)
|
||||
}
|
||||
return findMissingPartitions(partitions, MaxPartitionCount)
|
||||
}
|
||||
|
||||
// findMissingPartitions find the missing partitions
|
||||
func findMissingPartitions(partitions []topic.Partition, ringSize int32) (missingPartitions []topic.Partition) {
|
||||
// sort the partitions by range start
|
||||
sort.Slice(partitions, func(i, j int) bool {
|
||||
return partitions[i].RangeStart < partitions[j].RangeStart
|
||||
})
|
||||
|
||||
// calculate the average partition size
|
||||
var covered int32
|
||||
for _, partition := range partitions {
|
||||
covered += partition.RangeStop - partition.RangeStart
|
||||
}
|
||||
averagePartitionSize := covered / int32(len(partitions))
|
||||
|
||||
// find the missing partitions
|
||||
var coveredWatermark int32
|
||||
i := 0
|
||||
for i < len(partitions) {
|
||||
partition := partitions[i]
|
||||
if partition.RangeStart > coveredWatermark {
|
||||
upperBound := mathutil.MinInt32(coveredWatermark+averagePartitionSize, partition.RangeStart)
|
||||
missingPartitions = append(missingPartitions, topic.Partition{
|
||||
RangeStart: coveredWatermark,
|
||||
RangeStop: upperBound,
|
||||
RingSize: ringSize,
|
||||
})
|
||||
coveredWatermark = upperBound
|
||||
if coveredWatermark == partition.RangeStop {
|
||||
i++
|
||||
}
|
||||
} else {
|
||||
coveredWatermark = partition.RangeStop
|
||||
i++
|
||||
}
|
||||
}
|
||||
for coveredWatermark < ringSize {
|
||||
upperBound := mathutil.MinInt32(coveredWatermark+averagePartitionSize, ringSize)
|
||||
missingPartitions = append(missingPartitions, topic.Partition{
|
||||
RangeStart: coveredWatermark,
|
||||
RangeStop: upperBound,
|
||||
RingSize: ringSize,
|
||||
})
|
||||
coveredWatermark = upperBound
|
||||
}
|
||||
return missingPartitions
|
||||
}
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
package pub_balancer
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/mq/topic"
|
||||
)
|
||||
|
||||
func Test_findMissingPartitions(t *testing.T) {
|
||||
type args struct {
|
||||
partitions []topic.Partition
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantMissingPartitions []topic.Partition
|
||||
}{
|
||||
{
|
||||
name: "one partition",
|
||||
args: args{
|
||||
partitions: []topic.Partition{
|
||||
{RingSize: 1024, RangeStart: 0, RangeStop: 1024},
|
||||
},
|
||||
},
|
||||
wantMissingPartitions: nil,
|
||||
},
|
||||
{
|
||||
name: "two partitions",
|
||||
args: args{
|
||||
partitions: []topic.Partition{
|
||||
{RingSize: 1024, RangeStart: 0, RangeStop: 512},
|
||||
{RingSize: 1024, RangeStart: 512, RangeStop: 1024},
|
||||
},
|
||||
},
|
||||
wantMissingPartitions: nil,
|
||||
},
|
||||
{
|
||||
name: "four partitions, missing last two",
|
||||
args: args{
|
||||
partitions: []topic.Partition{
|
||||
{RingSize: 1024, RangeStart: 0, RangeStop: 256},
|
||||
{RingSize: 1024, RangeStart: 256, RangeStop: 512},
|
||||
},
|
||||
},
|
||||
wantMissingPartitions: []topic.Partition{
|
||||
{RingSize: 1024, RangeStart: 512, RangeStop: 768},
|
||||
{RingSize: 1024, RangeStart: 768, RangeStop: 1024},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "four partitions, missing first two",
|
||||
args: args{
|
||||
partitions: []topic.Partition{
|
||||
{RingSize: 1024, RangeStart: 512, RangeStop: 768},
|
||||
{RingSize: 1024, RangeStart: 768, RangeStop: 1024},
|
||||
},
|
||||
},
|
||||
wantMissingPartitions: []topic.Partition{
|
||||
{RingSize: 1024, RangeStart: 0, RangeStop: 256},
|
||||
{RingSize: 1024, RangeStart: 256, RangeStop: 512},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "four partitions, missing middle two",
|
||||
args: args{
|
||||
partitions: []topic.Partition{
|
||||
{RingSize: 1024, RangeStart: 0, RangeStop: 256},
|
||||
{RingSize: 1024, RangeStart: 768, RangeStop: 1024},
|
||||
},
|
||||
},
|
||||
wantMissingPartitions: []topic.Partition{
|
||||
{RingSize: 1024, RangeStart: 256, RangeStop: 512},
|
||||
{RingSize: 1024, RangeStart: 512, RangeStop: 768},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "four partitions, missing three",
|
||||
args: args{
|
||||
partitions: []topic.Partition{
|
||||
{RingSize: 1024, RangeStart: 512, RangeStop: 768},
|
||||
},
|
||||
},
|
||||
wantMissingPartitions: []topic.Partition{
|
||||
{RingSize: 1024, RangeStart: 0, RangeStop: 256},
|
||||
{RingSize: 1024, RangeStart: 256, RangeStop: 512},
|
||||
{RingSize: 1024, RangeStart: 768, RangeStop: 1024},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if gotMissingPartitions := findMissingPartitions(tt.args.partitions, 1024); !reflect.DeepEqual(gotMissingPartitions, tt.wantMissingPartitions) {
|
||||
t.Errorf("findMissingPartitions() = %v, want %v", gotMissingPartitions, tt.wantMissingPartitions)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,109 +0,0 @@
|
||||
package segment
|
||||
|
||||
import (
|
||||
flatbuffers "github.com/google/flatbuffers/go"
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/message_fbs"
|
||||
)
|
||||
|
||||
type MessageBatchBuilder struct {
|
||||
b *flatbuffers.Builder
|
||||
producerId int32
|
||||
producerEpoch int32
|
||||
segmentId int32
|
||||
flags int32
|
||||
messageOffsets []flatbuffers.UOffsetT
|
||||
segmentSeqBase int64
|
||||
segmentSeqLast int64
|
||||
tsMsBase int64
|
||||
tsMsLast int64
|
||||
}
|
||||
|
||||
func NewMessageBatchBuilder(b *flatbuffers.Builder,
|
||||
producerId int32,
|
||||
producerEpoch int32,
|
||||
segmentId int32,
|
||||
flags int32) *MessageBatchBuilder {
|
||||
|
||||
b.Reset()
|
||||
|
||||
return &MessageBatchBuilder{
|
||||
b: b,
|
||||
producerId: producerId,
|
||||
producerEpoch: producerEpoch,
|
||||
segmentId: segmentId,
|
||||
flags: flags,
|
||||
}
|
||||
}
|
||||
|
||||
func (builder *MessageBatchBuilder) AddMessage(segmentSeq int64, tsMs int64, properties map[string][]byte, key []byte, value []byte) {
|
||||
if builder.segmentSeqBase == 0 {
|
||||
builder.segmentSeqBase = segmentSeq
|
||||
}
|
||||
builder.segmentSeqLast = segmentSeq
|
||||
if builder.tsMsBase == 0 {
|
||||
builder.tsMsBase = tsMs
|
||||
}
|
||||
builder.tsMsLast = tsMs
|
||||
|
||||
var names, values, pairs []flatbuffers.UOffsetT
|
||||
for k, v := range properties {
|
||||
names = append(names, builder.b.CreateString(k))
|
||||
values = append(values, builder.b.CreateByteVector(v))
|
||||
}
|
||||
for i, _ := range names {
|
||||
message_fbs.NameValueStart(builder.b)
|
||||
message_fbs.NameValueAddName(builder.b, names[i])
|
||||
message_fbs.NameValueAddValue(builder.b, values[i])
|
||||
pair := message_fbs.NameValueEnd(builder.b)
|
||||
pairs = append(pairs, pair)
|
||||
}
|
||||
|
||||
message_fbs.MessageStartPropertiesVector(builder.b, len(properties))
|
||||
for i := len(pairs) - 1; i >= 0; i-- {
|
||||
builder.b.PrependUOffsetT(pairs[i])
|
||||
}
|
||||
propOffset := builder.b.EndVector(len(properties))
|
||||
|
||||
keyOffset := builder.b.CreateByteVector(key)
|
||||
valueOffset := builder.b.CreateByteVector(value)
|
||||
|
||||
message_fbs.MessageStart(builder.b)
|
||||
message_fbs.MessageAddSeqDelta(builder.b, int32(segmentSeq-builder.segmentSeqBase))
|
||||
message_fbs.MessageAddTsMsDelta(builder.b, int32(tsMs-builder.tsMsBase))
|
||||
|
||||
message_fbs.MessageAddProperties(builder.b, propOffset)
|
||||
message_fbs.MessageAddKey(builder.b, keyOffset)
|
||||
message_fbs.MessageAddData(builder.b, valueOffset)
|
||||
messageOffset := message_fbs.MessageEnd(builder.b)
|
||||
|
||||
builder.messageOffsets = append(builder.messageOffsets, messageOffset)
|
||||
|
||||
}
|
||||
|
||||
func (builder *MessageBatchBuilder) BuildMessageBatch() {
|
||||
message_fbs.MessageBatchStartMessagesVector(builder.b, len(builder.messageOffsets))
|
||||
for i := len(builder.messageOffsets) - 1; i >= 0; i-- {
|
||||
builder.b.PrependUOffsetT(builder.messageOffsets[i])
|
||||
}
|
||||
messagesOffset := builder.b.EndVector(len(builder.messageOffsets))
|
||||
|
||||
message_fbs.MessageBatchStart(builder.b)
|
||||
message_fbs.MessageBatchAddProducerId(builder.b, builder.producerId)
|
||||
message_fbs.MessageBatchAddProducerEpoch(builder.b, builder.producerEpoch)
|
||||
message_fbs.MessageBatchAddSegmentId(builder.b, builder.segmentId)
|
||||
message_fbs.MessageBatchAddFlags(builder.b, builder.flags)
|
||||
message_fbs.MessageBatchAddSegmentSeqBase(builder.b, builder.segmentSeqBase)
|
||||
message_fbs.MessageBatchAddSegmentSeqMaxDelta(builder.b, int32(builder.segmentSeqLast-builder.segmentSeqBase))
|
||||
message_fbs.MessageBatchAddTsMsBase(builder.b, builder.tsMsBase)
|
||||
message_fbs.MessageBatchAddTsMsMaxDelta(builder.b, int32(builder.tsMsLast-builder.tsMsBase))
|
||||
|
||||
message_fbs.MessageBatchAddMessages(builder.b, messagesOffset)
|
||||
|
||||
messageBatch := message_fbs.MessageBatchEnd(builder.b)
|
||||
|
||||
builder.b.Finish(messageBatch)
|
||||
}
|
||||
|
||||
func (builder *MessageBatchBuilder) GetBytes() []byte {
|
||||
return builder.b.FinishedBytes()
|
||||
}
|
||||
@@ -1,61 +0,0 @@
|
||||
package segment
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
flatbuffers "github.com/google/flatbuffers/go"
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/message_fbs"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMessageSerde(t *testing.T) {
|
||||
b := flatbuffers.NewBuilder(1024)
|
||||
|
||||
prop := make(map[string][]byte)
|
||||
prop["n1"] = []byte("v1")
|
||||
prop["n2"] = []byte("v2")
|
||||
|
||||
bb := NewMessageBatchBuilder(b, 1, 2, 3, 4)
|
||||
|
||||
bb.AddMessage(5, 6, prop, []byte("the primary key"), []byte("body is here"))
|
||||
bb.AddMessage(5, 7, prop, []byte("the primary 2"), []byte("body is 2"))
|
||||
|
||||
bb.BuildMessageBatch()
|
||||
|
||||
buf := bb.GetBytes()
|
||||
|
||||
println("serialized size", len(buf))
|
||||
|
||||
mb := message_fbs.GetRootAsMessageBatch(buf, 0)
|
||||
|
||||
assert.Equal(t, int32(1), mb.ProducerId())
|
||||
assert.Equal(t, int32(2), mb.ProducerEpoch())
|
||||
assert.Equal(t, int32(3), mb.SegmentId())
|
||||
assert.Equal(t, int32(4), mb.Flags())
|
||||
assert.Equal(t, int64(5), mb.SegmentSeqBase())
|
||||
assert.Equal(t, int32(0), mb.SegmentSeqMaxDelta())
|
||||
assert.Equal(t, int64(6), mb.TsMsBase())
|
||||
assert.Equal(t, int32(1), mb.TsMsMaxDelta())
|
||||
|
||||
assert.Equal(t, 2, mb.MessagesLength())
|
||||
|
||||
m := &message_fbs.Message{}
|
||||
mb.Messages(m, 0)
|
||||
|
||||
/*
|
||||
// the vector seems not consistent
|
||||
nv := &message_fbs.NameValue{}
|
||||
m.Properties(nv, 0)
|
||||
assert.Equal(t, "n1", string(nv.Name()))
|
||||
assert.Equal(t, "v1", string(nv.Value()))
|
||||
m.Properties(nv, 1)
|
||||
assert.Equal(t, "n2", string(nv.Name()))
|
||||
assert.Equal(t, "v2", string(nv.Value()))
|
||||
*/
|
||||
assert.Equal(t, []byte("the primary key"), m.Key())
|
||||
assert.Equal(t, []byte("body is here"), m.Data())
|
||||
|
||||
assert.Equal(t, int32(0), m.SeqDelta())
|
||||
assert.Equal(t, int32(0), m.TsMsDelta())
|
||||
|
||||
}
|
||||
@@ -28,28 +28,6 @@ func (imt *InflightMessageTracker) EnflightMessage(key []byte, tsNs int64) {
|
||||
imt.timestamps.EnflightTimestamp(tsNs)
|
||||
}
|
||||
|
||||
// IsMessageAcknowledged returns true if the message has been acknowledged.
|
||||
// If the message is older than the oldest inflight messages, returns false.
|
||||
// returns false if the message is inflight.
|
||||
// Otherwise, returns false if the message is old and can be ignored.
|
||||
func (imt *InflightMessageTracker) IsMessageAcknowledged(key []byte, tsNs int64) bool {
|
||||
imt.mu.Lock()
|
||||
defer imt.mu.Unlock()
|
||||
|
||||
if tsNs <= imt.timestamps.OldestAckedTimestamp() {
|
||||
return true
|
||||
}
|
||||
if tsNs > imt.timestamps.Latest() {
|
||||
return false
|
||||
}
|
||||
|
||||
if _, found := imt.messages[string(key)]; found {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// AcknowledgeMessage acknowledges the message with the key and timestamp.
|
||||
func (imt *InflightMessageTracker) AcknowledgeMessage(key []byte, tsNs int64) bool {
|
||||
// fmt.Printf("AcknowledgeMessage(%s,%d)\n", string(key), tsNs)
|
||||
@@ -164,8 +142,3 @@ func (rb *RingBuffer) AckTimestamp(timestamp int64) {
|
||||
func (rb *RingBuffer) OldestAckedTimestamp() int64 {
|
||||
return rb.maxAllAckedTs
|
||||
}
|
||||
|
||||
// Latest returns the most recently known timestamp in the ring buffer.
|
||||
func (rb *RingBuffer) Latest() int64 {
|
||||
return rb.maxTimestamp
|
||||
}
|
||||
|
||||
@@ -1,134 +0,0 @@
|
||||
package sub_coordinator
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRingBuffer(t *testing.T) {
|
||||
// Initialize a RingBuffer with capacity 5
|
||||
rb := NewRingBuffer(5)
|
||||
|
||||
// Add timestamps to the buffer
|
||||
timestamps := []int64{100, 200, 300, 400, 500}
|
||||
for _, ts := range timestamps {
|
||||
rb.EnflightTimestamp(ts)
|
||||
}
|
||||
|
||||
// Test Add method and buffer size
|
||||
expectedSize := 5
|
||||
if rb.size != expectedSize {
|
||||
t.Errorf("Expected buffer size %d, got %d", expectedSize, rb.size)
|
||||
}
|
||||
|
||||
assert.Equal(t, int64(0), rb.OldestAckedTimestamp())
|
||||
assert.Equal(t, int64(500), rb.Latest())
|
||||
|
||||
rb.AckTimestamp(200)
|
||||
assert.Equal(t, int64(0), rb.OldestAckedTimestamp())
|
||||
rb.AckTimestamp(100)
|
||||
assert.Equal(t, int64(200), rb.OldestAckedTimestamp())
|
||||
|
||||
rb.EnflightTimestamp(int64(600))
|
||||
rb.EnflightTimestamp(int64(700))
|
||||
|
||||
rb.AckTimestamp(500)
|
||||
assert.Equal(t, int64(200), rb.OldestAckedTimestamp())
|
||||
rb.AckTimestamp(400)
|
||||
assert.Equal(t, int64(200), rb.OldestAckedTimestamp())
|
||||
rb.AckTimestamp(300)
|
||||
assert.Equal(t, int64(500), rb.OldestAckedTimestamp())
|
||||
|
||||
assert.Equal(t, int64(700), rb.Latest())
|
||||
}
|
||||
|
||||
func TestInflightMessageTracker(t *testing.T) {
|
||||
// Initialize an InflightMessageTracker with capacity 5
|
||||
tracker := NewInflightMessageTracker(5)
|
||||
|
||||
// Add inflight messages
|
||||
key := []byte("1")
|
||||
timestamp := int64(1)
|
||||
tracker.EnflightMessage(key, timestamp)
|
||||
|
||||
// Test IsMessageAcknowledged method
|
||||
isOld := tracker.IsMessageAcknowledged(key, timestamp-10)
|
||||
if !isOld {
|
||||
t.Error("Expected message to be old")
|
||||
}
|
||||
|
||||
// Test AcknowledgeMessage method
|
||||
acked := tracker.AcknowledgeMessage(key, timestamp)
|
||||
if !acked {
|
||||
t.Error("Expected message to be acked")
|
||||
}
|
||||
if _, exists := tracker.messages[string(key)]; exists {
|
||||
t.Error("Expected message to be deleted after ack")
|
||||
}
|
||||
if tracker.timestamps.size != 0 {
|
||||
t.Error("Expected buffer size to be 0 after ack")
|
||||
}
|
||||
assert.Equal(t, timestamp, tracker.GetOldestAckedTimestamp())
|
||||
}
|
||||
|
||||
func TestInflightMessageTracker2(t *testing.T) {
|
||||
// Initialize an InflightMessageTracker with initial capacity 1
|
||||
tracker := NewInflightMessageTracker(1)
|
||||
|
||||
tracker.EnflightMessage([]byte("1"), int64(1))
|
||||
tracker.EnflightMessage([]byte("2"), int64(2))
|
||||
tracker.EnflightMessage([]byte("3"), int64(3))
|
||||
tracker.EnflightMessage([]byte("4"), int64(4))
|
||||
tracker.EnflightMessage([]byte("5"), int64(5))
|
||||
assert.True(t, tracker.AcknowledgeMessage([]byte("1"), int64(1)))
|
||||
assert.Equal(t, int64(1), tracker.GetOldestAckedTimestamp())
|
||||
|
||||
// Test IsMessageAcknowledged method
|
||||
isAcked := tracker.IsMessageAcknowledged([]byte("2"), int64(2))
|
||||
if isAcked {
|
||||
t.Error("Expected message to be not acked")
|
||||
}
|
||||
|
||||
// Test AcknowledgeMessage method
|
||||
assert.True(t, tracker.AcknowledgeMessage([]byte("2"), int64(2)))
|
||||
assert.Equal(t, int64(2), tracker.GetOldestAckedTimestamp())
|
||||
|
||||
}
|
||||
|
||||
func TestInflightMessageTracker3(t *testing.T) {
|
||||
// Initialize an InflightMessageTracker with initial capacity 1
|
||||
tracker := NewInflightMessageTracker(1)
|
||||
|
||||
tracker.EnflightMessage([]byte("1"), int64(1))
|
||||
tracker.EnflightMessage([]byte("2"), int64(2))
|
||||
tracker.EnflightMessage([]byte("3"), int64(3))
|
||||
assert.True(t, tracker.AcknowledgeMessage([]byte("1"), int64(1)))
|
||||
tracker.EnflightMessage([]byte("4"), int64(4))
|
||||
tracker.EnflightMessage([]byte("5"), int64(5))
|
||||
assert.True(t, tracker.AcknowledgeMessage([]byte("2"), int64(2)))
|
||||
assert.True(t, tracker.AcknowledgeMessage([]byte("3"), int64(3)))
|
||||
tracker.EnflightMessage([]byte("6"), int64(6))
|
||||
tracker.EnflightMessage([]byte("7"), int64(7))
|
||||
assert.True(t, tracker.AcknowledgeMessage([]byte("4"), int64(4)))
|
||||
assert.True(t, tracker.AcknowledgeMessage([]byte("5"), int64(5)))
|
||||
assert.True(t, tracker.AcknowledgeMessage([]byte("6"), int64(6)))
|
||||
assert.Equal(t, int64(6), tracker.GetOldestAckedTimestamp())
|
||||
assert.True(t, tracker.AcknowledgeMessage([]byte("7"), int64(7)))
|
||||
assert.Equal(t, int64(7), tracker.GetOldestAckedTimestamp())
|
||||
|
||||
}
|
||||
|
||||
func TestInflightMessageTracker4(t *testing.T) {
|
||||
// Initialize an InflightMessageTracker with initial capacity 1
|
||||
tracker := NewInflightMessageTracker(1)
|
||||
|
||||
tracker.EnflightMessage([]byte("1"), int64(1))
|
||||
tracker.EnflightMessage([]byte("2"), int64(2))
|
||||
assert.True(t, tracker.AcknowledgeMessage([]byte("1"), int64(1)))
|
||||
assert.True(t, tracker.AcknowledgeMessage([]byte("2"), int64(2)))
|
||||
tracker.EnflightMessage([]byte("3"), int64(3))
|
||||
assert.True(t, tracker.AcknowledgeMessage([]byte("3"), int64(3)))
|
||||
assert.Equal(t, int64(3), tracker.GetOldestAckedTimestamp())
|
||||
|
||||
}
|
||||
@@ -1,130 +1,6 @@
|
||||
package sub_coordinator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/mq/pub_balancer"
|
||||
)
|
||||
|
||||
type PartitionConsumerMapping struct {
|
||||
currentMapping *PartitionSlotToConsumerInstanceList
|
||||
prevMappings []*PartitionSlotToConsumerInstanceList
|
||||
}
|
||||
|
||||
// Balance goal:
|
||||
// 1. max processing power utilization
|
||||
// 2. allow one consumer instance to be down unexpectedly
|
||||
// without affecting the processing power utilization
|
||||
|
||||
func (pcm *PartitionConsumerMapping) BalanceToConsumerInstances(partitionSlotToBrokerList *pub_balancer.PartitionSlotToBrokerList, consumerInstances []*ConsumerGroupInstance) {
|
||||
if len(partitionSlotToBrokerList.PartitionSlots) == 0 || len(consumerInstances) == 0 {
|
||||
return
|
||||
}
|
||||
newMapping := NewPartitionSlotToConsumerInstanceList(partitionSlotToBrokerList.RingSize, time.Now())
|
||||
var prevMapping *PartitionSlotToConsumerInstanceList
|
||||
if len(pcm.prevMappings) > 0 {
|
||||
prevMapping = pcm.prevMappings[len(pcm.prevMappings)-1]
|
||||
} else {
|
||||
prevMapping = nil
|
||||
}
|
||||
newMapping.PartitionSlots = doBalanceSticky(partitionSlotToBrokerList.PartitionSlots, consumerInstances, prevMapping)
|
||||
if pcm.currentMapping != nil {
|
||||
pcm.prevMappings = append(pcm.prevMappings, pcm.currentMapping)
|
||||
if len(pcm.prevMappings) > 10 {
|
||||
pcm.prevMappings = pcm.prevMappings[1:]
|
||||
}
|
||||
}
|
||||
pcm.currentMapping = newMapping
|
||||
}
|
||||
|
||||
func doBalanceSticky(partitions []*pub_balancer.PartitionSlotToBroker, consumerInstances []*ConsumerGroupInstance, prevMapping *PartitionSlotToConsumerInstanceList) (partitionSlots []*PartitionSlotToConsumerInstance) {
|
||||
// collect previous consumer instance ids
|
||||
prevConsumerInstanceIds := make(map[ConsumerGroupInstanceId]struct{})
|
||||
if prevMapping != nil {
|
||||
for _, prevPartitionSlot := range prevMapping.PartitionSlots {
|
||||
if prevPartitionSlot.AssignedInstanceId != "" {
|
||||
prevConsumerInstanceIds[prevPartitionSlot.AssignedInstanceId] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
// collect current consumer instance ids
|
||||
currConsumerInstanceIds := make(map[ConsumerGroupInstanceId]struct{})
|
||||
for _, consumerInstance := range consumerInstances {
|
||||
currConsumerInstanceIds[consumerInstance.InstanceId] = struct{}{}
|
||||
}
|
||||
|
||||
// check deleted consumer instances
|
||||
deletedConsumerInstanceIds := make(map[ConsumerGroupInstanceId]struct{})
|
||||
for consumerInstanceId := range prevConsumerInstanceIds {
|
||||
if _, ok := currConsumerInstanceIds[consumerInstanceId]; !ok {
|
||||
deletedConsumerInstanceIds[consumerInstanceId] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// convert partition slots from list to a map
|
||||
prevPartitionSlotMap := make(map[string]*PartitionSlotToConsumerInstance)
|
||||
if prevMapping != nil {
|
||||
for _, partitionSlot := range prevMapping.PartitionSlots {
|
||||
key := fmt.Sprintf("%d-%d", partitionSlot.RangeStart, partitionSlot.RangeStop)
|
||||
prevPartitionSlotMap[key] = partitionSlot
|
||||
}
|
||||
}
|
||||
|
||||
// make a copy of old mapping, skipping the deleted consumer instances
|
||||
newPartitionSlots := make([]*PartitionSlotToConsumerInstance, 0, len(partitions))
|
||||
for _, partition := range partitions {
|
||||
newPartitionSlots = append(newPartitionSlots, &PartitionSlotToConsumerInstance{
|
||||
RangeStart: partition.RangeStart,
|
||||
RangeStop: partition.RangeStop,
|
||||
UnixTimeNs: partition.UnixTimeNs,
|
||||
Broker: partition.AssignedBroker,
|
||||
FollowerBroker: partition.FollowerBroker,
|
||||
})
|
||||
}
|
||||
for _, newPartitionSlot := range newPartitionSlots {
|
||||
key := fmt.Sprintf("%d-%d", newPartitionSlot.RangeStart, newPartitionSlot.RangeStop)
|
||||
if prevPartitionSlot, ok := prevPartitionSlotMap[key]; ok {
|
||||
if _, ok := deletedConsumerInstanceIds[prevPartitionSlot.AssignedInstanceId]; !ok {
|
||||
newPartitionSlot.AssignedInstanceId = prevPartitionSlot.AssignedInstanceId
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// for all consumer instances, count the average number of partitions
|
||||
// that are assigned to them
|
||||
consumerInstancePartitionCount := make(map[ConsumerGroupInstanceId]int)
|
||||
for _, newPartitionSlot := range newPartitionSlots {
|
||||
if newPartitionSlot.AssignedInstanceId != "" {
|
||||
consumerInstancePartitionCount[newPartitionSlot.AssignedInstanceId]++
|
||||
}
|
||||
}
|
||||
// average number of partitions that are assigned to each consumer instance
|
||||
averageConsumerInstanceLoad := float32(len(partitions)) / float32(len(consumerInstances))
|
||||
|
||||
// assign unassigned partition slots to consumer instances that is underloaded
|
||||
consumerInstanceIdsIndex := 0
|
||||
for _, newPartitionSlot := range newPartitionSlots {
|
||||
if newPartitionSlot.AssignedInstanceId == "" {
|
||||
for avoidDeadLoop := len(consumerInstances); avoidDeadLoop > 0; avoidDeadLoop-- {
|
||||
consumerInstance := consumerInstances[consumerInstanceIdsIndex]
|
||||
if float32(consumerInstancePartitionCount[consumerInstance.InstanceId]) < averageConsumerInstanceLoad {
|
||||
newPartitionSlot.AssignedInstanceId = consumerInstance.InstanceId
|
||||
consumerInstancePartitionCount[consumerInstance.InstanceId]++
|
||||
consumerInstanceIdsIndex++
|
||||
if consumerInstanceIdsIndex >= len(consumerInstances) {
|
||||
consumerInstanceIdsIndex = 0
|
||||
}
|
||||
break
|
||||
} else {
|
||||
consumerInstanceIdsIndex++
|
||||
if consumerInstanceIdsIndex >= len(consumerInstances) {
|
||||
consumerInstanceIdsIndex = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return newPartitionSlots
|
||||
}
|
||||
|
||||
@@ -1,385 +0,0 @@
|
||||
package sub_coordinator
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/mq/pub_balancer"
|
||||
)
|
||||
|
||||
func Test_doBalanceSticky(t *testing.T) {
|
||||
type args struct {
|
||||
partitions []*pub_balancer.PartitionSlotToBroker
|
||||
consumerInstanceIds []*ConsumerGroupInstance
|
||||
prevMapping *PartitionSlotToConsumerInstanceList
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantPartitionSlots []*PartitionSlotToConsumerInstance
|
||||
}{
|
||||
{
|
||||
name: "1 consumer instance, 1 partition",
|
||||
args: args{
|
||||
partitions: []*pub_balancer.PartitionSlotToBroker{
|
||||
{
|
||||
RangeStart: 0,
|
||||
RangeStop: 100,
|
||||
},
|
||||
},
|
||||
consumerInstanceIds: []*ConsumerGroupInstance{
|
||||
{
|
||||
InstanceId: "consumer-instance-1",
|
||||
MaxPartitionCount: 1,
|
||||
},
|
||||
},
|
||||
prevMapping: nil,
|
||||
},
|
||||
wantPartitionSlots: []*PartitionSlotToConsumerInstance{
|
||||
{
|
||||
RangeStart: 0,
|
||||
RangeStop: 100,
|
||||
AssignedInstanceId: "consumer-instance-1",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "2 consumer instances, 1 partition",
|
||||
args: args{
|
||||
partitions: []*pub_balancer.PartitionSlotToBroker{
|
||||
{
|
||||
RangeStart: 0,
|
||||
RangeStop: 100,
|
||||
},
|
||||
},
|
||||
consumerInstanceIds: []*ConsumerGroupInstance{
|
||||
{
|
||||
InstanceId: "consumer-instance-1",
|
||||
MaxPartitionCount: 1,
|
||||
},
|
||||
{
|
||||
InstanceId: "consumer-instance-2",
|
||||
MaxPartitionCount: 1,
|
||||
},
|
||||
},
|
||||
prevMapping: nil,
|
||||
},
|
||||
wantPartitionSlots: []*PartitionSlotToConsumerInstance{
|
||||
{
|
||||
RangeStart: 0,
|
||||
RangeStop: 100,
|
||||
AssignedInstanceId: "consumer-instance-1",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "1 consumer instance, 2 partitions",
|
||||
args: args{
|
||||
partitions: []*pub_balancer.PartitionSlotToBroker{
|
||||
{
|
||||
RangeStart: 0,
|
||||
RangeStop: 50,
|
||||
},
|
||||
{
|
||||
RangeStart: 50,
|
||||
RangeStop: 100,
|
||||
},
|
||||
},
|
||||
consumerInstanceIds: []*ConsumerGroupInstance{
|
||||
{
|
||||
InstanceId: "consumer-instance-1",
|
||||
MaxPartitionCount: 1,
|
||||
},
|
||||
},
|
||||
prevMapping: nil,
|
||||
},
|
||||
wantPartitionSlots: []*PartitionSlotToConsumerInstance{
|
||||
{
|
||||
RangeStart: 0,
|
||||
RangeStop: 50,
|
||||
AssignedInstanceId: "consumer-instance-1",
|
||||
},
|
||||
{
|
||||
RangeStart: 50,
|
||||
RangeStop: 100,
|
||||
AssignedInstanceId: "consumer-instance-1",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "2 consumer instances, 2 partitions",
|
||||
args: args{
|
||||
partitions: []*pub_balancer.PartitionSlotToBroker{
|
||||
{
|
||||
RangeStart: 0,
|
||||
RangeStop: 50,
|
||||
},
|
||||
{
|
||||
RangeStart: 50,
|
||||
RangeStop: 100,
|
||||
},
|
||||
},
|
||||
consumerInstanceIds: []*ConsumerGroupInstance{
|
||||
{
|
||||
InstanceId: "consumer-instance-1",
|
||||
MaxPartitionCount: 1,
|
||||
},
|
||||
{
|
||||
InstanceId: "consumer-instance-2",
|
||||
MaxPartitionCount: 1,
|
||||
},
|
||||
},
|
||||
prevMapping: nil,
|
||||
},
|
||||
wantPartitionSlots: []*PartitionSlotToConsumerInstance{
|
||||
{
|
||||
RangeStart: 0,
|
||||
RangeStop: 50,
|
||||
AssignedInstanceId: "consumer-instance-1",
|
||||
},
|
||||
{
|
||||
RangeStart: 50,
|
||||
RangeStop: 100,
|
||||
AssignedInstanceId: "consumer-instance-2",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "2 consumer instances, 2 partitions, 1 deleted consumer instance",
|
||||
args: args{
|
||||
partitions: []*pub_balancer.PartitionSlotToBroker{
|
||||
{
|
||||
RangeStart: 0,
|
||||
RangeStop: 50,
|
||||
},
|
||||
{
|
||||
RangeStart: 50,
|
||||
RangeStop: 100,
|
||||
},
|
||||
},
|
||||
consumerInstanceIds: []*ConsumerGroupInstance{
|
||||
{
|
||||
InstanceId: "consumer-instance-1",
|
||||
MaxPartitionCount: 1,
|
||||
},
|
||||
{
|
||||
InstanceId: "consumer-instance-2",
|
||||
MaxPartitionCount: 1,
|
||||
},
|
||||
},
|
||||
prevMapping: &PartitionSlotToConsumerInstanceList{
|
||||
PartitionSlots: []*PartitionSlotToConsumerInstance{
|
||||
{
|
||||
RangeStart: 0,
|
||||
RangeStop: 50,
|
||||
AssignedInstanceId: "consumer-instance-3",
|
||||
},
|
||||
{
|
||||
RangeStart: 50,
|
||||
RangeStop: 100,
|
||||
AssignedInstanceId: "consumer-instance-2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantPartitionSlots: []*PartitionSlotToConsumerInstance{
|
||||
{
|
||||
RangeStart: 0,
|
||||
RangeStop: 50,
|
||||
AssignedInstanceId: "consumer-instance-1",
|
||||
},
|
||||
{
|
||||
RangeStart: 50,
|
||||
RangeStop: 100,
|
||||
AssignedInstanceId: "consumer-instance-2",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "2 consumer instances, 2 partitions, 1 new consumer instance",
|
||||
args: args{
|
||||
partitions: []*pub_balancer.PartitionSlotToBroker{
|
||||
{
|
||||
RangeStart: 0,
|
||||
RangeStop: 50,
|
||||
},
|
||||
{
|
||||
RangeStart: 50,
|
||||
RangeStop: 100,
|
||||
},
|
||||
},
|
||||
consumerInstanceIds: []*ConsumerGroupInstance{
|
||||
{
|
||||
InstanceId: "consumer-instance-1",
|
||||
MaxPartitionCount: 1,
|
||||
},
|
||||
{
|
||||
InstanceId: "consumer-instance-2",
|
||||
MaxPartitionCount: 1,
|
||||
},
|
||||
{
|
||||
InstanceId: "consumer-instance-3",
|
||||
MaxPartitionCount: 1,
|
||||
},
|
||||
},
|
||||
prevMapping: &PartitionSlotToConsumerInstanceList{
|
||||
PartitionSlots: []*PartitionSlotToConsumerInstance{
|
||||
{
|
||||
RangeStart: 0,
|
||||
RangeStop: 50,
|
||||
AssignedInstanceId: "consumer-instance-3",
|
||||
},
|
||||
{
|
||||
RangeStart: 50,
|
||||
RangeStop: 100,
|
||||
AssignedInstanceId: "consumer-instance-2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantPartitionSlots: []*PartitionSlotToConsumerInstance{
|
||||
{
|
||||
RangeStart: 0,
|
||||
RangeStop: 50,
|
||||
AssignedInstanceId: "consumer-instance-3",
|
||||
},
|
||||
{
|
||||
RangeStart: 50,
|
||||
RangeStop: 100,
|
||||
AssignedInstanceId: "consumer-instance-2",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "2 consumer instances, 2 partitions, 1 new partition",
|
||||
args: args{
|
||||
partitions: []*pub_balancer.PartitionSlotToBroker{
|
||||
{
|
||||
RangeStart: 0,
|
||||
RangeStop: 50,
|
||||
},
|
||||
{
|
||||
RangeStart: 50,
|
||||
RangeStop: 100,
|
||||
},
|
||||
{
|
||||
RangeStart: 100,
|
||||
RangeStop: 150,
|
||||
},
|
||||
},
|
||||
consumerInstanceIds: []*ConsumerGroupInstance{
|
||||
{
|
||||
InstanceId: "consumer-instance-1",
|
||||
MaxPartitionCount: 1,
|
||||
},
|
||||
{
|
||||
InstanceId: "consumer-instance-2",
|
||||
MaxPartitionCount: 1,
|
||||
},
|
||||
},
|
||||
prevMapping: &PartitionSlotToConsumerInstanceList{
|
||||
PartitionSlots: []*PartitionSlotToConsumerInstance{
|
||||
{
|
||||
RangeStart: 0,
|
||||
RangeStop: 50,
|
||||
AssignedInstanceId: "consumer-instance-1",
|
||||
},
|
||||
{
|
||||
RangeStart: 50,
|
||||
RangeStop: 100,
|
||||
AssignedInstanceId: "consumer-instance-2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantPartitionSlots: []*PartitionSlotToConsumerInstance{
|
||||
{
|
||||
RangeStart: 0,
|
||||
RangeStop: 50,
|
||||
AssignedInstanceId: "consumer-instance-1",
|
||||
},
|
||||
{
|
||||
RangeStart: 50,
|
||||
RangeStop: 100,
|
||||
AssignedInstanceId: "consumer-instance-2",
|
||||
},
|
||||
{
|
||||
RangeStart: 100,
|
||||
RangeStop: 150,
|
||||
AssignedInstanceId: "consumer-instance-1",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "2 consumer instances, 2 partitions, 1 new partition, 1 new consumer instance",
|
||||
args: args{
|
||||
partitions: []*pub_balancer.PartitionSlotToBroker{
|
||||
{
|
||||
RangeStart: 0,
|
||||
RangeStop: 50,
|
||||
},
|
||||
{
|
||||
RangeStart: 50,
|
||||
RangeStop: 100,
|
||||
},
|
||||
{
|
||||
RangeStart: 100,
|
||||
RangeStop: 150,
|
||||
},
|
||||
},
|
||||
consumerInstanceIds: []*ConsumerGroupInstance{
|
||||
{
|
||||
InstanceId: "consumer-instance-1",
|
||||
MaxPartitionCount: 1,
|
||||
},
|
||||
{
|
||||
InstanceId: "consumer-instance-2",
|
||||
MaxPartitionCount: 1,
|
||||
},
|
||||
{
|
||||
InstanceId: "consumer-instance-3",
|
||||
MaxPartitionCount: 1,
|
||||
},
|
||||
},
|
||||
prevMapping: &PartitionSlotToConsumerInstanceList{
|
||||
PartitionSlots: []*PartitionSlotToConsumerInstance{
|
||||
{
|
||||
RangeStart: 0,
|
||||
RangeStop: 50,
|
||||
AssignedInstanceId: "consumer-instance-1",
|
||||
},
|
||||
{
|
||||
RangeStart: 50,
|
||||
RangeStop: 100,
|
||||
AssignedInstanceId: "consumer-instance-2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantPartitionSlots: []*PartitionSlotToConsumerInstance{
|
||||
{
|
||||
RangeStart: 0,
|
||||
RangeStop: 50,
|
||||
AssignedInstanceId: "consumer-instance-1",
|
||||
},
|
||||
{
|
||||
RangeStart: 50,
|
||||
RangeStop: 100,
|
||||
AssignedInstanceId: "consumer-instance-2",
|
||||
},
|
||||
{
|
||||
RangeStart: 100,
|
||||
RangeStop: 150,
|
||||
AssignedInstanceId: "consumer-instance-3",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if gotPartitionSlots := doBalanceSticky(tt.args.partitions, tt.args.consumerInstanceIds, tt.args.prevMapping); !reflect.DeepEqual(gotPartitionSlots, tt.wantPartitionSlots) {
|
||||
t.Errorf("doBalanceSticky() = %v, want %v", gotPartitionSlots, tt.wantPartitionSlots)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,5 @@
|
||||
package sub_coordinator
|
||||
|
||||
import "time"
|
||||
|
||||
type PartitionSlotToConsumerInstance struct {
|
||||
RangeStart int32
|
||||
RangeStop int32
|
||||
@@ -16,10 +14,3 @@ type PartitionSlotToConsumerInstanceList struct {
|
||||
RingSize int32
|
||||
Version int64
|
||||
}
|
||||
|
||||
func NewPartitionSlotToConsumerInstanceList(ringSize int32, version time.Time) *PartitionSlotToConsumerInstanceList {
|
||||
return &PartitionSlotToConsumerInstanceList{
|
||||
RingSize: ringSize,
|
||||
Version: version.UnixNano(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,22 +90,3 @@ type OffsetAwarePublisher struct {
|
||||
partition *LocalPartition
|
||||
assignOffsetFn OffsetAssignmentFunc
|
||||
}
|
||||
|
||||
// NewOffsetAwarePublisher creates a new offset-aware publisher
|
||||
func NewOffsetAwarePublisher(partition *LocalPartition, assignOffsetFn OffsetAssignmentFunc) *OffsetAwarePublisher {
|
||||
return &OffsetAwarePublisher{
|
||||
partition: partition,
|
||||
assignOffsetFn: assignOffsetFn,
|
||||
}
|
||||
}
|
||||
|
||||
// Publish publishes a message with automatic offset assignment
|
||||
func (oap *OffsetAwarePublisher) Publish(message *mq_pb.DataMessage) error {
|
||||
_, err := oap.partition.PublishWithOffset(message, oap.assignOffsetFn)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetPartition returns the underlying partition
|
||||
func (oap *OffsetAwarePublisher) GetPartition() *LocalPartition {
|
||||
return oap.partition
|
||||
}
|
||||
|
||||
@@ -16,15 +16,6 @@ type Partition struct {
|
||||
UnixTimeNs int64 // in nanoseconds
|
||||
}
|
||||
|
||||
func NewPartition(rangeStart, rangeStop, ringSize int32, unixTimeNs int64) *Partition {
|
||||
return &Partition{
|
||||
RangeStart: rangeStart,
|
||||
RangeStop: rangeStop,
|
||||
RingSize: ringSize,
|
||||
UnixTimeNs: unixTimeNs,
|
||||
}
|
||||
}
|
||||
|
||||
func (partition Partition) Equals(other Partition) bool {
|
||||
if partition.RangeStart != other.RangeStart {
|
||||
return false
|
||||
@@ -57,24 +48,6 @@ func FromPbPartition(partition *schema_pb.Partition) Partition {
|
||||
}
|
||||
}
|
||||
|
||||
func SplitPartitions(targetCount int32, ts int64) []*Partition {
|
||||
partitions := make([]*Partition, 0, targetCount)
|
||||
partitionSize := PartitionCount / targetCount
|
||||
for i := int32(0); i < targetCount; i++ {
|
||||
partitionStop := (i + 1) * partitionSize
|
||||
if i == targetCount-1 {
|
||||
partitionStop = PartitionCount
|
||||
}
|
||||
partitions = append(partitions, &Partition{
|
||||
RangeStart: i * partitionSize,
|
||||
RangeStop: partitionStop,
|
||||
RingSize: PartitionCount,
|
||||
UnixTimeNs: ts,
|
||||
})
|
||||
}
|
||||
return partitions
|
||||
}
|
||||
|
||||
func (partition Partition) ToPbPartition() *schema_pb.Partition {
|
||||
return &schema_pb.Partition{
|
||||
RangeStart: partition.RangeStart,
|
||||
|
||||
@@ -3,8 +3,6 @@ package operation
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb"
|
||||
@@ -41,118 +39,6 @@ type AssignResult struct {
|
||||
Replicas []Location `json:"replicas,omitempty"`
|
||||
}
|
||||
|
||||
// This is a proxy to the master server, only for assigning volume ids.
|
||||
// It runs via grpc to the master server in streaming mode.
|
||||
// The connection to the master would only be re-established when the last connection has error.
|
||||
type AssignProxy struct {
|
||||
grpcConnection *grpc.ClientConn
|
||||
pool chan *singleThreadAssignProxy
|
||||
}
|
||||
|
||||
func NewAssignProxy(masterFn GetMasterFn, grpcDialOption grpc.DialOption, concurrency int) (ap *AssignProxy, err error) {
|
||||
ap = &AssignProxy{
|
||||
pool: make(chan *singleThreadAssignProxy, concurrency),
|
||||
}
|
||||
ap.grpcConnection, err = pb.GrpcDial(context.Background(), masterFn(context.Background()).ToGrpcAddress(), true, grpcDialOption)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fail to dial %s: %v", masterFn(context.Background()).ToGrpcAddress(), err)
|
||||
}
|
||||
for i := 0; i < concurrency; i++ {
|
||||
ap.pool <- &singleThreadAssignProxy{}
|
||||
}
|
||||
return ap, nil
|
||||
}
|
||||
|
||||
func (ap *AssignProxy) Assign(primaryRequest *VolumeAssignRequest, alternativeRequests ...*VolumeAssignRequest) (ret *AssignResult, err error) {
|
||||
p := <-ap.pool
|
||||
defer func() {
|
||||
ap.pool <- p
|
||||
}()
|
||||
|
||||
return p.doAssign(ap.grpcConnection, primaryRequest, alternativeRequests...)
|
||||
}
|
||||
|
||||
type singleThreadAssignProxy struct {
|
||||
assignClient master_pb.Seaweed_StreamAssignClient
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func (ap *singleThreadAssignProxy) doAssign(grpcConnection *grpc.ClientConn, primaryRequest *VolumeAssignRequest, alternativeRequests ...*VolumeAssignRequest) (ret *AssignResult, err error) {
|
||||
ap.Lock()
|
||||
defer ap.Unlock()
|
||||
|
||||
if ap.assignClient == nil {
|
||||
client := master_pb.NewSeaweedClient(grpcConnection)
|
||||
ap.assignClient, err = client.StreamAssign(context.Background())
|
||||
if err != nil {
|
||||
ap.assignClient = nil
|
||||
return nil, fmt.Errorf("fail to create stream assign client: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var requests []*VolumeAssignRequest
|
||||
requests = append(requests, primaryRequest)
|
||||
requests = append(requests, alternativeRequests...)
|
||||
ret = &AssignResult{}
|
||||
|
||||
for _, request := range requests {
|
||||
if request == nil {
|
||||
continue
|
||||
}
|
||||
req := &master_pb.AssignRequest{
|
||||
Count: request.Count,
|
||||
Replication: request.Replication,
|
||||
Collection: request.Collection,
|
||||
Ttl: request.Ttl,
|
||||
DiskType: request.DiskType,
|
||||
DataCenter: request.DataCenter,
|
||||
Rack: request.Rack,
|
||||
DataNode: request.DataNode,
|
||||
WritableVolumeCount: request.WritableVolumeCount,
|
||||
}
|
||||
if err = ap.assignClient.Send(req); err != nil {
|
||||
ap.assignClient = nil
|
||||
return nil, fmt.Errorf("StreamAssignSend: %w", err)
|
||||
}
|
||||
resp, grpcErr := ap.assignClient.Recv()
|
||||
if grpcErr != nil {
|
||||
ap.assignClient = nil
|
||||
return nil, grpcErr
|
||||
}
|
||||
if resp.Error != "" {
|
||||
// StreamAssign returns transient warmup errors as in-band responses.
|
||||
// Wrap them as codes.Unavailable so the caller's retry logic can
|
||||
// classify them as retriable.
|
||||
if strings.Contains(resp.Error, "warming up") {
|
||||
return nil, status.Errorf(codes.Unavailable, "StreamAssignRecv: %s", resp.Error)
|
||||
}
|
||||
return nil, fmt.Errorf("StreamAssignRecv: %v", resp.Error)
|
||||
}
|
||||
|
||||
ret.Count = resp.Count
|
||||
ret.Fid = resp.Fid
|
||||
ret.Url = resp.Location.Url
|
||||
ret.PublicUrl = resp.Location.PublicUrl
|
||||
ret.GrpcPort = int(resp.Location.GrpcPort)
|
||||
ret.Error = resp.Error
|
||||
ret.Auth = security.EncodedJwt(resp.Auth)
|
||||
for _, r := range resp.Replicas {
|
||||
ret.Replicas = append(ret.Replicas, Location{
|
||||
Url: r.Url,
|
||||
PublicUrl: r.PublicUrl,
|
||||
DataCenter: r.DataCenter,
|
||||
})
|
||||
}
|
||||
|
||||
if ret.Count <= 0 {
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func Assign(ctx context.Context, masterFn GetMasterFn, grpcDialOption grpc.DialOption, primaryRequest *VolumeAssignRequest, alternativeRequests ...*VolumeAssignRequest) (*AssignResult, error) {
|
||||
|
||||
var requests []*VolumeAssignRequest
|
||||
|
||||
@@ -1,70 +0,0 @@
|
||||
package operation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func BenchmarkWithConcurrency(b *testing.B) {
|
||||
concurrencyLevels := []int{1, 10, 100, 1000}
|
||||
|
||||
ap, _ := NewAssignProxy(func(_ context.Context) pb.ServerAddress {
|
||||
return pb.ServerAddress("localhost:9333")
|
||||
}, grpc.WithInsecure(), 16)
|
||||
|
||||
for _, concurrency := range concurrencyLevels {
|
||||
b.Run(
|
||||
fmt.Sprintf("Concurrency-%d", concurrency),
|
||||
func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
done := make(chan struct{})
|
||||
startTime := time.Now()
|
||||
|
||||
for j := 0; j < concurrency; j++ {
|
||||
go func() {
|
||||
|
||||
ap.Assign(&VolumeAssignRequest{
|
||||
Count: 1,
|
||||
})
|
||||
|
||||
done <- struct{}{}
|
||||
}()
|
||||
}
|
||||
|
||||
for j := 0; j < concurrency; j++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
b.Logf("Concurrency: %d, Duration: %v", concurrency, duration)
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStreamAssign(b *testing.B) {
|
||||
ap, _ := NewAssignProxy(func(_ context.Context) pb.ServerAddress {
|
||||
return pb.ServerAddress("localhost:9333")
|
||||
}, grpc.WithInsecure(), 16)
|
||||
for i := 0; i < b.N; i++ {
|
||||
ap.Assign(&VolumeAssignRequest{
|
||||
Count: 1,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUnaryAssign(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
Assign(context.Background(), func(_ context.Context) pb.ServerAddress {
|
||||
return pb.ServerAddress("localhost:9333")
|
||||
}, grpc.WithInsecure(), &VolumeAssignRequest{
|
||||
Count: 1,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -93,11 +93,6 @@ func List(ctx context.Context, filerClient FilerClient, parentDirectoryPath, pre
|
||||
})
|
||||
}
|
||||
|
||||
func doList(ctx context.Context, filerClient FilerClient, fullDirPath util.FullPath, prefix string, fn EachEntryFunction, startFrom string, inclusive bool, limit uint32) (err error) {
|
||||
_, err = doListWithSnapshot(ctx, filerClient, fullDirPath, prefix, fn, startFrom, inclusive, limit, 0)
|
||||
return err
|
||||
}
|
||||
|
||||
func doListWithSnapshot(ctx context.Context, filerClient FilerClient, fullDirPath util.FullPath, prefix string, fn EachEntryFunction, startFrom string, inclusive bool, limit uint32, snapshotTsNs int64) (actualSnapshotTsNs int64, err error) {
|
||||
err = filerClient.WithFilerClient(false, func(client SeaweedFilerClient) error {
|
||||
actualSnapshotTsNs, err = DoSeaweedListWithSnapshot(ctx, client, fullDirPath, prefix, fn, startFrom, inclusive, limit, snapshotTsNs)
|
||||
@@ -212,26 +207,6 @@ func Exists(ctx context.Context, filerClient FilerClient, parentDirectoryPath st
|
||||
return
|
||||
}
|
||||
|
||||
func Touch(ctx context.Context, filerClient FilerClient, parentDirectoryPath string, entryName string, entry *Entry) (err error) {
|
||||
|
||||
return filerClient.WithFilerClient(false, func(client SeaweedFilerClient) error {
|
||||
|
||||
request := &UpdateEntryRequest{
|
||||
Directory: parentDirectoryPath,
|
||||
Entry: entry,
|
||||
}
|
||||
|
||||
glog.V(4).InfofCtx(ctx, "touch entry %v/%v: %v", parentDirectoryPath, entryName, request)
|
||||
if err := UpdateEntry(ctx, client, request); err != nil {
|
||||
glog.V(0).InfofCtx(ctx, "touch exists entry %v: %v", request, err)
|
||||
return fmt.Errorf("touch exists entry %s/%s: %v", parentDirectoryPath, entryName, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func Mkdir(ctx context.Context, filerClient FilerClient, parentDirectoryPath string, dirName string, fn func(entry *Entry)) error {
|
||||
return filerClient.WithFilerClient(false, func(client SeaweedFilerClient) error {
|
||||
return DoMkdir(ctx, client, parentDirectoryPath, dirName, fn)
|
||||
@@ -349,59 +324,3 @@ func DoRemoveWithResponse(ctx context.Context, client SeaweedFilerClient, parent
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
// DoDeleteEmptyParentDirectories recursively deletes empty parent directories.
|
||||
// It stops at root "/" or at stopAtPath.
|
||||
// For safety, dirPath must be under stopAtPath (when stopAtPath is provided).
|
||||
// The checked map tracks already-processed directories to avoid redundant work in batch operations.
|
||||
func DoDeleteEmptyParentDirectories(ctx context.Context, client SeaweedFilerClient, dirPath util.FullPath, stopAtPath util.FullPath, checked map[string]bool) {
|
||||
if dirPath == "/" || dirPath == stopAtPath {
|
||||
return
|
||||
}
|
||||
|
||||
// Skip if already checked (for batch delete optimization)
|
||||
dirPathStr := string(dirPath)
|
||||
if checked != nil {
|
||||
if checked[dirPathStr] {
|
||||
return
|
||||
}
|
||||
checked[dirPathStr] = true
|
||||
}
|
||||
|
||||
// Safety check: if stopAtPath is provided, dirPath must be under it (root "/" allows everything)
|
||||
stopStr := string(stopAtPath)
|
||||
if stopAtPath != "" && stopStr != "/" && !strings.HasPrefix(dirPathStr+"/", stopStr+"/") {
|
||||
glog.V(1).InfofCtx(ctx, "DoDeleteEmptyParentDirectories: %s is not under %s, skipping", dirPath, stopAtPath)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if directory is empty by listing with limit 1
|
||||
isEmpty := true
|
||||
err := SeaweedList(ctx, client, dirPathStr, "", func(entry *Entry, isLast bool) error {
|
||||
isEmpty = false
|
||||
return io.EOF // Use sentinel error to explicitly stop iteration
|
||||
}, "", false, 1)
|
||||
|
||||
if err != nil && err != io.EOF {
|
||||
glog.V(3).InfofCtx(ctx, "DoDeleteEmptyParentDirectories: error checking %s: %v", dirPath, err)
|
||||
return
|
||||
}
|
||||
|
||||
if !isEmpty {
|
||||
// Directory is not empty, stop checking upward
|
||||
glog.V(3).InfofCtx(ctx, "DoDeleteEmptyParentDirectories: directory %s is not empty, stopping cleanup", dirPath)
|
||||
return
|
||||
}
|
||||
|
||||
// Directory is empty, try to delete it
|
||||
glog.V(2).InfofCtx(ctx, "DoDeleteEmptyParentDirectories: deleting empty directory %s", dirPath)
|
||||
parentDir, dirName := dirPath.DirAndName()
|
||||
|
||||
if err := DoRemove(ctx, client, parentDir, dirName, false, false, false, false, nil); err == nil {
|
||||
// Successfully deleted, continue checking upwards
|
||||
DoDeleteEmptyParentDirectories(ctx, client, util.FullPath(parentDir), stopAtPath, checked)
|
||||
} else {
|
||||
// Failed to delete, stop cleanup
|
||||
glog.V(3).InfofCtx(ctx, "DoDeleteEmptyParentDirectories: failed to delete %s: %v", dirPath, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,15 +111,6 @@ func BeforeEntrySerialization(chunks []*FileChunk) {
|
||||
}
|
||||
}
|
||||
|
||||
func EnsureFid(chunk *FileChunk) {
|
||||
if chunk.Fid != nil {
|
||||
return
|
||||
}
|
||||
if fid, err := ToFileIdObject(chunk.FileId); err == nil {
|
||||
chunk.Fid = fid
|
||||
}
|
||||
}
|
||||
|
||||
func AfterEntryDeserialization(chunks []*FileChunk) {
|
||||
|
||||
for _, chunk := range chunks {
|
||||
@@ -309,16 +300,6 @@ func MetadataEventTouchesDirectory(event *SubscribeMetadataResponse, dir string)
|
||||
MetadataEventTargetDirectory(event) == dir
|
||||
}
|
||||
|
||||
func MetadataEventTouchesDirectoryPrefix(event *SubscribeMetadataResponse, prefix string) bool {
|
||||
if strings.HasPrefix(MetadataEventSourceDirectory(event), prefix) {
|
||||
return true
|
||||
}
|
||||
return event != nil &&
|
||||
event.EventNotification != nil &&
|
||||
event.EventNotification.NewEntry != nil &&
|
||||
strings.HasPrefix(MetadataEventTargetDirectory(event), prefix)
|
||||
}
|
||||
|
||||
func MetadataEventMatchesSubscription(event *SubscribeMetadataResponse, pathPrefix string, pathPrefixes []string, directories []string) bool {
|
||||
if event == nil {
|
||||
return false
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
package filer_pb
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
func TestFileIdSize(t *testing.T) {
|
||||
fileIdStr := "11745,0293434534cbb9892b"
|
||||
|
||||
fid, _ := ToFileIdObject(fileIdStr)
|
||||
bytes, _ := proto.Marshal(fid)
|
||||
|
||||
println(len(fileIdStr))
|
||||
println(len(bytes))
|
||||
}
|
||||
|
||||
func TestMetadataEventMatchesSubscription(t *testing.T) {
|
||||
event := &SubscribeMetadataResponse{
|
||||
Directory: "/tmp",
|
||||
EventNotification: &EventNotification{
|
||||
OldEntry: &Entry{Name: "old-name"},
|
||||
NewEntry: &Entry{Name: "new-name"},
|
||||
NewParentPath: "/watched",
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
pathPrefix string
|
||||
pathPrefixes []string
|
||||
directories []string
|
||||
}{
|
||||
{
|
||||
name: "primary path prefix matches rename target",
|
||||
pathPrefix: "/watched/new-name",
|
||||
},
|
||||
{
|
||||
name: "additional path prefix matches rename target",
|
||||
pathPrefix: "/data",
|
||||
pathPrefixes: []string{"/watched"},
|
||||
},
|
||||
{
|
||||
name: "directory watch matches rename target directory",
|
||||
pathPrefix: "/data",
|
||||
directories: []string{"/watched"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if !MetadataEventMatchesSubscription(event, tt.pathPrefix, tt.pathPrefixes, tt.directories) {
|
||||
t.Fatalf("MetadataEventMatchesSubscription returned false")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetadataEventTouchesDirectoryHelpers(t *testing.T) {
|
||||
renameInto := &SubscribeMetadataResponse{
|
||||
Directory: "/tmp",
|
||||
EventNotification: &EventNotification{
|
||||
OldEntry: &Entry{Name: "filer.conf"},
|
||||
NewEntry: &Entry{Name: "filer.conf"},
|
||||
NewParentPath: "/etc/seaweedfs",
|
||||
},
|
||||
}
|
||||
if got := MetadataEventTargetDirectory(renameInto); got != "/etc/seaweedfs" {
|
||||
t.Fatalf("MetadataEventTargetDirectory = %q, want /etc/seaweedfs", got)
|
||||
}
|
||||
if !MetadataEventTouchesDirectory(renameInto, "/etc/seaweedfs") {
|
||||
t.Fatalf("expected rename target to touch /etc/seaweedfs")
|
||||
}
|
||||
|
||||
renameOut := &SubscribeMetadataResponse{
|
||||
Directory: "/etc/remote",
|
||||
EventNotification: &EventNotification{
|
||||
OldEntry: &Entry{Name: "remote.conf"},
|
||||
NewEntry: &Entry{Name: "remote.conf"},
|
||||
NewParentPath: "/tmp",
|
||||
},
|
||||
}
|
||||
if !MetadataEventTouchesDirectoryPrefix(renameOut, "/etc/remote") {
|
||||
t.Fatalf("expected rename source to touch /etc/remote")
|
||||
}
|
||||
}
|
||||
@@ -28,7 +28,6 @@ import (
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/master_pb"
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/mq_pb"
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/worker_pb"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -318,18 +317,6 @@ func WithGrpcClient(streamingMode bool, signature int32, fn func(*grpc.ClientCon
|
||||
|
||||
}
|
||||
|
||||
func ParseServerAddress(server string, deltaPort int) (newServerAddress string, err error) {
|
||||
|
||||
host, port, parseErr := hostAndPort(server)
|
||||
if parseErr != nil {
|
||||
return "", fmt.Errorf("server port parse error: %w", parseErr)
|
||||
}
|
||||
|
||||
newPort := int(port) + deltaPort
|
||||
|
||||
return util.JoinHostPort(host, newPort), nil
|
||||
}
|
||||
|
||||
func hostAndPort(address string) (host string, port uint64, err error) {
|
||||
colonIndex := strings.LastIndex(address, ":")
|
||||
if colonIndex < 0 {
|
||||
@@ -457,10 +444,3 @@ func WithOneOfGrpcFilerClients(streamingMode bool, filerAddresses []ServerAddres
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func WithWorkerClient(streamingMode bool, workerAddress string, grpcDialOption grpc.DialOption, fn func(client worker_pb.WorkerServiceClient) error) error {
|
||||
return WithGrpcClient(streamingMode, 0, func(grpcConnection *grpc.ClientConn) error {
|
||||
client := worker_pb.NewWorkerServiceClient(grpcConnection)
|
||||
return fn(client)
|
||||
}, workerAddress, false, grpcDialOption)
|
||||
}
|
||||
|
||||
@@ -157,14 +157,6 @@ func (sa ServerAddresses) ToAddressMap() (addresses map[string]ServerAddress) {
|
||||
return
|
||||
}
|
||||
|
||||
func (sa ServerAddresses) ToAddressStrings() (addresses []string) {
|
||||
parts := strings.Split(string(sa), ",")
|
||||
for _, address := range parts {
|
||||
addresses = append(addresses, address)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ToAddressStrings(addresses []ServerAddress) []string {
|
||||
var strings []string
|
||||
for _, addr := range addresses {
|
||||
@@ -172,20 +164,6 @@ func ToAddressStrings(addresses []ServerAddress) []string {
|
||||
}
|
||||
return strings
|
||||
}
|
||||
func ToAddressStringsFromMap(addresses map[string]ServerAddress) []string {
|
||||
var strings []string
|
||||
for _, addr := range addresses {
|
||||
strings = append(strings, string(addr))
|
||||
}
|
||||
return strings
|
||||
}
|
||||
func FromAddressStrings(strings []string) []ServerAddress {
|
||||
var addresses []ServerAddress
|
||||
for _, addr := range strings {
|
||||
addresses = append(addresses, ServerAddress(addr))
|
||||
}
|
||||
return addresses
|
||||
}
|
||||
|
||||
func ParseUrl(input string) (address ServerAddress, path string, err error) {
|
||||
if !strings.HasPrefix(input, "http://") {
|
||||
|
||||
@@ -449,58 +449,6 @@ func hasEligibleCompaction(
|
||||
return len(bins) > 0, nil
|
||||
}
|
||||
|
||||
func countDataManifestsForRewrite(
|
||||
ctx context.Context,
|
||||
filerClient filer_pb.SeaweedFilerClient,
|
||||
bucketName, tablePath string,
|
||||
manifests []iceberg.ManifestFile,
|
||||
meta table.Metadata,
|
||||
predicate *partitionPredicate,
|
||||
) (int64, error) {
|
||||
if predicate == nil {
|
||||
return countDataManifests(manifests), nil
|
||||
}
|
||||
|
||||
specsByID := specByID(meta)
|
||||
|
||||
var count int64
|
||||
for _, mf := range manifests {
|
||||
if mf.ManifestContent() != iceberg.ManifestContentData {
|
||||
continue
|
||||
}
|
||||
manifestData, err := loadFileByIcebergPath(ctx, filerClient, bucketName, tablePath, mf.FilePath())
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("read manifest %s: %w", mf.FilePath(), err)
|
||||
}
|
||||
entries, err := iceberg.ReadManifest(mf, bytes.NewReader(manifestData), true)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parse manifest %s: %w", mf.FilePath(), err)
|
||||
}
|
||||
if len(entries) == 0 {
|
||||
continue
|
||||
}
|
||||
spec, ok := specsByID[int(mf.PartitionSpecID())]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
allMatch := len(entries) > 0
|
||||
for _, entry := range entries {
|
||||
match, err := predicate.Matches(spec, entry.DataFile().Partition())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !match {
|
||||
allMatch = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allMatch {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func compactionMinInputFiles(minInputFiles int64) (int, error) {
|
||||
// Ensure the configured value is positive and fits into the platform's int type
|
||||
if minInputFiles <= 0 {
|
||||
|
||||
@@ -137,26 +137,6 @@ func mergePlanningIndexSections(index, existing *planningIndex) *planningIndex {
|
||||
return index
|
||||
}
|
||||
|
||||
func buildPlanningIndex(
|
||||
ctx context.Context,
|
||||
filerClient filer_pb.SeaweedFilerClient,
|
||||
bucketName, tablePath string,
|
||||
meta table.Metadata,
|
||||
config Config,
|
||||
ops []string,
|
||||
) (*planningIndex, error) {
|
||||
currentSnap := meta.CurrentSnapshot()
|
||||
if currentSnap == nil || currentSnap.ManifestList == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
manifests, err := loadCurrentManifests(ctx, filerClient, bucketName, tablePath, meta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buildPlanningIndexFromManifests(ctx, filerClient, bucketName, tablePath, meta, config, ops, manifests)
|
||||
}
|
||||
|
||||
func buildPlanningIndexFromManifests(
|
||||
ctx context.Context,
|
||||
filerClient filer_pb.SeaweedFilerClient,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user