chore: remove ~50k lines of unreachable dead code (#8913)
* chore: remove unreachable dead code across the codebase Remove ~50,000 lines of unreachable code identified by static analysis. Major removals: - weed/filer/redis_lua: entire unused Redis Lua filer store implementation - weed/wdclient/net2, resource_pool: unused connection/resource pool packages - weed/plugin/worker/lifecycle: unused lifecycle plugin worker - weed/s3api: unused S3 policy templates, presigned URL IAM, streaming copy, multipart IAM, key rotation, and various SSE helper functions - weed/mq/kafka: unused partition mapping, compression, schema, and protocol functions - weed/mq/offset: unused SQL storage and migration code - weed/worker: unused registry, task, and monitoring functions - weed/query: unused SQL engine, parquet scanner, and type functions - weed/shell: unused EC proportional rebalance functions - weed/storage/erasure_coding/distribution: unused distribution analysis functions - Individual unreachable functions removed from 150+ files across admin, credential, filer, iam, kms, mount, mq, operation, pb, s3api, server, shell, storage, topology, and util packages * fix(s3): reset shared memory store in IAM test to prevent flaky failure TestLoadIAMManagerFromConfig_EmptyConfigWithFallbackKey was flaky because the MemoryStore credential backend is a singleton registered via init(). Earlier tests that create anonymous identities pollute the shared store, causing LookupAnonymous() to unexpectedly return true. Fix by calling Reset() on the memory store before the test runs. * style: run gofmt on changed files * fix: restore KMS functions used by integration tests * fix(plugin): prevent panic on send to closed worker session channel The Plugin.sendToWorker method could panic with "send on closed channel" when a worker disconnected while a message was being sent. The race was between streamSession.close() closing the outgoing channel and sendToWorker writing to it concurrently. Add a done channel to streamSession that is closed before the outgoing channel, and check it in sendToWorker's select to safely detect closed sessions without panicking.
This commit is contained in:
@@ -37,11 +37,6 @@ func GenerateRandomString(length int, charset string) (string, error) {
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// GenerateAccessKeyId generates a new access key ID.
|
||||
func GenerateAccessKeyId() (string, error) {
|
||||
return GenerateRandomString(AccessKeyIdLength, CharsetUpper)
|
||||
}
|
||||
|
||||
// GenerateSecretAccessKey generates a new secret access key.
|
||||
func GenerateSecretAccessKey() (string, error) {
|
||||
return GenerateRandomString(SecretAccessKeyLength, Charset)
|
||||
@@ -179,11 +174,3 @@ func MapToIdentitiesAction(action string) string {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// MaskAccessKey masks an access key for logging, showing only the first 4 characters.
|
||||
func MaskAccessKey(accessKeyId string) string {
|
||||
if len(accessKeyId) > 4 {
|
||||
return accessKeyId[:4] + "***"
|
||||
}
|
||||
return accessKeyId
|
||||
}
|
||||
|
||||
@@ -1,164 +0,0 @@
|
||||
package iam
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestHash(t *testing.T) {
|
||||
input := "test"
|
||||
result := Hash(&input)
|
||||
assert.NotEmpty(t, result)
|
||||
assert.Len(t, result, 40) // SHA1 hex is 40 chars
|
||||
|
||||
// Same input should produce same hash
|
||||
result2 := Hash(&input)
|
||||
assert.Equal(t, result, result2)
|
||||
|
||||
// Different input should produce different hash
|
||||
different := "different"
|
||||
result3 := Hash(&different)
|
||||
assert.NotEqual(t, result, result3)
|
||||
}
|
||||
|
||||
func TestGenerateRandomString(t *testing.T) {
|
||||
// Valid generation
|
||||
result, err := GenerateRandomString(10, CharsetUpper)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, 10)
|
||||
|
||||
// Different calls should produce different results (with high probability)
|
||||
result2, err := GenerateRandomString(10, CharsetUpper)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, result, result2)
|
||||
|
||||
// Invalid length
|
||||
_, err = GenerateRandomString(0, CharsetUpper)
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = GenerateRandomString(-1, CharsetUpper)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Empty charset
|
||||
_, err = GenerateRandomString(10, "")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestGenerateAccessKeyId(t *testing.T) {
|
||||
keyId, err := GenerateAccessKeyId()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, keyId, AccessKeyIdLength)
|
||||
}
|
||||
|
||||
func TestGenerateSecretAccessKey(t *testing.T) {
|
||||
secretKey, err := GenerateSecretAccessKey()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, secretKey, SecretAccessKeyLength)
|
||||
}
|
||||
|
||||
func TestGenerateSecretAccessKey_URLSafe(t *testing.T) {
|
||||
// Generate multiple keys to increase probability of catching unsafe chars
|
||||
for i := 0; i < 100; i++ {
|
||||
secretKey, err := GenerateSecretAccessKey()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify no URL-unsafe characters that would cause authentication issues
|
||||
assert.NotContains(t, secretKey, "/", "Secret key should not contain /")
|
||||
assert.NotContains(t, secretKey, "+", "Secret key should not contain +")
|
||||
|
||||
// Verify only expected characters are present
|
||||
for _, char := range secretKey {
|
||||
assert.Contains(t, Charset, string(char), "Secret key contains unexpected character: %c", char)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringSlicesEqual(t *testing.T) {
|
||||
tests := []struct {
|
||||
a []string
|
||||
b []string
|
||||
expected bool
|
||||
}{
|
||||
{[]string{"a", "b", "c"}, []string{"a", "b", "c"}, true},
|
||||
{[]string{"c", "b", "a"}, []string{"a", "b", "c"}, true}, // Order independent
|
||||
{[]string{"a", "b"}, []string{"a", "b", "c"}, false},
|
||||
{[]string{}, []string{}, true},
|
||||
{nil, nil, true},
|
||||
{[]string{"a"}, []string{"b"}, false},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result := StringSlicesEqual(test.a, test.b)
|
||||
assert.Equal(t, test.expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapToStatementAction(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{StatementActionAdmin, s3_constants.ACTION_ADMIN},
|
||||
{StatementActionWrite, s3_constants.ACTION_WRITE},
|
||||
{StatementActionRead, s3_constants.ACTION_READ},
|
||||
{StatementActionList, s3_constants.ACTION_LIST},
|
||||
{StatementActionDelete, s3_constants.ACTION_DELETE_BUCKET},
|
||||
// Test fine-grained S3 action mappings (Issue #7864)
|
||||
{"DeleteObject", s3_constants.ACTION_WRITE},
|
||||
{"s3:DeleteObject", s3_constants.ACTION_WRITE},
|
||||
{"PutObject", s3_constants.ACTION_WRITE},
|
||||
{"s3:PutObject", s3_constants.ACTION_WRITE},
|
||||
{"GetObject", s3_constants.ACTION_READ},
|
||||
{"s3:GetObject", s3_constants.ACTION_READ},
|
||||
{"ListBucket", s3_constants.ACTION_LIST},
|
||||
{"s3:ListBucket", s3_constants.ACTION_LIST},
|
||||
{"PutObjectAcl", s3_constants.ACTION_WRITE_ACP},
|
||||
{"s3:PutObjectAcl", s3_constants.ACTION_WRITE_ACP},
|
||||
{"GetObjectAcl", s3_constants.ACTION_READ_ACP},
|
||||
{"s3:GetObjectAcl", s3_constants.ACTION_READ_ACP},
|
||||
{"unknown", ""},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result := MapToStatementAction(test.input)
|
||||
assert.Equal(t, test.expected, result, "Failed for input: %s", test.input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapToIdentitiesAction(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{s3_constants.ACTION_ADMIN, StatementActionAdmin},
|
||||
{s3_constants.ACTION_WRITE, StatementActionWrite},
|
||||
{s3_constants.ACTION_READ, StatementActionRead},
|
||||
{s3_constants.ACTION_LIST, StatementActionList},
|
||||
{s3_constants.ACTION_DELETE_BUCKET, StatementActionDelete},
|
||||
{"unknown", ""},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result := MapToIdentitiesAction(test.input)
|
||||
assert.Equal(t, test.expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaskAccessKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"AKIAIOSFODNN7EXAMPLE", "AKIA***"},
|
||||
{"AKIA", "AKIA"},
|
||||
{"AKI", "AKI"},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result := MaskAccessKey(test.input)
|
||||
assert.Equal(t, test.expected, result)
|
||||
}
|
||||
}
|
||||
@@ -202,32 +202,6 @@ func (m *IAMManager) getFilerAddress() string {
|
||||
return "" // Fallback to empty string if no provider is set
|
||||
}
|
||||
|
||||
// createRoleStore creates a role store based on configuration
|
||||
func (m *IAMManager) createRoleStore(config *RoleStoreConfig) (RoleStore, error) {
|
||||
if config == nil {
|
||||
// Default to generic cached filer role store when no config provided
|
||||
return NewGenericCachedRoleStore(nil, nil)
|
||||
}
|
||||
|
||||
switch config.StoreType {
|
||||
case "", "filer":
|
||||
// Check if caching is explicitly disabled
|
||||
if config.StoreConfig != nil {
|
||||
if noCache, ok := config.StoreConfig["noCache"].(bool); ok && noCache {
|
||||
return NewFilerRoleStore(config.StoreConfig, nil)
|
||||
}
|
||||
}
|
||||
// Default to generic cached filer store for better performance
|
||||
return NewGenericCachedRoleStore(config.StoreConfig, nil)
|
||||
case "cached-filer", "generic-cached":
|
||||
return NewGenericCachedRoleStore(config.StoreConfig, nil)
|
||||
case "memory":
|
||||
return NewMemoryRoleStore(), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported role store type: %s", config.StoreType)
|
||||
}
|
||||
}
|
||||
|
||||
// createRoleStoreWithProvider creates a role store with a filer address provider function
|
||||
func (m *IAMManager) createRoleStoreWithProvider(config *RoleStoreConfig, filerAddressProvider func() string) (RoleStore, error) {
|
||||
if config == nil {
|
||||
|
||||
@@ -388,157 +388,3 @@ type CachedFilerRoleStoreConfig struct {
|
||||
ListTTL string `json:"listTtl,omitempty"` // e.g., "1m", "30s"
|
||||
MaxCacheSize int `json:"maxCacheSize,omitempty"` // Maximum number of cached roles
|
||||
}
|
||||
|
||||
// NewCachedFilerRoleStore creates a new cached filer-based role store
|
||||
func NewCachedFilerRoleStore(config map[string]interface{}) (*CachedFilerRoleStore, error) {
|
||||
// Create underlying filer store
|
||||
filerStore, err := NewFilerRoleStore(config, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create filer role store: %w", err)
|
||||
}
|
||||
|
||||
// Parse cache configuration with defaults
|
||||
cacheTTL := 5 * time.Minute // Default 5 minutes for role cache
|
||||
listTTL := 1 * time.Minute // Default 1 minute for list cache
|
||||
maxCacheSize := 1000 // Default max 1000 cached roles
|
||||
|
||||
if config != nil {
|
||||
if ttlStr, ok := config["ttl"].(string); ok && ttlStr != "" {
|
||||
if parsed, err := time.ParseDuration(ttlStr); err == nil {
|
||||
cacheTTL = parsed
|
||||
}
|
||||
}
|
||||
if listTTLStr, ok := config["listTtl"].(string); ok && listTTLStr != "" {
|
||||
if parsed, err := time.ParseDuration(listTTLStr); err == nil {
|
||||
listTTL = parsed
|
||||
}
|
||||
}
|
||||
if maxSize, ok := config["maxCacheSize"].(int); ok && maxSize > 0 {
|
||||
maxCacheSize = maxSize
|
||||
}
|
||||
}
|
||||
|
||||
// Create ccache instances with appropriate configurations
|
||||
pruneCount := int64(maxCacheSize) >> 3
|
||||
if pruneCount <= 0 {
|
||||
pruneCount = 100
|
||||
}
|
||||
|
||||
store := &CachedFilerRoleStore{
|
||||
filerStore: filerStore,
|
||||
cache: ccache.New(ccache.Configure().MaxSize(int64(maxCacheSize)).ItemsToPrune(uint32(pruneCount))),
|
||||
listCache: ccache.New(ccache.Configure().MaxSize(100).ItemsToPrune(10)), // Smaller cache for lists
|
||||
ttl: cacheTTL,
|
||||
listTTL: listTTL,
|
||||
}
|
||||
|
||||
glog.V(2).Infof("Initialized CachedFilerRoleStore with TTL %v, List TTL %v, Max Cache Size %d",
|
||||
cacheTTL, listTTL, maxCacheSize)
|
||||
|
||||
return store, nil
|
||||
}
|
||||
|
||||
// StoreRole stores a role definition and invalidates the cache
|
||||
func (c *CachedFilerRoleStore) StoreRole(ctx context.Context, filerAddress string, roleName string, role *RoleDefinition) error {
|
||||
// Store in filer
|
||||
err := c.filerStore.StoreRole(ctx, filerAddress, roleName, role)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Invalidate cache entries
|
||||
c.cache.Delete(roleName)
|
||||
c.listCache.Clear() // Invalidate list cache
|
||||
|
||||
glog.V(3).Infof("Stored and invalidated cache for role %s", roleName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRole retrieves a role definition with caching
|
||||
func (c *CachedFilerRoleStore) GetRole(ctx context.Context, filerAddress string, roleName string) (*RoleDefinition, error) {
|
||||
// Try to get from cache first
|
||||
item := c.cache.Get(roleName)
|
||||
if item != nil {
|
||||
// Cache hit - return cached role (DO NOT extend TTL)
|
||||
role := item.Value().(*RoleDefinition)
|
||||
glog.V(4).Infof("Cache hit for role %s", roleName)
|
||||
return copyRoleDefinition(role), nil
|
||||
}
|
||||
|
||||
// Cache miss - fetch from filer
|
||||
glog.V(4).Infof("Cache miss for role %s, fetching from filer", roleName)
|
||||
role, err := c.filerStore.GetRole(ctx, filerAddress, roleName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Cache the result with TTL
|
||||
c.cache.Set(roleName, copyRoleDefinition(role), c.ttl)
|
||||
glog.V(3).Infof("Cached role %s with TTL %v", roleName, c.ttl)
|
||||
return role, nil
|
||||
}
|
||||
|
||||
// ListRoles lists all role names with caching
|
||||
func (c *CachedFilerRoleStore) ListRoles(ctx context.Context, filerAddress string) ([]string, error) {
|
||||
// Use a constant key for the role list cache
|
||||
const listCacheKey = "role_list"
|
||||
|
||||
// Try to get from list cache first
|
||||
item := c.listCache.Get(listCacheKey)
|
||||
if item != nil {
|
||||
// Cache hit - return cached list (DO NOT extend TTL)
|
||||
roles := item.Value().([]string)
|
||||
glog.V(4).Infof("List cache hit, returning %d roles", len(roles))
|
||||
return append([]string(nil), roles...), nil // Return a copy
|
||||
}
|
||||
|
||||
// Cache miss - fetch from filer
|
||||
glog.V(4).Infof("List cache miss, fetching from filer")
|
||||
roles, err := c.filerStore.ListRoles(ctx, filerAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Cache the result with TTL (store a copy)
|
||||
rolesCopy := append([]string(nil), roles...)
|
||||
c.listCache.Set(listCacheKey, rolesCopy, c.listTTL)
|
||||
glog.V(3).Infof("Cached role list with %d entries, TTL %v", len(roles), c.listTTL)
|
||||
return roles, nil
|
||||
}
|
||||
|
||||
// DeleteRole deletes a role definition and invalidates the cache
|
||||
func (c *CachedFilerRoleStore) DeleteRole(ctx context.Context, filerAddress string, roleName string) error {
|
||||
// Delete from filer
|
||||
err := c.filerStore.DeleteRole(ctx, filerAddress, roleName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Invalidate cache entries
|
||||
c.cache.Delete(roleName)
|
||||
c.listCache.Clear() // Invalidate list cache
|
||||
|
||||
glog.V(3).Infof("Deleted and invalidated cache for role %s", roleName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearCache clears all cached entries (for testing or manual cache invalidation)
|
||||
func (c *CachedFilerRoleStore) ClearCache() {
|
||||
c.cache.Clear()
|
||||
c.listCache.Clear()
|
||||
glog.V(2).Infof("Cleared all role cache entries")
|
||||
}
|
||||
|
||||
// GetCacheStats returns cache statistics
|
||||
func (c *CachedFilerRoleStore) GetCacheStats() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"roleCache": map[string]interface{}{
|
||||
"size": c.cache.ItemCount(),
|
||||
"ttl": c.ttl.String(),
|
||||
},
|
||||
"listCache": map[string]interface{}{
|
||||
"size": c.listCache.ItemCount(),
|
||||
"ttl": c.listTTL.String(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,687 +0,0 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConditionSetOperators(t *testing.T) {
|
||||
engine := setupTestPolicyEngine(t)
|
||||
|
||||
t.Run("ForAnyValue:StringEquals", func(t *testing.T) {
|
||||
trustPolicy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowOIDC",
|
||||
Effect: "Allow",
|
||||
Action: []string{"sts:AssumeRoleWithWebIdentity"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"ForAnyValue:StringEquals": {
|
||||
"oidc:roles": []string{"Dev.SeaweedFS.TestBucket.ReadWrite", "Dev.SeaweedFS.Admin"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Match: Admin is in the requested roles
|
||||
evalCtxMatch := &EvaluationContext{
|
||||
Principal: "web-identity-user",
|
||||
Action: "sts:AssumeRoleWithWebIdentity",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"oidc:roles": []string{"Dev.SeaweedFS.Admin", "OtherRole"},
|
||||
},
|
||||
}
|
||||
resultMatch, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicy, evalCtxMatch)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultMatch.Effect)
|
||||
|
||||
// No Match
|
||||
evalCtxNoMatch := &EvaluationContext{
|
||||
Principal: "web-identity-user",
|
||||
Action: "sts:AssumeRoleWithWebIdentity",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"oidc:roles": []string{"OtherRole1", "OtherRole2"},
|
||||
},
|
||||
}
|
||||
resultNoMatch, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicy, evalCtxNoMatch)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectDeny, resultNoMatch.Effect)
|
||||
|
||||
// No Match: Empty context for ForAnyValue (should deny)
|
||||
evalCtxEmpty := &EvaluationContext{
|
||||
Principal: "web-identity-user",
|
||||
Action: "sts:AssumeRoleWithWebIdentity",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"oidc:roles": []string{},
|
||||
},
|
||||
}
|
||||
resultEmpty, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicy, evalCtxEmpty)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectDeny, resultEmpty.Effect, "ForAnyValue should deny when context is empty")
|
||||
})
|
||||
|
||||
t.Run("ForAllValues:StringEquals", func(t *testing.T) {
|
||||
trustPolicyAll := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowOIDCAll",
|
||||
Effect: "Allow",
|
||||
Action: []string{"sts:AssumeRoleWithWebIdentity"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"ForAllValues:StringEquals": {
|
||||
"oidc:roles": []string{"RoleA", "RoleB", "RoleC"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Match: All requested roles ARE in the allowed set
|
||||
evalCtxAllMatch := &EvaluationContext{
|
||||
Principal: "web-identity-user",
|
||||
Action: "sts:AssumeRoleWithWebIdentity",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"oidc:roles": []string{"RoleA", "RoleB"},
|
||||
},
|
||||
}
|
||||
resultAllMatch, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicyAll, evalCtxAllMatch)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultAllMatch.Effect)
|
||||
|
||||
// Fail: RoleD is NOT in the allowed set
|
||||
evalCtxAllFail := &EvaluationContext{
|
||||
Principal: "web-identity-user",
|
||||
Action: "sts:AssumeRoleWithWebIdentity",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"oidc:roles": []string{"RoleA", "RoleD"},
|
||||
},
|
||||
}
|
||||
resultAllFail, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicyAll, evalCtxAllFail)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectDeny, resultAllFail.Effect)
|
||||
|
||||
// Vacuously true: Request has NO roles
|
||||
evalCtxEmpty := &EvaluationContext{
|
||||
Principal: "web-identity-user",
|
||||
Action: "sts:AssumeRoleWithWebIdentity",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"oidc:roles": []string{},
|
||||
},
|
||||
}
|
||||
resultEmpty, err := engine.EvaluateTrustPolicy(context.Background(), trustPolicyAll, evalCtxEmpty)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultEmpty.Effect)
|
||||
})
|
||||
|
||||
t.Run("ForAllValues:NumericEqualsVacuouslyTrue", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowNumericAll",
|
||||
Effect: "Allow",
|
||||
Action: []string{"sts:AssumeRole"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"ForAllValues:NumericEquals": {
|
||||
"aws:MultiFactorAuthAge": []string{"3600", "7200"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Vacuously true: Request has NO MFA age info
|
||||
evalCtxEmpty := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "sts:AssumeRole",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:MultiFactorAuthAge": []string{},
|
||||
},
|
||||
}
|
||||
resultEmpty, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxEmpty)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultEmpty.Effect, "Should allow when numeric context is empty for ForAllValues")
|
||||
})
|
||||
|
||||
t.Run("ForAllValues:BoolVacuouslyTrue", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowBoolAll",
|
||||
Effect: "Allow",
|
||||
Action: []string{"sts:AssumeRole"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"ForAllValues:Bool": {
|
||||
"aws:SecureTransport": "true",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Vacuously true
|
||||
evalCtxEmpty := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "sts:AssumeRole",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:SecureTransport": []interface{}{},
|
||||
},
|
||||
}
|
||||
resultEmpty, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxEmpty)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultEmpty.Effect, "Should allow when bool context is empty for ForAllValues")
|
||||
})
|
||||
|
||||
t.Run("ForAllValues:DateVacuouslyTrue", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowDateAll",
|
||||
Effect: "Allow",
|
||||
Action: []string{"sts:AssumeRole"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"ForAllValues:DateGreaterThan": {
|
||||
"aws:CurrentTime": "2020-01-01T00:00:00Z",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Vacuously true
|
||||
evalCtxEmpty := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "sts:AssumeRole",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:CurrentTime": []interface{}{},
|
||||
},
|
||||
}
|
||||
resultEmpty, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxEmpty)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultEmpty.Effect, "Should allow when date context is empty for ForAllValues")
|
||||
})
|
||||
|
||||
t.Run("ForAllValues:DateWithLabelsAsStrings", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowDateStrings",
|
||||
Effect: "Allow",
|
||||
Action: []string{"sts:AssumeRole"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"ForAllValues:DateGreaterThan": {
|
||||
"aws:CurrentTime": "2020-01-01T00:00:00Z",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
evalCtx := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "sts:AssumeRole",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:CurrentTime": []string{"2021-01-01T00:00:00Z", "2022-01-01T00:00:00Z"},
|
||||
},
|
||||
}
|
||||
result, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, result.Effect, "Should allow when date context is a slice of strings")
|
||||
})
|
||||
|
||||
t.Run("ForAllValues:BoolWithLabelsAsStrings", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowBoolStrings",
|
||||
Effect: "Allow",
|
||||
Action: []string{"sts:AssumeRole"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"ForAllValues:Bool": {
|
||||
"aws:SecureTransport": "true",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
evalCtx := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "sts:AssumeRole",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:SecureTransport": []string{"true", "true"},
|
||||
},
|
||||
}
|
||||
result, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, result.Effect, "Should allow when bool context is a slice of strings")
|
||||
})
|
||||
|
||||
t.Run("StringEqualsIgnoreCaseWithVariable", func(t *testing.T) {
|
||||
policyDoc := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowVar",
|
||||
Effect: "Allow",
|
||||
Action: []string{"s3:GetObject"},
|
||||
Resource: []string{"arn:aws:s3:::bucket/*"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"StringEqualsIgnoreCase": {
|
||||
"s3:prefix": "${aws:username}/",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := engine.AddPolicy("", "var-policy", policyDoc)
|
||||
require.NoError(t, err)
|
||||
|
||||
evalCtx := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::bucket/ALICE/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"s3:prefix": "ALICE/",
|
||||
"aws:username": "alice",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := engine.Evaluate(context.Background(), "", evalCtx, []string{"var-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, result.Effect, "Should allow when variable expands and matches case-insensitively")
|
||||
})
|
||||
|
||||
t.Run("StringLike:CaseSensitivity", func(t *testing.T) {
|
||||
policyDoc := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowCaseSensitiveLike",
|
||||
Effect: "Allow",
|
||||
Action: []string{"s3:GetObject"},
|
||||
Resource: []string{"arn:aws:s3:::bucket/*"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"StringLike": {
|
||||
"s3:prefix": "Project/*",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := engine.AddPolicy("", "like-policy", policyDoc)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Match: Case sensitive match
|
||||
evalCtxMatch := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::bucket/Project/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"s3:prefix": "Project/data",
|
||||
},
|
||||
}
|
||||
resultMatch, err := engine.Evaluate(context.Background(), "", evalCtxMatch, []string{"like-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultMatch.Effect, "Should allow when case matches exactly")
|
||||
|
||||
// Fail: Case insensitive match (should fail for StringLike)
|
||||
evalCtxFail := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::bucket/project/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"s3:prefix": "project/data", // lowercase 'p'
|
||||
},
|
||||
}
|
||||
resultFail, err := engine.Evaluate(context.Background(), "", evalCtxFail, []string{"like-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectDeny, resultFail.Effect, "Should deny when case does not match for StringLike")
|
||||
})
|
||||
|
||||
t.Run("NumericNotEquals:Logic", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "DenySpecificAges",
|
||||
Effect: "Allow",
|
||||
Action: []string{"sts:AssumeRole"},
|
||||
Resource: []string{"*"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"ForAllValues:NumericNotEquals": {
|
||||
"aws:MultiFactorAuthAge": []string{"3600", "7200"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := engine.AddPolicy("", "numeric-not-equals-policy", policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fail: One age matches an excluded value (3600)
|
||||
evalCtxFail := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "sts:AssumeRole",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:MultiFactorAuthAge": []string{"3600", "1800"},
|
||||
},
|
||||
}
|
||||
resultFail, err := engine.Evaluate(context.Background(), "", evalCtxFail, []string{"numeric-not-equals-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectDeny, resultFail.Effect, "Should deny when one age matches an excluded value")
|
||||
|
||||
// Pass: No age matches any excluded value
|
||||
evalCtxPass := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "sts:AssumeRole",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:MultiFactorAuthAge": []string{"1800", "900"},
|
||||
},
|
||||
}
|
||||
resultPass, err := engine.Evaluate(context.Background(), "", evalCtxPass, []string{"numeric-not-equals-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultPass.Effect, "Should allow when no age matches excluded values")
|
||||
})
|
||||
|
||||
t.Run("DateNotEquals:Logic", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "DenySpecificTimes",
|
||||
Effect: "Allow",
|
||||
Action: []string{"sts:AssumeRole"},
|
||||
Resource: []string{"*"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"ForAllValues:DateNotEquals": {
|
||||
"aws:CurrentTime": []string{"2024-01-01T00:00:00Z", "2024-01-02T00:00:00Z"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := engine.AddPolicy("", "date-not-equals-policy", policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fail: One time matches an excluded value
|
||||
evalCtxFail := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "sts:AssumeRole",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:CurrentTime": []string{"2024-01-01T00:00:00Z", "2024-01-03T00:00:00Z"},
|
||||
},
|
||||
}
|
||||
resultFail, err := engine.Evaluate(context.Background(), "", evalCtxFail, []string{"date-not-equals-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectDeny, resultFail.Effect, "Should deny when one date matches an excluded value")
|
||||
})
|
||||
|
||||
t.Run("IpAddress:SetOperators", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowSpecificIPs",
|
||||
Effect: "Allow",
|
||||
Action: []string{"s3:GetObject"},
|
||||
Resource: []string{"*"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"ForAllValues:IpAddress": {
|
||||
"aws:SourceIp": []string{"192.168.1.0/24", "10.0.0.1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := engine.AddPolicy("", "ip-set-policy", policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Match: All source IPs are in allowed ranges
|
||||
evalCtxMatch := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::bucket/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:SourceIp": []string{"192.168.1.10", "10.0.0.1"},
|
||||
},
|
||||
}
|
||||
resultMatch, err := engine.Evaluate(context.Background(), "", evalCtxMatch, []string{"ip-set-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultMatch.Effect)
|
||||
|
||||
// Fail: One source IP is NOT in allowed ranges
|
||||
evalCtxFail := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::bucket/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:SourceIp": []string{"192.168.1.10", "172.16.0.1"},
|
||||
},
|
||||
}
|
||||
resultFail, err := engine.Evaluate(context.Background(), "", evalCtxFail, []string{"ip-set-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectDeny, resultFail.Effect)
|
||||
|
||||
// ForAnyValue: IPAddress
|
||||
policyAny := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowAnySpecificIPs",
|
||||
Effect: "Allow",
|
||||
Action: []string{"s3:GetObject"},
|
||||
Resource: []string{"*"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"ForAnyValue:IpAddress": {
|
||||
"aws:SourceIp": []string{"192.168.1.0/24"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err = engine.AddPolicy("", "ip-any-policy", policyAny)
|
||||
require.NoError(t, err)
|
||||
|
||||
evalCtxAnyMatch := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::bucket/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:SourceIp": []string{"192.168.1.10", "172.16.0.1"},
|
||||
},
|
||||
}
|
||||
resultAnyMatch, err := engine.Evaluate(context.Background(), "", evalCtxAnyMatch, []string{"ip-any-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultAnyMatch.Effect)
|
||||
})
|
||||
|
||||
t.Run("IpAddress:SingleStringValue", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowSingleIP",
|
||||
Effect: "Allow",
|
||||
Action: []string{"s3:GetObject"},
|
||||
Resource: []string{"*"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"IpAddress": {
|
||||
"aws:SourceIp": "192.168.1.1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := engine.AddPolicy("", "ip-single-policy", policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
evalCtxMatch := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::bucket/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:SourceIp": "192.168.1.1",
|
||||
},
|
||||
}
|
||||
resultMatch, err := engine.Evaluate(context.Background(), "", evalCtxMatch, []string{"ip-single-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultMatch.Effect)
|
||||
|
||||
evalCtxNoMatch := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::bucket/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:SourceIp": "10.0.0.1",
|
||||
},
|
||||
}
|
||||
resultNoMatch, err := engine.Evaluate(context.Background(), "", evalCtxNoMatch, []string{"ip-single-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectDeny, resultNoMatch.Effect)
|
||||
})
|
||||
|
||||
t.Run("Bool:StringSlicePolicyValues", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowWithBoolStrings",
|
||||
Effect: "Allow",
|
||||
Action: []string{"s3:GetObject"},
|
||||
Resource: []string{"*"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"Bool": {
|
||||
"aws:SecureTransport": []string{"true", "false"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := engine.AddPolicy("", "bool-string-slice-policy", policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
evalCtx := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::bucket/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:SecureTransport": "true",
|
||||
},
|
||||
}
|
||||
result, err := engine.Evaluate(context.Background(), "", evalCtx, []string{"bool-string-slice-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, result.Effect)
|
||||
})
|
||||
|
||||
t.Run("StringEqualsIgnoreCase:StringSlicePolicyValues", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowWithIgnoreCaseStrings",
|
||||
Effect: "Allow",
|
||||
Action: []string{"s3:GetObject"},
|
||||
Resource: []string{"*"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"StringEqualsIgnoreCase": {
|
||||
"s3:x-amz-server-side-encryption": []string{"AES256", "aws:kms"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := engine.AddPolicy("", "string-ignorecase-slice-policy", policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
evalCtx := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::bucket/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"s3:x-amz-server-side-encryption": "aes256",
|
||||
},
|
||||
}
|
||||
result, err := engine.Evaluate(context.Background(), "", evalCtx, []string{"string-ignorecase-slice-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, result.Effect)
|
||||
})
|
||||
|
||||
t.Run("IpAddress:CustomContextKey", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowCustomIPKey",
|
||||
Effect: "Allow",
|
||||
Action: []string{"s3:GetObject"},
|
||||
Resource: []string{"*"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"IpAddress": {
|
||||
"custom:VpcIp": "10.0.0.0/16",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := engine.AddPolicy("", "ip-custom-key-policy", policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
evalCtxMatch := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::bucket/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"custom:VpcIp": "10.0.5.1",
|
||||
},
|
||||
}
|
||||
resultMatch, err := engine.Evaluate(context.Background(), "", evalCtxMatch, []string{"ip-custom-key-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultMatch.Effect)
|
||||
|
||||
evalCtxNoMatch := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::bucket/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"custom:VpcIp": "192.168.1.1",
|
||||
},
|
||||
}
|
||||
resultNoMatch, err := engine.Evaluate(context.Background(), "", evalCtxNoMatch, []string{"ip-custom-key-policy"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectDeny, resultNoMatch.Effect)
|
||||
})
|
||||
}
|
||||
@@ -1,101 +0,0 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNegationSetOperators(t *testing.T) {
|
||||
engine := setupTestPolicyEngine(t)
|
||||
|
||||
t.Run("ForAllValues:StringNotEquals", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "DenyAdmin",
|
||||
Effect: "Allow",
|
||||
Action: []string{"sts:AssumeRole"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"ForAllValues:StringNotEquals": {
|
||||
"oidc:roles": []string{"Admin"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// All roles are NOT "Admin" -> Should Allow
|
||||
evalCtxAllow := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "sts:AssumeRole",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"oidc:roles": []string{"User", "Developer"},
|
||||
},
|
||||
}
|
||||
resultAllow, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxAllow)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultAllow.Effect, "Should allow when ALL roles satisfy StringNotEquals Admin")
|
||||
|
||||
// One role is "Admin" -> Should Deny
|
||||
evalCtxDeny := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "sts:AssumeRole",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"oidc:roles": []string{"Admin", "User"},
|
||||
},
|
||||
}
|
||||
resultDeny, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxDeny)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectDeny, resultDeny.Effect, "Should deny when one role is Admin and fails StringNotEquals")
|
||||
})
|
||||
|
||||
t.Run("ForAnyValue:StringNotEquals", func(t *testing.T) {
|
||||
policy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "Requirement",
|
||||
Effect: "Allow",
|
||||
Action: []string{"sts:AssumeRole"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"ForAnyValue:StringNotEquals": {
|
||||
"oidc:roles": []string{"Prohibited"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// At least one role is NOT prohibited -> Should Allow
|
||||
evalCtxAllow := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "sts:AssumeRole",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"oidc:roles": []string{"Prohibited", "Allowed"},
|
||||
},
|
||||
}
|
||||
resultAllow, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxAllow)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectAllow, resultAllow.Effect, "Should allow when at least one role is NOT Prohibited")
|
||||
|
||||
// All roles are Prohibited -> Should Deny
|
||||
evalCtxDeny := &EvaluationContext{
|
||||
Principal: "user",
|
||||
Action: "sts:AssumeRole",
|
||||
Resource: "arn:aws:iam::role/test-role",
|
||||
RequestContext: map[string]interface{}{
|
||||
"oidc:roles": []string{"Prohibited", "Prohibited"},
|
||||
},
|
||||
}
|
||||
resultDeny, err := engine.EvaluateTrustPolicy(context.Background(), policy, evalCtxDeny)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, EffectDeny, resultDeny.Effect, "Should deny when ALL roles are Prohibited")
|
||||
})
|
||||
}
|
||||
@@ -1155,11 +1155,6 @@ func ValidatePolicyDocumentWithType(policy *PolicyDocument, policyType string) e
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateStatement validates a single statement (for backward compatibility)
|
||||
func validateStatement(statement *Statement) error {
|
||||
return validateStatementWithType(statement, "resource")
|
||||
}
|
||||
|
||||
// validateStatementWithType validates a single statement based on policy type
|
||||
func validateStatementWithType(statement *Statement, policyType string) error {
|
||||
if statement.Effect != "Allow" && statement.Effect != "Deny" {
|
||||
@@ -1198,29 +1193,6 @@ func validateStatementWithType(statement *Statement, policyType string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// matchResource checks if a resource pattern matches a requested resource
|
||||
// Uses hybrid approach: simple suffix wildcards for compatibility, filepath.Match for complex patterns
|
||||
func matchResource(pattern, resource string) bool {
|
||||
if pattern == resource {
|
||||
return true
|
||||
}
|
||||
|
||||
// Handle simple suffix wildcard (backward compatibility)
|
||||
if strings.HasSuffix(pattern, "*") {
|
||||
prefix := pattern[:len(pattern)-1]
|
||||
return strings.HasPrefix(resource, prefix)
|
||||
}
|
||||
|
||||
// For complex patterns, use filepath.Match for advanced wildcard support (*, ?, [])
|
||||
matched, err := filepath.Match(pattern, resource)
|
||||
if err != nil {
|
||||
// Fallback to exact match if pattern is malformed
|
||||
return pattern == resource
|
||||
}
|
||||
|
||||
return matched
|
||||
}
|
||||
|
||||
// awsIAMMatch performs AWS IAM-compliant pattern matching with case-insensitivity and policy variable support
|
||||
func awsIAMMatch(pattern, value string, evalCtx *EvaluationContext) bool {
|
||||
// Step 1: Substitute policy variables (e.g., ${aws:username}, ${saml:username})
|
||||
@@ -1274,16 +1246,6 @@ func expandPolicyVariables(pattern string, evalCtx *EvaluationContext) string {
|
||||
return result
|
||||
}
|
||||
|
||||
// getContextValue safely gets a value from the evaluation context
|
||||
func getContextValue(evalCtx *EvaluationContext, key, defaultValue string) string {
|
||||
if value, exists := evalCtx.RequestContext[key]; exists {
|
||||
if str, ok := value.(string); ok {
|
||||
return str
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// AwsWildcardMatch performs case-insensitive wildcard matching like AWS IAM
|
||||
func AwsWildcardMatch(pattern, value string) bool {
|
||||
// Create regex pattern key for caching
|
||||
@@ -1322,29 +1284,6 @@ func AwsWildcardMatch(pattern, value string) bool {
|
||||
return regex.MatchString(value)
|
||||
}
|
||||
|
||||
// matchAction checks if an action pattern matches a requested action
|
||||
// Uses hybrid approach: simple suffix wildcards for compatibility, filepath.Match for complex patterns
|
||||
func matchAction(pattern, action string) bool {
|
||||
if pattern == action {
|
||||
return true
|
||||
}
|
||||
|
||||
// Handle simple suffix wildcard (backward compatibility)
|
||||
if strings.HasSuffix(pattern, "*") {
|
||||
prefix := pattern[:len(pattern)-1]
|
||||
return strings.HasPrefix(action, prefix)
|
||||
}
|
||||
|
||||
// For complex patterns, use filepath.Match for advanced wildcard support (*, ?, [])
|
||||
matched, err := filepath.Match(pattern, action)
|
||||
if err != nil {
|
||||
// Fallback to exact match if pattern is malformed
|
||||
return pattern == action
|
||||
}
|
||||
|
||||
return matched
|
||||
}
|
||||
|
||||
// evaluateStringConditionIgnoreCase evaluates string conditions with case insensitivity
|
||||
func (e *PolicyEngine) evaluateStringConditionIgnoreCase(block map[string]interface{}, evalCtx *EvaluationContext, shouldMatch bool, useWildcard bool, forAllValues bool) bool {
|
||||
for key, expectedValues := range block {
|
||||
|
||||
@@ -1,421 +0,0 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestPrincipalMatching tests the matchesPrincipal method
|
||||
func TestPrincipalMatching(t *testing.T) {
|
||||
engine := setupTestPolicyEngine(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
principal interface{}
|
||||
evalCtx *EvaluationContext
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "plain wildcard principal",
|
||||
principal: "*",
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "structured wildcard federated principal",
|
||||
principal: map[string]interface{}{
|
||||
"Federated": "*",
|
||||
},
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard in array",
|
||||
principal: map[string]interface{}{
|
||||
"Federated": []interface{}{"specific-provider", "*"},
|
||||
},
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "specific federated provider match",
|
||||
principal: map[string]interface{}{
|
||||
"Federated": "https://example.com/oidc",
|
||||
},
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:FederatedProvider": "https://example.com/oidc",
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "specific federated provider no match",
|
||||
principal: map[string]interface{}{
|
||||
"Federated": "https://example.com/oidc",
|
||||
},
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:FederatedProvider": "https://other.com/oidc",
|
||||
},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "array with specific provider match",
|
||||
principal: map[string]interface{}{
|
||||
"Federated": []string{"https://provider1.com", "https://provider2.com"},
|
||||
},
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:FederatedProvider": "https://provider2.com",
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "AWS principal match",
|
||||
principal: map[string]interface{}{
|
||||
"AWS": "arn:aws:iam::123456789012:user/alice",
|
||||
},
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:PrincipalArn": "arn:aws:iam::123456789012:user/alice",
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Service principal match",
|
||||
principal: map[string]interface{}{
|
||||
"Service": "s3.amazonaws.com",
|
||||
},
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:PrincipalServiceName": "s3.amazonaws.com",
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := engine.matchesPrincipal(tt.principal, tt.evalCtx)
|
||||
assert.Equal(t, tt.want, result, "Principal matching failed for: %s", tt.name)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEvaluatePrincipalValue tests the evaluatePrincipalValue method
|
||||
func TestEvaluatePrincipalValue(t *testing.T) {
|
||||
engine := setupTestPolicyEngine(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
principalValue interface{}
|
||||
contextKey string
|
||||
evalCtx *EvaluationContext
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "wildcard string",
|
||||
principalValue: "*",
|
||||
contextKey: "aws:FederatedProvider",
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "specific string match",
|
||||
principalValue: "https://example.com",
|
||||
contextKey: "aws:FederatedProvider",
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:FederatedProvider": "https://example.com",
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "specific string no match",
|
||||
principalValue: "https://example.com",
|
||||
contextKey: "aws:FederatedProvider",
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:FederatedProvider": "https://other.com",
|
||||
},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard in array",
|
||||
principalValue: []interface{}{"provider1", "*"},
|
||||
contextKey: "aws:FederatedProvider",
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "array match",
|
||||
principalValue: []string{"provider1", "provider2", "provider3"},
|
||||
contextKey: "aws:FederatedProvider",
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:FederatedProvider": "provider2",
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "array no match",
|
||||
principalValue: []string{"provider1", "provider2"},
|
||||
contextKey: "aws:FederatedProvider",
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:FederatedProvider": "provider3",
|
||||
},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "missing context key",
|
||||
principalValue: "specific-value",
|
||||
contextKey: "aws:FederatedProvider",
|
||||
evalCtx: &EvaluationContext{
|
||||
RequestContext: map[string]interface{}{},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := engine.evaluatePrincipalValue(tt.principalValue, tt.evalCtx, tt.contextKey)
|
||||
assert.Equal(t, tt.want, result, "Principal value evaluation failed for: %s", tt.name)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTrustPolicyEvaluation tests the EvaluateTrustPolicy method
|
||||
func TestTrustPolicyEvaluation(t *testing.T) {
|
||||
engine := setupTestPolicyEngine(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
trustPolicy *PolicyDocument
|
||||
evalCtx *EvaluationContext
|
||||
wantEffect Effect
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "wildcard federated principal allows any provider",
|
||||
trustPolicy: &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Effect: "Allow",
|
||||
Principal: map[string]interface{}{
|
||||
"Federated": "*",
|
||||
},
|
||||
Action: []string{"sts:AssumeRoleWithWebIdentity"},
|
||||
},
|
||||
},
|
||||
},
|
||||
evalCtx: &EvaluationContext{
|
||||
Action: "sts:AssumeRoleWithWebIdentity",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:FederatedProvider": "https://any-provider.com",
|
||||
},
|
||||
},
|
||||
wantEffect: EffectAllow,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "specific federated principal matches",
|
||||
trustPolicy: &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Effect: "Allow",
|
||||
Principal: map[string]interface{}{
|
||||
"Federated": "https://example.com/oidc",
|
||||
},
|
||||
Action: []string{"sts:AssumeRoleWithWebIdentity"},
|
||||
},
|
||||
},
|
||||
},
|
||||
evalCtx: &EvaluationContext{
|
||||
Action: "sts:AssumeRoleWithWebIdentity",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:FederatedProvider": "https://example.com/oidc",
|
||||
},
|
||||
},
|
||||
wantEffect: EffectAllow,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "specific federated principal does not match",
|
||||
trustPolicy: &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Effect: "Allow",
|
||||
Principal: map[string]interface{}{
|
||||
"Federated": "https://example.com/oidc",
|
||||
},
|
||||
Action: []string{"sts:AssumeRoleWithWebIdentity"},
|
||||
},
|
||||
},
|
||||
},
|
||||
evalCtx: &EvaluationContext{
|
||||
Action: "sts:AssumeRoleWithWebIdentity",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:FederatedProvider": "https://other.com/oidc",
|
||||
},
|
||||
},
|
||||
wantEffect: EffectDeny,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "plain wildcard principal",
|
||||
trustPolicy: &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Effect: "Allow",
|
||||
Principal: "*",
|
||||
Action: []string{"sts:AssumeRoleWithWebIdentity"},
|
||||
},
|
||||
},
|
||||
},
|
||||
evalCtx: &EvaluationContext{
|
||||
Action: "sts:AssumeRoleWithWebIdentity",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:FederatedProvider": "https://any-provider.com",
|
||||
},
|
||||
},
|
||||
wantEffect: EffectAllow,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "trust policy with conditions",
|
||||
trustPolicy: &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Effect: "Allow",
|
||||
Principal: map[string]interface{}{
|
||||
"Federated": "*",
|
||||
},
|
||||
Action: []string{"sts:AssumeRoleWithWebIdentity"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"StringEquals": {
|
||||
"oidc:aud": "my-app-id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
evalCtx: &EvaluationContext{
|
||||
Action: "sts:AssumeRoleWithWebIdentity",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:FederatedProvider": "https://provider.com",
|
||||
"oidc:aud": "my-app-id",
|
||||
},
|
||||
},
|
||||
wantEffect: EffectAllow,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "trust policy condition not met",
|
||||
trustPolicy: &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Effect: "Allow",
|
||||
Principal: map[string]interface{}{
|
||||
"Federated": "*",
|
||||
},
|
||||
Action: []string{"sts:AssumeRoleWithWebIdentity"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"StringEquals": {
|
||||
"oidc:aud": "my-app-id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
evalCtx: &EvaluationContext{
|
||||
Action: "sts:AssumeRoleWithWebIdentity",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:FederatedProvider": "https://provider.com",
|
||||
"oidc:aud": "wrong-app-id",
|
||||
},
|
||||
},
|
||||
wantEffect: EffectDeny,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := engine.EvaluateTrustPolicy(context.Background(), tt.trustPolicy, tt.evalCtx)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantEffect, result.Effect, "Trust policy evaluation failed for: %s", tt.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetPrincipalContextKey tests the context key mapping
|
||||
func TestGetPrincipalContextKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
principalType string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "Federated principal",
|
||||
principalType: "Federated",
|
||||
want: "aws:FederatedProvider",
|
||||
},
|
||||
{
|
||||
name: "AWS principal",
|
||||
principalType: "AWS",
|
||||
want: "aws:PrincipalArn",
|
||||
},
|
||||
{
|
||||
name: "Service principal",
|
||||
principalType: "Service",
|
||||
want: "aws:PrincipalServiceName",
|
||||
},
|
||||
{
|
||||
name: "Custom principal type",
|
||||
principalType: "CustomType",
|
||||
want: "aws:PrincipalCustomType",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := getPrincipalContextKey(tt.principalType)
|
||||
assert.Equal(t, tt.want, result, "Context key mapping failed for: %s", tt.name)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,426 +0,0 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestPolicyEngineInitialization tests policy engine initialization
|
||||
func TestPolicyEngineInitialization(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *PolicyEngineConfig
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
config: &PolicyEngineConfig{
|
||||
DefaultEffect: "Deny",
|
||||
StoreType: "memory",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid default effect",
|
||||
config: &PolicyEngineConfig{
|
||||
DefaultEffect: "Invalid",
|
||||
StoreType: "memory",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "nil config",
|
||||
config: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
engine := NewPolicyEngine()
|
||||
|
||||
err := engine.Initialize(tt.config)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, engine.IsInitialized())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPolicyDocumentValidation tests policy document structure validation
|
||||
func TestPolicyDocumentValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
policy *PolicyDocument
|
||||
wantErr bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid policy document",
|
||||
policy: &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowS3Read",
|
||||
Effect: "Allow",
|
||||
Action: []string{"s3:GetObject", "s3:ListBucket"},
|
||||
Resource: []string{"arn:aws:s3:::mybucket/*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing version",
|
||||
policy: &PolicyDocument{
|
||||
Statement: []Statement{
|
||||
{
|
||||
Effect: "Allow",
|
||||
Action: []string{"s3:GetObject"},
|
||||
Resource: []string{"arn:aws:s3:::mybucket/*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errorMsg: "version is required",
|
||||
},
|
||||
{
|
||||
name: "empty statements",
|
||||
policy: &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{},
|
||||
},
|
||||
wantErr: true,
|
||||
errorMsg: "at least one statement is required",
|
||||
},
|
||||
{
|
||||
name: "invalid effect",
|
||||
policy: &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Effect: "Maybe",
|
||||
Action: []string{"s3:GetObject"},
|
||||
Resource: []string{"arn:aws:s3:::mybucket/*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errorMsg: "invalid effect",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidatePolicyDocument(tt.policy)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPolicyEvaluation tests policy evaluation logic
|
||||
func TestPolicyEvaluation(t *testing.T) {
|
||||
engine := setupTestPolicyEngine(t)
|
||||
|
||||
// Add test policies
|
||||
readPolicy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowS3Read",
|
||||
Effect: "Allow",
|
||||
Action: []string{"s3:GetObject", "s3:ListBucket"},
|
||||
Resource: []string{
|
||||
"arn:aws:s3:::public-bucket/*", // For object operations
|
||||
"arn:aws:s3:::public-bucket", // For bucket operations
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := engine.AddPolicy("", "read-policy", readPolicy)
|
||||
require.NoError(t, err)
|
||||
|
||||
denyPolicy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "DenyS3Delete",
|
||||
Effect: "Deny",
|
||||
Action: []string{"s3:DeleteObject"},
|
||||
Resource: []string{"arn:aws:s3:::*"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = engine.AddPolicy("", "deny-policy", denyPolicy)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
context *EvaluationContext
|
||||
policies []string
|
||||
want Effect
|
||||
}{
|
||||
{
|
||||
name: "allow read access",
|
||||
context: &EvaluationContext{
|
||||
Principal: "user:alice",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::public-bucket/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:SourceIp": "192.168.1.100",
|
||||
},
|
||||
},
|
||||
policies: []string{"read-policy"},
|
||||
want: EffectAllow,
|
||||
},
|
||||
{
|
||||
name: "deny delete access (explicit deny)",
|
||||
context: &EvaluationContext{
|
||||
Principal: "user:alice",
|
||||
Action: "s3:DeleteObject",
|
||||
Resource: "arn:aws:s3:::public-bucket/file.txt",
|
||||
},
|
||||
policies: []string{"read-policy", "deny-policy"},
|
||||
want: EffectDeny,
|
||||
},
|
||||
{
|
||||
name: "deny by default (no matching policy)",
|
||||
context: &EvaluationContext{
|
||||
Principal: "user:alice",
|
||||
Action: "s3:PutObject",
|
||||
Resource: "arn:aws:s3:::public-bucket/file.txt",
|
||||
},
|
||||
policies: []string{"read-policy"},
|
||||
want: EffectDeny,
|
||||
},
|
||||
{
|
||||
name: "allow with wildcard action",
|
||||
context: &EvaluationContext{
|
||||
Principal: "user:admin",
|
||||
Action: "s3:ListBucket",
|
||||
Resource: "arn:aws:s3:::public-bucket",
|
||||
},
|
||||
policies: []string{"read-policy"},
|
||||
want: EffectAllow,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := engine.Evaluate(context.Background(), "", tt.context, tt.policies)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, result.Effect)
|
||||
|
||||
// Verify evaluation details
|
||||
assert.NotNil(t, result.EvaluationDetails)
|
||||
assert.Equal(t, tt.context.Action, result.EvaluationDetails.Action)
|
||||
assert.Equal(t, tt.context.Resource, result.EvaluationDetails.Resource)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConditionEvaluation tests policy conditions
|
||||
func TestConditionEvaluation(t *testing.T) {
|
||||
engine := setupTestPolicyEngine(t)
|
||||
|
||||
// Policy with IP address condition
|
||||
conditionalPolicy := &PolicyDocument{
|
||||
Version: "2012-10-17",
|
||||
Statement: []Statement{
|
||||
{
|
||||
Sid: "AllowFromOfficeIP",
|
||||
Effect: "Allow",
|
||||
Action: []string{"s3:*"},
|
||||
Resource: []string{"arn:aws:s3:::*"},
|
||||
Condition: map[string]map[string]interface{}{
|
||||
"IpAddress": {
|
||||
"aws:SourceIp": []string{"192.168.1.0/24", "10.0.0.0/8"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := engine.AddPolicy("", "ip-conditional", conditionalPolicy)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
context *EvaluationContext
|
||||
want Effect
|
||||
}{
|
||||
{
|
||||
name: "allow from office IP",
|
||||
context: &EvaluationContext{
|
||||
Principal: "user:alice",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::mybucket/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:SourceIp": "192.168.1.100",
|
||||
},
|
||||
},
|
||||
want: EffectAllow,
|
||||
},
|
||||
{
|
||||
name: "deny from external IP",
|
||||
context: &EvaluationContext{
|
||||
Principal: "user:alice",
|
||||
Action: "s3:GetObject",
|
||||
Resource: "arn:aws:s3:::mybucket/file.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:SourceIp": "8.8.8.8",
|
||||
},
|
||||
},
|
||||
want: EffectDeny,
|
||||
},
|
||||
{
|
||||
name: "allow from internal IP",
|
||||
context: &EvaluationContext{
|
||||
Principal: "user:alice",
|
||||
Action: "s3:PutObject",
|
||||
Resource: "arn:aws:s3:::mybucket/newfile.txt",
|
||||
RequestContext: map[string]interface{}{
|
||||
"aws:SourceIp": "10.1.2.3",
|
||||
},
|
||||
},
|
||||
want: EffectAllow,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := engine.Evaluate(context.Background(), "", tt.context, []string{"ip-conditional"})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, result.Effect)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestResourceMatching tests resource ARN matching
|
||||
func TestResourceMatching(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
policyResource string
|
||||
requestResource string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "exact match",
|
||||
policyResource: "arn:aws:s3:::mybucket/file.txt",
|
||||
requestResource: "arn:aws:s3:::mybucket/file.txt",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard match",
|
||||
policyResource: "arn:aws:s3:::mybucket/*",
|
||||
requestResource: "arn:aws:s3:::mybucket/folder/file.txt",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "bucket wildcard",
|
||||
policyResource: "arn:aws:s3:::*",
|
||||
requestResource: "arn:aws:s3:::anybucket/file.txt",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "no match different bucket",
|
||||
policyResource: "arn:aws:s3:::mybucket/*",
|
||||
requestResource: "arn:aws:s3:::otherbucket/file.txt",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "prefix match",
|
||||
policyResource: "arn:aws:s3:::mybucket/documents/*",
|
||||
requestResource: "arn:aws:s3:::mybucket/documents/secret.txt",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := matchResource(tt.policyResource, tt.requestResource)
|
||||
assert.Equal(t, tt.want, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestActionMatching tests action pattern matching
|
||||
func TestActionMatching(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
policyAction string
|
||||
requestAction string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "exact match",
|
||||
policyAction: "s3:GetObject",
|
||||
requestAction: "s3:GetObject",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard service",
|
||||
policyAction: "s3:*",
|
||||
requestAction: "s3:PutObject",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard all",
|
||||
policyAction: "*",
|
||||
requestAction: "filer:CreateEntry",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "prefix match",
|
||||
policyAction: "s3:Get*",
|
||||
requestAction: "s3:GetObject",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "no match different service",
|
||||
policyAction: "s3:GetObject",
|
||||
requestAction: "filer:GetEntry",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := matchAction(tt.policyAction, tt.requestAction)
|
||||
assert.Equal(t, tt.want, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to set up test policy engine
|
||||
func setupTestPolicyEngine(t *testing.T) *PolicyEngine {
|
||||
engine := NewPolicyEngine()
|
||||
config := &PolicyEngineConfig{
|
||||
DefaultEffect: "Deny",
|
||||
StoreType: "memory",
|
||||
}
|
||||
|
||||
err := engine.Initialize(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
return engine
|
||||
}
|
||||
@@ -1,246 +0,0 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestIdentityProviderInterface tests the core identity provider interface
|
||||
func TestIdentityProviderInterface(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
provider IdentityProvider
|
||||
wantErr bool
|
||||
}{
|
||||
// We'll add test cases as we implement providers
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test provider name
|
||||
name := tt.provider.Name()
|
||||
assert.NotEmpty(t, name, "Provider name should not be empty")
|
||||
|
||||
// Test initialization
|
||||
err := tt.provider.Initialize(nil)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test authentication with invalid token
|
||||
ctx := context.Background()
|
||||
_, err = tt.provider.Authenticate(ctx, "invalid-token")
|
||||
assert.Error(t, err, "Should fail with invalid token")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestExternalIdentityValidation tests external identity structure validation
|
||||
func TestExternalIdentityValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
identity *ExternalIdentity
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid identity",
|
||||
identity: &ExternalIdentity{
|
||||
UserID: "user123",
|
||||
Email: "user@example.com",
|
||||
DisplayName: "Test User",
|
||||
Groups: []string{"group1", "group2"},
|
||||
Attributes: map[string]string{"dept": "engineering"},
|
||||
Provider: "test-provider",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing user id",
|
||||
identity: &ExternalIdentity{
|
||||
Email: "user@example.com",
|
||||
Provider: "test-provider",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing provider",
|
||||
identity: &ExternalIdentity{
|
||||
UserID: "user123",
|
||||
Email: "user@example.com",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid email",
|
||||
identity: &ExternalIdentity{
|
||||
UserID: "user123",
|
||||
Email: "invalid-email",
|
||||
Provider: "test-provider",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.identity.Validate()
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenClaimsValidation tests token claims structure
|
||||
func TestTokenClaimsValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
claims *TokenClaims
|
||||
valid bool
|
||||
}{
|
||||
{
|
||||
name: "valid claims",
|
||||
claims: &TokenClaims{
|
||||
Subject: "user123",
|
||||
Issuer: "https://provider.example.com",
|
||||
Audience: "seaweedfs",
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
IssuedAt: time.Now().Add(-time.Minute),
|
||||
Claims: map[string]interface{}{"email": "user@example.com"},
|
||||
},
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "expired token",
|
||||
claims: &TokenClaims{
|
||||
Subject: "user123",
|
||||
Issuer: "https://provider.example.com",
|
||||
Audience: "seaweedfs",
|
||||
ExpiresAt: time.Now().Add(-time.Hour), // Expired
|
||||
IssuedAt: time.Now().Add(-time.Hour * 2),
|
||||
Claims: map[string]interface{}{"email": "user@example.com"},
|
||||
},
|
||||
valid: false,
|
||||
},
|
||||
{
|
||||
name: "future issued token",
|
||||
claims: &TokenClaims{
|
||||
Subject: "user123",
|
||||
Issuer: "https://provider.example.com",
|
||||
Audience: "seaweedfs",
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
IssuedAt: time.Now().Add(time.Hour), // Future
|
||||
Claims: map[string]interface{}{"email": "user@example.com"},
|
||||
},
|
||||
valid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
valid := tt.claims.IsValid()
|
||||
assert.Equal(t, tt.valid, valid)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProviderRegistry tests provider registration and discovery
|
||||
func TestProviderRegistry(t *testing.T) {
|
||||
// Clear registry for test
|
||||
registry := NewProviderRegistry()
|
||||
|
||||
t.Run("register provider", func(t *testing.T) {
|
||||
mockProvider := &MockProvider{name: "test-provider"}
|
||||
|
||||
err := registry.RegisterProvider(mockProvider)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test duplicate registration
|
||||
err = registry.RegisterProvider(mockProvider)
|
||||
assert.Error(t, err, "Should not allow duplicate registration")
|
||||
})
|
||||
|
||||
t.Run("get provider", func(t *testing.T) {
|
||||
provider, exists := registry.GetProvider("test-provider")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, "test-provider", provider.Name())
|
||||
|
||||
// Test non-existent provider
|
||||
_, exists = registry.GetProvider("non-existent")
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("list providers", func(t *testing.T) {
|
||||
providers := registry.ListProviders()
|
||||
assert.Len(t, providers, 1)
|
||||
assert.Equal(t, "test-provider", providers[0])
|
||||
})
|
||||
}
|
||||
|
||||
// MockProvider for testing
|
||||
type MockProvider struct {
|
||||
name string
|
||||
initialized bool
|
||||
shouldError bool
|
||||
}
|
||||
|
||||
func (m *MockProvider) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *MockProvider) Initialize(config interface{}) error {
|
||||
if m.shouldError {
|
||||
return assert.AnError
|
||||
}
|
||||
m.initialized = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockProvider) Authenticate(ctx context.Context, token string) (*ExternalIdentity, error) {
|
||||
if !m.initialized {
|
||||
return nil, assert.AnError
|
||||
}
|
||||
if token == "invalid-token" {
|
||||
return nil, assert.AnError
|
||||
}
|
||||
return &ExternalIdentity{
|
||||
UserID: "test-user",
|
||||
Email: "test@example.com",
|
||||
DisplayName: "Test User",
|
||||
Provider: m.name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *MockProvider) GetUserInfo(ctx context.Context, userID string) (*ExternalIdentity, error) {
|
||||
if !m.initialized || userID == "" {
|
||||
return nil, assert.AnError
|
||||
}
|
||||
return &ExternalIdentity{
|
||||
UserID: userID,
|
||||
Email: userID + "@example.com",
|
||||
DisplayName: "User " + userID,
|
||||
Provider: m.name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *MockProvider) ValidateToken(ctx context.Context, token string) (*TokenClaims, error) {
|
||||
if !m.initialized || token == "invalid-token" {
|
||||
return nil, assert.AnError
|
||||
}
|
||||
return &TokenClaims{
|
||||
Subject: "test-user",
|
||||
Issuer: "test-issuer",
|
||||
Audience: "seaweedfs",
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
IssuedAt: time.Now(),
|
||||
Claims: map[string]interface{}{"email": "test@example.com"},
|
||||
}, nil
|
||||
}
|
||||
@@ -1,109 +0,0 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ProviderRegistry manages registered identity providers
|
||||
type ProviderRegistry struct {
|
||||
mu sync.RWMutex
|
||||
providers map[string]IdentityProvider
|
||||
}
|
||||
|
||||
// NewProviderRegistry creates a new provider registry
|
||||
func NewProviderRegistry() *ProviderRegistry {
|
||||
return &ProviderRegistry{
|
||||
providers: make(map[string]IdentityProvider),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterProvider registers a new identity provider
|
||||
func (r *ProviderRegistry) RegisterProvider(provider IdentityProvider) error {
|
||||
if provider == nil {
|
||||
return fmt.Errorf("provider cannot be nil")
|
||||
}
|
||||
|
||||
name := provider.Name()
|
||||
if name == "" {
|
||||
return fmt.Errorf("provider name cannot be empty")
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.providers[name]; exists {
|
||||
return fmt.Errorf("provider %s is already registered", name)
|
||||
}
|
||||
|
||||
r.providers[name] = provider
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetProvider retrieves a provider by name
|
||||
func (r *ProviderRegistry) GetProvider(name string) (IdentityProvider, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
provider, exists := r.providers[name]
|
||||
return provider, exists
|
||||
}
|
||||
|
||||
// ListProviders returns all registered provider names
|
||||
func (r *ProviderRegistry) ListProviders() []string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
var names []string
|
||||
for name := range r.providers {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// UnregisterProvider removes a provider from the registry
|
||||
func (r *ProviderRegistry) UnregisterProvider(name string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.providers[name]; !exists {
|
||||
return fmt.Errorf("provider %s is not registered", name)
|
||||
}
|
||||
|
||||
delete(r.providers, name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear removes all providers from the registry
|
||||
func (r *ProviderRegistry) Clear() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.providers = make(map[string]IdentityProvider)
|
||||
}
|
||||
|
||||
// GetProviderCount returns the number of registered providers
|
||||
func (r *ProviderRegistry) GetProviderCount() int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
return len(r.providers)
|
||||
}
|
||||
|
||||
// Default global registry
|
||||
var defaultRegistry = NewProviderRegistry()
|
||||
|
||||
// RegisterProvider registers a provider in the default registry
|
||||
func RegisterProvider(provider IdentityProvider) error {
|
||||
return defaultRegistry.RegisterProvider(provider)
|
||||
}
|
||||
|
||||
// GetProvider retrieves a provider from the default registry
|
||||
func GetProvider(name string) (IdentityProvider, bool) {
|
||||
return defaultRegistry.GetProvider(name)
|
||||
}
|
||||
|
||||
// ListProviders returns all provider names from the default registry
|
||||
func ListProviders() []string {
|
||||
return defaultRegistry.ListProviders()
|
||||
}
|
||||
@@ -1,503 +0,0 @@
|
||||
package sts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/seaweedfs/seaweedfs/weed/iam/oidc"
|
||||
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Test-only constants for mock providers
|
||||
const (
|
||||
ProviderTypeMock = "mock"
|
||||
)
|
||||
|
||||
// createMockOIDCProvider creates a mock OIDC provider for testing
|
||||
// This is only available in test builds
|
||||
func createMockOIDCProvider(name string, config map[string]interface{}) (providers.IdentityProvider, error) {
|
||||
// Convert config to OIDC format
|
||||
factory := NewProviderFactory()
|
||||
oidcConfig, err := factory.convertToOIDCConfig(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set default values for mock provider if not provided
|
||||
if oidcConfig.Issuer == "" {
|
||||
oidcConfig.Issuer = "http://localhost:9999"
|
||||
}
|
||||
|
||||
provider := oidc.NewMockOIDCProvider(name)
|
||||
if err := provider.Initialize(oidcConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set up default test data for the mock provider
|
||||
provider.SetupDefaultTestData()
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// createMockJWT creates a test JWT token with the specified issuer for mock provider testing
|
||||
func createMockJWT(t *testing.T, issuer, subject string) string {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"sub": subject,
|
||||
"aud": "test-client",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
})
|
||||
|
||||
tokenString, err := token.SignedString([]byte("test-signing-key"))
|
||||
require.NoError(t, err)
|
||||
return tokenString
|
||||
}
|
||||
|
||||
// TestCrossInstanceTokenUsage verifies that tokens generated by one STS instance
|
||||
// can be used and validated by other STS instances in a distributed environment
|
||||
func TestCrossInstanceTokenUsage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
// Dummy filer address for testing
|
||||
|
||||
// Common configuration that would be shared across all instances in production
|
||||
sharedConfig := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{12 * time.Hour},
|
||||
Issuer: "distributed-sts-cluster", // SAME across all instances
|
||||
SigningKey: []byte(TestSigningKey32Chars), // SAME across all instances
|
||||
Providers: []*ProviderConfig{
|
||||
{
|
||||
Name: "company-oidc",
|
||||
Type: ProviderTypeOIDC,
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{
|
||||
ConfigFieldIssuer: "https://sso.company.com/realms/production",
|
||||
ConfigFieldClientID: "seaweedfs-cluster",
|
||||
ConfigFieldJWKSUri: "https://sso.company.com/realms/production/protocol/openid-connect/certs",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create multiple STS instances simulating different S3 gateway instances
|
||||
instanceA := NewSTSService() // e.g., s3-gateway-1
|
||||
instanceB := NewSTSService() // e.g., s3-gateway-2
|
||||
instanceC := NewSTSService() // e.g., s3-gateway-3
|
||||
|
||||
// Initialize all instances with IDENTICAL configuration
|
||||
err := instanceA.Initialize(sharedConfig)
|
||||
require.NoError(t, err, "Instance A should initialize")
|
||||
|
||||
err = instanceB.Initialize(sharedConfig)
|
||||
require.NoError(t, err, "Instance B should initialize")
|
||||
|
||||
err = instanceC.Initialize(sharedConfig)
|
||||
require.NoError(t, err, "Instance C should initialize")
|
||||
|
||||
// Set up mock trust policy validator for all instances (required for STS testing)
|
||||
mockValidator := &MockTrustPolicyValidator{}
|
||||
instanceA.SetTrustPolicyValidator(mockValidator)
|
||||
instanceB.SetTrustPolicyValidator(mockValidator)
|
||||
instanceC.SetTrustPolicyValidator(mockValidator)
|
||||
|
||||
// Manually register mock provider for testing (not available in production)
|
||||
mockProviderConfig := map[string]interface{}{
|
||||
ConfigFieldIssuer: "http://test-mock:9999",
|
||||
ConfigFieldClientID: TestClientID,
|
||||
}
|
||||
mockProviderA, err := createMockOIDCProvider("test-mock", mockProviderConfig)
|
||||
require.NoError(t, err)
|
||||
mockProviderB, err := createMockOIDCProvider("test-mock", mockProviderConfig)
|
||||
require.NoError(t, err)
|
||||
mockProviderC, err := createMockOIDCProvider("test-mock", mockProviderConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
instanceA.RegisterProvider(mockProviderA)
|
||||
instanceB.RegisterProvider(mockProviderB)
|
||||
instanceC.RegisterProvider(mockProviderC)
|
||||
|
||||
// Test 1: Token generated on Instance A can be validated on Instance B & C
|
||||
t.Run("cross_instance_token_validation", func(t *testing.T) {
|
||||
// Generate session token on Instance A
|
||||
sessionId := TestSessionID
|
||||
expiresAt := time.Now().Add(time.Hour)
|
||||
|
||||
tokenFromA, err := instanceA.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
|
||||
require.NoError(t, err, "Instance A should generate token")
|
||||
|
||||
// Validate token on Instance B
|
||||
claimsFromB, err := instanceB.GetTokenGenerator().ValidateSessionToken(tokenFromA)
|
||||
require.NoError(t, err, "Instance B should validate token from Instance A")
|
||||
assert.Equal(t, sessionId, claimsFromB.SessionId, "Session ID should match")
|
||||
|
||||
// Validate same token on Instance C
|
||||
claimsFromC, err := instanceC.GetTokenGenerator().ValidateSessionToken(tokenFromA)
|
||||
require.NoError(t, err, "Instance C should validate token from Instance A")
|
||||
assert.Equal(t, sessionId, claimsFromC.SessionId, "Session ID should match")
|
||||
|
||||
// All instances should extract identical claims
|
||||
assert.Equal(t, claimsFromB.SessionId, claimsFromC.SessionId)
|
||||
assert.Equal(t, claimsFromB.ExpiresAt.Unix(), claimsFromC.ExpiresAt.Unix())
|
||||
assert.Equal(t, claimsFromB.IssuedAt.Unix(), claimsFromC.IssuedAt.Unix())
|
||||
})
|
||||
|
||||
// Test 2: Complete assume role flow across instances
|
||||
t.Run("cross_instance_assume_role_flow", func(t *testing.T) {
|
||||
// Step 1: User authenticates and assumes role on Instance A
|
||||
// Create a valid JWT token for the mock provider
|
||||
mockToken := createMockJWT(t, "http://test-mock:9999", "test-user")
|
||||
|
||||
assumeRequest := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: "arn:aws:iam::role/CrossInstanceTestRole",
|
||||
WebIdentityToken: mockToken, // JWT token for mock provider
|
||||
RoleSessionName: "cross-instance-test-session",
|
||||
DurationSeconds: int64ToPtr(3600),
|
||||
}
|
||||
|
||||
// Instance A processes assume role request
|
||||
responseFromA, err := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest)
|
||||
require.NoError(t, err, "Instance A should process assume role")
|
||||
|
||||
sessionToken := responseFromA.Credentials.SessionToken
|
||||
accessKeyId := responseFromA.Credentials.AccessKeyId
|
||||
secretAccessKey := responseFromA.Credentials.SecretAccessKey
|
||||
|
||||
// Verify response structure
|
||||
assert.NotEmpty(t, sessionToken, "Should have session token")
|
||||
assert.NotEmpty(t, accessKeyId, "Should have access key ID")
|
||||
assert.NotEmpty(t, secretAccessKey, "Should have secret access key")
|
||||
assert.NotNil(t, responseFromA.AssumedRoleUser, "Should have assumed role user")
|
||||
|
||||
// Step 2: Use session token on Instance B (different instance)
|
||||
sessionInfoFromB, err := instanceB.ValidateSessionToken(ctx, sessionToken)
|
||||
require.NoError(t, err, "Instance B should validate session token from Instance A")
|
||||
|
||||
assert.Equal(t, assumeRequest.RoleSessionName, sessionInfoFromB.SessionName)
|
||||
assert.Equal(t, assumeRequest.RoleArn, sessionInfoFromB.RoleArn)
|
||||
|
||||
// Step 3: Use same session token on Instance C (yet another instance)
|
||||
sessionInfoFromC, err := instanceC.ValidateSessionToken(ctx, sessionToken)
|
||||
require.NoError(t, err, "Instance C should validate session token from Instance A")
|
||||
|
||||
// All instances should return identical session information
|
||||
assert.Equal(t, sessionInfoFromB.SessionId, sessionInfoFromC.SessionId)
|
||||
assert.Equal(t, sessionInfoFromB.SessionName, sessionInfoFromC.SessionName)
|
||||
assert.Equal(t, sessionInfoFromB.RoleArn, sessionInfoFromC.RoleArn)
|
||||
assert.Equal(t, sessionInfoFromB.Subject, sessionInfoFromC.Subject)
|
||||
assert.Equal(t, sessionInfoFromB.Provider, sessionInfoFromC.Provider)
|
||||
})
|
||||
|
||||
// Test 3: Session revocation across instances
|
||||
t.Run("cross_instance_session_revocation", func(t *testing.T) {
|
||||
// Create session on Instance A
|
||||
mockToken := createMockJWT(t, "http://test-mock:9999", "test-user")
|
||||
|
||||
assumeRequest := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: "arn:aws:iam::role/RevocationTestRole",
|
||||
WebIdentityToken: mockToken,
|
||||
RoleSessionName: "revocation-test-session",
|
||||
}
|
||||
|
||||
response, err := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest)
|
||||
require.NoError(t, err)
|
||||
sessionToken := response.Credentials.SessionToken
|
||||
|
||||
// Verify token works on Instance B
|
||||
_, err = instanceB.ValidateSessionToken(ctx, sessionToken)
|
||||
require.NoError(t, err, "Token should be valid on Instance B initially")
|
||||
|
||||
// Validate session on Instance C to verify cross-instance token compatibility
|
||||
_, err = instanceC.ValidateSessionToken(ctx, sessionToken)
|
||||
require.NoError(t, err, "Instance C should be able to validate session token")
|
||||
|
||||
// In a stateless JWT system, tokens remain valid on all instances since they're self-contained
|
||||
// No revocation is possible without breaking the stateless architecture
|
||||
_, err = instanceA.ValidateSessionToken(ctx, sessionToken)
|
||||
assert.NoError(t, err, "Token should still be valid on Instance A (stateless system)")
|
||||
|
||||
// Verify token is still valid on Instance B
|
||||
_, err = instanceB.ValidateSessionToken(ctx, sessionToken)
|
||||
assert.NoError(t, err, "Token should still be valid on Instance B (stateless system)")
|
||||
})
|
||||
|
||||
// Test 4: Provider consistency across instances
|
||||
t.Run("provider_consistency_affects_token_generation", func(t *testing.T) {
|
||||
// All instances should have same providers and be able to process same OIDC tokens
|
||||
providerNamesA := instanceA.getProviderNames()
|
||||
providerNamesB := instanceB.getProviderNames()
|
||||
providerNamesC := instanceC.getProviderNames()
|
||||
|
||||
assert.ElementsMatch(t, providerNamesA, providerNamesB, "Instance A and B should have same providers")
|
||||
assert.ElementsMatch(t, providerNamesB, providerNamesC, "Instance B and C should have same providers")
|
||||
|
||||
// All instances should be able to process same web identity token
|
||||
testToken := createMockJWT(t, "http://test-mock:9999", "test-user")
|
||||
|
||||
// Try to assume role with same token on different instances
|
||||
assumeRequest := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: "arn:aws:iam::role/ProviderTestRole",
|
||||
WebIdentityToken: testToken,
|
||||
RoleSessionName: "provider-consistency-test",
|
||||
}
|
||||
|
||||
// Should work on any instance
|
||||
responseA, errA := instanceA.AssumeRoleWithWebIdentity(ctx, assumeRequest)
|
||||
responseB, errB := instanceB.AssumeRoleWithWebIdentity(ctx, assumeRequest)
|
||||
responseC, errC := instanceC.AssumeRoleWithWebIdentity(ctx, assumeRequest)
|
||||
|
||||
require.NoError(t, errA, "Instance A should process OIDC token")
|
||||
require.NoError(t, errB, "Instance B should process OIDC token")
|
||||
require.NoError(t, errC, "Instance C should process OIDC token")
|
||||
|
||||
// All should return valid responses (sessions will have different IDs but same structure)
|
||||
assert.NotEmpty(t, responseA.Credentials.SessionToken)
|
||||
assert.NotEmpty(t, responseB.Credentials.SessionToken)
|
||||
assert.NotEmpty(t, responseC.Credentials.SessionToken)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSTSDistributedConfigurationRequirements tests the configuration requirements
|
||||
// for cross-instance token compatibility
|
||||
func TestSTSDistributedConfigurationRequirements(t *testing.T) {
|
||||
_ = "localhost:8888" // Dummy filer address for testing (not used in these tests)
|
||||
|
||||
t.Run("same_signing_key_required", func(t *testing.T) {
|
||||
// Instance A with signing key 1
|
||||
configA := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{12 * time.Hour},
|
||||
Issuer: "test-sts",
|
||||
SigningKey: []byte("signing-key-1-32-characters-long"),
|
||||
}
|
||||
|
||||
// Instance B with different signing key
|
||||
configB := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{12 * time.Hour},
|
||||
Issuer: "test-sts",
|
||||
SigningKey: []byte("signing-key-2-32-characters-long"), // DIFFERENT!
|
||||
}
|
||||
|
||||
instanceA := NewSTSService()
|
||||
instanceB := NewSTSService()
|
||||
|
||||
err := instanceA.Initialize(configA)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = instanceB.Initialize(configB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate token on Instance A
|
||||
sessionId := "test-session"
|
||||
expiresAt := time.Now().Add(time.Hour)
|
||||
tokenFromA, err := instanceA.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Instance A should validate its own token
|
||||
_, err = instanceA.GetTokenGenerator().ValidateSessionToken(tokenFromA)
|
||||
assert.NoError(t, err, "Instance A should validate own token")
|
||||
|
||||
// Instance B should REJECT token due to different signing key
|
||||
_, err = instanceB.GetTokenGenerator().ValidateSessionToken(tokenFromA)
|
||||
assert.Error(t, err, "Instance B should reject token with different signing key")
|
||||
assert.Contains(t, err.Error(), "invalid token", "Should be signature validation error")
|
||||
})
|
||||
|
||||
t.Run("same_issuer_required", func(t *testing.T) {
|
||||
sharedSigningKey := []byte("shared-signing-key-32-characters-lo")
|
||||
|
||||
// Instance A with issuer 1
|
||||
configA := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{12 * time.Hour},
|
||||
Issuer: "sts-cluster-1",
|
||||
SigningKey: sharedSigningKey,
|
||||
}
|
||||
|
||||
// Instance B with different issuer
|
||||
configB := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{12 * time.Hour},
|
||||
Issuer: "sts-cluster-2", // DIFFERENT!
|
||||
SigningKey: sharedSigningKey,
|
||||
}
|
||||
|
||||
instanceA := NewSTSService()
|
||||
instanceB := NewSTSService()
|
||||
|
||||
err := instanceA.Initialize(configA)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = instanceB.Initialize(configB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate token on Instance A
|
||||
sessionId := "test-session"
|
||||
expiresAt := time.Now().Add(time.Hour)
|
||||
tokenFromA, err := instanceA.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Instance B should REJECT token due to different issuer
|
||||
_, err = instanceB.GetTokenGenerator().ValidateSessionToken(tokenFromA)
|
||||
assert.Error(t, err, "Instance B should reject token with different issuer")
|
||||
assert.Contains(t, err.Error(), "invalid issuer", "Should be issuer validation error")
|
||||
})
|
||||
|
||||
t.Run("identical_configuration_required", func(t *testing.T) {
|
||||
// Identical configuration
|
||||
identicalConfig := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{12 * time.Hour},
|
||||
Issuer: "production-sts-cluster",
|
||||
SigningKey: []byte("production-signing-key-32-chars-l"),
|
||||
}
|
||||
|
||||
// Create multiple instances with identical config
|
||||
instances := make([]*STSService, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
instances[i] = NewSTSService()
|
||||
err := instances[i].Initialize(identicalConfig)
|
||||
require.NoError(t, err, "Instance %d should initialize", i)
|
||||
}
|
||||
|
||||
// Generate token on Instance 0
|
||||
sessionId := "multi-instance-test"
|
||||
expiresAt := time.Now().Add(time.Hour)
|
||||
token, err := instances[0].GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// All other instances should validate the token
|
||||
for i := 1; i < 5; i++ {
|
||||
claims, err := instances[i].GetTokenGenerator().ValidateSessionToken(token)
|
||||
require.NoError(t, err, "Instance %d should validate token", i)
|
||||
assert.Equal(t, sessionId, claims.SessionId, "Instance %d should extract correct session ID", i)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestSTSRealWorldDistributedScenarios tests realistic distributed deployment scenarios
|
||||
func TestSTSRealWorldDistributedScenarios(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("load_balanced_s3_gateway_scenario", func(t *testing.T) {
|
||||
// Simulate real production scenario:
|
||||
// 1. User authenticates with OIDC provider
|
||||
// 2. User calls AssumeRoleWithWebIdentity on S3 Gateway 1
|
||||
// 3. User makes S3 requests that hit S3 Gateway 2 & 3 via load balancer
|
||||
// 4. All instances should handle the session token correctly
|
||||
|
||||
productionConfig := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{2 * time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{24 * time.Hour},
|
||||
Issuer: "seaweedfs-production-sts",
|
||||
SigningKey: []byte("prod-signing-key-32-characters-lon"),
|
||||
|
||||
Providers: []*ProviderConfig{
|
||||
{
|
||||
Name: "corporate-oidc",
|
||||
Type: "oidc",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "https://sso.company.com/realms/production",
|
||||
"clientId": "seaweedfs-prod-cluster",
|
||||
"clientSecret": "supersecret-prod-key",
|
||||
"scopes": []string{"openid", "profile", "email", "groups"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create 3 S3 Gateway instances behind load balancer
|
||||
gateway1 := NewSTSService()
|
||||
gateway2 := NewSTSService()
|
||||
gateway3 := NewSTSService()
|
||||
|
||||
err := gateway1.Initialize(productionConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = gateway2.Initialize(productionConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = gateway3.Initialize(productionConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set up mock trust policy validator for all gateway instances
|
||||
mockValidator := &MockTrustPolicyValidator{}
|
||||
gateway1.SetTrustPolicyValidator(mockValidator)
|
||||
gateway2.SetTrustPolicyValidator(mockValidator)
|
||||
gateway3.SetTrustPolicyValidator(mockValidator)
|
||||
|
||||
// Manually register mock provider for testing (not available in production)
|
||||
mockProviderConfig := map[string]interface{}{
|
||||
ConfigFieldIssuer: "http://test-mock:9999",
|
||||
ConfigFieldClientID: "test-client-id",
|
||||
}
|
||||
mockProvider1, err := createMockOIDCProvider("test-mock", mockProviderConfig)
|
||||
require.NoError(t, err)
|
||||
mockProvider2, err := createMockOIDCProvider("test-mock", mockProviderConfig)
|
||||
require.NoError(t, err)
|
||||
mockProvider3, err := createMockOIDCProvider("test-mock", mockProviderConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
gateway1.RegisterProvider(mockProvider1)
|
||||
gateway2.RegisterProvider(mockProvider2)
|
||||
gateway3.RegisterProvider(mockProvider3)
|
||||
|
||||
// Step 1: User authenticates and hits Gateway 1 for AssumeRole
|
||||
mockToken := createMockJWT(t, "http://test-mock:9999", "production-user")
|
||||
|
||||
assumeRequest := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: "arn:aws:iam::role/ProductionS3User",
|
||||
WebIdentityToken: mockToken, // JWT token from mock provider
|
||||
RoleSessionName: "user-production-session",
|
||||
DurationSeconds: int64ToPtr(7200), // 2 hours
|
||||
}
|
||||
|
||||
stsResponse, err := gateway1.AssumeRoleWithWebIdentity(ctx, assumeRequest)
|
||||
require.NoError(t, err, "Gateway 1 should handle AssumeRole")
|
||||
|
||||
sessionToken := stsResponse.Credentials.SessionToken
|
||||
accessKey := stsResponse.Credentials.AccessKeyId
|
||||
secretKey := stsResponse.Credentials.SecretAccessKey
|
||||
|
||||
// Step 2: User makes S3 requests that hit different gateways via load balancer
|
||||
// Simulate S3 request validation on Gateway 2
|
||||
sessionInfo2, err := gateway2.ValidateSessionToken(ctx, sessionToken)
|
||||
require.NoError(t, err, "Gateway 2 should validate session from Gateway 1")
|
||||
assert.Equal(t, "user-production-session", sessionInfo2.SessionName)
|
||||
assert.Equal(t, "arn:aws:iam::role/ProductionS3User", sessionInfo2.RoleArn)
|
||||
|
||||
// Simulate S3 request validation on Gateway 3
|
||||
sessionInfo3, err := gateway3.ValidateSessionToken(ctx, sessionToken)
|
||||
require.NoError(t, err, "Gateway 3 should validate session from Gateway 1")
|
||||
assert.Equal(t, sessionInfo2.SessionId, sessionInfo3.SessionId, "Should be same session")
|
||||
|
||||
// Step 3: Verify credentials are consistent
|
||||
assert.Equal(t, accessKey, stsResponse.Credentials.AccessKeyId, "Access key should be consistent")
|
||||
assert.Equal(t, secretKey, stsResponse.Credentials.SecretAccessKey, "Secret key should be consistent")
|
||||
|
||||
// Step 4: Session expiration should be honored across all instances
|
||||
assert.True(t, sessionInfo2.ExpiresAt.After(time.Now()), "Session should not be expired")
|
||||
assert.True(t, sessionInfo3.ExpiresAt.After(time.Now()), "Session should not be expired")
|
||||
|
||||
// Step 5: Token should be identical when parsed
|
||||
claims2, err := gateway2.GetTokenGenerator().ValidateSessionToken(sessionToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
claims3, err := gateway3.GetTokenGenerator().ValidateSessionToken(sessionToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, claims2.SessionId, claims3.SessionId, "Session IDs should match")
|
||||
assert.Equal(t, claims2.ExpiresAt.Unix(), claims3.ExpiresAt.Unix(), "Expiration should match")
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function to convert int64 to pointer
|
||||
func int64ToPtr(i int64) *int64 {
|
||||
return &i
|
||||
}
|
||||
@@ -1,340 +0,0 @@
|
||||
package sts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestDistributedSTSService verifies that multiple STS instances with identical configurations
|
||||
// behave consistently across distributed environments
|
||||
func TestDistributedSTSService(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Common configuration for all instances
|
||||
commonConfig := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{12 * time.Hour},
|
||||
Issuer: "distributed-sts-test",
|
||||
SigningKey: []byte("test-signing-key-32-characters-long"),
|
||||
|
||||
Providers: []*ProviderConfig{
|
||||
{
|
||||
Name: "keycloak-oidc",
|
||||
Type: "oidc",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "http://keycloak:8080/realms/seaweedfs-test",
|
||||
"clientId": "seaweedfs-s3",
|
||||
"jwksUri": "http://keycloak:8080/realms/seaweedfs-test/protocol/openid-connect/certs",
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
Name: "disabled-ldap",
|
||||
Type: "oidc", // Use OIDC as placeholder since LDAP isn't implemented
|
||||
Enabled: false,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "ldap://company.com",
|
||||
"clientId": "ldap-client",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create multiple STS instances simulating distributed deployment
|
||||
instance1 := NewSTSService()
|
||||
instance2 := NewSTSService()
|
||||
instance3 := NewSTSService()
|
||||
|
||||
// Initialize all instances with identical configuration
|
||||
err := instance1.Initialize(commonConfig)
|
||||
require.NoError(t, err, "Instance 1 should initialize successfully")
|
||||
|
||||
err = instance2.Initialize(commonConfig)
|
||||
require.NoError(t, err, "Instance 2 should initialize successfully")
|
||||
|
||||
err = instance3.Initialize(commonConfig)
|
||||
require.NoError(t, err, "Instance 3 should initialize successfully")
|
||||
|
||||
// Manually register mock providers for testing (not available in production)
|
||||
mockProviderConfig := map[string]interface{}{
|
||||
"issuer": "http://localhost:9999",
|
||||
"clientId": "test-client",
|
||||
}
|
||||
mockProvider1, err := createMockOIDCProvider("test-mock-provider", mockProviderConfig)
|
||||
require.NoError(t, err)
|
||||
mockProvider2, err := createMockOIDCProvider("test-mock-provider", mockProviderConfig)
|
||||
require.NoError(t, err)
|
||||
mockProvider3, err := createMockOIDCProvider("test-mock-provider", mockProviderConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
instance1.RegisterProvider(mockProvider1)
|
||||
instance2.RegisterProvider(mockProvider2)
|
||||
instance3.RegisterProvider(mockProvider3)
|
||||
|
||||
// Verify all instances have identical provider configurations
|
||||
t.Run("provider_consistency", func(t *testing.T) {
|
||||
// All instances should have same number of providers
|
||||
assert.Len(t, instance1.providers, 2, "Instance 1 should have 2 enabled providers")
|
||||
assert.Len(t, instance2.providers, 2, "Instance 2 should have 2 enabled providers")
|
||||
assert.Len(t, instance3.providers, 2, "Instance 3 should have 2 enabled providers")
|
||||
|
||||
// All instances should have same provider names
|
||||
instance1Names := instance1.getProviderNames()
|
||||
instance2Names := instance2.getProviderNames()
|
||||
instance3Names := instance3.getProviderNames()
|
||||
|
||||
assert.ElementsMatch(t, instance1Names, instance2Names, "Instance 1 and 2 should have same providers")
|
||||
assert.ElementsMatch(t, instance2Names, instance3Names, "Instance 2 and 3 should have same providers")
|
||||
|
||||
// Verify specific providers exist on all instances
|
||||
expectedProviders := []string{"keycloak-oidc", "test-mock-provider"}
|
||||
assert.ElementsMatch(t, instance1Names, expectedProviders, "Instance 1 should have expected providers")
|
||||
assert.ElementsMatch(t, instance2Names, expectedProviders, "Instance 2 should have expected providers")
|
||||
assert.ElementsMatch(t, instance3Names, expectedProviders, "Instance 3 should have expected providers")
|
||||
|
||||
// Verify disabled providers are not loaded
|
||||
assert.NotContains(t, instance1Names, "disabled-ldap", "Disabled providers should not be loaded")
|
||||
assert.NotContains(t, instance2Names, "disabled-ldap", "Disabled providers should not be loaded")
|
||||
assert.NotContains(t, instance3Names, "disabled-ldap", "Disabled providers should not be loaded")
|
||||
})
|
||||
|
||||
// Test token generation consistency across instances
|
||||
t.Run("token_generation_consistency", func(t *testing.T) {
|
||||
sessionId := "test-session-123"
|
||||
expiresAt := time.Now().Add(time.Hour)
|
||||
|
||||
// Generate tokens from different instances
|
||||
token1, err1 := instance1.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
|
||||
token2, err2 := instance2.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
|
||||
token3, err3 := instance3.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
|
||||
|
||||
require.NoError(t, err1, "Instance 1 token generation should succeed")
|
||||
require.NoError(t, err2, "Instance 2 token generation should succeed")
|
||||
require.NoError(t, err3, "Instance 3 token generation should succeed")
|
||||
|
||||
// All tokens should be different (due to timestamp variations)
|
||||
// But they should all be valid JWTs with same signing key
|
||||
assert.NotEmpty(t, token1)
|
||||
assert.NotEmpty(t, token2)
|
||||
assert.NotEmpty(t, token3)
|
||||
})
|
||||
|
||||
// Test token validation consistency - any instance should validate tokens from any other instance
|
||||
t.Run("cross_instance_token_validation", func(t *testing.T) {
|
||||
sessionId := "cross-validation-session"
|
||||
expiresAt := time.Now().Add(time.Hour)
|
||||
|
||||
// Generate token on instance 1
|
||||
token, err := instance1.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validate on all instances
|
||||
claims1, err1 := instance1.GetTokenGenerator().ValidateSessionToken(token)
|
||||
claims2, err2 := instance2.GetTokenGenerator().ValidateSessionToken(token)
|
||||
claims3, err3 := instance3.GetTokenGenerator().ValidateSessionToken(token)
|
||||
|
||||
require.NoError(t, err1, "Instance 1 should validate token from instance 1")
|
||||
require.NoError(t, err2, "Instance 2 should validate token from instance 1")
|
||||
require.NoError(t, err3, "Instance 3 should validate token from instance 1")
|
||||
|
||||
// All instances should extract same session ID
|
||||
assert.Equal(t, sessionId, claims1.SessionId)
|
||||
assert.Equal(t, sessionId, claims2.SessionId)
|
||||
assert.Equal(t, sessionId, claims3.SessionId)
|
||||
|
||||
assert.Equal(t, claims1.SessionId, claims2.SessionId)
|
||||
assert.Equal(t, claims2.SessionId, claims3.SessionId)
|
||||
})
|
||||
|
||||
// Test provider access consistency
|
||||
t.Run("provider_access_consistency", func(t *testing.T) {
|
||||
// All instances should be able to access the same providers
|
||||
provider1, exists1 := instance1.providers["test-mock-provider"]
|
||||
provider2, exists2 := instance2.providers["test-mock-provider"]
|
||||
provider3, exists3 := instance3.providers["test-mock-provider"]
|
||||
|
||||
assert.True(t, exists1, "Instance 1 should have test-mock-provider")
|
||||
assert.True(t, exists2, "Instance 2 should have test-mock-provider")
|
||||
assert.True(t, exists3, "Instance 3 should have test-mock-provider")
|
||||
|
||||
assert.Equal(t, provider1.Name(), provider2.Name())
|
||||
assert.Equal(t, provider2.Name(), provider3.Name())
|
||||
|
||||
// Test authentication with the mock provider on all instances
|
||||
testToken := "valid_test_token"
|
||||
|
||||
identity1, err1 := provider1.Authenticate(ctx, testToken)
|
||||
identity2, err2 := provider2.Authenticate(ctx, testToken)
|
||||
identity3, err3 := provider3.Authenticate(ctx, testToken)
|
||||
|
||||
require.NoError(t, err1, "Instance 1 provider should authenticate successfully")
|
||||
require.NoError(t, err2, "Instance 2 provider should authenticate successfully")
|
||||
require.NoError(t, err3, "Instance 3 provider should authenticate successfully")
|
||||
|
||||
// All instances should return identical identity information
|
||||
assert.Equal(t, identity1.UserID, identity2.UserID)
|
||||
assert.Equal(t, identity2.UserID, identity3.UserID)
|
||||
assert.Equal(t, identity1.Email, identity2.Email)
|
||||
assert.Equal(t, identity2.Email, identity3.Email)
|
||||
assert.Equal(t, identity1.Provider, identity2.Provider)
|
||||
assert.Equal(t, identity2.Provider, identity3.Provider)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSTSConfigurationValidation tests configuration validation for distributed deployments
|
||||
func TestSTSConfigurationValidation(t *testing.T) {
|
||||
t.Run("consistent_signing_keys_required", func(t *testing.T) {
|
||||
// Different signing keys should result in incompatible token validation
|
||||
config1 := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{12 * time.Hour},
|
||||
Issuer: "test-sts",
|
||||
SigningKey: []byte("signing-key-1-32-characters-long"),
|
||||
}
|
||||
|
||||
config2 := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{12 * time.Hour},
|
||||
Issuer: "test-sts",
|
||||
SigningKey: []byte("signing-key-2-32-characters-long"), // Different key!
|
||||
}
|
||||
|
||||
instance1 := NewSTSService()
|
||||
instance2 := NewSTSService()
|
||||
|
||||
err1 := instance1.Initialize(config1)
|
||||
err2 := instance2.Initialize(config2)
|
||||
|
||||
require.NoError(t, err1)
|
||||
require.NoError(t, err2)
|
||||
|
||||
// Generate token on instance 1
|
||||
sessionId := "test-session"
|
||||
expiresAt := time.Now().Add(time.Hour)
|
||||
token, err := instance1.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Instance 1 should validate its own token
|
||||
_, err = instance1.GetTokenGenerator().ValidateSessionToken(token)
|
||||
assert.NoError(t, err, "Instance 1 should validate its own token")
|
||||
|
||||
// Instance 2 should reject token from instance 1 (different signing key)
|
||||
_, err = instance2.GetTokenGenerator().ValidateSessionToken(token)
|
||||
assert.Error(t, err, "Instance 2 should reject token with different signing key")
|
||||
})
|
||||
|
||||
t.Run("consistent_issuer_required", func(t *testing.T) {
|
||||
// Different issuers should result in incompatible tokens
|
||||
commonSigningKey := []byte("shared-signing-key-32-characters-lo")
|
||||
|
||||
config1 := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{12 * time.Hour},
|
||||
Issuer: "sts-instance-1",
|
||||
SigningKey: commonSigningKey,
|
||||
}
|
||||
|
||||
config2 := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{12 * time.Hour},
|
||||
Issuer: "sts-instance-2", // Different issuer!
|
||||
SigningKey: commonSigningKey,
|
||||
}
|
||||
|
||||
instance1 := NewSTSService()
|
||||
instance2 := NewSTSService()
|
||||
|
||||
err1 := instance1.Initialize(config1)
|
||||
err2 := instance2.Initialize(config2)
|
||||
|
||||
require.NoError(t, err1)
|
||||
require.NoError(t, err2)
|
||||
|
||||
// Generate token on instance 1
|
||||
sessionId := "test-session"
|
||||
expiresAt := time.Now().Add(time.Hour)
|
||||
token, err := instance1.GetTokenGenerator().GenerateSessionToken(sessionId, expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Instance 2 should reject token due to issuer mismatch
|
||||
// (Even though signing key is the same, issuer validation will fail)
|
||||
_, err = instance2.GetTokenGenerator().ValidateSessionToken(token)
|
||||
assert.Error(t, err, "Instance 2 should reject token with different issuer")
|
||||
})
|
||||
}
|
||||
|
||||
// TestProviderFactoryDistributed tests the provider factory in distributed scenarios
|
||||
func TestProviderFactoryDistributed(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
// Simulate configuration that would be identical across all instances
|
||||
configs := []*ProviderConfig{
|
||||
{
|
||||
Name: "production-keycloak",
|
||||
Type: "oidc",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "https://keycloak.company.com/realms/seaweedfs",
|
||||
"clientId": "seaweedfs-prod",
|
||||
"clientSecret": "super-secret-key",
|
||||
"jwksUri": "https://keycloak.company.com/realms/seaweedfs/protocol/openid-connect/certs",
|
||||
"scopes": []string{"openid", "profile", "email", "roles"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "backup-oidc",
|
||||
Type: "oidc",
|
||||
Enabled: false, // Disabled by default
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "https://backup-oidc.company.com",
|
||||
"clientId": "seaweedfs-backup",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create providers multiple times (simulating multiple instances)
|
||||
providers1, err1 := factory.LoadProvidersFromConfig(configs)
|
||||
providers2, err2 := factory.LoadProvidersFromConfig(configs)
|
||||
providers3, err3 := factory.LoadProvidersFromConfig(configs)
|
||||
|
||||
require.NoError(t, err1, "First load should succeed")
|
||||
require.NoError(t, err2, "Second load should succeed")
|
||||
require.NoError(t, err3, "Third load should succeed")
|
||||
|
||||
// All instances should have same provider counts
|
||||
assert.Len(t, providers1, 1, "First instance should have 1 enabled provider")
|
||||
assert.Len(t, providers2, 1, "Second instance should have 1 enabled provider")
|
||||
assert.Len(t, providers3, 1, "Third instance should have 1 enabled provider")
|
||||
|
||||
// All instances should have same provider names
|
||||
names1 := make([]string, 0, len(providers1))
|
||||
names2 := make([]string, 0, len(providers2))
|
||||
names3 := make([]string, 0, len(providers3))
|
||||
|
||||
for name := range providers1 {
|
||||
names1 = append(names1, name)
|
||||
}
|
||||
for name := range providers2 {
|
||||
names2 = append(names2, name)
|
||||
}
|
||||
for name := range providers3 {
|
||||
names3 = append(names3, name)
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, names1, names2, "Instance 1 and 2 should have same provider names")
|
||||
assert.ElementsMatch(t, names2, names3, "Instance 2 and 3 should have same provider names")
|
||||
|
||||
// Verify specific providers
|
||||
expectedProviders := []string{"production-keycloak"}
|
||||
assert.ElementsMatch(t, names1, expectedProviders, "Should have expected enabled providers")
|
||||
|
||||
// Verify disabled providers are not included
|
||||
assert.NotContains(t, names1, "backup-oidc", "Disabled providers should not be loaded")
|
||||
assert.NotContains(t, names2, "backup-oidc", "Disabled providers should not be loaded")
|
||||
assert.NotContains(t, names3, "backup-oidc", "Disabled providers should not be loaded")
|
||||
}
|
||||
@@ -274,69 +274,3 @@ func (f *ProviderFactory) convertToRoleMapping(value interface{}) (*providers.Ro
|
||||
|
||||
return roleMapping, nil
|
||||
}
|
||||
|
||||
// ValidateProviderConfig validates a provider configuration
|
||||
func (f *ProviderFactory) ValidateProviderConfig(config *ProviderConfig) error {
|
||||
if config == nil {
|
||||
return fmt.Errorf("provider config cannot be nil")
|
||||
}
|
||||
|
||||
if config.Name == "" {
|
||||
return fmt.Errorf("provider name cannot be empty")
|
||||
}
|
||||
|
||||
if config.Type == "" {
|
||||
return fmt.Errorf("provider type cannot be empty")
|
||||
}
|
||||
|
||||
if config.Config == nil {
|
||||
return fmt.Errorf("provider config cannot be nil")
|
||||
}
|
||||
|
||||
// Type-specific validation
|
||||
switch config.Type {
|
||||
case "oidc":
|
||||
return f.validateOIDCConfig(config.Config)
|
||||
case "ldap":
|
||||
return f.validateLDAPConfig(config.Config)
|
||||
case "saml":
|
||||
return f.validateSAMLConfig(config.Config)
|
||||
default:
|
||||
return fmt.Errorf("unsupported provider type: %s", config.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// validateOIDCConfig validates OIDC provider configuration
|
||||
func (f *ProviderFactory) validateOIDCConfig(config map[string]interface{}) error {
|
||||
if _, ok := config[ConfigFieldIssuer]; !ok {
|
||||
return fmt.Errorf("OIDC provider requires '%s' field", ConfigFieldIssuer)
|
||||
}
|
||||
|
||||
if _, ok := config[ConfigFieldClientID]; !ok {
|
||||
return fmt.Errorf("OIDC provider requires '%s' field", ConfigFieldClientID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateLDAPConfig validates LDAP provider configuration
|
||||
func (f *ProviderFactory) validateLDAPConfig(config map[string]interface{}) error {
|
||||
if _, ok := config["server"]; !ok {
|
||||
return fmt.Errorf("LDAP provider requires 'server' field")
|
||||
}
|
||||
if _, ok := config["baseDN"]; !ok {
|
||||
return fmt.Errorf("LDAP provider requires 'baseDN' field")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateSAMLConfig validates SAML provider configuration
|
||||
func (f *ProviderFactory) validateSAMLConfig(config map[string]interface{}) error {
|
||||
// TODO: Implement when SAML provider is available
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSupportedProviderTypes returns list of supported provider types
|
||||
func (f *ProviderFactory) GetSupportedProviderTypes() []string {
|
||||
return []string{ProviderTypeOIDC}
|
||||
}
|
||||
|
||||
@@ -1,312 +0,0 @@
|
||||
package sts
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestProviderFactory_CreateOIDCProvider(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
config := &ProviderConfig{
|
||||
Name: "test-oidc",
|
||||
Type: "oidc",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "https://test-issuer.com",
|
||||
"clientId": "test-client",
|
||||
"clientSecret": "test-secret",
|
||||
"jwksUri": "https://test-issuer.com/.well-known/jwks.json",
|
||||
"scopes": []string{"openid", "profile", "email"},
|
||||
},
|
||||
}
|
||||
|
||||
provider, err := factory.CreateProvider(config)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, provider)
|
||||
assert.Equal(t, "test-oidc", provider.Name())
|
||||
}
|
||||
|
||||
// Note: Mock provider tests removed - mock providers are now test-only
|
||||
// and not available through the production ProviderFactory
|
||||
|
||||
func TestProviderFactory_DisabledProvider(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
config := &ProviderConfig{
|
||||
Name: "disabled-provider",
|
||||
Type: "oidc",
|
||||
Enabled: false,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "https://test-issuer.com",
|
||||
"clientId": "test-client",
|
||||
},
|
||||
}
|
||||
|
||||
provider, err := factory.CreateProvider(config)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, provider) // Should return nil for disabled providers
|
||||
}
|
||||
|
||||
func TestProviderFactory_InvalidProviderType(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
config := &ProviderConfig{
|
||||
Name: "invalid-provider",
|
||||
Type: "unsupported-type",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{},
|
||||
}
|
||||
|
||||
provider, err := factory.CreateProvider(config)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, provider)
|
||||
assert.Contains(t, err.Error(), "unsupported provider type")
|
||||
}
|
||||
|
||||
func TestProviderFactory_LoadMultipleProviders(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
configs := []*ProviderConfig{
|
||||
{
|
||||
Name: "oidc-provider",
|
||||
Type: "oidc",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "https://oidc-issuer.com",
|
||||
"clientId": "oidc-client",
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
Name: "disabled-provider",
|
||||
Type: "oidc",
|
||||
Enabled: false,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "https://disabled-issuer.com",
|
||||
"clientId": "disabled-client",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
providers, err := factory.LoadProvidersFromConfig(configs)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, providers, 1) // Only enabled providers should be loaded
|
||||
|
||||
assert.Contains(t, providers, "oidc-provider")
|
||||
assert.NotContains(t, providers, "disabled-provider")
|
||||
}
|
||||
|
||||
func TestProviderFactory_ValidateOIDCConfig(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
t.Run("valid config", func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
Name: "valid-oidc",
|
||||
Type: "oidc",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "https://valid-issuer.com",
|
||||
"clientId": "valid-client",
|
||||
},
|
||||
}
|
||||
|
||||
err := factory.ValidateProviderConfig(config)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("missing issuer", func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
Name: "invalid-oidc",
|
||||
Type: "oidc",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{
|
||||
"clientId": "valid-client",
|
||||
},
|
||||
}
|
||||
|
||||
err := factory.ValidateProviderConfig(config)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "issuer")
|
||||
})
|
||||
|
||||
t.Run("missing clientId", func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
Name: "invalid-oidc",
|
||||
Type: "oidc",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "https://valid-issuer.com",
|
||||
},
|
||||
}
|
||||
|
||||
err := factory.ValidateProviderConfig(config)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "clientId")
|
||||
})
|
||||
}
|
||||
|
||||
func TestProviderFactory_ConvertToStringSlice(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
t.Run("string slice", func(t *testing.T) {
|
||||
input := []string{"a", "b", "c"}
|
||||
result, err := factory.convertToStringSlice(input)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"a", "b", "c"}, result)
|
||||
})
|
||||
|
||||
t.Run("interface slice", func(t *testing.T) {
|
||||
input := []interface{}{"a", "b", "c"}
|
||||
result, err := factory.convertToStringSlice(input)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"a", "b", "c"}, result)
|
||||
})
|
||||
|
||||
t.Run("invalid type", func(t *testing.T) {
|
||||
input := "not-a-slice"
|
||||
result, err := factory.convertToStringSlice(input)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestProviderFactory_ConfigConversionErrors(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
t.Run("invalid scopes type", func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
Name: "invalid-scopes",
|
||||
Type: "oidc",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "https://test-issuer.com",
|
||||
"clientId": "test-client",
|
||||
"scopes": "invalid-not-array", // Should be array
|
||||
},
|
||||
}
|
||||
|
||||
provider, err := factory.CreateProvider(config)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, provider)
|
||||
assert.Contains(t, err.Error(), "failed to convert scopes")
|
||||
})
|
||||
|
||||
t.Run("invalid claimsMapping type", func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
Name: "invalid-claims",
|
||||
Type: "oidc",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "https://test-issuer.com",
|
||||
"clientId": "test-client",
|
||||
"claimsMapping": "invalid-not-map", // Should be map
|
||||
},
|
||||
}
|
||||
|
||||
provider, err := factory.CreateProvider(config)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, provider)
|
||||
assert.Contains(t, err.Error(), "failed to convert claimsMapping")
|
||||
})
|
||||
|
||||
t.Run("invalid roleMapping type", func(t *testing.T) {
|
||||
config := &ProviderConfig{
|
||||
Name: "invalid-roles",
|
||||
Type: "oidc",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "https://test-issuer.com",
|
||||
"clientId": "test-client",
|
||||
"roleMapping": "invalid-not-map", // Should be map
|
||||
},
|
||||
}
|
||||
|
||||
provider, err := factory.CreateProvider(config)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, provider)
|
||||
assert.Contains(t, err.Error(), "failed to convert roleMapping")
|
||||
})
|
||||
}
|
||||
|
||||
func TestProviderFactory_ConvertToStringMap(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
t.Run("string map", func(t *testing.T) {
|
||||
input := map[string]string{"key1": "value1", "key2": "value2"}
|
||||
result, err := factory.convertToStringMap(input)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, map[string]string{"key1": "value1", "key2": "value2"}, result)
|
||||
})
|
||||
|
||||
t.Run("interface map", func(t *testing.T) {
|
||||
input := map[string]interface{}{"key1": "value1", "key2": "value2"}
|
||||
result, err := factory.convertToStringMap(input)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, map[string]string{"key1": "value1", "key2": "value2"}, result)
|
||||
})
|
||||
|
||||
t.Run("invalid type", func(t *testing.T) {
|
||||
input := "not-a-map"
|
||||
result, err := factory.convertToStringMap(input)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestProviderFactory_GetSupportedProviderTypes(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
supportedTypes := factory.GetSupportedProviderTypes()
|
||||
assert.Contains(t, supportedTypes, "oidc")
|
||||
assert.Len(t, supportedTypes, 1) // Currently only OIDC is supported in production
|
||||
}
|
||||
|
||||
func TestSTSService_LoadProvidersFromConfig(t *testing.T) {
|
||||
stsConfig := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{3600 * time.Second},
|
||||
MaxSessionLength: FlexibleDuration{43200 * time.Second},
|
||||
Issuer: "test-issuer",
|
||||
SigningKey: []byte("test-signing-key-32-characters-long"),
|
||||
Providers: []*ProviderConfig{
|
||||
{
|
||||
Name: "test-provider",
|
||||
Type: "oidc",
|
||||
Enabled: true,
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "https://test-issuer.com",
|
||||
"clientId": "test-client",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
stsService := NewSTSService()
|
||||
err := stsService.Initialize(stsConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that provider was loaded
|
||||
assert.Len(t, stsService.providers, 1)
|
||||
assert.Contains(t, stsService.providers, "test-provider")
|
||||
assert.Equal(t, "test-provider", stsService.providers["test-provider"].Name())
|
||||
}
|
||||
|
||||
func TestSTSService_NoProvidersConfig(t *testing.T) {
|
||||
stsConfig := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{3600 * time.Second},
|
||||
MaxSessionLength: FlexibleDuration{43200 * time.Second},
|
||||
Issuer: "test-issuer",
|
||||
SigningKey: []byte("test-signing-key-32-characters-long"),
|
||||
// No providers configured
|
||||
}
|
||||
|
||||
stsService := NewSTSService()
|
||||
err := stsService.Initialize(stsConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should initialize successfully with no providers
|
||||
assert.Len(t, stsService.providers, 0)
|
||||
}
|
||||
@@ -1,193 +0,0 @@
|
||||
package sts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestSecurityIssuerToProviderMapping tests the security fix that ensures JWT tokens
|
||||
// with specific issuer claims can only be validated by the provider registered for that issuer
|
||||
func TestSecurityIssuerToProviderMapping(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create STS service with two mock providers
|
||||
service := NewSTSService()
|
||||
config := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{time.Hour * 12},
|
||||
Issuer: "test-sts",
|
||||
SigningKey: []byte("test-signing-key-32-characters-long"),
|
||||
}
|
||||
|
||||
err := service.Initialize(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set up mock trust policy validator
|
||||
mockValidator := &MockTrustPolicyValidator{}
|
||||
service.SetTrustPolicyValidator(mockValidator)
|
||||
|
||||
// Create two mock providers with different issuers
|
||||
providerA := &MockIdentityProviderWithIssuer{
|
||||
name: "provider-a",
|
||||
issuer: "https://provider-a.com",
|
||||
validTokens: map[string]bool{
|
||||
"token-for-provider-a": true,
|
||||
},
|
||||
}
|
||||
|
||||
providerB := &MockIdentityProviderWithIssuer{
|
||||
name: "provider-b",
|
||||
issuer: "https://provider-b.com",
|
||||
validTokens: map[string]bool{
|
||||
"token-for-provider-b": true,
|
||||
},
|
||||
}
|
||||
|
||||
// Register both providers
|
||||
err = service.RegisterProvider(providerA)
|
||||
require.NoError(t, err)
|
||||
err = service.RegisterProvider(providerB)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create JWT tokens with specific issuer claims
|
||||
tokenForProviderA := createTestJWT(t, "https://provider-a.com", "user-a")
|
||||
tokenForProviderB := createTestJWT(t, "https://provider-b.com", "user-b")
|
||||
|
||||
t.Run("jwt_token_with_issuer_a_only_validated_by_provider_a", func(t *testing.T) {
|
||||
// This should succeed - token has issuer A and provider A is registered
|
||||
identity, provider, err := service.validateWebIdentityToken(ctx, tokenForProviderA)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, identity)
|
||||
assert.Equal(t, "provider-a", provider.Name())
|
||||
})
|
||||
|
||||
t.Run("jwt_token_with_issuer_b_only_validated_by_provider_b", func(t *testing.T) {
|
||||
// This should succeed - token has issuer B and provider B is registered
|
||||
identity, provider, err := service.validateWebIdentityToken(ctx, tokenForProviderB)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, identity)
|
||||
assert.Equal(t, "provider-b", provider.Name())
|
||||
})
|
||||
|
||||
t.Run("jwt_token_with_unregistered_issuer_fails", func(t *testing.T) {
|
||||
// Create token with unregistered issuer
|
||||
tokenWithUnknownIssuer := createTestJWT(t, "https://unknown-issuer.com", "user-x")
|
||||
|
||||
// This should fail - no provider registered for this issuer
|
||||
identity, provider, err := service.validateWebIdentityToken(ctx, tokenWithUnknownIssuer)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, identity)
|
||||
assert.Nil(t, provider)
|
||||
assert.Contains(t, err.Error(), "no identity provider registered for issuer: https://unknown-issuer.com")
|
||||
})
|
||||
|
||||
t.Run("non_jwt_tokens_are_rejected", func(t *testing.T) {
|
||||
// Non-JWT tokens should be rejected - no fallback mechanism exists for security
|
||||
identity, provider, err := service.validateWebIdentityToken(ctx, "token-for-provider-a")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, identity)
|
||||
assert.Nil(t, provider)
|
||||
assert.Contains(t, err.Error(), "web identity token must be a valid JWT token")
|
||||
})
|
||||
}
|
||||
|
||||
// createTestJWT creates a test JWT token with the specified issuer and subject
|
||||
func createTestJWT(t *testing.T, issuer, subject string) string {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"sub": subject,
|
||||
"aud": "test-client",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
})
|
||||
|
||||
tokenString, err := token.SignedString([]byte("test-signing-key"))
|
||||
require.NoError(t, err)
|
||||
return tokenString
|
||||
}
|
||||
|
||||
// MockIdentityProviderWithIssuer is a mock provider that supports issuer mapping
|
||||
type MockIdentityProviderWithIssuer struct {
|
||||
name string
|
||||
issuer string
|
||||
validTokens map[string]bool
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithIssuer) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithIssuer) GetIssuer() string {
|
||||
return m.issuer
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithIssuer) Initialize(config interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithIssuer) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
|
||||
// For JWT tokens, parse and validate the token format
|
||||
if len(token) > 50 && strings.Contains(token, ".") {
|
||||
// This looks like a JWT - parse it to get the subject
|
||||
parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT token")
|
||||
}
|
||||
|
||||
claims, ok := parsedToken.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid claims")
|
||||
}
|
||||
|
||||
issuer, _ := claims["iss"].(string)
|
||||
subject, _ := claims["sub"].(string)
|
||||
|
||||
// Verify the issuer matches what we expect
|
||||
if issuer != m.issuer {
|
||||
return nil, fmt.Errorf("token issuer %s does not match provider issuer %s", issuer, m.issuer)
|
||||
}
|
||||
|
||||
return &providers.ExternalIdentity{
|
||||
UserID: subject,
|
||||
Email: subject + "@" + m.name + ".com",
|
||||
Provider: m.name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// For non-JWT tokens, check our simple token list
|
||||
if m.validTokens[token] {
|
||||
return &providers.ExternalIdentity{
|
||||
UserID: "test-user",
|
||||
Email: "test@" + m.name + ".com",
|
||||
Provider: m.name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithIssuer) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
|
||||
return &providers.ExternalIdentity{
|
||||
UserID: userID,
|
||||
Email: userID + "@" + m.name + ".com",
|
||||
Provider: m.name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithIssuer) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
|
||||
if m.validTokens[token] {
|
||||
return &providers.TokenClaims{
|
||||
Subject: "test-user",
|
||||
Issuer: m.issuer,
|
||||
}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
}
|
||||
@@ -1,168 +0,0 @@
|
||||
package sts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// createSessionPolicyTestJWT creates a test JWT token for session policy tests
|
||||
func createSessionPolicyTestJWT(t *testing.T, issuer, subject string) string {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"sub": subject,
|
||||
"aud": "test-client",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
})
|
||||
|
||||
tokenString, err := token.SignedString([]byte("test-signing-key"))
|
||||
require.NoError(t, err)
|
||||
return tokenString
|
||||
}
|
||||
|
||||
// TestAssumeRoleWithWebIdentity_SessionPolicy verifies inline session policies are preserved in tokens.
|
||||
func TestAssumeRoleWithWebIdentity_SessionPolicy(t *testing.T) {
|
||||
service := setupTestSTSService(t)
|
||||
ctx := context.Background()
|
||||
|
||||
sessionPolicy := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:GetObject","Resource":"arn:aws:s3:::example-bucket/*"}]}`
|
||||
testToken := createSessionPolicyTestJWT(t, "test-issuer", "test-user")
|
||||
|
||||
request := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: "arn:aws:iam::role/TestRole",
|
||||
WebIdentityToken: testToken,
|
||||
RoleSessionName: "test-session",
|
||||
Policy: &sessionPolicy,
|
||||
}
|
||||
|
||||
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
|
||||
sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
normalized, err := NormalizeSessionPolicy(sessionPolicy)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, normalized, sessionInfo.SessionPolicy)
|
||||
|
||||
t.Run("should_succeed_without_session_policy", func(t *testing.T) {
|
||||
request := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: "arn:aws:iam::role/TestRole",
|
||||
WebIdentityToken: createSessionPolicyTestJWT(t, "test-issuer", "test-user"),
|
||||
RoleSessionName: "test-session",
|
||||
}
|
||||
|
||||
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
|
||||
sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, sessionInfo.SessionPolicy)
|
||||
})
|
||||
}
|
||||
|
||||
// Test edge case scenarios for the Policy field handling
|
||||
func TestAssumeRoleWithWebIdentity_SessionPolicy_EdgeCases(t *testing.T) {
|
||||
service := setupTestSTSService(t)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("malformed_json_policy_rejected", func(t *testing.T) {
|
||||
malformedPolicy := `{"Version": "2012-10-17", "Statement": [` // Incomplete JSON
|
||||
|
||||
request := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: "arn:aws:iam::role/TestRole",
|
||||
WebIdentityToken: createSessionPolicyTestJWT(t, "test-issuer", "test-user"),
|
||||
RoleSessionName: "test-session",
|
||||
Policy: &malformedPolicy,
|
||||
}
|
||||
|
||||
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, response)
|
||||
assert.Contains(t, err.Error(), "invalid session policy JSON")
|
||||
})
|
||||
|
||||
t.Run("invalid_policy_document_rejected", func(t *testing.T) {
|
||||
invalidPolicy := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow"}]}`
|
||||
|
||||
request := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: "arn:aws:iam::role/TestRole",
|
||||
WebIdentityToken: createSessionPolicyTestJWT(t, "test-issuer", "test-user"),
|
||||
RoleSessionName: "test-session",
|
||||
Policy: &invalidPolicy,
|
||||
}
|
||||
|
||||
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, response)
|
||||
assert.Contains(t, err.Error(), "invalid session policy document")
|
||||
})
|
||||
|
||||
t.Run("whitespace_policy_ignored", func(t *testing.T) {
|
||||
whitespacePolicy := " \t\n "
|
||||
|
||||
request := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: "arn:aws:iam::role/TestRole",
|
||||
WebIdentityToken: createSessionPolicyTestJWT(t, "test-issuer", "test-user"),
|
||||
RoleSessionName: "test-session",
|
||||
Policy: &whitespacePolicy,
|
||||
}
|
||||
|
||||
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
|
||||
sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, sessionInfo.SessionPolicy)
|
||||
})
|
||||
}
|
||||
|
||||
// TestAssumeRoleWithWebIdentity_PolicyFieldDocumentation verifies that the struct field exists and is optional.
|
||||
func TestAssumeRoleWithWebIdentity_PolicyFieldDocumentation(t *testing.T) {
|
||||
request := &AssumeRoleWithWebIdentityRequest{}
|
||||
|
||||
assert.IsType(t, (*string)(nil), request.Policy,
|
||||
"Policy field should be *string type for optional JSON policy")
|
||||
assert.Nil(t, request.Policy,
|
||||
"Policy field should default to nil (no session policy)")
|
||||
|
||||
policyValue := `{"Version": "2012-10-17"}`
|
||||
request.Policy = &policyValue
|
||||
assert.NotNil(t, request.Policy, "Should be able to assign policy value")
|
||||
assert.Equal(t, policyValue, *request.Policy, "Policy value should be preserved")
|
||||
}
|
||||
|
||||
// TestAssumeRoleWithCredentials_SessionPolicy verifies session policy support for credentials-based flow.
|
||||
func TestAssumeRoleWithCredentials_SessionPolicy(t *testing.T) {
|
||||
service := setupTestSTSService(t)
|
||||
ctx := context.Background()
|
||||
|
||||
sessionPolicy := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"filer:CreateEntry","Resource":"arn:aws:filer::path/user-docs/*"}]}`
|
||||
request := &AssumeRoleWithCredentialsRequest{
|
||||
RoleArn: "arn:aws:iam::role/TestRole",
|
||||
Username: "testuser",
|
||||
Password: "testpass",
|
||||
RoleSessionName: "test-session",
|
||||
ProviderName: "test-ldap",
|
||||
Policy: &sessionPolicy,
|
||||
}
|
||||
|
||||
response, err := service.AssumeRoleWithCredentials(ctx, request)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
|
||||
sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
normalized, err := NormalizeSessionPolicy(sessionPolicy)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, normalized, sessionInfo.SessionPolicy)
|
||||
}
|
||||
@@ -879,21 +879,6 @@ func (s *STSService) calculateSessionDuration(durationSeconds *int64, tokenExpir
|
||||
return duration
|
||||
}
|
||||
|
||||
// extractSessionIdFromToken extracts session ID from JWT session token
|
||||
func (s *STSService) extractSessionIdFromToken(sessionToken string) string {
|
||||
// Validate JWT and extract session claims
|
||||
claims, err := s.tokenGenerator.ValidateJWTWithClaims(sessionToken)
|
||||
if err != nil {
|
||||
// For test compatibility, also handle direct session IDs
|
||||
if len(sessionToken) == 32 { // Typical session ID length
|
||||
return sessionToken
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
return claims.SessionId
|
||||
}
|
||||
|
||||
// validateAssumeRoleWithCredentialsRequest validates the credentials request parameters
|
||||
func (s *STSService) validateAssumeRoleWithCredentialsRequest(request *AssumeRoleWithCredentialsRequest) error {
|
||||
if request.RoleArn == "" {
|
||||
|
||||
@@ -1,778 +0,0 @@
|
||||
package sts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// createSTSTestJWT creates a test JWT token for STS service tests
|
||||
func createSTSTestJWT(t *testing.T, issuer, subject string) string {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"sub": subject,
|
||||
"aud": "test-client",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
})
|
||||
|
||||
tokenString, err := token.SignedString([]byte("test-signing-key"))
|
||||
require.NoError(t, err)
|
||||
return tokenString
|
||||
}
|
||||
|
||||
// TestSTSServiceInitialization tests STS service initialization
|
||||
func TestSTSServiceInitialization(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *STSConfig
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
config: &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{time.Hour * 12},
|
||||
Issuer: "seaweedfs-sts",
|
||||
SigningKey: []byte("test-signing-key"),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing signing key",
|
||||
config: &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
Issuer: "seaweedfs-sts",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid token duration",
|
||||
config: &STSConfig{
|
||||
TokenDuration: FlexibleDuration{-time.Hour},
|
||||
Issuer: "seaweedfs-sts",
|
||||
SigningKey: []byte("test-key"),
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
service := NewSTSService()
|
||||
|
||||
err := service.Initialize(tt.config)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, service.IsInitialized())
|
||||
|
||||
// Verify defaults if applicable
|
||||
if tt.config.Issuer == "" {
|
||||
assert.Equal(t, DefaultIssuer, service.Config.Issuer)
|
||||
}
|
||||
if tt.config.TokenDuration.Duration == 0 {
|
||||
assert.Equal(t, time.Duration(DefaultTokenDuration)*time.Second, service.Config.TokenDuration.Duration)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSTSServiceDefaults(t *testing.T) {
|
||||
service := NewSTSService()
|
||||
config := &STSConfig{
|
||||
SigningKey: []byte("test-signing-key"),
|
||||
// Missing duration and issuer
|
||||
}
|
||||
|
||||
err := service.Initialize(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, DefaultIssuer, config.Issuer)
|
||||
assert.Equal(t, time.Duration(DefaultTokenDuration)*time.Second, config.TokenDuration.Duration)
|
||||
assert.Equal(t, time.Duration(DefaultMaxSessionLength)*time.Second, config.MaxSessionLength.Duration)
|
||||
}
|
||||
|
||||
// TestAssumeRoleWithWebIdentity tests role assumption with OIDC tokens
|
||||
func TestAssumeRoleWithWebIdentity(t *testing.T) {
|
||||
service := setupTestSTSService(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
roleArn string
|
||||
webIdentityToken string
|
||||
sessionName string
|
||||
durationSeconds *int64
|
||||
wantErr bool
|
||||
expectedSubject string
|
||||
}{
|
||||
{
|
||||
name: "successful role assumption",
|
||||
roleArn: "arn:aws:iam::role/TestRole",
|
||||
webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user-id"),
|
||||
sessionName: "test-session",
|
||||
durationSeconds: nil, // Use default
|
||||
wantErr: false,
|
||||
expectedSubject: "test-user-id",
|
||||
},
|
||||
{
|
||||
name: "invalid web identity token",
|
||||
roleArn: "arn:aws:iam::role/TestRole",
|
||||
webIdentityToken: "invalid-token",
|
||||
sessionName: "test-session",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "non-existent role",
|
||||
roleArn: "arn:aws:iam::role/NonExistentRole",
|
||||
webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
|
||||
sessionName: "test-session",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "custom session duration",
|
||||
roleArn: "arn:aws:iam::role/TestRole",
|
||||
webIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
|
||||
sessionName: "test-session",
|
||||
durationSeconds: int64Ptr(7200), // 2 hours
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
request := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: tt.roleArn,
|
||||
WebIdentityToken: tt.webIdentityToken,
|
||||
RoleSessionName: tt.sessionName,
|
||||
DurationSeconds: tt.durationSeconds,
|
||||
}
|
||||
|
||||
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, response)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, response)
|
||||
assert.NotNil(t, response.Credentials)
|
||||
assert.NotNil(t, response.AssumedRoleUser)
|
||||
|
||||
// Verify credentials
|
||||
creds := response.Credentials
|
||||
assert.NotEmpty(t, creds.AccessKeyId)
|
||||
assert.NotEmpty(t, creds.SecretAccessKey)
|
||||
assert.NotEmpty(t, creds.SessionToken)
|
||||
assert.True(t, creds.Expiration.After(time.Now()))
|
||||
|
||||
// Verify assumed role user
|
||||
user := response.AssumedRoleUser
|
||||
assert.Equal(t, tt.roleArn, user.AssumedRoleId)
|
||||
assert.Contains(t, user.Arn, tt.sessionName)
|
||||
|
||||
if tt.expectedSubject != "" {
|
||||
assert.Equal(t, tt.expectedSubject, user.Subject)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAssumeRoleWithLDAP tests role assumption with LDAP credentials
|
||||
func TestAssumeRoleWithLDAP(t *testing.T) {
|
||||
service := setupTestSTSService(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
roleArn string
|
||||
username string
|
||||
password string
|
||||
sessionName string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "successful LDAP role assumption",
|
||||
roleArn: "arn:aws:iam::role/LDAPRole",
|
||||
username: "testuser",
|
||||
password: "testpass",
|
||||
sessionName: "ldap-session",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid LDAP credentials",
|
||||
roleArn: "arn:aws:iam::role/LDAPRole",
|
||||
username: "testuser",
|
||||
password: "wrongpass",
|
||||
sessionName: "ldap-session",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
request := &AssumeRoleWithCredentialsRequest{
|
||||
RoleArn: tt.roleArn,
|
||||
Username: tt.username,
|
||||
Password: tt.password,
|
||||
RoleSessionName: tt.sessionName,
|
||||
ProviderName: "test-ldap",
|
||||
}
|
||||
|
||||
response, err := service.AssumeRoleWithCredentials(ctx, request)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, response)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, response)
|
||||
assert.NotNil(t, response.Credentials)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionTokenValidation tests session token validation
|
||||
func TestSessionTokenValidation(t *testing.T) {
|
||||
service := setupTestSTSService(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// First, create a session
|
||||
request := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: "arn:aws:iam::role/TestRole",
|
||||
WebIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
|
||||
RoleSessionName: "test-session",
|
||||
}
|
||||
|
||||
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
|
||||
sessionToken := response.Credentials.SessionToken
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid session token",
|
||||
token: sessionToken,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid session token",
|
||||
token: "invalid-session-token",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty session token",
|
||||
token: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
session, err := service.ValidateSessionToken(ctx, tt.token)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, session)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, session)
|
||||
assert.Equal(t, "test-session", session.SessionName)
|
||||
assert.Equal(t, "arn:aws:iam::role/TestRole", session.RoleArn)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionTokenPersistence tests that JWT tokens remain valid throughout their lifetime
|
||||
// Note: In the stateless JWT design, tokens cannot be revoked and remain valid until expiration
|
||||
func TestSessionTokenPersistence(t *testing.T) {
|
||||
service := setupTestSTSService(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a session first
|
||||
request := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: "arn:aws:iam::role/TestRole",
|
||||
WebIdentityToken: createSTSTestJWT(t, "test-issuer", "test-user"),
|
||||
RoleSessionName: "test-session",
|
||||
}
|
||||
|
||||
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
|
||||
require.NoError(t, err)
|
||||
|
||||
sessionToken := response.Credentials.SessionToken
|
||||
|
||||
// Verify token is valid initially
|
||||
session, err := service.ValidateSessionToken(ctx, sessionToken)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, session)
|
||||
assert.Equal(t, "test-session", session.SessionName)
|
||||
|
||||
// In a stateless JWT system, tokens remain valid throughout their lifetime
|
||||
// Multiple validations should all succeed as long as the token hasn't expired
|
||||
session2, err := service.ValidateSessionToken(ctx, sessionToken)
|
||||
assert.NoError(t, err, "Token should remain valid in stateless system")
|
||||
assert.NotNil(t, session2, "Session should be returned from JWT token")
|
||||
assert.Equal(t, session.SessionId, session2.SessionId, "Session ID should be consistent")
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func setupTestSTSService(t *testing.T) *STSService {
|
||||
service := NewSTSService()
|
||||
|
||||
config := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{time.Hour * 12},
|
||||
Issuer: "test-sts",
|
||||
SigningKey: []byte("test-signing-key-32-characters-long"),
|
||||
}
|
||||
|
||||
err := service.Initialize(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set up mock trust policy validator (required for STS testing)
|
||||
mockValidator := &MockTrustPolicyValidator{}
|
||||
service.SetTrustPolicyValidator(mockValidator)
|
||||
|
||||
// Register test providers
|
||||
mockOIDCProvider := &MockIdentityProvider{
|
||||
name: "test-oidc",
|
||||
validTokens: map[string]*providers.TokenClaims{
|
||||
createSTSTestJWT(t, "test-issuer", "test-user"): {
|
||||
Subject: "test-user-id",
|
||||
Issuer: "test-issuer",
|
||||
Claims: map[string]interface{}{
|
||||
"email": "test@example.com",
|
||||
"name": "Test User",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mockLDAPProvider := &MockIdentityProvider{
|
||||
name: "test-ldap",
|
||||
validCredentials: map[string]string{
|
||||
"testuser": "testpass",
|
||||
},
|
||||
}
|
||||
|
||||
service.RegisterProvider(mockOIDCProvider)
|
||||
service.RegisterProvider(mockLDAPProvider)
|
||||
|
||||
return service
|
||||
}
|
||||
|
||||
func int64Ptr(v int64) *int64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// Mock identity provider for testing
|
||||
type MockIdentityProvider struct {
|
||||
name string
|
||||
validTokens map[string]*providers.TokenClaims
|
||||
validCredentials map[string]string
|
||||
}
|
||||
|
||||
func (m *MockIdentityProvider) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *MockIdentityProvider) GetIssuer() string {
|
||||
return "test-issuer" // This matches the issuer in the token claims
|
||||
}
|
||||
|
||||
func (m *MockIdentityProvider) Initialize(config interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockIdentityProvider) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
|
||||
// First try to parse as JWT token
|
||||
if len(token) > 20 && strings.Count(token, ".") >= 2 {
|
||||
parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
|
||||
if err == nil {
|
||||
if claims, ok := parsedToken.Claims.(jwt.MapClaims); ok {
|
||||
issuer, _ := claims["iss"].(string)
|
||||
subject, _ := claims["sub"].(string)
|
||||
|
||||
// Verify the issuer matches what we expect
|
||||
if issuer == "test-issuer" && subject != "" {
|
||||
return &providers.ExternalIdentity{
|
||||
UserID: subject,
|
||||
Email: subject + "@test-domain.com",
|
||||
DisplayName: "Test User " + subject,
|
||||
Provider: m.name,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle legacy OIDC tokens (for backwards compatibility)
|
||||
if claims, exists := m.validTokens[token]; exists {
|
||||
email, _ := claims.GetClaimString("email")
|
||||
name, _ := claims.GetClaimString("name")
|
||||
|
||||
return &providers.ExternalIdentity{
|
||||
UserID: claims.Subject,
|
||||
Email: email,
|
||||
DisplayName: name,
|
||||
Provider: m.name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Handle LDAP credentials (username:password format)
|
||||
if m.validCredentials != nil {
|
||||
parts := strings.Split(token, ":")
|
||||
if len(parts) == 2 {
|
||||
username, password := parts[0], parts[1]
|
||||
if expectedPassword, exists := m.validCredentials[username]; exists && expectedPassword == password {
|
||||
return &providers.ExternalIdentity{
|
||||
UserID: username,
|
||||
Email: username + "@" + m.name + ".com",
|
||||
DisplayName: "Test User " + username,
|
||||
Provider: m.name,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unknown test token: %s", token)
|
||||
}
|
||||
|
||||
func (m *MockIdentityProvider) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
|
||||
return &providers.ExternalIdentity{
|
||||
UserID: userID,
|
||||
Email: userID + "@" + m.name + ".com",
|
||||
Provider: m.name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *MockIdentityProvider) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
|
||||
if claims, exists := m.validTokens[token]; exists {
|
||||
return claims, nil
|
||||
}
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
// TestSessionDurationCappedByTokenExpiration tests that session duration is capped by the source token's exp claim
|
||||
func TestSessionDurationCappedByTokenExpiration(t *testing.T) {
|
||||
service := NewSTSService()
|
||||
|
||||
config := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour}, // Default: 1 hour
|
||||
MaxSessionLength: FlexibleDuration{time.Hour * 12},
|
||||
Issuer: "test-sts",
|
||||
SigningKey: []byte("test-signing-key-32-characters-long"),
|
||||
}
|
||||
|
||||
err := service.Initialize(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
durationSeconds *int64
|
||||
tokenExpiration *time.Time
|
||||
expectedMaxSeconds int64
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "no token expiration - use default duration",
|
||||
durationSeconds: nil,
|
||||
tokenExpiration: nil,
|
||||
expectedMaxSeconds: 3600, // 1 hour default
|
||||
description: "When no token expiration is set, use the configured default duration",
|
||||
},
|
||||
{
|
||||
name: "token expires before default duration",
|
||||
durationSeconds: nil,
|
||||
tokenExpiration: timePtr(time.Now().Add(30 * time.Minute)),
|
||||
expectedMaxSeconds: 30 * 60, // 30 minutes
|
||||
description: "When token expires in 30 min, session should be capped at 30 min",
|
||||
},
|
||||
{
|
||||
name: "token expires after default duration - use default",
|
||||
durationSeconds: nil,
|
||||
tokenExpiration: timePtr(time.Now().Add(2 * time.Hour)),
|
||||
expectedMaxSeconds: 3600, // 1 hour default, since it's less than 2 hour token expiry
|
||||
description: "When token expires after default duration, use the default duration",
|
||||
},
|
||||
{
|
||||
name: "requested duration shorter than token expiry",
|
||||
durationSeconds: int64Ptr(1800), // 30 min requested
|
||||
tokenExpiration: timePtr(time.Now().Add(time.Hour)),
|
||||
expectedMaxSeconds: 1800, // 30 minutes as requested
|
||||
description: "When requested duration is shorter than token expiry, use requested duration",
|
||||
},
|
||||
{
|
||||
name: "requested duration longer than token expiry - cap at token expiry",
|
||||
durationSeconds: int64Ptr(3600), // 1 hour requested
|
||||
tokenExpiration: timePtr(time.Now().Add(15 * time.Minute)),
|
||||
expectedMaxSeconds: 15 * 60, // Capped at 15 minutes
|
||||
description: "When requested duration exceeds token expiry, cap at token expiry",
|
||||
},
|
||||
{
|
||||
name: "GitLab CI short-lived token scenario",
|
||||
durationSeconds: nil,
|
||||
tokenExpiration: timePtr(time.Now().Add(5 * time.Minute)),
|
||||
expectedMaxSeconds: 5 * 60, // 5 minutes
|
||||
description: "GitLab CI job with 5 minute timeout should result in 5 minute session",
|
||||
},
|
||||
{
|
||||
name: "already expired token - defense in depth",
|
||||
durationSeconds: nil,
|
||||
tokenExpiration: timePtr(time.Now().Add(-5 * time.Minute)), // Expired 5 minutes ago
|
||||
expectedMaxSeconds: 60, // 1 minute minimum
|
||||
description: "Already expired token should result in minimal 1 minute session",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
duration := service.calculateSessionDuration(tt.durationSeconds, tt.tokenExpiration)
|
||||
|
||||
// Allow 5 second tolerance for time calculations
|
||||
maxExpected := time.Duration(tt.expectedMaxSeconds+5) * time.Second
|
||||
minExpected := time.Duration(tt.expectedMaxSeconds-5) * time.Second
|
||||
|
||||
assert.GreaterOrEqual(t, duration, minExpected,
|
||||
"%s: duration %v should be >= %v", tt.description, duration, minExpected)
|
||||
assert.LessOrEqual(t, duration, maxExpected,
|
||||
"%s: duration %v should be <= %v", tt.description, duration, maxExpected)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAssumeRoleWithWebIdentityRespectsTokenExpiration tests end-to-end that session duration is capped
|
||||
func TestAssumeRoleWithWebIdentityRespectsTokenExpiration(t *testing.T) {
|
||||
service := NewSTSService()
|
||||
|
||||
config := &STSConfig{
|
||||
TokenDuration: FlexibleDuration{time.Hour},
|
||||
MaxSessionLength: FlexibleDuration{time.Hour * 12},
|
||||
Issuer: "test-sts",
|
||||
SigningKey: []byte("test-signing-key-32-characters-long"),
|
||||
}
|
||||
|
||||
err := service.Initialize(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set up mock trust policy validator
|
||||
mockValidator := &MockTrustPolicyValidator{}
|
||||
service.SetTrustPolicyValidator(mockValidator)
|
||||
|
||||
// Create a mock provider that returns tokens with short expiration
|
||||
shortLivedTokenExpiration := time.Now().Add(10 * time.Minute)
|
||||
mockProvider := &MockIdentityProviderWithExpiration{
|
||||
name: "short-lived-issuer",
|
||||
tokenExpiration: &shortLivedTokenExpiration,
|
||||
}
|
||||
service.RegisterProvider(mockProvider)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a JWT token with short expiration
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"iss": "short-lived-issuer",
|
||||
"sub": "test-user",
|
||||
"aud": "test-client",
|
||||
"exp": shortLivedTokenExpiration.Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
})
|
||||
tokenString, err := token.SignedString([]byte("test-signing-key"))
|
||||
require.NoError(t, err)
|
||||
|
||||
request := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: "arn:aws:iam::role/TestRole",
|
||||
WebIdentityToken: tokenString,
|
||||
RoleSessionName: "test-session",
|
||||
}
|
||||
|
||||
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
|
||||
// Verify the session expires at or before the token expiration
|
||||
// Allow 5 second tolerance
|
||||
assert.True(t, response.Credentials.Expiration.Before(shortLivedTokenExpiration.Add(5*time.Second)),
|
||||
"Session expiration (%v) should not exceed token expiration (%v)",
|
||||
response.Credentials.Expiration, shortLivedTokenExpiration)
|
||||
}
|
||||
|
||||
// MockIdentityProviderWithExpiration is a mock provider that returns tokens with configurable expiration
|
||||
type MockIdentityProviderWithExpiration struct {
|
||||
name string
|
||||
tokenExpiration *time.Time
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithExpiration) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithExpiration) GetIssuer() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithExpiration) Initialize(config interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithExpiration) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
|
||||
// Parse the token to get subject
|
||||
parsedToken, _, err := new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
claims, ok := parsedToken.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid claims")
|
||||
}
|
||||
|
||||
subject, _ := claims["sub"].(string)
|
||||
|
||||
identity := &providers.ExternalIdentity{
|
||||
UserID: subject,
|
||||
Email: subject + "@example.com",
|
||||
DisplayName: "Test User",
|
||||
Provider: m.name,
|
||||
TokenExpiration: m.tokenExpiration,
|
||||
}
|
||||
|
||||
return identity, nil
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithExpiration) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
|
||||
return &providers.ExternalIdentity{
|
||||
UserID: userID,
|
||||
Provider: m.name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithExpiration) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
|
||||
claims := &providers.TokenClaims{
|
||||
Subject: "test-user",
|
||||
Issuer: m.name,
|
||||
}
|
||||
if m.tokenExpiration != nil {
|
||||
claims.ExpiresAt = *m.tokenExpiration
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func timePtr(t time.Time) *time.Time {
|
||||
return &t
|
||||
}
|
||||
|
||||
// TestAssumeRoleWithWebIdentity_PreservesAttributes tests that attributes from the identity provider
|
||||
// are correctly propagated to the session token's request context
|
||||
func TestAssumeRoleWithWebIdentity_PreservesAttributes(t *testing.T) {
|
||||
service := setupTestSTSService(t)
|
||||
|
||||
// Create a mock provider that returns a user with attributes
|
||||
mockProvider := &MockIdentityProviderWithAttributes{
|
||||
name: "attr-provider",
|
||||
attributes: map[string]string{
|
||||
"preferred_username": "my-user",
|
||||
"department": "engineering",
|
||||
"project": "seaweedfs",
|
||||
},
|
||||
}
|
||||
service.RegisterProvider(mockProvider)
|
||||
|
||||
// Create a valid JWT token for the provider
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"iss": "attr-provider",
|
||||
"sub": "test-user-id",
|
||||
"aud": "test-client",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
})
|
||||
tokenString, err := token.SignedString([]byte("test-signing-key"))
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
request := &AssumeRoleWithWebIdentityRequest{
|
||||
RoleArn: "arn:aws:iam::role/TestRole",
|
||||
WebIdentityToken: tokenString,
|
||||
RoleSessionName: "test-session",
|
||||
}
|
||||
|
||||
response, err := service.AssumeRoleWithWebIdentity(ctx, request)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
|
||||
// Validate the session token to check claims
|
||||
sessionInfo, err := service.ValidateSessionToken(ctx, response.Credentials.SessionToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that attributes are present in RequestContext
|
||||
require.NotNil(t, sessionInfo.RequestContext, "RequestContext should not be nil")
|
||||
assert.Equal(t, "my-user", sessionInfo.RequestContext["preferred_username"])
|
||||
assert.Equal(t, "engineering", sessionInfo.RequestContext["department"])
|
||||
assert.Equal(t, "seaweedfs", sessionInfo.RequestContext["project"])
|
||||
|
||||
// Check standard claims are also present
|
||||
assert.Equal(t, "test-user-id", sessionInfo.RequestContext["sub"])
|
||||
assert.Equal(t, "test@example.com", sessionInfo.RequestContext["email"])
|
||||
assert.Equal(t, "Test User", sessionInfo.RequestContext["name"])
|
||||
}
|
||||
|
||||
// MockIdentityProviderWithAttributes is a mock provider that returns configured attributes
|
||||
type MockIdentityProviderWithAttributes struct {
|
||||
name string
|
||||
attributes map[string]string
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithAttributes) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithAttributes) GetIssuer() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithAttributes) Initialize(config interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithAttributes) Authenticate(ctx context.Context, token string) (*providers.ExternalIdentity, error) {
|
||||
return &providers.ExternalIdentity{
|
||||
UserID: "test-user-id",
|
||||
Email: "test@example.com",
|
||||
DisplayName: "Test User",
|
||||
Provider: m.name,
|
||||
Attributes: m.attributes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithAttributes) GetUserInfo(ctx context.Context, userID string) (*providers.ExternalIdentity, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockIdentityProviderWithAttributes) ValidateToken(ctx context.Context, token string) (*providers.TokenClaims, error) {
|
||||
return &providers.TokenClaims{
|
||||
Subject: "test-user-id",
|
||||
Issuer: m.name,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,53 +1,4 @@
|
||||
package sts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/iam/providers"
|
||||
)
|
||||
|
||||
// MockTrustPolicyValidator is a simple mock for testing STS functionality
|
||||
type MockTrustPolicyValidator struct{}
|
||||
|
||||
// ValidateTrustPolicyForWebIdentity allows valid JWT test tokens for STS testing
|
||||
func (m *MockTrustPolicyValidator) ValidateTrustPolicyForWebIdentity(ctx context.Context, roleArn string, webIdentityToken string, durationSeconds *int64) error {
|
||||
// Reject non-existent roles for testing
|
||||
if strings.Contains(roleArn, "NonExistentRole") {
|
||||
return fmt.Errorf("trust policy validation failed: role does not exist")
|
||||
}
|
||||
|
||||
// For STS unit tests, allow JWT tokens that look valid (contain dots for JWT structure)
|
||||
// In real implementation, this would validate against actual trust policies
|
||||
if len(webIdentityToken) > 20 && strings.Count(webIdentityToken, ".") >= 2 {
|
||||
// This appears to be a JWT token - allow it for testing
|
||||
return nil
|
||||
}
|
||||
|
||||
// Legacy support for specific test tokens during migration
|
||||
if webIdentityToken == "valid_test_token" || webIdentityToken == "valid-oidc-token" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reject invalid tokens
|
||||
if webIdentityToken == "invalid_token" || webIdentityToken == "expired_token" || webIdentityToken == "invalid-token" {
|
||||
return fmt.Errorf("trust policy denies token")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateTrustPolicyForCredentials allows valid test identities for STS testing
|
||||
func (m *MockTrustPolicyValidator) ValidateTrustPolicyForCredentials(ctx context.Context, roleArn string, identity *providers.ExternalIdentity) error {
|
||||
// Reject non-existent roles for testing
|
||||
if strings.Contains(roleArn, "NonExistentRole") {
|
||||
return fmt.Errorf("trust policy validation failed: role does not exist")
|
||||
}
|
||||
|
||||
// For STS unit tests, allow test identities
|
||||
if identity != nil && identity.UserID != "" {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("invalid identity for role assumption")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user