chore: remove ~50k lines of unreachable dead code (#8913)

* chore: remove unreachable dead code across the codebase

Remove ~50,000 lines of unreachable code identified by static analysis.

Major removals:
- weed/filer/redis_lua: entire unused Redis Lua filer store implementation
- weed/wdclient/net2, resource_pool: unused connection/resource pool packages
- weed/plugin/worker/lifecycle: unused lifecycle plugin worker
- weed/s3api: unused S3 policy templates, presigned URL IAM, streaming copy,
  multipart IAM, key rotation, and various SSE helper functions
- weed/mq/kafka: unused partition mapping, compression, schema, and protocol functions
- weed/mq/offset: unused SQL storage and migration code
- weed/worker: unused registry, task, and monitoring functions
- weed/query: unused SQL engine, parquet scanner, and type functions
- weed/shell: unused EC proportional rebalance functions
- weed/storage/erasure_coding/distribution: unused distribution analysis functions
- Individual unreachable functions removed from 150+ files across admin,
  credential, filer, iam, kms, mount, mq, operation, pb, s3api, server,
  shell, storage, topology, and util packages

* fix(s3): reset shared memory store in IAM test to prevent flaky failure

TestLoadIAMManagerFromConfig_EmptyConfigWithFallbackKey was flaky because
the MemoryStore credential backend is a singleton registered via init().
Earlier tests that create anonymous identities pollute the shared store,
causing LookupAnonymous() to unexpectedly return true.

Fix by calling Reset() on the memory store before the test runs.

* style: run gofmt on changed files

* fix: restore KMS functions used by integration tests

* fix(plugin): prevent panic on send to closed worker session channel

The Plugin.sendToWorker method could panic with "send on closed channel"
when a worker disconnected while a message was being sent. The race was
between streamSession.close() closing the outgoing channel and sendToWorker
writing to it concurrently.

Add a done channel to streamSession that is closed before the outgoing
channel, and check it in sendToWorker's select to safely detect closed
sessions without panicking.
This commit is contained in:
Chris Lu
2026-04-03 16:04:27 -07:00
committed by GitHub
parent 8fad85aed7
commit 995dfc4d5d
264 changed files with 62 additions and 46027 deletions

View File

@@ -1,503 +0,0 @@
package sts
import (
"context"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/seaweedfs/seaweedfs/weed/iam/oidc"
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test-only constants for mock providers
const (
ProviderTypeMock = "mock"
)
// createMockOIDCProvider creates a mock OIDC provider for testing
// This is only available in test builds
func createMockOIDCProvider(name string, config map[string]interface{}) (providers.IdentityProvider, error) {
// Convert config to OIDC format
factory := NewProviderFactory()
oidcConfig, err := factory.convertToOIDCConfig(config)
if err != nil {
return nil, err
}
// Set default values for mock provider if not provided
if oidcConfig.Issuer == "" {
oidcConfig.Issuer = "http://localhost:9999"
}
provider := oidc.NewMockOIDCProvider(name)
if err := provider.Initialize(oidcConfig); err != nil {
return nil, err
}
// Set up default test data for the mock provider
provider.SetupDefaultTestData()
return provider, nil
}
// createMockJWT creates a test JWT token with the specified issuer for mock provider testing
func createMockJWT(t *testing.T, issuer, subject string) string {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"iss": issuer,
"sub": subject,
"aud": "test-client",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
})
tokenString, err := token.SignedString([]byte("test-signing-key"))
require.NoError(t, err)
return tokenString
}
// TestCrossInstanceTokenUsage verifies that tokens generated by one STS instance
// can be used and validated by other STS instances in a distributed environment
func TestCrossInstanceTokenUsage(t *testing.T) {
ctx := context.Background()
// Dummy filer address for testing
// Common configuration that would be shared across all instances in production
sharedConfig := &STSConfig{
TokenDuration: FlexibleDuration{time.Hour},
MaxSessionLength: FlexibleDuration{12 * time.Hour},
Issuer: "distributed-sts-cluster", // SAME across all instances
SigningKey: []byte(TestSigningKey32Chars), // SAME across all instances
Providers: []*ProviderConfig{
{
Name: "company-oidc",
Type: ProviderTypeOIDC,
Enabled: true,
Config: map[string]interface{}{
ConfigFieldIssuer: "https://sso.company.com/realms/production",
ConfigFieldClientID: "seaweedfs-cluster",
ConfigFieldJWKSUri: "https://sso.company.com/realms/production/protocol/openid-connect/certs",
},
},
},
}
// Create multiple STS instances simulating different S3 gateway instances
instanceA := NewSTSService() // e.g., s3-gateway-1
instanceB := NewSTSService() // e.g., s3-gateway-2
instanceC := NewSTSService() // e.g., s3-gateway-3
// Initialize all instances with IDENTICAL configuration
err := instanceA.Initialize(sharedConfig)
require.NoError(t, err, "Instance A should initialize")
err = instanceB.Initialize(sharedConfig)
require.NoError(t, err, "Instance B should initialize")
err = instanceC.Initialize(sharedConfig)
require.NoError(t, err, "Instance C should initialize")
// Set up mock trust policy validator for all instances (required for STS testing)
mockValidator := &MockTrustPolicyValidator{}
instanceA.SetTrustPolicyValidator(mockValidator)
instanceB.SetTrustPolicyValidator(mockValidator)
instanceC.SetTrustPolicyValidator(mockValidator)
// Manually register mock provider for testing (not available in production)
mockProviderConfig := map[string]interface{}{
ConfigFieldIssuer: "http://test-mock:9999",
ConfigFieldClientID: TestClientID,
}
mockProviderA, err := createMockOIDCProvider("test-mock", mockProviderConfig)
require.NoError(t, err)
mockProviderB, err := createMockOIDCProvider("test-mock", mockProviderConfig)
require.NoError(t, err)
mockProviderC, err := createMockOIDCProvider("test-mock", mockProviderConfig)
require.NoError(t, err)
instanceA.RegisterProvider(mockProviderA)
instanceB.RegisterProvider(mockProviderB)
instanceC.RegisterProvider(mockProviderC)
// Test 1: Token generated on Instance A can be validated on Instance B & C
t.Run("cross_instance_token_validation", func(t *testing.T) {
// Generate session token on Instance A
sessionId := TestSessionID
expiresAt := time.Now().Add(time.Hour)
tokenFromA, err := instanceA.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
require.NoError(t, err, "Instance A should generate token")
// Validate token on Instance B
claimsFromB, err := instanceB.GetTokenGenerator().ValidateSessionToken(tokenFromA)
require.NoError(t, err, "Instance B should validate token from Instance A")
assert.Equal(t, sessionId, claimsFromB.SessionId, "Session ID should match")
// Validate same token on Instance C
claimsFromC, err := instanceC.GetTokenGenerator().ValidateSessionToken(tokenFromA)
require.NoError(t, err, "Instance C should validate token from Instance A")
assert.Equal(t, sessionId, claimsFromC.SessionId, "Session ID should match")
// All instances should extract identical claims
assert.Equal(t, claimsFromB.SessionId, claimsFromC.SessionId)
assert.Equal(t, claimsFromB.ExpiresAt.Unix(), claimsFromC.ExpiresAt.Unix())
assert.Equal(t, claimsFromB.IssuedAt.Unix(), claimsFromC.IssuedAt.Unix())
})
// Test 2: Complete assume role flow across instances
t.Run("cross_instance_assume_role_flow", func(t *testing.T) {
// Step 1: User authenticates and assumes role on Instance A
// Create a valid JWT token for the mock provider
mockToken := createMockJWT(t, "http://test-mock:9999", "test-user")
assumeRequest := &AssumeRoleWithWebIdentityRequest{
RoleArn: "arn:aws:iam::role/CrossInstanceTestRole",
WebIdentityToken: mockToken, // JWT token for mock provider
RoleSessionName: "cross-instance-test-session",
DurationSeconds: int64ToPtr(3600),
}
// Instance A processes assume role request
responseFromA, err := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest)
require.NoError(t, err, "Instance A should process assume role")
sessionToken := responseFromA.Credentials.SessionToken
accessKeyId := responseFromA.Credentials.AccessKeyId
secretAccessKey := responseFromA.Credentials.SecretAccessKey
// Verify response structure
assert.NotEmpty(t, sessionToken, "Should have session token")
assert.NotEmpty(t, accessKeyId, "Should have access key ID")
assert.NotEmpty(t, secretAccessKey, "Should have secret access key")
assert.NotNil(t, responseFromA.AssumedRoleUser, "Should have assumed role user")
// Step 2: Use session token on Instance B (different instance)
sessionInfoFromB, err := instanceB.ValidateSessionToken(ctx, sessionToken)
require.NoError(t, err, "Instance B should validate session token from Instance A")
assert.Equal(t, assumeRequest.RoleSessionName, sessionInfoFromB.SessionName)
assert.Equal(t, assumeRequest.RoleArn, sessionInfoFromB.RoleArn)
// Step 3: Use same session token on Instance C (yet another instance)
sessionInfoFromC, err := instanceC.ValidateSessionToken(ctx, sessionToken)
require.NoError(t, err, "Instance C should validate session token from Instance A")
// All instances should return identical session information
assert.Equal(t, sessionInfoFromB.SessionId, sessionInfoFromC.SessionId)
assert.Equal(t, sessionInfoFromB.SessionName, sessionInfoFromC.SessionName)
assert.Equal(t, sessionInfoFromB.RoleArn, sessionInfoFromC.RoleArn)
assert.Equal(t, sessionInfoFromB.Subject, sessionInfoFromC.Subject)
assert.Equal(t, sessionInfoFromB.Provider, sessionInfoFromC.Provider)
})
// Test 3: Session revocation across instances
t.Run("cross_instance_session_revocation", func(t *testing.T) {
// Create session on Instance A
mockToken := createMockJWT(t, "http://test-mock:9999", "test-user")
assumeRequest := &AssumeRoleWithWebIdentityRequest{
RoleArn: "arn:aws:iam::role/RevocationTestRole",
WebIdentityToken: mockToken,
RoleSessionName: "revocation-test-session",
}
response, err := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest)
require.NoError(t, err)
sessionToken := response.Credentials.SessionToken
// Verify token works on Instance B
_, err = instanceB.ValidateSessionToken(ctx, sessionToken)
require.NoError(t, err, "Token should be valid on Instance B initially")
// Validate session on Instance C to verify cross-instance token compatibility
_, err = instanceC.ValidateSessionToken(ctx, sessionToken)
require.NoError(t, err, "Instance C should be able to validate session token")
// In a stateless JWT system, tokens remain valid on all instances since they're self-contained
// No revocation is possible without breaking the stateless architecture
_, err = instanceA.ValidateSessionToken(ctx, sessionToken)
assert.NoError(t, err, "Token should still be valid on Instance A (stateless system)")
// Verify token is still valid on Instance B
_, err = instanceB.ValidateSessionToken(ctx, sessionToken)
assert.NoError(t, err, "Token should still be valid on Instance B (stateless system)")
})
// Test 4: Provider consistency across instances
t.Run("provider_consistency_affects_token_generation", func(t *testing.T) {
// All instances should have same providers and be able to process same OIDC tokens
providerNamesA := instanceA.getProviderNames()
providerNamesB := instanceB.getProviderNames()
providerNamesC := instanceC.getProviderNames()
assert.ElementsMatch(t, providerNamesA, providerNamesB, "Instance A and B should have same providers")
assert.ElementsMatch(t, providerNamesB, providerNamesC, "Instance B and C should have same providers")
// All instances should be able to process same web identity token
testToken := createMockJWT(t, "http://test-mock:9999", "test-user")
// Try to assume role with same token on different instances
assumeRequest := &AssumeRoleWithWebIdentityRequest{
RoleArn: "arn:aws:iam::role/ProviderTestRole",
WebIdentityToken: testToken,
RoleSessionName: "provider-consistency-test",
}
// Should work on any instance
responseA, errA := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest)
responseB, errB := instanceB.AssumeRoleWithWebIdentity(ctx, assumeRequest)
responseC, errC := instanceC.AssumeRoleWithWebIdentity(ctx, assumeRequest)
require.NoError(t, errA, "Instance A should process OIDC token")
require.NoError(t, errB, "Instance B should process OIDC token")
require.NoError(t, errC, "Instance C should process OIDC token")
// All should return valid responses (sessions will have different IDs but same structure)
assert.NotEmpty(t, responseA.Credentials.SessionToken)
assert.NotEmpty(t, responseB.Credentials.SessionToken)
assert.NotEmpty(t, responseC.Credentials.SessionToken)
})
}
// TestSTSDistributedConfigurationRequirements tests the configuration requirements
// for cross-instance token compatibility
func TestSTSDistributedConfigurationRequirements(t *testing.T) {
_ = "localhost:8888" // Dummy filer address for testing (not used in these tests)
t.Run("same_signing_key_required", func(t *testing.T) {
// Instance A with signing key 1
configA := &STSConfig{
TokenDuration: FlexibleDuration{time.Hour},
MaxSessionLength: FlexibleDuration{12 * time.Hour},
Issuer: "test-sts",
SigningKey: []byte("signing-key-1-32-characters-long"),
}
// Instance B with different signing key
configB := &STSConfig{
TokenDuration: FlexibleDuration{time.Hour},
MaxSessionLength: FlexibleDuration{12 * time.Hour},
Issuer: "test-sts",
SigningKey: []byte("signing-key-2-32-characters-long"), // DIFFERENT!
}
instanceA := NewSTSService()
instanceB := NewSTSService()
err := instanceA.Initialize(configA)
require.NoError(t, err)
err = instanceB.Initialize(configB)
require.NoError(t, err)
// Generate token on Instance A
sessionId := "test-session"
expiresAt := time.Now().Add(time.Hour)
tokenFromA, err := instanceA.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
require.NoError(t, err)
// Instance A should validate its own token
_, err = instanceA.GetTokenGenerator().ValidateSessionToken(tokenFromA)
assert.NoError(t, err, "Instance A should validate own token")
// Instance B should REJECT token due to different signing key
_, err = instanceB.GetTokenGenerator().ValidateSessionToken(tokenFromA)
assert.Error(t, err, "Instance B should reject token with different signing key")
assert.Contains(t, err.Error(), "invalid token", "Should be signature validation error")
})
t.Run("same_issuer_required", func(t *testing.T) {
sharedSigningKey := []byte("shared-signing-key-32-characters-lo")
// Instance A with issuer 1
configA := &STSConfig{
TokenDuration: FlexibleDuration{time.Hour},
MaxSessionLength: FlexibleDuration{12 * time.Hour},
Issuer: "sts-cluster-1",
SigningKey: sharedSigningKey,
}
// Instance B with different issuer
configB := &STSConfig{
TokenDuration: FlexibleDuration{time.Hour},
MaxSessionLength: FlexibleDuration{12 * time.Hour},
Issuer: "sts-cluster-2", // DIFFERENT!
SigningKey: sharedSigningKey,
}
instanceA := NewSTSService()
instanceB := NewSTSService()
err := instanceA.Initialize(configA)
require.NoError(t, err)
err = instanceB.Initialize(configB)
require.NoError(t, err)
// Generate token on Instance A
sessionId := "test-session"
expiresAt := time.Now().Add(time.Hour)
tokenFromA, err := instanceA.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
require.NoError(t, err)
// Instance B should REJECT token due to different issuer
_, err = instanceB.GetTokenGenerator().ValidateSessionToken(tokenFromA)
assert.Error(t, err, "Instance B should reject token with different issuer")
assert.Contains(t, err.Error(), "invalid issuer", "Should be issuer validation error")
})
t.Run("identical_configuration_required", func(t *testing.T) {
// Identical configuration
identicalConfig := &STSConfig{
TokenDuration: FlexibleDuration{time.Hour},
MaxSessionLength: FlexibleDuration{12 * time.Hour},
Issuer: "production-sts-cluster",
SigningKey: []byte("production-signing-key-32-chars-l"),
}
// Create multiple instances with identical config
instances := make([]*STSService, 5)
for i := 0; i < 5; i++ {
instances[i] = NewSTSService()
err := instances[i].Initialize(identicalConfig)
require.NoError(t, err, "Instance %d should initialize", i)
}
// Generate token on Instance 0
sessionId := "multi-instance-test"
expiresAt := time.Now().Add(time.Hour)
token, err := instances[0].GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
require.NoError(t, err)
// All other instances should validate the token
for i := 1; i < 5; i++ {
claims, err := instances[i].GetTokenGenerator().ValidateSessionToken(token)
require.NoError(t, err, "Instance %d should validate token", i)
assert.Equal(t, sessionId, claims.SessionId, "Instance %d should extract correct session ID", i)
}
})
}
// TestSTSRealWorldDistributedScenarios tests realistic distributed deployment scenarios
func TestSTSRealWorldDistributedScenarios(t *testing.T) {
ctx := context.Background()
t.Run("load_balanced_s3_gateway_scenario", func(t *testing.T) {
// Simulate real production scenario:
// 1. User authenticates with OIDC provider
// 2. User calls AssumeRoleWithWebIdentity on S3 Gateway 1
// 3. User makes S3 requests that hit S3 Gateway 2 & 3 via load balancer
// 4. All instances should handle the session token correctly
productionConfig := &STSConfig{
TokenDuration: FlexibleDuration{2 * time.Hour},
MaxSessionLength: FlexibleDuration{24 * time.Hour},
Issuer: "seaweedfs-production-sts",
SigningKey: []byte("prod-signing-key-32-characters-lon"),
Providers: []*ProviderConfig{
{
Name: "corporate-oidc",
Type: "oidc",
Enabled: true,
Config: map[string]interface{}{
"issuer": "https://sso.company.com/realms/production",
"clientId": "seaweedfs-prod-cluster",
"clientSecret": "supersecret-prod-key",
"scopes": []string{"openid", "profile", "email", "groups"},
},
},
},
}
// Create 3 S3 Gateway instances behind load balancer
gateway1 := NewSTSService()
gateway2 := NewSTSService()
gateway3 := NewSTSService()
err := gateway1.Initialize(productionConfig)
require.NoError(t, err)
err = gateway2.Initialize(productionConfig)
require.NoError(t, err)
err = gateway3.Initialize(productionConfig)
require.NoError(t, err)
// Set up mock trust policy validator for all gateway instances
mockValidator := &MockTrustPolicyValidator{}
gateway1.SetTrustPolicyValidator(mockValidator)
gateway2.SetTrustPolicyValidator(mockValidator)
gateway3.SetTrustPolicyValidator(mockValidator)
// Manually register mock provider for testing (not available in production)
mockProviderConfig := map[string]interface{}{
ConfigFieldIssuer: "http://test-mock:9999",
ConfigFieldClientID: "test-client-id",
}
mockProvider1, err := createMockOIDCProvider("test-mock", mockProviderConfig)
require.NoError(t, err)
mockProvider2, err := createMockOIDCProvider("test-mock", mockProviderConfig)
require.NoError(t, err)
mockProvider3, err := createMockOIDCProvider("test-mock", mockProviderConfig)
require.NoError(t, err)
gateway1.RegisterProvider(mockProvider1)
gateway2.RegisterProvider(mockProvider2)
gateway3.RegisterProvider(mockProvider3)
// Step 1: User authenticates and hits Gateway 1 for AssumeRole
mockToken := createMockJWT(t, "http://test-mock:9999", "production-user")
assumeRequest := &AssumeRoleWithWebIdentityRequest{
RoleArn: "arn:aws:iam::role/ProductionS3User",
WebIdentityToken: mockToken, // JWT token from mock provider
RoleSessionName: "user-production-session",
DurationSeconds: int64ToPtr(7200), // 2 hours
}
stsResponse, err := gateway1.AssumeRoleWithWebIdentity(ctx, assumeRequest)
require.NoError(t, err, "Gateway 1 should handle AssumeRole")
sessionToken := stsResponse.Credentials.SessionToken
accessKey := stsResponse.Credentials.AccessKeyId
secretKey := stsResponse.Credentials.SecretAccessKey
// Step 2: User makes S3 requests that hit different gateways via load balancer
// Simulate S3 request validation on Gateway 2
sessionInfo2, err := gateway2.ValidateSessionToken(ctx, sessionToken)
require.NoError(t, err, "Gateway 2 should validate session from Gateway 1")
assert.Equal(t, "user-production-session", sessionInfo2.SessionName)
assert.Equal(t, "arn:aws:iam::role/ProductionS3User", sessionInfo2.RoleArn)
// Simulate S3 request validation on Gateway 3
sessionInfo3, err := gateway3.ValidateSessionToken(ctx, sessionToken)
require.NoError(t, err, "Gateway 3 should validate session from Gateway 1")
assert.Equal(t, sessionInfo2.SessionId, sessionInfo3.SessionId, "Should be same session")
// Step 3: Verify credentials are consistent
assert.Equal(t, accessKey, stsResponse.Credentials.AccessKeyId, "Access key should be consistent")
assert.Equal(t, secretKey, stsResponse.Credentials.SecretAccessKey, "Secret key should be consistent")
// Step 4: Session expiration should be honored across all instances
assert.True(t, sessionInfo2.ExpiresAt.After(time.Now()), "Session should not be expired")
assert.True(t, sessionInfo3.ExpiresAt.After(time.Now()), "Session should not be expired")
// Step 5: Token should be identical when parsed
claims2, err := gateway2.GetTokenGenerator().ValidateSessionToken(sessionToken)
require.NoError(t, err)
claims3, err := gateway3.GetTokenGenerator().ValidateSessionToken(sessionToken)
require.NoError(t, err)
assert.Equal(t, claims2.SessionId, claims3.SessionId, "Session IDs should match")
assert.Equal(t, claims2.ExpiresAt.Unix(), claims3.ExpiresAt.Unix(), "Expiration should match")
})
}
// Helper function to convert int64 to pointer
func int64ToPtr(i int64) *int64 {
return &i
}

View File

@@ -1,340 +0,0 @@
package sts
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestDistributedSTSService verifies that multiple STS instances with identical configurations
// behave consistently across distributed environments
func TestDistributedSTSService(t *testing.T) {
ctx := context.Background()
// Common configuration for all instances
commonConfig := &STSConfig{
TokenDuration: FlexibleDuration{time.Hour},
MaxSessionLength: FlexibleDuration{12 * time.Hour},
Issuer: "distributed-sts-test",
SigningKey: []byte("test-signing-key-32-characters-long"),
Providers: []*ProviderConfig{
{
Name: "keycloak-oidc",
Type: "oidc",
Enabled: true,
Config: map[string]interface{}{
"issuer": "http://keycloak:8080/realms/seaweedfs-test",
"clientId": "seaweedfs-s3",
"jwksUri": "http://keycloak:8080/realms/seaweedfs-test/protocol/openid-connect/certs",
},
},
{
Name: "disabled-ldap",
Type: "oidc", // Use OIDC as placeholder since LDAP isn't implemented
Enabled: false,
Config: map[string]interface{}{
"issuer": "ldap://company.com",
"clientId": "ldap-client",
},
},
},
}
// Create multiple STS instances simulating distributed deployment
instance1 := NewSTSService()
instance2 := NewSTSService()
instance3 := NewSTSService()
// Initialize all instances with identical configuration
err := instance1.Initialize(commonConfig)
require.NoError(t, err, "Instance 1 should initialize successfully")
err = instance2.Initialize(commonConfig)
require.NoError(t, err, "Instance 2 should initialize successfully")
err = instance3.Initialize(commonConfig)
require.NoError(t, err, "Instance 3 should initialize successfully")
// Manually register mock providers for testing (not available in production)
mockProviderConfig := map[string]interface{}{
"issuer": "http://localhost:9999",
"clientId": "test-client",
}
mockProvider1, err := createMockOIDCProvider("test-mock-provider", mockProviderConfig)
require.NoError(t, err)
mockProvider2, err := createMockOIDCProvider("test-mock-provider", mockProviderConfig)
require.NoError(t, err)
mockProvider3, err := createMockOIDCProvider("test-mock-provider", mockProviderConfig)
require.NoError(t, err)
instance1.RegisterProvider(mockProvider1)
instance2.RegisterProvider(mockProvider2)
instance3.RegisterProvider(mockProvider3)
// Verify all instances have identical provider configurations
t.Run("provider_consistency", func(t *testing.T) {
// All instances should have same number of providers
assert.Len(t, instance1.providers, 2, "Instance 1 should have 2 enabled providers")
assert.Len(t, instance2.providers, 2, "Instance 2 should have 2 enabled providers")
assert.Len(t, instance3.providers, 2, "Instance 3 should have 2 enabled providers")
// All instances should have same provider names
instance1Names := instance1.getProviderNames()
instance2Names := instance2.getProviderNames()
instance3Names := instance3.getProviderNames()
assert.ElementsMatch(t, instance1Names, instance2Names, "Instance 1 and 2 should have same providers")
assert.ElementsMatch(t, instance2Names, instance3Names, "Instance 2 and 3 should have same providers")
// Verify specific providers exist on all instances
expectedProviders := []string{"keycloak-oidc", "test-mock-provider"}
assert.ElementsMatch(t, instance1Names, expectedProviders, "Instance 1 should have expected providers")
assert.ElementsMatch(t, instance2Names, expectedProviders, "Instance 2 should have expected providers")
assert.ElementsMatch(t, instance3Names, expectedProviders, "Instance 3 should have expected providers")
// Verify disabled providers are not loaded
assert.NotContains(t, instance1Names, "disabled-ldap", "Disabled providers should not be loaded")
assert.NotContains(t, instance2Names, "disabled-ldap", "Disabled providers should not be loaded")
assert.NotContains(t, instance3Names, "disabled-ldap", "Disabled providers should not be loaded")
})
// Test token generation consistency across instances
t.Run("token_generation_consistency", func(t *testing.T) {
sessionId := "test-session-123"
expiresAt := time.Now().Add(time.Hour)
// Generate tokens from different instances
token1, err1 := instance1.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
token2, err2 := instance2.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
token3, err3 := instance3.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
require.NoError(t, err1, "Instance 1 token generation should succeed")
require.NoError(t, err2, "Instance 2 token generation should succeed")
require.NoError(t, err3, "Instance 3 token generation should succeed")
// All tokens should be different (due to timestamp variations)
// But they should all be valid JWTs with same signing key
assert.NotEmpty(t, token1)
assert.NotEmpty(t, token2)
assert.NotEmpty(t, token3)
})
// Test token validation consistency - any instance should validate tokens from any other instance
t.Run("cross_instance_token_validation", func(t *testing.T) {
sessionId := "cross-validation-session"
expiresAt := time.Now().Add(time.Hour)
// Generate token on instance 1
token, err := instance1.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
require.NoError(t, err)
// Validate on all instances
claims1, err1 := instance1.GetTokenGenerator().ValidateSessionToken(token)
claims2, err2 := instance2.GetTokenGenerator().ValidateSessionToken(token)
claims3, err3 := instance3.GetTokenGenerator().ValidateSessionToken(token)
require.NoError(t, err1, "Instance 1 should validate token from instance 1")
require.NoError(t, err2, "Instance 2 should validate token from instance 1")
require.NoError(t, err3, "Instance 3 should validate token from instance 1")
// All instances should extract same session ID
assert.Equal(t, sessionId, claims1.SessionId)
assert.Equal(t, sessionId, claims2.SessionId)
assert.Equal(t, sessionId, claims3.SessionId)
assert.Equal(t, claims1.SessionId, claims2.SessionId)
assert.Equal(t, claims2.SessionId, claims3.SessionId)
})
// Test provider access consistency
t.Run("provider_access_consistency", func(t *testing.T) {
// All instances should be able to access the same providers
provider1, exists1 := instance1.providers["test-mock-provider"]
provider2, exists2 := instance2.providers["test-mock-provider"]
provider3, exists3 := instance3.providers["test-mock-provider"]
assert.True(t, exists1, "Instance 1 should have test-mock-provider")
assert.True(t, exists2, "Instance 2 should have test-mock-provider")
assert.True(t, exists3, "Instance 3 should have test-mock-provider")
assert.Equal(t, provider1.Name(), provider2.Name())
assert.Equal(t, provider2.Name(), provider3.Name())
// Test authentication with the mock provider on all instances
testToken := "valid_test_token"
identity1, err1 := provider1.Authenticate(ctx, testToken)
identity2, err2 := provider2.Authenticate(ctx, testToken)
identity3, err3 := provider3.Authenticate(ctx, testToken)
require.NoError(t, err1, "Instance 1 provider should authenticate successfully")
require.NoError(t, err2, "Instance 2 provider should authenticate successfully")
require.NoError(t, err3, "Instance 3 provider should authenticate successfully")
// All instances should return identical identity information
assert.Equal(t, identity1.UserID, identity2.UserID)
assert.Equal(t, identity2.UserID, identity3.UserID)
assert.Equal(t, identity1.Email, identity2.Email)
assert.Equal(t, identity2.Email, identity3.Email)
assert.Equal(t, identity1.Provider, identity2.Provider)
assert.Equal(t, identity2.Provider, identity3.Provider)
})
}
// TestSTSConfigurationValidation tests configuration validation for distributed deployments
func TestSTSConfigurationValidation(t *testing.T) {
t.Run("consistent_signing_keys_required", func(t *testing.T) {
// Different signing keys should result in incompatible token validation
config1 := &STSConfig{
TokenDuration: FlexibleDuration{time.Hour},
MaxSessionLength: FlexibleDuration{12 * time.Hour},
Issuer: "test-sts",
SigningKey: []byte("signing-key-1-32-characters-long"),
}
config2 := &STSConfig{
TokenDuration: FlexibleDuration{time.Hour},
MaxSessionLength: FlexibleDuration{12 * time.Hour},
Issuer: "test-sts",
SigningKey: []byte("signing-key-2-32-characters-long"), // Different key!
}
instance1 := NewSTSService()
instance2 := NewSTSService()
err1 := instance1.Initialize(config1)
err2 := instance2.Initialize(config2)
require.NoError(t, err1)
require.NoError(t, err2)
// Generate token on instance 1
sessionId := "test-session"
expiresAt := time.Now().Add(time.Hour)
token, err := instance1.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
require.NoError(t, err)
// Instance 1 should validate its own token
_, err = instance1.GetTokenGenerator().ValidateSessionToken(token)
assert.NoError(t, err, "Instance 1 should validate its own token")
// Instance 2 should reject token from instance 1 (different signing key)
_, err = instance2.GetTokenGenerator().ValidateSessionToken(token)
assert.Error(t, err, "Instance 2 should reject token with different signing key")
})
t.Run("consistent_issuer_required", func(t *testing.T) {
// Different issuers should result in incompatible tokens
commonSigningKey := []byte("shared-signing-key-32-characters-lo")
config1 := &STSConfig{
TokenDuration: FlexibleDuration{time.Hour},
MaxSessionLength: FlexibleDuration{12 * time.Hour},
Issuer: "sts-instance-1",
SigningKey: commonSigningKey,
}
config2 := &STSConfig{
TokenDuration: FlexibleDuration{time.Hour},
MaxSessionLength: FlexibleDuration{12 * time.Hour},
Issuer: "sts-instance-2", // Different issuer!
SigningKey: commonSigningKey,
}
instance1 := NewSTSService()
instance2 := NewSTSService()
err1 := instance1.Initialize(config1)
err2 := instance2.Initialize(config2)
require.NoError(t, err1)
require.NoError(t, err2)
// Generate token on instance 1
sessionId := "test-session"
expiresAt := time.Now().Add(time.Hour)
token, err := instance1.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
require.NoError(t, err)
// Instance 2 should reject token due to issuer mismatch
// (Even though signing key is the same, issuer validation will fail)
_, err = instance2.GetTokenGenerator().ValidateSessionToken(token)
assert.Error(t, err, "Instance 2 should reject token with different issuer")
})
}
// TestProviderFactoryDistributed tests the provider factory in distributed scenarios
func TestProviderFactoryDistributed(t *testing.T) {
factory := NewProviderFactory()
// Simulate configuration that would be identical across all instances
configs := []*ProviderConfig{
{
Name: "production-keycloak",
Type: "oidc",
Enabled: true,
Config: map[string]interface{}{
"issuer": "https://keycloak.company.com/realms/seaweedfs",
"clientId": "seaweedfs-prod",
"clientSecret": "super-secret-key",
"jwksUri": "https://keycloak.company.com/realms/seaweedfs/protocol/openid-connect/certs",
"scopes": []string{"openid", "profile", "email", "roles"},
},
},
{
Name: "backup-oidc",
Type: "oidc",
Enabled: false, // Disabled by default
Config: map[string]interface{}{
"issuer": "https://backup-oidc.company.com",
"clientId": "seaweedfs-backup",
},
},
}
// Create providers multiple times (simulating multiple instances)
providers1, err1 := factory.LoadProvidersFromConfig(configs)
providers2, err2 := factory.LoadProvidersFromConfig(configs)
providers3, err3 := factory.LoadProvidersFromConfig(configs)
require.NoError(t, err1, "First load should succeed")
require.NoError(t, err2, "Second load should succeed")
require.NoError(t, err3, "Third load should succeed")
// All instances should have same provider counts
assert.Len(t, providers1, 1, "First instance should have 1 enabled provider")
assert.Len(t, providers2, 1, "Second instance should have 1 enabled provider")
assert.Len(t, providers3, 1, "Third instance should have 1 enabled provider")
// All instances should have same provider names
names1 := make([]string, 0, len(providers1))
names2 := make([]string, 0, len(providers2))
names3 := make([]string, 0, len(providers3))
for name := range providers1 {
names1 = append(names1, name)
}
for name := range providers2 {
names2 = append(names2, name)
}
for name := range providers3 {
names3 = append(names3, name)
}
assert.ElementsMatch(t, names1, names2, "Instance 1 and 2 should have same provider names")
assert.ElementsMatch(t, names2, names3, "Instance 2 and 3 should have same provider names")
// Verify specific providers
expectedProviders := []string{"production-keycloak"}
assert.ElementsMatch(t, names1, expectedProviders, "Should have expected enabled providers")
// Verify disabled providers are not included
assert.NotContains(t, names1, "backup-oidc", "Disabled providers should not be loaded")
assert.NotContains(t, names2, "backup-oidc", "Disabled providers should not be loaded")
assert.NotContains(t, names3, "backup-oidc", "Disabled providers should not be loaded")
}

View File

@@ -274,69 +274,3 @@ func (f *ProviderFactory) convertToRoleMapping(value interface{}) (*providers.Ro
return roleMapping, nil
}
// ValidateProviderConfig validates a provider configuration
func (f *ProviderFactory) ValidateProviderConfig(config *ProviderConfig) error {
if config == nil {
return fmt.Errorf("provider config cannot be nil")
}
if config.Name == "" {
return fmt.Errorf("provider name cannot be empty")
}
if config.Type == "" {
return fmt.Errorf("provider type cannot be empty")
}
if config.Config == nil {
return fmt.Errorf("provider config cannot be nil")
}
// Type-specific validation
switch config.Type {
case "oidc":
return f.validateOIDCConfig(config.Config)
case "ldap":
return f.validateLDAPConfig(config.Config)
case "saml":
return f.validateSAMLConfig(config.Config)
default:
return fmt.Errorf("unsupported provider type: %s", config.Type)
}
}
// validateOIDCConfig validates OIDC provider configuration
func (f *ProviderFactory) validateOIDCConfig(config map[string]interface{}) error {
if _, ok := config[ConfigFieldIssuer]; !ok {
return fmt.Errorf("OIDC provider requires '%s' field", ConfigFieldIssuer)
}
if _, ok := config[ConfigFieldClientID]; !ok {
return fmt.Errorf("OIDC provider requires '%s' field", ConfigFieldClientID)
}
return nil
}
// validateLDAPConfig validates LDAP provider configuration
func (f *ProviderFactory) validateLDAPConfig(config map[string]interface{}) error {
if _, ok := config["server"]; !ok {
return fmt.Errorf("LDAP provider requires 'server' field")
}
if _, ok := config["baseDN"]; !ok {
return fmt.Errorf("LDAP provider requires 'baseDN' field")
}
return nil
}
// validateSAMLConfig validates SAML provider configuration
func (f *ProviderFactory) validateSAMLConfig(config map[string]interface{}) error {
// TODO: Implement when SAML provider is available
return nil
}
// GetSupportedProviderTypes returns list of supported provider types
func (f *ProviderFactory) GetSupportedProviderTypes() []string {
return []string{ProviderTypeOIDC}
}

View File

@@ -1,312 +0,0 @@
package sts
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestProviderFactory_CreateOIDCProvider(t *testing.T) {
factory := NewProviderFactory()
config := &ProviderConfig{
Name: "test-oidc",
Type: "oidc",
Enabled: true,
Config: map[string]interface{}{
"issuer": "https://test-issuer.com",
"clientId": "test-client",
"clientSecret": "test-secret",
"jwksUri": "https://test-issuer.com/.well-known/jwks.json",
"scopes": []string{"openid", "profile", "email"},
},
}
provider, err := factory.CreateProvider(config)
require.NoError(t, err)
assert.NotNil(t, provider)
assert.Equal(t, "test-oidc", provider.Name())
}
// Note: Mock provider tests removed - mock providers are now test-only
// and not available through the production ProviderFactory
func TestProviderFactory_DisabledProvider(t *testing.T) {
factory := NewProviderFactory()
config := &ProviderConfig{
Name: "disabled-provider",
Type: "oidc",
Enabled: false,
Config: map[string]interface{}{
"issuer": "https://test-issuer.com",
"clientId": "test-client",
},
}
provider, err := factory.CreateProvider(config)
require.NoError(t, err)
assert.Nil(t, provider) // Should return nil for disabled providers
}
func TestProviderFactory_InvalidProviderType(t *testing.T) {
factory := NewProviderFactory()
config := &ProviderConfig{
Name: "invalid-provider",
Type: "unsupported-type",
Enabled: true,
Config: map[string]interface{}{},
}
provider, err := factory.CreateProvider(config)
assert.Error(t, err)
assert.Nil(t, provider)
assert.Contains(t, err.Error(), "unsupported provider type")
}
func TestProviderFactory_LoadMultipleProviders(t *testing.T) {
factory := NewProviderFactory()
configs := []*ProviderConfig{
{
Name: "oidc-provider",
Type: "oidc",
Enabled: true,
Config: map[string]interface{}{
"issuer": "https://oidc-issuer.com",
"clientId": "oidc-client",
},
},
{
Name: "disabled-provider",
Type: "oidc",
Enabled: false,
Config: map[string]interface{}{
"issuer": "https://disabled-issuer.com",
"clientId": "disabled-client",
},
},
}
providers, err := factory.LoadProvidersFromConfig(configs)
require.NoError(t, err)
assert.Len(t, providers, 1) // Only enabled providers should be loaded
assert.Contains(t, providers, "oidc-provider")
assert.NotContains(t, providers, "disabled-provider")
}
func TestProviderFactory_ValidateOIDCConfig(t *testing.T) {
factory := NewProviderFactory()
t.Run("valid config", func(t *testing.T) {
config := &ProviderConfig{
Name: "valid-oidc",
Type: "oidc",
Enabled: true,
Config: map[string]interface{}{
"issuer": "https://valid-issuer.com",
"clientId": "valid-client",
},
}
err := factory.ValidateProviderConfig(config)
assert.NoError(t, err)
})
t.Run("missing issuer", func(t *testing.T) {
config := &ProviderConfig{
Name: "invalid-oidc",
Type: "oidc",
Enabled: true,
Config: map[string]interface{}{
"clientId": "valid-client",
},
}
err := factory.ValidateProviderConfig(config)
assert.Error(t, err)
assert.Contains(t, err.Error(), "issuer")
})
t.Run("missing clientId", func(t *testing.T) {
config := &ProviderConfig{
Name: "invalid-oidc",
Type: "oidc",
Enabled: true,
Config: map[string]interface{}{
"issuer": "https://valid-issuer.com",
},
}
err := factory.ValidateProviderConfig(config)
assert.Error(t, err)
assert.Contains(t, err.Error(), "clientId")
})
}
func TestProviderFactory_ConvertToStringSlice(t *testing.T) {
factory := NewProviderFactory()
t.Run("string slice", func(t *testing.T) {
input := []string{"a", "b", "c"}
result, err := factory.convertToStringSlice(input)
require.NoError(t, err)
assert.Equal(t, []string{"a", "b", "c"}, result)
})
t.Run("interface slice", func(t *testing.T) {
input := []interface{}{"a", "b", "c"}
result, err := factory.convertToStringSlice(input)
require.NoError(t, err)
assert.Equal(t, []string{"a", "b", "c"}, result)
})
t.Run("invalid type", func(t *testing.T) {
input := "not-a-slice"
result, err := factory.convertToStringSlice(input)
assert.Error(t, err)
assert.Nil(t, result)
})
}
func TestProviderFactory_ConfigConversionErrors(t *testing.T) {
factory := NewProviderFactory()
t.Run("invalid scopes type", func(t *testing.T) {
config := &ProviderConfig{
Name: "invalid-scopes",
Type: "oidc",
Enabled: true,
Config: map[string]interface{}{
"issuer": "https://test-issuer.com",
"clientId": "test-client",
"scopes": "invalid-not-array", // Should be array
},
}
provider, err := factory.CreateProvider(config)
assert.Error(t, err)
assert.Nil(t, provider)
assert.Contains(t, err.Error(), "failed to convert scopes")
})
t.Run("invalid claimsMapping type", func(t *testing.T) {
config := &ProviderConfig{
Name: "invalid-claims",
Type: "oidc",
Enabled: true,
Config: map[string]interface{}{
"issuer": "https://test-issuer.com",
"clientId": "test-client",
"claimsMapping": "invalid-not-map", // Should be map
},
}
provider, err := factory.CreateProvider(config)
assert.Error(t, err)
assert.Nil(t, provider)
assert.Contains(t, err.Error(), "failed to convert claimsMapping")
})
t.Run("invalid roleMapping type", func(t *testing.T) {
config := &ProviderConfig{
Name: "invalid-roles",
Type: "oidc",
Enabled: true,
Config: map[string]interface{}{
"issuer": "https://test-issuer.com",
"clientId": "test-client",
"roleMapping": "invalid-not-map", // Should be map
},
}
provider, err := factory.CreateProvider(config)
assert.Error(t, err)
assert.Nil(t, provider)
assert.Contains(t, err.Error(), "failed to convert roleMapping")
})
}
func TestProviderFactory_ConvertToStringMap(t *testing.T) {
factory := NewProviderFactory()
t.Run("string map", func(t *testing.T) {
input := map[string]string{"key1": "value1", "key2": "value2"}
result, err := factory.convertToStringMap(input)
require.NoError(t, err)
assert.Equal(t, map[string]string{"key1": "value1", "key2": "value2"}, result)
})
t.Run("interface map", func(t *testing.T) {
input := map[string]interface{}{"key1": "value1", "key2": "value2"}
result, err := factory.convertToStringMap(input)
require.NoError(t, err)
assert.Equal(t, map[string]string{"key1": "value1", "key2": "value2"}, result)
})
t.Run("invalid type", func(t *testing.T) {
input := "not-a-map"
result, err := factory.convertToStringMap(input)
assert.Error(t, err)
assert.Nil(t, result)
})
}
func TestProviderFactory_GetSupportedProviderTypes(t *testing.T) {
factory := NewProviderFactory()
supportedTypes := factory.GetSupportedProviderTypes()
assert.Contains(t, supportedTypes, "oidc")
assert.Len(t, supportedTypes, 1) // Currently only OIDC is supported in production
}
func TestSTSService_LoadProvidersFromConfig(t *testing.T) {
stsConfig := &STSConfig{
TokenDuration: FlexibleDuration{3600 * time.Second},
MaxSessionLength: FlexibleDuration{43200 * time.Second},
Issuer: "test-issuer",
SigningKey: []byte("test-signing-key-32-characters-long"),
Providers: []*ProviderConfig{
{
Name: "test-provider",
Type: "oidc",
Enabled: true,
Config: map[string]interface{}{
"issuer": "https://test-issuer.com",
"clientId": "test-client",
},
},
},
}
stsService := NewSTSService()
err := stsService.Initialize(stsConfig)
require.NoError(t, err)
// Check that provider was loaded
assert.Len(t, stsService.providers, 1)
assert.Contains(t, stsService.providers, "test-provider")
assert.Equal(t, "test-provider", stsService.providers["test-provider"].Name())
}
func TestSTSService_NoProvidersConfig(t *testing.T) {
stsConfig := &STSConfig{
TokenDuration: FlexibleDuration{3600 * time.Second},
MaxSessionLength: FlexibleDuration{43200 * time.Second},
Issuer: "test-issuer",
SigningKey: []byte("test-signing-key-32-characters-long"),
// No providers configured
}
stsService := NewSTSService()
err := stsService.Initialize(stsConfig)
require.NoError(t, err)
// Should initialize successfully with no providers
assert.Len(t, stsService.providers, 0)
}

View File

@@ -1,193 +0,0 @@
package sts
import (
"context"
"fmt"
"strings"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestSecurityIssuerToProviderMapping tests the security fix that ensures JWT tokens
// with specific issuer claims can only be validated by the provider registered for that issuer
func TestSecurityIssuerToProviderMapping(t *testing.T) {
ctx := context.Background()
// Create STS service with two mock providers
service := NewSTSService()
config := &STSConfig{
TokenDuration: FlexibleDuration{time.Hour},
MaxSessionLength: FlexibleDuration{time.Hour * 12},
Issuer: "test-sts",
SigningKey: []byte("test-signing-key-32-characters-long"),
}
err := service.Initialize(config)
require.NoError(t, err)
// Set up mock trust policy validator
mockValidator := &MockTrustPolicyValidator{}
service.SetTrustPolicyValidator(mockValidator)
// Create two mock providers with different issuers
providerA := &MockIdentityProviderWithIssuer{
name: "provider-a",
issuer: "https://provider-a.com",
validTokens: map[string]bool{
"token-for-provider-a": true,
},
}
providerB := &MockIdentityProviderWithIssuer{
name: "provider-b",
issuer: "https://provider-b.com",
validTokens: map[string]bool{
"token-for-provider-b": true,
},
}
// Register both providers
err = service.RegisterProvider(providerA)
require.NoError(t, err)
err = service.RegisterProvider(providerB)
require.NoError(t, err)
// Create JWT tokens with specific issuer claims
tokenForProviderA := createTestJWT(t, "https://provider-a.com", "user-a")
tokenForProviderB := createTestJWT(t, "https://provider-b.com", "user-b")
t.Run("jwt_token_with_issuer_a_only_validated_by_provider_a", func(t *testing.T) {
// This should succeed - token has issuer A and provider A is registered
identity, provider, err := service.validateWebIdentityToken(ctx, tokenForProviderA)
assert.NoError(t, err)
assert.NotNil(t, identity)
assert.Equal(t, "provider-a", provider.Name())
})
t.Run("jwt_token_with_issuer_b_only_validated_by_provider_b", func(t *testing.T) {
// This should succeed - token has issuer B and provider B is registered
identity, provider, err := service.validateWebIdentityToken(ctx, tokenForProviderB)
assert.NoError(t, err)
assert.NotNil(t, identity)
assert.Equal(t, "provider-b", provider.Name())
})
t.Run("jwt_token_with_unregistered_issuer_fails", func(t *testing.T) {
// Create token with unregistered issuer
tokenWithUnknownIssuer := createTestJWT(t, "https://unknown-issuer.com", "user-x")
// This should fail - no provider registered for this issuer
identity, provider, err := service.validateWebIdentityToken(ctx, tokenWithUnknownIssuer)
assert.Error(t, err)
assert.Nil(t, identity)
assert.Nil(t, provider)
assert.Contains(t, err.Error(), "no identity provider registered for issuer: https://unknown-issuer.com")
})
t.Run("non_jwt_tokens_are_rejected", func(t *testing.T) {
// Non-JWT tokens should be rejected - no fallback mechanism exists for security
identity, provider, err := service.validateWebIdentityToken(ctx, "token-for-provider-a")
assert.Error(t, err)
assert.Nil(t, identity)
assert.Nil(t, provider)
assert.Contains(t, err.Error(), "web identity token must be a valid JWT token")
})
}
// createTestJWT creates a test JWT token with the specified issuer and subject
func createTestJWT(t *testing.T, issuer, subject string) string {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"iss": issuer,
"sub": subject,
"aud": "test-client",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
})
tokenString, err := token.SignedString([]byte("test-signing-key"))
require.NoError(t, err)
return tokenString
}
// MockIdentityProviderWithIssuer is a mock provider that supports issuer mapping
type MockIdentityProviderWithIssuer struct {
name string
issuer string
validTokens map[string]bool
}
func (m *MockIdentityProviderWithIssuer) Name() string {
return m.name
}
func (m *MockIdentityProviderWithIssuer) GetIssuer() string {
return m.issuer
}
func (m *MockIdentityProviderWithIssuer) Initialize(config interface{}) error {
return nil
}
func (m *MockIdentityProviderWithIssuer) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
// For JWT tokens, parse and validate the token format
if len(token) > 50 && strings.Contains(token, ".") {
// This looks like a JWT - parse it to get the subject
parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
if err != nil {
return nil, fmt.Errorf("invalid JWT token")
}
claims, ok := parsedToken.Claims.(jwt.MapClaims)
if !ok {
return nil, fmt.Errorf("invalid claims")
}
issuer, _ := claims["iss"].(string)
subject, _ := claims["sub"].(string)
// Verify the issuer matches what we expect
if issuer != m.issuer {
return nil, fmt.Errorf("token issuer %s does not match provider issuer %s", issuer, m.issuer)
}
return &providers.ExternalIdentity{
UserID: subject,
Email: subject + "@" + m.name + ".com",
Provider: m.name,
}, nil
}
// For non-JWT tokens, check our simple token list
if m.validTokens[token] {
return &providers.ExternalIdentity{
UserID: "test-user",
Email: "test@" + m.name + ".com",
Provider: m.name,
}, nil
}
return nil, fmt.Errorf("invalid token")
}
func (m *MockIdentityProviderWithIssuer) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
return &providers.ExternalIdentity{
UserID: userID,
Email: userID + "@" + m.name + ".com",
Provider: m.name,
}, nil
}
func (m *MockIdentityProviderWithIssuer) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
if m.validTokens[token] {
return &providers.TokenClaims{
Subject: "test-user",
Issuer: m.issuer,
}, nil
}
return nil, fmt.Errorf("invalid token")
}

View File

@@ -1,168 +0,0 @@
package sts
import (
"context"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// createSessionPolicyTestJWT creates a test JWT token for session policy tests
func createSessionPolicyTestJWT(t *testing.T, issuer, subject string) string {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"iss": issuer,
"sub": subject,
"aud": "test-client",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
})
tokenString, err := token.SignedString([]byte("test-signing-key"))
require.NoError(t, err)
return tokenString
}
// TestAssumeRoleWithWebIdentity_SessionPolicy verifies inline session policies are preserved in tokens.
func TestAssumeRoleWithWebIdentity_SessionPolicy(t *testing.T) {
service := setupTestSTSService(t)
ctx := context.Background()
sessionPolicy := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:GetObject","Resource":"arn:aws:s3:::example-bucket/*"}]}`
testToken := createSessionPolicyTestJWT(t, "test-issuer", "test-user")
request := &AssumeRoleWithWebIdentityRequest{
RoleArn: "arn:aws:iam::role/TestRole",
WebIdentityToken: testToken,
RoleSessionName: "test-session",
Policy: &sessionPolicy,
}
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
require.NoError(t, err)
require.NotNil(t, response)
sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken)
require.NoError(t, err)
normalized, err := NormalizeSessionPolicy(sessionPolicy)
require.NoError(t, err)
assert.Equal(t, normalized, sessionInfo.SessionPolicy)
t.Run("should_succeed_without_session_policy", func(t *testing.T) {
request := &AssumeRoleWithWebIdentityRequest{
RoleArn: "arn:aws:iam::role/TestRole",
WebIdentityToken: createSessionPolicyTestJWT(t, "test-issuer", "test-user"),
RoleSessionName: "test-session",
}
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
require.NoError(t, err)
require.NotNil(t, response)
sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken)
require.NoError(t, err)
assert.Empty(t, sessionInfo.SessionPolicy)
})
}
// Test edge case scenarios for the Policy field handling
func TestAssumeRoleWithWebIdentity_SessionPolicy_EdgeCases(t *testing.T) {
service := setupTestSTSService(t)
ctx := context.Background()
t.Run("malformed_json_policy_rejected", func(t *testing.T) {
malformedPolicy := `{"Version": "2012-10-17", "Statement": [` // Incomplete JSON
request := &AssumeRoleWithWebIdentityRequest{
RoleArn: "arn:aws:iam::role/TestRole",
WebIdentityToken: createSessionPolicyTestJWT(t, "test-issuer", "test-user"),
RoleSessionName: "test-session",
Policy: &malformedPolicy,
}
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
assert.Error(t, err)
assert.Nil(t, response)
assert.Contains(t, err.Error(), "invalid session policy JSON")
})
t.Run("invalid_policy_document_rejected", func(t *testing.T) {
invalidPolicy := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow"}]}`
request := &AssumeRoleWithWebIdentityRequest{
RoleArn: "arn:aws:iam::role/TestRole",
WebIdentityToken: createSessionPolicyTestJWT(t, "test-issuer", "test-user"),
RoleSessionName: "test-session",
Policy: &invalidPolicy,
}
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
assert.Error(t, err)
assert.Nil(t, response)
assert.Contains(t, err.Error(), "invalid session policy document")
})
t.Run("whitespace_policy_ignored", func(t *testing.T) {
whitespacePolicy := " \t\n "
request := &AssumeRoleWithWebIdentityRequest{
RoleArn: "arn:aws:iam::role/TestRole",
WebIdentityToken: createSessionPolicyTestJWT(t, "test-issuer", "test-user"),
RoleSessionName: "test-session",
Policy: &whitespacePolicy,
}
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
require.NoError(t, err)
require.NotNil(t, response)
sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken)
require.NoError(t, err)
assert.Empty(t, sessionInfo.SessionPolicy)
})
}
// TestAssumeRoleWithWebIdentity_PolicyFieldDocumentation verifies that the struct field exists and is optional.
func TestAssumeRoleWithWebIdentity_PolicyFieldDocumentation(t *testing.T) {
request := &AssumeRoleWithWebIdentityRequest{}
assert.IsType(t, (*string)(nil), request.Policy,
"Policy field should be *string type for optional JSON policy")
assert.Nil(t, request.Policy,
"Policy field should default to nil (no session policy)")
policyValue := `{"Version": "2012-10-17"}`
request.Policy = &policyValue
assert.NotNil(t, request.Policy, "Should be able to assign policy value")
assert.Equal(t, policyValue, *request.Policy, "Policy value should be preserved")
}
// TestAssumeRoleWithCredentials_SessionPolicy verifies session policy support for credentials-based flow.
func TestAssumeRoleWithCredentials_SessionPolicy(t *testing.T) {
service := setupTestSTSService(t)
ctx := context.Background()
sessionPolicy := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"filer:CreateEntry","Resource":"arn:aws:filer::path/user-docs/*"}]}`
request := &AssumeRoleWithCredentialsRequest{
RoleArn: "arn:aws:iam::role/TestRole",
Username: "testuser",
Password: "testpass",
RoleSessionName: "test-session",
ProviderName: "test-ldap",
Policy: &sessionPolicy,
}
response, err := service.AssumeRoleWithCredentials(ctx, request)
require.NoError(t, err)
require.NotNil(t, response)
sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken)
require.NoError(t, err)
normalized, err := NormalizeSessionPolicy(sessionPolicy)
require.NoError(t, err)
assert.Equal(t, normalized, sessionInfo.SessionPolicy)
}

View File

@@ -879,21 +879,6 @@ func (s *STSService) calculateSessionDuration(durationSeconds *int64, tokenExpir
return duration
}
// extractSessionIdFromToken extracts session ID from JWT session token
func (s *STSService) extractSessionIdFromToken(sessionToken string) string {
// Validate JWT and extract session claims
claims, err := s.tokenGenerator.ValidateJWTWithClaims(sessionToken)
if err != nil {
// For test compatibility, also handle direct session IDs
if len(sessionToken) == 32 { // Typical session ID length
return sessionToken
}
return ""
}
return claims.SessionId
}
// validateAssumeRoleWithCredentialsRequest validates the credentials request parameters
func (s *STSService) validateAssumeRoleWithCredentialsRequest(request *AssumeRoleWithCredentialsRequest) error {
if request.RoleArn == "" {

View File

@@ -1,778 +0,0 @@
package sts
import (
"context"
"fmt"
"strings"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// createSTSTestJWT creates a test JWT token for STS service tests
func createSTSTestJWT(t *testing.T, issuer, subject string) string {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"iss": issuer,
"sub": subject,
"aud": "test-client",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
})
tokenString, err := token.SignedString([]byte("test-signing-key"))
require.NoError(t, err)
return tokenString
}
// TestSTSServiceInitialization tests STS service initialization
func TestSTSServiceInitialization(t *testing.T) {
tests := []struct {
name string
config *STSConfig
wantErr bool
}{
{
name: "valid config",
config: &STSConfig{
TokenDuration: FlexibleDuration{time.Hour},
MaxSessionLength: FlexibleDuration{time.Hour * 12},
Issuer: "seaweedfs-sts",
SigningKey: []byte("test-signing-key"),
},
wantErr: false,
},
{
name: "missing signing key",
config: &STSConfig{
TokenDuration: FlexibleDuration{time.Hour},
Issuer: "seaweedfs-sts",
},
wantErr: true,
},
{
name: "invalid token duration",
config: &STSConfig{
TokenDuration: FlexibleDuration{-time.Hour},
Issuer: "seaweedfs-sts",
SigningKey: []byte("test-key"),
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
service := NewSTSService()
err := service.Initialize(tt.config)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.True(t, service.IsInitialized())
// Verify defaults if applicable
if tt.config.Issuer == "" {
assert.Equal(t, DefaultIssuer, service.Config.Issuer)
}
if tt.config.TokenDuration.Duration == 0 {
assert.Equal(t, time.Duration(DefaultTokenDuration)*time.Second, service.Config.TokenDuration.Duration)
}
}
})
}
}
func TestSTSServiceDefaults(t *testing.T) {
service := NewSTSService()
config := &STSConfig{
SigningKey: []byte("test-signing-key"),
// Missing duration and issuer
}
err := service.Initialize(config)
assert.NoError(t, err)
assert.Equal(t, DefaultIssuer, config.Issuer)
assert.Equal(t, time.Duration(DefaultTokenDuration)*time.Second, config.TokenDuration.Duration)
assert.Equal(t, time.Duration(DefaultMaxSessionLength)*time.Second, config.MaxSessionLength.Duration)
}
// TestAssumeRoleWithWebIdentity tests role assumption with OIDC tokens
func TestAssumeRoleWithWebIdentity(t *testing.T) {
service := setupTestSTSService(t)
tests := []struct {
name string
roleArn string
webIdentityToken string
sessionName string
durationSeconds *int64
wantErr bool
expectedSubject string
}{
{
name: "successful role assumption",
roleArn: "arn:aws:iam::role/TestRole",
webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user-id"),
sessionName: "test-session",
durationSeconds: nil, // Use default
wantErr: false,
expectedSubject: "test-user-id",
},
{
name: "invalid web identity token",
roleArn: "arn:aws:iam::role/TestRole",
webIdentityToken: "invalid-token",
sessionName: "test-session",
wantErr: true,
},
{
name: "non-existent role",
roleArn: "arn:aws:iam::role/NonExistentRole",
webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
sessionName: "test-session",
wantErr: true,
},
{
name: "custom session duration",
roleArn: "arn:aws:iam::role/TestRole",
webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
sessionName: "test-session",
durationSeconds: int64Ptr(7200), // 2 hours
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
request := &AssumeRoleWithWebIdentityRequest{
RoleArn: tt.roleArn,
WebIdentityToken: tt.webIdentityToken,
RoleSessionName: tt.sessionName,
DurationSeconds: tt.durationSeconds,
}
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, response)
} else {
assert.NoError(t, err)
assert.NotNil(t, response)
assert.NotNil(t, response.Credentials)
assert.NotNil(t, response.AssumedRoleUser)
// Verify credentials
creds := response.Credentials
assert.NotEmpty(t, creds.AccessKeyId)
assert.NotEmpty(t, creds.SecretAccessKey)
assert.NotEmpty(t, creds.SessionToken)
assert.True(t, creds.Expiration.After(time.Now()))
// Verify assumed role user
user := response.AssumedRoleUser
assert.Equal(t, tt.roleArn, user.AssumedRoleId)
assert.Contains(t, user.Arn, tt.sessionName)
if tt.expectedSubject != "" {
assert.Equal(t, tt.expectedSubject, user.Subject)
}
}
})
}
}
// TestAssumeRoleWithLDAP tests role assumption with LDAP credentials
func TestAssumeRoleWithLDAP(t *testing.T) {
service := setupTestSTSService(t)
tests := []struct {
name string
roleArn string
username string
password string
sessionName string
wantErr bool
}{
{
name: "successful LDAP role assumption",
roleArn: "arn:aws:iam::role/LDAPRole",
username: "testuser",
password: "testpass",
sessionName: "ldap-session",
wantErr: false,
},
{
name: "invalid LDAP credentials",
roleArn: "arn:aws:iam::role/LDAPRole",
username: "testuser",
password: "wrongpass",
sessionName: "ldap-session",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
request := &AssumeRoleWithCredentialsRequest{
RoleArn: tt.roleArn,
Username: tt.username,
Password: tt.password,
RoleSessionName: tt.sessionName,
ProviderName: "test-ldap",
}
response, err := service.AssumeRoleWithCredentials(ctx, request)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, response)
} else {
assert.NoError(t, err)
assert.NotNil(t, response)
assert.NotNil(t, response.Credentials)
}
})
}
}
// TestSessionTokenValidation tests session token validation
func TestSessionTokenValidation(t *testing.T) {
service := setupTestSTSService(t)
ctx := context.Background()
// First, create a session
request := &AssumeRoleWithWebIdentityRequest{
RoleArn: "arn:aws:iam::role/TestRole",
WebIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
RoleSessionName: "test-session",
}
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
require.NoError(t, err)
require.NotNil(t, response)
sessionToken := response.Credentials.SessionToken
tests := []struct {
name string
token string
wantErr bool
}{
{
name: "valid session token",
token: sessionToken,
wantErr: false,
},
{
name: "invalid session token",
token: "invalid-session-token",
wantErr: true,
},
{
name: "empty session token",
token: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
session, err := service.ValidateSessionToken(ctx, tt.token)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, session)
} else {
assert.NoError(t, err)
assert.NotNil(t, session)
assert.Equal(t, "test-session", session.SessionName)
assert.Equal(t, "arn:aws:iam::role/TestRole", session.RoleArn)
}
})
}
}
// TestSessionTokenPersistence tests that JWT tokens remain valid throughout their lifetime
// Note: In the stateless JWT design, tokens cannot be revoked and remain valid until expiration
func TestSessionTokenPersistence(t *testing.T) {
service := setupTestSTSService(t)
ctx := context.Background()
// Create a session first
request := &AssumeRoleWithWebIdentityRequest{
RoleArn: "arn:aws:iam::role/TestRole",
WebIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
RoleSessionName: "test-session",
}
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
require.NoError(t, err)
sessionToken := response.Credentials.SessionToken
// Verify token is valid initially
session, err := service.ValidateSessionToken(ctx, sessionToken)
assert.NoError(t, err)
assert.NotNil(t, session)
assert.Equal(t, "test-session", session.SessionName)
// In a stateless JWT system, tokens remain valid throughout their lifetime
// Multiple validations should all succeed as long as the token hasn't expired
session2, err := service.ValidateSessionToken(ctx, sessionToken)
assert.NoError(t, err, "Token should remain valid in stateless system")
assert.NotNil(t, session2, "Session should be returned from JWT token")
assert.Equal(t, session.SessionId, session2.SessionId, "Session ID should be consistent")
}
// Helper functions
func setupTestSTSService(t *testing.T) *STSService {
service := NewSTSService()
config := &STSConfig{
TokenDuration: FlexibleDuration{time.Hour},
MaxSessionLength: FlexibleDuration{time.Hour * 12},
Issuer: "test-sts",
SigningKey: []byte("test-signing-key-32-characters-long"),
}
err := service.Initialize(config)
require.NoError(t, err)
// Set up mock trust policy validator (required for STS testing)
mockValidator := &MockTrustPolicyValidator{}
service.SetTrustPolicyValidator(mockValidator)
// Register test providers
mockOIDCProvider := &MockIdentityProvider{
name: "test-oidc",
validTokens: map[string]*providers.TokenClaims{
createSTSTestJWT(t, "test-issuer", "test-user"): {
Subject: "test-user-id",
Issuer: "test-issuer",
Claims: map[string]interface{}{
"email": "test@example.com",
"name": "Test User",
},
},
},
}
mockLDAPProvider := &MockIdentityProvider{
name: "test-ldap",
validCredentials: map[string]string{
"testuser": "testpass",
},
}
service.RegisterProvider(mockOIDCProvider)
service.RegisterProvider(mockLDAPProvider)
return service
}
func int64Ptr(v int64) *int64 {
return &v
}
// Mock identity provider for testing
type MockIdentityProvider struct {
name string
validTokens map[string]*providers.TokenClaims
validCredentials map[string]string
}
func (m *MockIdentityProvider) Name() string {
return m.name
}
func (m *MockIdentityProvider) GetIssuer() string {
return "test-issuer" // This matches the issuer in the token claims
}
func (m *MockIdentityProvider) Initialize(config interface{}) error {
return nil
}
func (m *MockIdentityProvider) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
// First try to parse as JWT token
if len(token) > 20 && strings.Count(token, ".") >= 2 {
parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
if err == nil {
if claims, ok := parsedToken.Claims.(jwt.MapClaims); ok {
issuer, _ := claims["iss"].(string)
subject, _ := claims["sub"].(string)
// Verify the issuer matches what we expect
if issuer == "test-issuer" && subject != "" {
return &providers.ExternalIdentity{
UserID: subject,
Email: subject + "@test-domain.com",
DisplayName: "Test User " + subject,
Provider: m.name,
}, nil
}
}
}
}
// Handle legacy OIDC tokens (for backwards compatibility)
if claims, exists := m.validTokens[token]; exists {
email, _ := claims.GetClaimString("email")
name, _ := claims.GetClaimString("name")
return &providers.ExternalIdentity{
UserID: claims.Subject,
Email: email,
DisplayName: name,
Provider: m.name,
}, nil
}
// Handle LDAP credentials (username:password format)
if m.validCredentials != nil {
parts := strings.Split(token, ":")
if len(parts) == 2 {
username, password := parts[0], parts[1]
if expectedPassword, exists := m.validCredentials[username]; exists && expectedPassword == password {
return &providers.ExternalIdentity{
UserID: username,
Email: username + "@" + m.name + ".com",
DisplayName: "Test User " + username,
Provider: m.name,
}, nil
}
}
}
return nil, fmt.Errorf("unknown test token: %s", token)
}
func (m *MockIdentityProvider) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
return &providers.ExternalIdentity{
UserID: userID,
Email: userID + "@" + m.name + ".com",
Provider: m.name,
}, nil
}
func (m *MockIdentityProvider) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
if claims, exists := m.validTokens[token]; exists {
return claims, nil
}
return nil, fmt.Errorf("invalid token")
}
// TestSessionDurationCappedByTokenExpiration tests that session duration is capped by the source token's exp claim
func TestSessionDurationCappedByTokenExpiration(t *testing.T) {
service := NewSTSService()
config := &STSConfig{
TokenDuration: FlexibleDuration{time.Hour}, // Default: 1 hour
MaxSessionLength: FlexibleDuration{time.Hour * 12},
Issuer: "test-sts",
SigningKey: []byte("test-signing-key-32-characters-long"),
}
err := service.Initialize(config)
require.NoError(t, err)
tests := []struct {
name string
durationSeconds *int64
tokenExpiration *time.Time
expectedMaxSeconds int64
description string
}{
{
name: "no token expiration - use default duration",
durationSeconds: nil,
tokenExpiration: nil,
expectedMaxSeconds: 3600, // 1 hour default
description: "When no token expiration is set, use the configured default duration",
},
{
name: "token expires before default duration",
durationSeconds: nil,
tokenExpiration: timePtr(time.Now().Add(30 * time.Minute)),
expectedMaxSeconds: 30 * 60, // 30 minutes
description: "When token expires in 30 min, session should be capped at 30 min",
},
{
name: "token expires after default duration - use default",
durationSeconds: nil,
tokenExpiration: timePtr(time.Now().Add(2 * time.Hour)),
expectedMaxSeconds: 3600, // 1 hour default, since it's less than 2 hour token expiry
description: "When token expires after default duration, use the default duration",
},
{
name: "requested duration shorter than token expiry",
durationSeconds: int64Ptr(1800), // 30 min requested
tokenExpiration: timePtr(time.Now().Add(time.Hour)),
expectedMaxSeconds: 1800, // 30 minutes as requested
description: "When requested duration is shorter than token expiry, use requested duration",
},
{
name: "requested duration longer than token expiry - cap at token expiry",
durationSeconds: int64Ptr(3600), // 1 hour requested
tokenExpiration: timePtr(time.Now().Add(15 * time.Minute)),
expectedMaxSeconds: 15 * 60, // Capped at 15 minutes
description: "When requested duration exceeds token expiry, cap at token expiry",
},
{
name: "GitLab CI short-lived token scenario",
durationSeconds: nil,
tokenExpiration: timePtr(time.Now().Add(5 * time.Minute)),
expectedMaxSeconds: 5 * 60, // 5 minutes
description: "GitLab CI job with 5 minute timeout should result in 5 minute session",
},
{
name: "already expired token - defense in depth",
durationSeconds: nil,
tokenExpiration: timePtr(time.Now().Add(-5 * time.Minute)), // Expired 5 minutes ago
expectedMaxSeconds: 60, // 1 minute minimum
description: "Already expired token should result in minimal 1 minute session",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
duration := service.calculateSessionDuration(tt.durationSeconds, tt.tokenExpiration)
// Allow 5 second tolerance for time calculations
maxExpected := time.Duration(tt.expectedMaxSeconds+5) * time.Second
minExpected := time.Duration(tt.expectedMaxSeconds-5) * time.Second
assert.GreaterOrEqual(t, duration, minExpected,
"%s: duration %v should be >= %v", tt.description, duration, minExpected)
assert.LessOrEqual(t, duration, maxExpected,
"%s: duration %v should be <= %v", tt.description, duration, maxExpected)
})
}
}
// TestAssumeRoleWithWebIdentityRespectsTokenExpiration tests end-to-end that session duration is capped
func TestAssumeRoleWithWebIdentityRespectsTokenExpiration(t *testing.T) {
service := NewSTSService()
config := &STSConfig{
TokenDuration: FlexibleDuration{time.Hour},
MaxSessionLength: FlexibleDuration{time.Hour * 12},
Issuer: "test-sts",
SigningKey: []byte("test-signing-key-32-characters-long"),
}
err := service.Initialize(config)
require.NoError(t, err)
// Set up mock trust policy validator
mockValidator := &MockTrustPolicyValidator{}
service.SetTrustPolicyValidator(mockValidator)
// Create a mock provider that returns tokens with short expiration
shortLivedTokenExpiration := time.Now().Add(10 * time.Minute)
mockProvider := &MockIdentityProviderWithExpiration{
name: "short-lived-issuer",
tokenExpiration: &shortLivedTokenExpiration,
}
service.RegisterProvider(mockProvider)
ctx := context.Background()
// Create a JWT token with short expiration
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"iss": "short-lived-issuer",
"sub": "test-user",
"aud": "test-client",
"exp": shortLivedTokenExpiration.Unix(),
"iat": time.Now().Unix(),
})
tokenString, err := token.SignedString([]byte("test-signing-key"))
require.NoError(t, err)
request := &AssumeRoleWithWebIdentityRequest{
RoleArn: "arn:aws:iam::role/TestRole",
WebIdentityToken: tokenString,
RoleSessionName: "test-session",
}
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
require.NoError(t, err)
require.NotNil(t, response)
// Verify the session expires at or before the token expiration
// Allow 5 second tolerance
assert.True(t, response.Credentials.Expiration.Before(shortLivedTokenExpiration.Add(5*time.Second)),
"Session expiration (%v) should not exceed token expiration (%v)",
response.Credentials.Expiration, shortLivedTokenExpiration)
}
// MockIdentityProviderWithExpiration is a mock provider that returns tokens with configurable expiration
type MockIdentityProviderWithExpiration struct {
name string
tokenExpiration *time.Time
}
func (m *MockIdentityProviderWithExpiration) Name() string {
return m.name
}
func (m *MockIdentityProviderWithExpiration) GetIssuer() string {
return m.name
}
func (m *MockIdentityProviderWithExpiration) Initialize(config interface{}) error {
return nil
}
func (m *MockIdentityProviderWithExpiration) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
// Parse the token to get subject
parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
if err != nil {
return nil, fmt.Errorf("failed to parse token: %w", err)
}
claims, ok := parsedToken.Claims.(jwt.MapClaims)
if !ok {
return nil, fmt.Errorf("invalid claims")
}
subject, _ := claims["sub"].(string)
identity := &providers.ExternalIdentity{
UserID: subject,
Email: subject + "@example.com",
DisplayName: "Test User",
Provider: m.name,
TokenExpiration: m.tokenExpiration,
}
return identity, nil
}
func (m *MockIdentityProviderWithExpiration) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
return &providers.ExternalIdentity{
UserID: userID,
Provider: m.name,
}, nil
}
func (m *MockIdentityProviderWithExpiration) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
claims := &providers.TokenClaims{
Subject: "test-user",
Issuer: m.name,
}
if m.tokenExpiration != nil {
claims.ExpiresAt = *m.tokenExpiration
}
return claims, nil
}
func timePtr(t time.Time) *time.Time {
return &t
}
// TestAssumeRoleWithWebIdentity_PreservesAttributes tests that attributes from the identity provider
// are correctly propagated to the session token's request context
func TestAssumeRoleWithWebIdentity_PreservesAttributes(t *testing.T) {
service := setupTestSTSService(t)
// Create a mock provider that returns a user with attributes
mockProvider := &MockIdentityProviderWithAttributes{
name: "attr-provider",
attributes: map[string]string{
"preferred_username": "my-user",
"department": "engineering",
"project": "seaweedfs",
},
}
service.RegisterProvider(mockProvider)
// Create a valid JWT token for the provider
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"iss": "attr-provider",
"sub": "test-user-id",
"aud": "test-client",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
})
tokenString, err := token.SignedString([]byte("test-signing-key"))
require.NoError(t, err)
ctx := context.Background()
request := &AssumeRoleWithWebIdentityRequest{
RoleArn: "arn:aws:iam::role/TestRole",
WebIdentityToken: tokenString,
RoleSessionName: "test-session",
}
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
require.NoError(t, err)
require.NotNil(t, response)
// Validate the session token to check claims
sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken)
require.NoError(t, err)
// Check that attributes are present in RequestContext
require.NotNil(t, sessionInfo.RequestContext, "RequestContext should not be nil")
assert.Equal(t, "my-user", sessionInfo.RequestContext["preferred_username"])
assert.Equal(t, "engineering", sessionInfo.RequestContext["department"])
assert.Equal(t, "seaweedfs", sessionInfo.RequestContext["project"])
// Check standard claims are also present
assert.Equal(t, "test-user-id", sessionInfo.RequestContext["sub"])
assert.Equal(t, "test@example.com", sessionInfo.RequestContext["email"])
assert.Equal(t, "Test User", sessionInfo.RequestContext["name"])
}
// MockIdentityProviderWithAttributes is a mock provider that returns configured attributes
type MockIdentityProviderWithAttributes struct {
name string
attributes map[string]string
}
func (m *MockIdentityProviderWithAttributes) Name() string {
return m.name
}
func (m *MockIdentityProviderWithAttributes) GetIssuer() string {
return m.name
}
func (m *MockIdentityProviderWithAttributes) Initialize(config interface{}) error {
return nil
}
func (m *MockIdentityProviderWithAttributes) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
return &providers.ExternalIdentity{
UserID: "test-user-id",
Email: "test@example.com",
DisplayName: "Test User",
Provider: m.name,
Attributes: m.attributes,
}, nil
}
func (m *MockIdentityProviderWithAttributes) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
return nil, nil
}
func (m *MockIdentityProviderWithAttributes) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
return &providers.TokenClaims{
Subject: "test-user-id",
Issuer: m.name,
}, nil
}

View File

@@ -1,53 +1,4 @@
package sts
import (
"context"
"fmt"
"strings"
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
)
// MockTrustPolicyValidator is a simple mock for testing STS functionality
type MockTrustPolicyValidator struct{}
// ValidateTrustPolicyForWebIdentity allows valid JWT test tokens for STS testing
func (m *MockTrustPolicyValidator) ValidateTrustPolicyForWebIdentity(ctx context.Context, roleArn string, webIdentityToken string, durationSeconds *int64) error {
// Reject non-existent roles for testing
if strings.Contains(roleArn, "NonExistentRole") {
return fmt.Errorf("trust policy validation failed: role does not exist")
}
// For STS unit tests, allow JWT tokens that look valid (contain dots for JWT structure)
// In real implementation, this would validate against actual trust policies
if len(webIdentityToken) > 20 && strings.Count(webIdentityToken, ".") >= 2 {
// This appears to be a JWT token - allow it for testing
return nil
}
// Legacy support for specific test tokens during migration
if webIdentityToken == "valid_test_token" || webIdentityToken == "valid-oidc-token" {
return nil
}
// Reject invalid tokens
if webIdentityToken == "invalid_token" || webIdentityToken == "expired_token" || webIdentityToken == "invalid-token" {
return fmt.Errorf("trust policy denies token")
}
return nil
}
// ValidateTrustPolicyForCredentials allows valid test identities for STS testing
func (m *MockTrustPolicyValidator) ValidateTrustPolicyForCredentials(ctx context.Context, roleArn string, identity *providers.ExternalIdentity) error {
// Reject non-existent roles for testing
if strings.Contains(roleArn, "NonExistentRole") {
return fmt.Errorf("trust policy validation failed: role does not exist")
}
// For STS unit tests, allow test identities
if identity != nil && identity.UserID != "" {
return nil
}
return fmt.Errorf("invalid identity for role assumption")
}