diff --git a/test/s3/iam/s3_sts_get_federation_token_test.go b/test/s3/iam/s3_sts_get_federation_token_test.go new file mode 100644 index 000000000..2a718cba9 --- /dev/null +++ b/test/s3/iam/s3_sts_get_federation_token_test.go @@ -0,0 +1,511 @@ +package iam + +import ( + "encoding/xml" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + v4 "github.com/aws/aws-sdk-go/aws/signer/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// GetFederationTokenTestResponse represents the STS GetFederationToken response +type GetFederationTokenTestResponse struct { + XMLName xml.Name `xml:"GetFederationTokenResponse"` + Result struct { + Credentials struct { + AccessKeyId string `xml:"AccessKeyId"` + SecretAccessKey string `xml:"SecretAccessKey"` + SessionToken string `xml:"SessionToken"` + Expiration string `xml:"Expiration"` + } `xml:"Credentials"` + FederatedUser struct { + FederatedUserId string `xml:"FederatedUserId"` + Arn string `xml:"Arn"` + } `xml:"FederatedUser"` + } `xml:"GetFederationTokenResult"` +} + +func getTestCredentials() (string, string) { + accessKey := os.Getenv("STS_TEST_ACCESS_KEY") + if accessKey == "" { + accessKey = "admin" + } + secretKey := os.Getenv("STS_TEST_SECRET_KEY") + if secretKey == "" { + secretKey = "admin" + } + return accessKey, secretKey +} + +// isGetFederationTokenImplemented checks if the running server supports GetFederationToken +func isGetFederationTokenImplemented(t *testing.T) bool { + accessKey, secretKey := getTestCredentials() + resp, err := callSTSAPIWithSigV4(t, url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + "Name": {"probe"}, + }, accessKey, secretKey) + if err != nil { + return false + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + var errResp STSErrorTestResponse + if xml.Unmarshal(body, &errResp) == nil { + if errResp.Error.Code == "InvalidAction" || errResp.Error.Code == "NotImplemented" { + return false + } + } + return true +} + +// TestSTSGetFederationTokenValidation tests input validation for the GetFederationToken endpoint +func TestSTSGetFederationTokenValidation(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + if !isSTSEndpointRunning(t) { + t.Fatal("SeaweedFS STS endpoint is not running at", TestSTSEndpoint, "- please run 'make setup-all-tests' first") + } + + if !isGetFederationTokenImplemented(t) { + t.Fatal("GetFederationToken action is not implemented in the running server") + } + + accessKey, secretKey := getTestCredentials() + + t.Run("missing_name", func(t *testing.T) { + resp, err := callSTSAPIWithSigV4(t, url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + // Name is missing + }, accessKey, secretKey) + require.NoError(t, err) + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + var errResp STSErrorTestResponse + require.NoError(t, xml.Unmarshal(body, &errResp), "Failed to parse: %s", string(body)) + assert.Equal(t, "MissingParameter", errResp.Error.Code) + }) + + t.Run("name_too_short", func(t *testing.T) { + resp, err := callSTSAPIWithSigV4(t, url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + "Name": {"A"}, + }, accessKey, secretKey) + require.NoError(t, err) + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + var errResp STSErrorTestResponse + require.NoError(t, xml.Unmarshal(body, &errResp), "Failed to parse: %s", string(body)) + assert.Equal(t, "InvalidParameterValue", errResp.Error.Code) + }) + + t.Run("name_too_long", func(t *testing.T) { + resp, err := callSTSAPIWithSigV4(t, url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + "Name": {strings.Repeat("A", 33)}, + }, accessKey, secretKey) + require.NoError(t, err) + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + var errResp STSErrorTestResponse + require.NoError(t, xml.Unmarshal(body, &errResp), "Failed to parse: %s", string(body)) + assert.Equal(t, "InvalidParameterValue", errResp.Error.Code) + }) + + t.Run("name_invalid_characters", func(t *testing.T) { + resp, err := callSTSAPIWithSigV4(t, url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + "Name": {"bad name"}, + }, accessKey, secretKey) + require.NoError(t, err) + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + var errResp STSErrorTestResponse + require.NoError(t, xml.Unmarshal(body, &errResp), "Failed to parse: %s", string(body)) + assert.Equal(t, "InvalidParameterValue", errResp.Error.Code) + }) + + t.Run("duration_too_short", func(t *testing.T) { + resp, err := callSTSAPIWithSigV4(t, url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + "Name": {"TestApp"}, + "DurationSeconds": {"100"}, + }, accessKey, secretKey) + require.NoError(t, err) + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + var errResp STSErrorTestResponse + require.NoError(t, xml.Unmarshal(body, &errResp), "Failed to parse: %s", string(body)) + assert.Equal(t, "InvalidParameterValue", errResp.Error.Code) + }) + + t.Run("duration_too_long", func(t *testing.T) { + resp, err := callSTSAPIWithSigV4(t, url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + "Name": {"TestApp"}, + "DurationSeconds": {"200000"}, + }, accessKey, secretKey) + require.NoError(t, err) + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + var errResp STSErrorTestResponse + require.NoError(t, xml.Unmarshal(body, &errResp), "Failed to parse: %s", string(body)) + assert.Equal(t, "InvalidParameterValue", errResp.Error.Code) + }) + + t.Run("malformed_policy", func(t *testing.T) { + resp, err := callSTSAPIWithSigV4(t, url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + "Name": {"TestApp"}, + "Policy": {"not-valid-json"}, + }, accessKey, secretKey) + require.NoError(t, err) + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + var errResp STSErrorTestResponse + require.NoError(t, xml.Unmarshal(body, &errResp), "Failed to parse: %s", string(body)) + assert.Equal(t, "MalformedPolicyDocument", errResp.Error.Code) + }) + + t.Run("anonymous_rejected", func(t *testing.T) { + // GetFederationToken requires SigV4, anonymous should fail + resp, err := callSTSAPI(t, url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + "Name": {"TestApp"}, + }) + require.NoError(t, err) + defer resp.Body.Close() + + assert.NotEqual(t, http.StatusOK, resp.StatusCode) + }) +} + +// TestSTSGetFederationTokenRejectTemporaryCredentials tests that temporary +// credentials (session tokens) are rejected by GetFederationToken +func TestSTSGetFederationTokenRejectTemporaryCredentials(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + if !isSTSEndpointRunning(t) { + t.Skip("SeaweedFS STS endpoint is not running at", TestSTSEndpoint) + } + + if !isGetFederationTokenImplemented(t) { + t.Skip("GetFederationToken not implemented") + } + + accessKey, secretKey := getTestCredentials() + + // First, obtain temporary credentials via AssumeRole + resp, err := callSTSAPIWithSigV4(t, url.Values{ + "Action": {"AssumeRole"}, + "Version": {"2011-06-15"}, + "RoleArn": {"arn:aws:iam::role/admin"}, + "RoleSessionName": {"temp-session"}, + }, accessKey, secretKey) + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + if resp.StatusCode != http.StatusOK { + t.Skipf("AssumeRole failed (may not be configured): status=%d body=%s", resp.StatusCode, string(body)) + } + + var assumeResp AssumeRoleTestResponse + require.NoError(t, xml.Unmarshal(body, &assumeResp), "Parse AssumeRole response: %s", string(body)) + + tempAccessKey := assumeResp.Result.Credentials.AccessKeyId + tempSecretKey := assumeResp.Result.Credentials.SecretAccessKey + tempSessionToken := assumeResp.Result.Credentials.SessionToken + require.NotEmpty(t, tempAccessKey) + require.NotEmpty(t, tempSessionToken) + + // Now try GetFederationToken with the temporary credentials + // Include X-Amz-Security-Token header which marks this as a temp credential call + params := url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + "Name": {"ShouldFail"}, + } + + reqBody := params.Encode() + req, err := http.NewRequest(http.MethodPost, TestSTSEndpoint+"/", strings.NewReader(reqBody)) + require.NoError(t, err) + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("X-Amz-Security-Token", tempSessionToken) + + creds := credentials.NewStaticCredentials(tempAccessKey, tempSecretKey, tempSessionToken) + signer := v4.NewSigner(creds) + _, err = signer.Sign(req, strings.NewReader(reqBody), "sts", "us-east-1", time.Now()) + require.NoError(t, err) + + client := &http.Client{Timeout: 30 * time.Second} + resp2, err := client.Do(req) + require.NoError(t, err) + defer resp2.Body.Close() + + body2, _ := io.ReadAll(resp2.Body) + assert.Equal(t, http.StatusForbidden, resp2.StatusCode, + "GetFederationToken should reject temporary credentials: %s", string(body2)) + assert.Contains(t, string(body2), "temporary credentials", + "Error should mention temporary credentials") +} + +// TestSTSGetFederationTokenSuccess tests a successful GetFederationToken call +// and verifies the returned credentials can be used to access S3 +func TestSTSGetFederationTokenSuccess(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + if !isSTSEndpointRunning(t) { + t.Skip("SeaweedFS STS endpoint is not running at", TestSTSEndpoint) + } + + if !isGetFederationTokenImplemented(t) { + t.Skip("GetFederationToken not implemented") + } + + accessKey, secretKey := getTestCredentials() + + t.Run("basic_success", func(t *testing.T) { + resp, err := callSTSAPIWithSigV4(t, url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + "Name": {"AppClient"}, + }, accessKey, secretKey) + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + t.Logf("Response status: %d, body: %s", resp.StatusCode, string(body)) + + if resp.StatusCode != http.StatusOK { + var errResp STSErrorTestResponse + _ = xml.Unmarshal(body, &errResp) + t.Fatalf("GetFederationToken failed: code=%s message=%s", errResp.Error.Code, errResp.Error.Message) + } + + var stsResp GetFederationTokenTestResponse + require.NoError(t, xml.Unmarshal(body, &stsResp), "Parse response: %s", string(body)) + + creds := stsResp.Result.Credentials + assert.NotEmpty(t, creds.AccessKeyId) + assert.NotEmpty(t, creds.SecretAccessKey) + assert.NotEmpty(t, creds.SessionToken) + assert.NotEmpty(t, creds.Expiration) + + fedUser := stsResp.Result.FederatedUser + assert.Contains(t, fedUser.Arn, "federated-user/AppClient") + assert.Contains(t, fedUser.FederatedUserId, "AppClient") + }) + + t.Run("with_custom_duration", func(t *testing.T) { + resp, err := callSTSAPIWithSigV4(t, url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + "Name": {"DurationTest"}, + "DurationSeconds": {"3600"}, + }, accessKey, secretKey) + require.NoError(t, err) + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + t.Logf("Response status: %d, body: %s", resp.StatusCode, string(body)) + + if resp.StatusCode == http.StatusOK { + var stsResp GetFederationTokenTestResponse + require.NoError(t, xml.Unmarshal(body, &stsResp)) + assert.NotEmpty(t, stsResp.Result.Credentials.AccessKeyId) + + // Verify expiration is roughly 1 hour from now + expTime, err := time.Parse(time.RFC3339, stsResp.Result.Credentials.Expiration) + require.NoError(t, err) + diff := time.Until(expTime) + assert.InDelta(t, 3600, diff.Seconds(), 60, + "Expiration should be ~1 hour from now") + } + }) + + t.Run("with_36_hour_duration", func(t *testing.T) { + // GetFederationToken allows up to 36 hours (unlike AssumeRole's 12h max) + resp, err := callSTSAPIWithSigV4(t, url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + "Name": {"LongDuration"}, + "DurationSeconds": {"129600"}, // 36 hours + }, accessKey, secretKey) + require.NoError(t, err) + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode == http.StatusOK { + var stsResp GetFederationTokenTestResponse + require.NoError(t, xml.Unmarshal(body, &stsResp)) + + expTime, err := time.Parse(time.RFC3339, stsResp.Result.Credentials.Expiration) + require.NoError(t, err) + diff := time.Until(expTime) + assert.InDelta(t, 129600, diff.Seconds(), 60, + "Expiration should be ~36 hours from now") + } else { + // Duration should not cause a rejection + var errResp STSErrorTestResponse + _ = xml.Unmarshal(body, &errResp) + assert.NotContains(t, errResp.Error.Message, "DurationSeconds", + "36-hour duration should be accepted by GetFederationToken") + } + }) +} + +// TestSTSGetFederationTokenWithSessionPolicy tests that vended credentials +// are scoped down by an inline session policy +func TestSTSGetFederationTokenWithSessionPolicy(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + if !isSTSEndpointRunning(t) { + t.Skip("SeaweedFS STS endpoint is not running at", TestSTSEndpoint) + } + + if !isGetFederationTokenImplemented(t) { + t.Skip("GetFederationToken not implemented") + } + + accessKey, secretKey := getTestCredentials() + + // Create a test bucket using admin credentials + adminSess, err := session.NewSession(&aws.Config{ + Region: aws.String("us-east-1"), + Endpoint: aws.String(TestSTSEndpoint), + DisableSSL: aws.Bool(true), + S3ForcePathStyle: aws.Bool(true), + Credentials: credentials.NewStaticCredentials(accessKey, secretKey, ""), + }) + require.NoError(t, err) + + adminS3 := s3.New(adminSess) + bucket := fmt.Sprintf("fed-token-test-%d", time.Now().UnixNano()) + + _, err = adminS3.CreateBucket(&s3.CreateBucketInput{Bucket: aws.String(bucket)}) + require.NoError(t, err) + defer adminS3.DeleteBucket(&s3.DeleteBucketInput{Bucket: aws.String(bucket)}) + + _, err = adminS3.PutObject(&s3.PutObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String("test.txt"), + Body: strings.NewReader("hello"), + }) + require.NoError(t, err) + defer adminS3.DeleteObject(&s3.DeleteObjectInput{Bucket: aws.String(bucket), Key: aws.String("test.txt")}) + + // Get federated credentials with a session policy that only allows GetObject + sessionPolicy := fmt.Sprintf(`{ + "Version": "2012-10-17", + "Statement": [{ + "Effect": "Allow", + "Action": ["s3:GetObject"], + "Resource": ["arn:aws:s3:::%s/*"] + }] + }`, bucket) + + resp, err := callSTSAPIWithSigV4(t, url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + "Name": {"ScopedClient"}, + "Policy": {sessionPolicy}, + }, accessKey, secretKey) + require.NoError(t, err) + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + t.Logf("GetFederationToken response: status=%d body=%s", resp.StatusCode, string(body)) + + if resp.StatusCode != http.StatusOK { + t.Skipf("GetFederationToken failed (may need IAM policy config): %s", string(body)) + } + + var stsResp GetFederationTokenTestResponse + require.NoError(t, xml.Unmarshal(body, &stsResp)) + + fedCreds := stsResp.Result.Credentials + require.NotEmpty(t, fedCreds.AccessKeyId) + require.NotEmpty(t, fedCreds.SessionToken) + + // Create S3 client with the federated credentials + fedSess, err := session.NewSession(&aws.Config{ + Region: aws.String("us-east-1"), + Endpoint: aws.String(TestSTSEndpoint), + DisableSSL: aws.Bool(true), + S3ForcePathStyle: aws.Bool(true), + Credentials: credentials.NewStaticCredentials( + fedCreds.AccessKeyId, fedCreds.SecretAccessKey, fedCreds.SessionToken), + }) + require.NoError(t, err) + + fedS3 := s3.New(fedSess) + + // GetObject should succeed (allowed by session policy) + getResp, err := fedS3.GetObject(&s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String("test.txt"), + }) + if err == nil { + defer getResp.Body.Close() + t.Log("GetObject with federated credentials succeeded (as expected)") + } else { + t.Logf("GetObject with federated credentials: %v (session policy enforcement may vary)", err) + } + + // PutObject should be denied (not allowed by session policy) + _, err = fedS3.PutObject(&s3.PutObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String("denied.txt"), + Body: strings.NewReader("should fail"), + }) + if err != nil { + t.Log("PutObject correctly denied with federated credentials") + assert.Contains(t, err.Error(), "AccessDenied", + "PutObject should be denied by session policy") + } else { + // Clean up if unexpectedly succeeded + adminS3.DeleteObject(&s3.DeleteObjectInput{Bucket: aws.String(bucket), Key: aws.String("denied.txt")}) + t.Log("PutObject unexpectedly succeeded — session policy enforcement may not be active") + } +} diff --git a/weed/iam/integration/iam_manager.go b/weed/iam/integration/iam_manager.go index 2e1225a89..fb8a47895 100644 --- a/weed/iam/integration/iam_manager.go +++ b/weed/iam/integration/iam_manager.go @@ -725,6 +725,23 @@ func (m *IAMManager) ExpireSessionForTesting(ctx context.Context, sessionToken s return m.stsService.ExpireSessionForTesting(ctx, sessionToken) } +// GetPoliciesForUser returns the policy names attached to an IAM user. +// Returns an error if the user store is not configured or the lookup fails, +// so callers can fail closed on policy-resolution failures. +func (m *IAMManager) GetPoliciesForUser(ctx context.Context, username string) ([]string, error) { + if m.userStore == nil { + return nil, fmt.Errorf("user store not configured") + } + user, err := m.userStore.GetUser(ctx, username) + if err != nil { + return nil, fmt.Errorf("failed to look up user %q: %w", username, err) + } + if user == nil { + return nil, nil + } + return user.PolicyNames, nil +} + // GetSTSService returns the STS service instance func (m *IAMManager) GetSTSService() *sts.STSService { return m.stsService diff --git a/weed/iam/sts/constants.go b/weed/iam/sts/constants.go index 021aca906..6e293028b 100644 --- a/weed/iam/sts/constants.go +++ b/weed/iam/sts/constants.go @@ -124,6 +124,7 @@ const ( ActionAssumeRole = "sts:AssumeRole" ActionAssumeRoleWithWebIdentity = "sts:AssumeRoleWithWebIdentity" ActionAssumeRoleWithCredentials = "sts:AssumeRoleWithCredentials" + ActionGetFederationToken = "sts:GetFederationToken" ActionValidateSession = "sts:ValidateSession" ) diff --git a/weed/s3api/s3_action_resolver.go b/weed/s3api/s3_action_resolver.go index 1a9edfca8..fa2e0a134 100644 --- a/weed/s3api/s3_action_resolver.go +++ b/weed/s3api/s3_action_resolver.go @@ -296,8 +296,8 @@ func resolveBucketLevelAction(method string, baseAction string) string { // mapBaseActionToS3Format converts coarse-grained base actions to S3 format // This is the fallback when no specific resolution is found func mapBaseActionToS3Format(baseAction string) string { - // Handle actions that already have s3: or iam: prefix - if strings.HasPrefix(baseAction, "s3:") || strings.HasPrefix(baseAction, "iam:") { + // Handle actions that already have a known service prefix + if strings.HasPrefix(baseAction, "s3:") || strings.HasPrefix(baseAction, "iam:") || strings.HasPrefix(baseAction, "sts:") { return baseAction } diff --git a/weed/s3api/s3_action_resolver_test.go b/weed/s3api/s3_action_resolver_test.go index c95ec3972..9e11f1dcb 100644 --- a/weed/s3api/s3_action_resolver_test.go +++ b/weed/s3api/s3_action_resolver_test.go @@ -7,6 +7,59 @@ import ( "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" ) +// TestMapBaseActionToS3Format_ServicePrefixPassthrough verifies that actions +// with known service prefixes (s3:, iam:, sts:) are returned unchanged. +func TestMapBaseActionToS3Format_ServicePrefixPassthrough(t *testing.T) { + tests := []struct { + name string + input string + expect string + }{ + {"s3 prefix", "s3:GetObject", "s3:GetObject"}, + {"iam prefix", "iam:CreateUser", "iam:CreateUser"}, + {"sts:AssumeRole", "sts:AssumeRole", "sts:AssumeRole"}, + {"sts:GetFederationToken", "sts:GetFederationToken", "sts:GetFederationToken"}, + {"sts:GetCallerIdentity", "sts:GetCallerIdentity", "sts:GetCallerIdentity"}, + {"coarse Read maps to s3:GetObject", "Read", s3_constants.S3_ACTION_GET_OBJECT}, + {"coarse Write maps to s3:PutObject", "Write", s3_constants.S3_ACTION_PUT_OBJECT}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := mapBaseActionToS3Format(tt.input) + if got != tt.expect { + t.Errorf("mapBaseActionToS3Format(%q) = %q, want %q", tt.input, got, tt.expect) + } + }) + } +} + +// TestResolveS3Action_STSActionsPassthrough verifies that STS actions flow +// through ResolveS3Action unchanged, both with and without an HTTP request. +func TestResolveS3Action_STSActionsPassthrough(t *testing.T) { + stsActions := []string{ + "sts:AssumeRole", + "sts:GetFederationToken", + "sts:GetCallerIdentity", + } + + for _, action := range stsActions { + t.Run("nil_request_"+action, func(t *testing.T) { + got := ResolveS3Action(nil, action, "", "") + if got != action { + t.Errorf("ResolveS3Action(nil, %q) = %q, want %q", action, got, action) + } + }) + t.Run("with_request_"+action, func(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, "http://localhost/", nil) + got := ResolveS3Action(r, action, "", "") + if got != action { + t.Errorf("ResolveS3Action(r, %q) = %q, want %q", action, got, action) + } + }) + } +} + func TestResolveS3Action_AttributesBeforeVersionId(t *testing.T) { tests := []struct { name string diff --git a/weed/s3api/s3api_sts.go b/weed/s3api/s3api_sts.go index 478c09139..f9167d271 100644 --- a/weed/s3api/s3api_sts.go +++ b/weed/s3api/s3api_sts.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "net/http" + "regexp" "strconv" "time" @@ -37,6 +38,10 @@ const ( actionAssumeRoleWithWebIdentity = "AssumeRoleWithWebIdentity" actionAssumeRoleWithLDAPIdentity = "AssumeRoleWithLDAPIdentity" actionGetCallerIdentity = "GetCallerIdentity" + actionGetFederationToken = "GetFederationToken" + + // GetFederationToken-specific parameters + stsFederationName = "Name" // LDAP parameter names stsLDAPUsername = "LDAPUsername" @@ -44,15 +49,20 @@ const ( stsLDAPProviderName = "LDAPProviderName" ) +// federationNameRegex validates the Name parameter for GetFederationToken per AWS spec +var federationNameRegex = regexp.MustCompile(`^[\w+=,.@-]+$`) + // STS duration constants (AWS specification) const ( - minDurationSeconds = int64(900) // 15 minutes - maxDurationSeconds = int64(43200) // 12 hours + minDurationSeconds = int64(900) // 15 minutes + maxDurationSeconds = int64(43200) // 12 hours (AssumeRole) + defaultFederationDurationSeconds = int64(43200) // 12 hours (GetFederationToken default) + maxFederationDurationSeconds = int64(129600) // 36 hours (GetFederationToken max) ) -// parseDurationSeconds parses and validates the DurationSeconds parameter -// Returns nil if the parameter is not provided, or a pointer to the parsed value -func parseDurationSeconds(r *http.Request) (*int64, STSErrorCode, error) { +// parseDurationSecondsWithBounds parses and validates the DurationSeconds parameter +// against the given min and max bounds. Returns nil if the parameter is not provided. +func parseDurationSecondsWithBounds(r *http.Request, minSec, maxSec int64) (*int64, STSErrorCode, error) { dsStr := r.FormValue("DurationSeconds") if dsStr == "" { return nil, "", nil @@ -63,14 +73,19 @@ func parseDurationSeconds(r *http.Request) (*int64, STSErrorCode, error) { return nil, STSErrInvalidParameterValue, fmt.Errorf("invalid DurationSeconds: %w", err) } - if ds < minDurationSeconds || ds > maxDurationSeconds { + if ds < minSec || ds > maxSec { return nil, STSErrInvalidParameterValue, - fmt.Errorf("DurationSeconds must be between %d and %d seconds", minDurationSeconds, maxDurationSeconds) + fmt.Errorf("DurationSeconds must be between %d and %d seconds", minSec, maxSec) } return &ds, "", nil } +// parseDurationSeconds parses DurationSeconds for AssumeRole (15 min to 12 hours) +func parseDurationSeconds(r *http.Request) (*int64, STSErrorCode, error) { + return parseDurationSecondsWithBounds(r, minDurationSeconds, maxDurationSeconds) +} + // Removed generateSecureCredentials - now using STS service's JWT token generation // The STS service generates proper JWT tokens with embedded claims that can be validated // across distributed instances without shared state. @@ -124,6 +139,8 @@ func (h *STSHandlers) HandleSTSRequest(w http.ResponseWriter, r *http.Request) { h.handleAssumeRoleWithLDAPIdentity(w, r) case actionGetCallerIdentity: h.handleGetCallerIdentity(w, r) + case actionGetFederationToken: + h.handleGetFederationToken(w, r) default: h.writeSTSErrorResponse(w, r, STSErrInvalidAction, fmt.Errorf("unsupported action: %s", action)) @@ -296,7 +313,7 @@ func (h *STSHandlers) handleAssumeRole(w http.ResponseWriter, r *http.Request) { // Check authorizations if roleArn != "" { // Check if the caller is authorized to assume the role (sts:AssumeRole permission) - if authErr := h.iam.VerifyActionPermission(r, identity, Action("sts:AssumeRole"), "", roleArn); authErr != s3err.ErrNone { + if authErr := h.iam.VerifyActionPermission(r, identity, Action(sts.ActionAssumeRole), "", roleArn); authErr != s3err.ErrNone { glog.V(2).Infof("AssumeRole: caller %s is not authorized to assume role %s", identity.Name, roleArn) h.writeSTSErrorResponse(w, r, STSErrAccessDenied, fmt.Errorf("user %s is not authorized to assume role %s", identity.Name, roleArn)) @@ -320,7 +337,7 @@ func (h *STSHandlers) handleAssumeRole(w http.ResponseWriter, r *http.Request) { // For safety/consistency with previous logic, we keep the check but strictly it might not be required by AWS for GetSessionToken. // But since this IS AssumeRole, let's keep it. // Admin/Global check when no specific role is requested - if authErr := h.iam.VerifyActionPermission(r, identity, Action("sts:AssumeRole"), "", ""); authErr != s3err.ErrNone { + if authErr := h.iam.VerifyActionPermission(r, identity, Action(sts.ActionAssumeRole), "", ""); authErr != s3err.ErrNone { glog.Warningf("AssumeRole: caller %s attempted to assume role without RoleArn and lacks global sts:AssumeRole permission", identity.Name) h.writeSTSErrorResponse(w, r, STSErrAccessDenied, fmt.Errorf("access denied")) return @@ -505,6 +522,202 @@ func (h *STSHandlers) handleAssumeRoleWithLDAPIdentity(w http.ResponseWriter, r s3err.WriteXMLResponse(w, r, http.StatusOK, xmlResponse) } +// handleGetFederationToken handles the GetFederationToken API action. +// This allows long-term IAM users to obtain temporary credentials scoped down +// by an optional inline session policy. Temporary credentials cannot call this action. +func (h *STSHandlers) handleGetFederationToken(w http.ResponseWriter, r *http.Request) { + // Extract parameters + name := r.FormValue(stsFederationName) + + // Validate required parameters + if name == "" { + h.writeSTSErrorResponse(w, r, STSErrMissingParameter, + fmt.Errorf("Name is required")) + return + } + + // AWS requires Name to be 2-32 characters matching [\w+=,.@-]+ + if len(name) < 2 || len(name) > 32 { + h.writeSTSErrorResponse(w, r, STSErrInvalidParameterValue, + fmt.Errorf("Name must be between 2 and 32 characters")) + return + } + if !federationNameRegex.MatchString(name) { + h.writeSTSErrorResponse(w, r, STSErrInvalidParameterValue, + fmt.Errorf("Name contains invalid characters, must match [\\w+=,.@-]+")) + return + } + + // Parse and validate DurationSeconds (GetFederationToken allows up to 36 hours) + durationSeconds, errCode, err := parseDurationSecondsWithBounds(r, minDurationSeconds, maxFederationDurationSeconds) + if err != nil { + h.writeSTSErrorResponse(w, r, errCode, err) + return + } + + // Reject calls from temporary credentials (session tokens) early, + // before SigV4 verification — no need to authenticate first. + // GetFederationToken can only be called by long-term IAM users. + securityToken := r.Header.Get("X-Amz-Security-Token") + if securityToken == "" { + securityToken = r.URL.Query().Get("X-Amz-Security-Token") + } + if securityToken != "" { + h.writeSTSErrorResponse(w, r, STSErrAccessDenied, + fmt.Errorf("GetFederationToken cannot be called with temporary credentials")) + return + } + + // Check if STS service is initialized + if h.stsService == nil || !h.stsService.IsInitialized() { + h.writeSTSErrorResponse(w, r, STSErrSTSNotReady, + fmt.Errorf("STS service not initialized")) + return + } + + // Check if IAM is available for SigV4 verification + if h.iam == nil { + h.writeSTSErrorResponse(w, r, STSErrSTSNotReady, + fmt.Errorf("IAM not configured for STS")) + return + } + + // Validate AWS SigV4 authentication + identity, _, _, _, sigErrCode := h.iam.verifyV4Signature(r, false) + if sigErrCode != s3err.ErrNone { + glog.V(2).Infof("GetFederationToken SigV4 verification failed: %v", sigErrCode) + h.writeSTSErrorResponse(w, r, STSErrAccessDenied, + fmt.Errorf("invalid AWS signature: %v", sigErrCode)) + return + } + + if identity == nil { + h.writeSTSErrorResponse(w, r, STSErrAccessDenied, + fmt.Errorf("unable to identify caller")) + return + } + + glog.V(2).Infof("GetFederationToken: caller identity=%s, name=%s", identity.Name, name) + + // Check if the caller is authorized to call GetFederationToken + if authErr := h.iam.VerifyActionPermission(r, identity, Action(sts.ActionGetFederationToken), "", ""); authErr != s3err.ErrNone { + glog.V(2).Infof("GetFederationToken: caller %s is not authorized to call GetFederationToken", identity.Name) + h.writeSTSErrorResponse(w, r, STSErrAccessDenied, + fmt.Errorf("user %s is not authorized to call GetFederationToken", identity.Name)) + return + } + + // Validate session policy if provided + sessionPolicyJSON, err := sts.NormalizeSessionPolicy(r.FormValue("Policy")) + if err != nil { + h.writeSTSErrorResponse(w, r, STSErrMalformedPolicyDocument, + fmt.Errorf("invalid Policy document: %w", err)) + return + } + + // Calculate duration (default 12 hours for GetFederationToken) + duration := time.Duration(defaultFederationDurationSeconds) * time.Second + if durationSeconds != nil { + duration = time.Duration(*durationSeconds) * time.Second + } + + // Generate session ID + sessionId, err := sts.GenerateSessionId() + if err != nil { + h.writeSTSErrorResponse(w, r, STSErrInternalError, + fmt.Errorf("failed to generate session ID: %w", err)) + return + } + + expiration := time.Now().Add(duration) + accountID := h.getAccountID() + + // Build federated user ARN: arn:aws:sts:::federated-user/ + federatedUserArn := fmt.Sprintf("arn:aws:sts::%s:federated-user/%s", accountID, name) + federatedUserId := fmt.Sprintf("%s:%s", accountID, name) + + // Create session claims — use the caller's principal ARN as the RoleArn + // so that policy evaluation resolves the caller's attached policies + claims := sts.NewSTSSessionClaims(sessionId, h.stsService.Config.Issuer, expiration). + WithSessionName(name). + WithRoleInfo(identity.PrincipalArn, federatedUserId, federatedUserArn) + + // Embed the caller's effective policies into the token. + // Merge identity.PolicyNames (from SigV4 identity) with policies resolved + // from the IAM manager (which may include group-attached policies). + policySet := make(map[string]struct{}) + for _, p := range identity.PolicyNames { + policySet[p] = struct{}{} + } + + var policyManager *integration.IAMManager + if h.iam.iamIntegration != nil { + if provider, ok := h.iam.iamIntegration.(IAMManagerProvider); ok { + policyManager = provider.GetIAMManager() + } + } + if policyManager != nil { + userPolicies, err := policyManager.GetPoliciesForUser(r.Context(), identity.Name) + if err != nil { + glog.V(2).Infof("GetFederationToken: failed to resolve policies for %s: %v", identity.Name, err) + h.writeSTSErrorResponse(w, r, STSErrInternalError, + fmt.Errorf("failed to resolve caller policies")) + return + } + for _, p := range userPolicies { + policySet[p] = struct{}{} + } + } + + if len(policySet) > 0 { + merged := make([]string, 0, len(policySet)) + for p := range policySet { + merged = append(merged, p) + } + claims.WithPolicies(merged) + } + + if sessionPolicyJSON != "" { + claims.WithSessionPolicy(sessionPolicyJSON) + } + + // Generate JWT session token + sessionToken, err := h.stsService.GetTokenGenerator().GenerateJWTWithClaims(claims) + if err != nil { + h.writeSTSErrorResponse(w, r, STSErrInternalError, + fmt.Errorf("failed to generate session token: %w", err)) + return + } + + // Generate temporary credentials + stsCredGen := sts.NewCredentialGenerator() + stsCredsDet, err := stsCredGen.GenerateTemporaryCredentials(sessionId, expiration) + if err != nil { + h.writeSTSErrorResponse(w, r, STSErrInternalError, + fmt.Errorf("failed to generate temporary credentials: %w", err)) + return + } + + // Build and return response + xmlResponse := &GetFederationTokenResponse{ + Result: GetFederationTokenResult{ + Credentials: STSCredentials{ + AccessKeyId: stsCredsDet.AccessKeyId, + SecretAccessKey: stsCredsDet.SecretAccessKey, + SessionToken: sessionToken, + Expiration: expiration.Format(time.RFC3339), + }, + FederatedUser: FederatedUser{ + FederatedUserId: federatedUserId, + Arn: federatedUserArn, + }, + }, + } + xmlResponse.ResponseMetadata.RequestId = request_id.GetFromRequest(r) + + s3err.WriteXMLResponse(w, r, http.StatusOK, xmlResponse) +} + // prepareSTSCredentials extracts common shared logic for credential generation func (h *STSHandlers) prepareSTSCredentials(ctx context.Context, roleArn, roleSessionName string, durationSeconds *int64, sessionPolicy string, modifyClaims func(*sts.STSSessionClaims)) (STSCredentials, *AssumedRoleUser, error) { @@ -743,6 +956,27 @@ type GetCallerIdentityResult struct { Account string `xml:"Account"` } +// GetFederationTokenResponse is the response for GetFederationToken +type GetFederationTokenResponse struct { + XMLName xml.Name `xml:"https://sts.amazonaws.com/doc/2011-06-15/ GetFederationTokenResponse"` + Result GetFederationTokenResult `xml:"GetFederationTokenResult"` + ResponseMetadata struct { + RequestId string `xml:"RequestId,omitempty"` + } `xml:"ResponseMetadata,omitempty"` +} + +// GetFederationTokenResult contains the result of GetFederationToken +type GetFederationTokenResult struct { + Credentials STSCredentials `xml:"Credentials"` + FederatedUser FederatedUser `xml:"FederatedUser"` +} + +// FederatedUser contains information about the federated user +type FederatedUser struct { + FederatedUserId string `xml:"FederatedUserId"` + Arn string `xml:"Arn"` +} + // STS Error types // STSErrorCode represents STS error codes diff --git a/weed/s3api/s3api_sts_get_federation_token_test.go b/weed/s3api/s3api_sts_get_federation_token_test.go new file mode 100644 index 000000000..7375a1822 --- /dev/null +++ b/weed/s3api/s3api_sts_get_federation_token_test.go @@ -0,0 +1,746 @@ +package s3api + +import ( + "context" + "encoding/xml" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "sort" + "strings" + "testing" + "time" + + "github.com/seaweedfs/seaweedfs/weed/iam/integration" + "github.com/seaweedfs/seaweedfs/weed/iam/policy" + "github.com/seaweedfs/seaweedfs/weed/iam/sts" + "github.com/seaweedfs/seaweedfs/weed/pb/iam_pb" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockUserStore implements integration.UserStore for testing GetPoliciesForUser +type mockUserStore struct { + users map[string]*iam_pb.Identity +} + +func (m *mockUserStore) GetUser(_ context.Context, username string) (*iam_pb.Identity, error) { + u, ok := m.users[username] + if !ok { + return nil, nil + } + return u, nil +} + +// TestGetFederationToken_BasicFlow tests basic credential generation for GetFederationToken +func TestGetFederationToken_BasicFlow(t *testing.T) { + stsService, _ := setupTestSTSService(t) + + iam := &IdentityAccessManagement{ + iamIntegration: &MockIAMIntegration{}, + } + stsHandlers := NewSTSHandlers(stsService, iam) + + // Simulate the core logic of handleGetFederationToken + name := "BobApp" + callerIdentity := &Identity{ + Name: "alice", + PrincipalArn: fmt.Sprintf("arn:aws:iam::%s:user/alice", defaultAccountID), + PolicyNames: []string{"S3ReadPolicy"}, + } + + accountID := stsHandlers.getAccountID() + + // Generate session ID and credentials + sessionId, err := sts.GenerateSessionId() + require.NoError(t, err) + + expiration := time.Now().Add(12 * time.Hour) + federatedUserArn := fmt.Sprintf("arn:aws:sts::%s:federated-user/%s", accountID, name) + federatedUserId := fmt.Sprintf("%s:%s", accountID, name) + + claims := sts.NewSTSSessionClaims(sessionId, stsService.Config.Issuer, expiration). + WithSessionName(name). + WithRoleInfo(callerIdentity.PrincipalArn, federatedUserId, federatedUserArn). + WithPolicies(callerIdentity.PolicyNames) + + sessionToken, err := stsService.GetTokenGenerator().GenerateJWTWithClaims(claims) + require.NoError(t, err) + + // Validate the session token + sessionInfo, err := stsService.ValidateSessionToken(context.Background(), sessionToken) + require.NoError(t, err) + require.NotNil(t, sessionInfo) + + // Verify the session info contains caller's policies + assert.Equal(t, []string{"S3ReadPolicy"}, sessionInfo.Policies) + + // Verify principal is the federated user ARN + assert.Equal(t, federatedUserArn, sessionInfo.Principal) + + // Verify the RoleArn points to the caller's identity (for policy resolution) + assert.Equal(t, callerIdentity.PrincipalArn, sessionInfo.RoleArn) + + // Verify session name + assert.Equal(t, name, sessionInfo.SessionName) +} + +// TestGetFederationToken_WithSessionPolicy tests session policy scoping +func TestGetFederationToken_WithSessionPolicy(t *testing.T) { + stsService, _ := setupTestSTSService(t) + + stsHandlers := NewSTSHandlers(stsService, &IdentityAccessManagement{ + iamIntegration: &MockIAMIntegration{}, + }) + + accountID := stsHandlers.getAccountID() + name := "ScopedApp" + + sessionPolicyJSON := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":["s3:GetObject"],"Resource":["arn:aws:s3:::my-bucket/*"]}]}` + normalizedPolicy, err := sts.NormalizeSessionPolicy(sessionPolicyJSON) + require.NoError(t, err) + + sessionId, err := sts.GenerateSessionId() + require.NoError(t, err) + + expiration := time.Now().Add(12 * time.Hour) + federatedUserArn := fmt.Sprintf("arn:aws:sts::%s:federated-user/%s", accountID, name) + federatedUserId := fmt.Sprintf("%s:%s", accountID, name) + + claims := sts.NewSTSSessionClaims(sessionId, stsService.Config.Issuer, expiration). + WithSessionName(name). + WithRoleInfo("arn:aws:iam::000000000000:user/caller", federatedUserId, federatedUserArn). + WithPolicies([]string{"S3FullAccess"}). + WithSessionPolicy(normalizedPolicy) + + sessionToken, err := stsService.GetTokenGenerator().GenerateJWTWithClaims(claims) + require.NoError(t, err) + + sessionInfo, err := stsService.ValidateSessionToken(context.Background(), sessionToken) + require.NoError(t, err) + require.NotNil(t, sessionInfo) + + // Verify session policy is embedded + assert.NotEmpty(t, sessionInfo.SessionPolicy) + assert.Contains(t, sessionInfo.SessionPolicy, "s3:GetObject") + + // Verify caller's policies are still present + assert.Equal(t, []string{"S3FullAccess"}, sessionInfo.Policies) +} + +// TestGetFederationToken_RejectTemporaryCredentials tests that requests with +// session tokens are rejected. +func TestGetFederationToken_RejectTemporaryCredentials(t *testing.T) { + stsService, _ := setupTestSTSService(t) + stsHandlers := NewSTSHandlers(stsService, &IdentityAccessManagement{ + iamIntegration: &MockIAMIntegration{}, + }) + + tests := []struct { + name string + setToken func(r *http.Request) + description string + }{ + { + name: "SessionTokenInHeader", + setToken: func(r *http.Request) { + r.Header.Set("X-Amz-Security-Token", "some-session-token") + }, + description: "Session token in X-Amz-Security-Token header should be rejected", + }, + { + name: "SessionTokenInQuery", + setToken: func(r *http.Request) { + q := r.URL.Query() + q.Set("X-Amz-Security-Token", "some-session-token") + r.URL.RawQuery = q.Encode() + }, + description: "Session token in query string should be rejected", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + form := url.Values{} + form.Set("Action", "GetFederationToken") + form.Set("Name", "TestUser") + form.Set("Version", "2011-06-15") + + req := httptest.NewRequest("POST", "/", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + tt.setToken(req) + + // Parse form so the handler can read it + require.NoError(t, req.ParseForm()) + // Re-set values after parse + req.Form.Set("Action", "GetFederationToken") + req.Form.Set("Name", "TestUser") + req.Form.Set("Version", "2011-06-15") + + rr := httptest.NewRecorder() + stsHandlers.HandleSTSRequest(rr, req) + + // The handler rejects temporary credentials before SigV4 verification + assert.Equal(t, http.StatusForbidden, rr.Code, tt.description) + assert.Contains(t, rr.Body.String(), "AccessDenied") + assert.Contains(t, rr.Body.String(), "cannot be called with temporary credentials") + }) + } +} + +// TestGetFederationToken_MissingName tests that a missing Name parameter returns an error +func TestGetFederationToken_MissingName(t *testing.T) { + stsService, _ := setupTestSTSService(t) + stsHandlers := NewSTSHandlers(stsService, &IdentityAccessManagement{ + iamIntegration: &MockIAMIntegration{}, + }) + + req := httptest.NewRequest("POST", "/", nil) + req.Form = url.Values{} + req.Form.Set("Action", "GetFederationToken") + req.Form.Set("Version", "2011-06-15") + // Name is intentionally omitted + + rr := httptest.NewRecorder() + stsHandlers.HandleSTSRequest(rr, req) + + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Name is required") +} + +// TestGetFederationToken_NameValidation tests Name parameter validation +func TestGetFederationToken_NameValidation(t *testing.T) { + stsService, _ := setupTestSTSService(t) + stsHandlers := NewSTSHandlers(stsService, &IdentityAccessManagement{ + iamIntegration: &MockIAMIntegration{}, + }) + + tests := []struct { + name string + federName string + expectError bool + errContains string + }{ + { + name: "TooShort", + federName: "A", + expectError: true, + errContains: "between 2 and 32", + }, + { + name: "TooLong", + federName: strings.Repeat("A", 33), + expectError: true, + errContains: "between 2 and 32", + }, + { + name: "MinLength", + federName: "AB", + expectError: false, + }, + { + name: "MaxLength", + federName: strings.Repeat("A", 32), + expectError: false, + }, + { + name: "ValidSpecialChars", + federName: "user+=,.@-test", + expectError: false, + }, + { + name: "InvalidChars_Space", + federName: "bad name", + expectError: true, + errContains: "invalid characters", + }, + { + name: "InvalidChars_Slash", + federName: "bad/name", + expectError: true, + errContains: "invalid characters", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "/", nil) + req.Form = url.Values{} + req.Form.Set("Action", "GetFederationToken") + req.Form.Set("Name", tt.federName) + req.Form.Set("Version", "2011-06-15") + + rr := httptest.NewRecorder() + stsHandlers.HandleSTSRequest(rr, req) + + if tt.expectError { + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), tt.errContains) + } else { + // Valid name should proceed past validation — will fail at SigV4 + // (returns 403 because we have no real signature) + assert.NotEqual(t, http.StatusBadRequest, rr.Code, + "Valid name should not produce a 400 for name validation") + } + }) + } +} + +// TestGetFederationToken_DurationValidation tests DurationSeconds validation +func TestGetFederationToken_DurationValidation(t *testing.T) { + stsService, _ := setupTestSTSService(t) + stsHandlers := NewSTSHandlers(stsService, &IdentityAccessManagement{ + iamIntegration: &MockIAMIntegration{}, + }) + + tests := []struct { + name string + duration string + expectError bool + errContains string + }{ + { + name: "BelowMinimum", + duration: "899", + expectError: true, + errContains: "between", + }, + { + name: "AboveMaximum", + duration: "129601", + expectError: true, + errContains: "between", + }, + { + name: "InvalidFormat", + duration: "not-a-number", + expectError: true, + errContains: "invalid DurationSeconds", + }, + { + name: "MinimumValid", + duration: "900", + expectError: false, + }, + { + name: "MaximumValid_36Hours", + duration: "129600", + expectError: false, + }, + { + name: "Default12Hours", + duration: "43200", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "/", nil) + req.Form = url.Values{} + req.Form.Set("Action", "GetFederationToken") + req.Form.Set("Name", "TestUser") + req.Form.Set("DurationSeconds", tt.duration) + req.Form.Set("Version", "2011-06-15") + + rr := httptest.NewRecorder() + stsHandlers.HandleSTSRequest(rr, req) + + if tt.expectError { + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), tt.errContains) + } else { + // Valid duration should proceed past validation — will fail at SigV4 + assert.NotEqual(t, http.StatusBadRequest, rr.Code, + "Valid duration should not produce a 400 for duration validation") + } + }) + } +} + +// TestGetFederationToken_ResponseFormat tests the XML response structure +func TestGetFederationToken_ResponseFormat(t *testing.T) { + // Verify the response XML structure matches AWS format + response := GetFederationTokenResponse{ + Result: GetFederationTokenResult{ + Credentials: STSCredentials{ + AccessKeyId: "ASIA1234567890", + SecretAccessKey: "secret123", + SessionToken: "token123", + Expiration: "2026-04-02T12:00:00Z", + }, + FederatedUser: FederatedUser{ + FederatedUserId: "000000000000:BobApp", + Arn: "arn:aws:sts::000000000000:federated-user/BobApp", + }, + }, + } + response.ResponseMetadata.RequestId = "test-request-id" + + data, err := xml.MarshalIndent(response, "", " ") + require.NoError(t, err) + + xmlStr := string(data) + assert.Contains(t, xmlStr, "GetFederationTokenResponse") + assert.Contains(t, xmlStr, "GetFederationTokenResult") + assert.Contains(t, xmlStr, "FederatedUser") + assert.Contains(t, xmlStr, "FederatedUserId") + assert.Contains(t, xmlStr, "federated-user/BobApp") + assert.Contains(t, xmlStr, "ASIA1234567890") + assert.Contains(t, xmlStr, "test-request-id") + + // Verify it can be unmarshaled back + var parsed GetFederationTokenResponse + err = xml.Unmarshal(data, &parsed) + require.NoError(t, err) + assert.Equal(t, "ASIA1234567890", parsed.Result.Credentials.AccessKeyId) + assert.Equal(t, "arn:aws:sts::000000000000:federated-user/BobApp", parsed.Result.FederatedUser.Arn) + assert.Equal(t, "000000000000:BobApp", parsed.Result.FederatedUser.FederatedUserId) +} + +// TestGetFederationToken_PolicyEmbedding tests that the caller's policies are embedded +// into the session token using the IAM integration manager +func TestGetFederationToken_PolicyEmbedding(t *testing.T) { + ctx := context.Background() + manager := newTestSTSIntegrationManager(t) + + // Create a policy that the user has attached + userPolicy := &policy.PolicyDocument{ + Version: "2012-10-17", + Statement: []policy.Statement{ + { + Effect: "Allow", + Action: []string{"s3:GetObject", "s3:PutObject"}, + Resource: []string{"arn:aws:s3:::user-bucket/*"}, + }, + }, + } + require.NoError(t, manager.CreatePolicy(ctx, "", "UserS3Policy", userPolicy)) + + stsService := manager.GetSTSService() + + // Simulate what handleGetFederationToken does for policy embedding + name := "AppClient" + callerPolicies := []string{"UserS3Policy"} + + sessionId, err := sts.GenerateSessionId() + require.NoError(t, err) + + expiration := time.Now().Add(12 * time.Hour) + accountID := defaultAccountID + federatedUserArn := fmt.Sprintf("arn:aws:sts::%s:federated-user/%s", accountID, name) + federatedUserId := fmt.Sprintf("%s:%s", accountID, name) + + claims := sts.NewSTSSessionClaims(sessionId, stsService.Config.Issuer, expiration). + WithSessionName(name). + WithRoleInfo("arn:aws:iam::000000000000:user/caller", federatedUserId, federatedUserArn). + WithPolicies(callerPolicies) + + sessionToken, err := stsService.GetTokenGenerator().GenerateJWTWithClaims(claims) + require.NoError(t, err) + + sessionInfo, err := stsService.ValidateSessionToken(ctx, sessionToken) + require.NoError(t, err) + require.NotNil(t, sessionInfo) + + // Verify the caller's policy names are embedded + assert.Equal(t, []string{"UserS3Policy"}, sessionInfo.Policies) +} + +// TestGetFederationToken_PolicyIntersection tests that both the caller's base policies +// and the restrictive session policy are embedded in the token, enabling the +// authorization layer to compute their intersection at request time. +func TestGetFederationToken_PolicyIntersection(t *testing.T) { + ctx := context.Background() + manager := newTestSTSIntegrationManager(t) + + // Create a broad policy for the caller + broadPolicy := &policy.PolicyDocument{ + Version: "2012-10-17", + Statement: []policy.Statement{ + { + Effect: "Allow", + Action: []string{"s3:*"}, + Resource: []string{"arn:aws:s3:::*", "arn:aws:s3:::*/*"}, + }, + }, + } + require.NoError(t, manager.CreatePolicy(ctx, "", "S3FullAccess", broadPolicy)) + + stsService := manager.GetSTSService() + + // Session policy restricts to one bucket and one action + sessionPolicyJSON := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":["s3:GetObject"],"Resource":["arn:aws:s3:::restricted-bucket/*"]}]}` + normalizedPolicy, err := sts.NormalizeSessionPolicy(sessionPolicyJSON) + require.NoError(t, err) + + sessionId, err := sts.GenerateSessionId() + require.NoError(t, err) + + expiration := time.Now().Add(12 * time.Hour) + name := "RestrictedApp" + accountID := defaultAccountID + federatedUserArn := fmt.Sprintf("arn:aws:sts::%s:federated-user/%s", accountID, name) + federatedUserId := fmt.Sprintf("%s:%s", accountID, name) + + claims := sts.NewSTSSessionClaims(sessionId, stsService.Config.Issuer, expiration). + WithSessionName(name). + WithRoleInfo("arn:aws:iam::000000000000:user/caller", federatedUserId, federatedUserArn). + WithPolicies([]string{"S3FullAccess"}). + WithSessionPolicy(normalizedPolicy) + + sessionToken, err := stsService.GetTokenGenerator().GenerateJWTWithClaims(claims) + require.NoError(t, err) + + sessionInfo, err := stsService.ValidateSessionToken(ctx, sessionToken) + require.NoError(t, err) + require.NotNil(t, sessionInfo) + + // Verify both the broad base policies and the restrictive session policy are embedded + // The authorization layer computes intersection at request time + assert.Equal(t, []string{"S3FullAccess"}, sessionInfo.Policies, + "Caller's base policies should be embedded in token") + assert.Contains(t, sessionInfo.SessionPolicy, "restricted-bucket", + "Session policy should restrict to specific bucket") + assert.Contains(t, sessionInfo.SessionPolicy, "s3:GetObject", + "Session policy should restrict to specific action") +} + +// TestGetFederationToken_MalformedPolicy tests that invalid policy JSON is rejected +// by the session policy normalization used in the handler +func TestGetFederationToken_MalformedPolicy(t *testing.T) { + tests := []struct { + name string + policyStr string + expectErr bool + }{ + { + name: "InvalidJSON", + policyStr: "not-valid-json", + expectErr: true, + }, + { + name: "EmptyObject", + policyStr: "{}", + expectErr: true, + }, + { + name: "TooLarge", + policyStr: `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":["s3:GetObject"],"Resource":["` + strings.Repeat("a", 2048) + `"]}]}`, + expectErr: true, + }, + { + name: "ValidPolicy", + policyStr: `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":["s3:GetObject"],"Resource":["arn:aws:s3:::bucket/*"]}]}`, + expectErr: false, + }, + { + name: "EmptyString", + policyStr: "", + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := sts.NormalizeSessionPolicy(tt.policyStr) + if tt.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestGetFederationToken_STSNotReady tests that the handler returns 503 when STS is not initialized +func TestGetFederationToken_STSNotReady(t *testing.T) { + // Create handlers with nil STS service + stsHandlers := NewSTSHandlers(nil, &IdentityAccessManagement{ + iamIntegration: &MockIAMIntegration{}, + }) + + req := httptest.NewRequest("POST", "/", nil) + req.Form = url.Values{} + req.Form.Set("Action", "GetFederationToken") + req.Form.Set("Name", "TestUser") + req.Form.Set("Version", "2011-06-15") + + rr := httptest.NewRecorder() + stsHandlers.HandleSTSRequest(rr, req) + + assert.Equal(t, http.StatusServiceUnavailable, rr.Code) + assert.Contains(t, rr.Body.String(), "ServiceUnavailable") +} + +// TestGetFederationToken_DefaultDuration tests that the default duration is 12 hours +func TestGetFederationToken_DefaultDuration(t *testing.T) { + assert.Equal(t, int64(43200), defaultFederationDurationSeconds, "Default duration should be 12 hours (43200 seconds)") + assert.Equal(t, int64(129600), maxFederationDurationSeconds, "Max duration should be 36 hours (129600 seconds)") +} + +// TestGetFederationToken_GetPoliciesForUser tests that GetPoliciesForUser +// correctly resolves user policies from the UserStore and returns errors +// when the store is unavailable. +func TestGetFederationToken_GetPoliciesForUser(t *testing.T) { + ctx := context.Background() + manager := newTestSTSIntegrationManager(t) + + t.Run("NoUserStore", func(t *testing.T) { + // UserStore not set — should return error + policies, err := manager.GetPoliciesForUser(ctx, "alice") + assert.Error(t, err) + assert.Nil(t, policies) + assert.Contains(t, err.Error(), "user store not configured") + }) + + t.Run("UserNotFound", func(t *testing.T) { + manager.SetUserStore(&mockUserStore{users: map[string]*iam_pb.Identity{}}) + policies, err := manager.GetPoliciesForUser(ctx, "nonexistent") + assert.NoError(t, err) + assert.Nil(t, policies) + }) + + t.Run("UserWithPolicies", func(t *testing.T) { + manager.SetUserStore(&mockUserStore{ + users: map[string]*iam_pb.Identity{ + "alice": { + Name: "alice", + PolicyNames: []string{"GroupReadPolicy", "GroupWritePolicy"}, + }, + }, + }) + policies, err := manager.GetPoliciesForUser(ctx, "alice") + assert.NoError(t, err) + assert.Equal(t, []string{"GroupReadPolicy", "GroupWritePolicy"}, policies) + }) + + t.Run("UserWithNoPolicies", func(t *testing.T) { + manager.SetUserStore(&mockUserStore{ + users: map[string]*iam_pb.Identity{ + "bob": {Name: "bob"}, + }, + }) + policies, err := manager.GetPoliciesForUser(ctx, "bob") + assert.NoError(t, err) + assert.Empty(t, policies) + }) +} + +// TestGetFederationToken_PolicyMergeAndDedup tests that the handler's policy +// merge logic correctly combines identity.PolicyNames with IAM-manager-resolved +// policies and deduplicates the result. +func TestGetFederationToken_PolicyMergeAndDedup(t *testing.T) { + ctx := context.Background() + manager := newTestSTSIntegrationManager(t) + + // Create policies so they exist in the engine + for _, name := range []string{"DirectPolicy", "GroupPolicy", "SharedPolicy"} { + require.NoError(t, manager.CreatePolicy(ctx, "", name, &policy.PolicyDocument{ + Version: "2012-10-17", + Statement: []policy.Statement{ + {Effect: "Allow", Action: []string{"s3:GetObject"}, Resource: []string{"arn:aws:s3:::*/*"}}, + }, + })) + } + + // Set up a user store that returns group-attached policies + manager.SetUserStore(&mockUserStore{ + users: map[string]*iam_pb.Identity{ + "alice": { + Name: "alice", + PolicyNames: []string{"GroupPolicy", "SharedPolicy"}, + }, + }, + }) + + stsService := manager.GetSTSService() + + // Simulate what the handler does: merge identity.PolicyNames with GetPoliciesForUser + identityPolicies := []string{"DirectPolicy", "SharedPolicy"} // SharedPolicy overlaps + + policySet := make(map[string]struct{}) + for _, p := range identityPolicies { + policySet[p] = struct{}{} + } + + userPolicies, err := manager.GetPoliciesForUser(ctx, "alice") + require.NoError(t, err) + for _, p := range userPolicies { + policySet[p] = struct{}{} + } + + merged := make([]string, 0, len(policySet)) + for p := range policySet { + merged = append(merged, p) + } + sort.Strings(merged) // deterministic for assertion + + // Should contain all three unique policies, no duplicates + assert.Equal(t, []string{"DirectPolicy", "GroupPolicy", "SharedPolicy"}, merged) + + // Verify the merged policies can be embedded in a token and recovered + sessionId, err := sts.GenerateSessionId() + require.NoError(t, err) + + expiration := time.Now().Add(time.Hour) + claims := sts.NewSTSSessionClaims(sessionId, stsService.Config.Issuer, expiration). + WithSessionName("test"). + WithRoleInfo("arn:aws:iam::000000000000:user/alice", "000000000000:test", "arn:aws:sts::000000000000:federated-user/test"). + WithPolicies(merged) + + token, err := stsService.GetTokenGenerator().GenerateJWTWithClaims(claims) + require.NoError(t, err) + + sessionInfo, err := stsService.ValidateSessionToken(ctx, token) + require.NoError(t, err) + + sort.Strings(sessionInfo.Policies) + assert.Equal(t, []string{"DirectPolicy", "GroupPolicy", "SharedPolicy"}, sessionInfo.Policies, + "Token should contain the deduplicated merge of identity and group policies") +} + +// TestGetFederationToken_PolicyMergeNoManager tests that when the IAM manager +// is unavailable, identity.PolicyNames alone are still embedded. +func TestGetFederationToken_PolicyMergeNoManager(t *testing.T) { + ctx := context.Background() + stsService, _ := setupTestSTSService(t) + + // No IAM manager — only identity.PolicyNames should be used + identityPolicies := []string{"UserDirectPolicy"} + + policySet := make(map[string]struct{}) + for _, p := range identityPolicies { + policySet[p] = struct{}{} + } + + // IAM manager is nil — skip GetPoliciesForUser (mirrors handler logic) + var policyManager *integration.IAMManager // nil + if policyManager != nil { + t.Fatal("policyManager should be nil in this test") + } + + merged := make([]string, 0, len(policySet)) + for p := range policySet { + merged = append(merged, p) + } + + sessionId, err := sts.GenerateSessionId() + require.NoError(t, err) + + expiration := time.Now().Add(time.Hour) + claims := sts.NewSTSSessionClaims(sessionId, stsService.Config.Issuer, expiration). + WithSessionName("test"). + WithRoleInfo("arn:aws:iam::000000000000:user/alice", "000000000000:test", "arn:aws:sts::000000000000:federated-user/test"). + WithPolicies(merged) + + token, err := stsService.GetTokenGenerator().GenerateJWTWithClaims(claims) + require.NoError(t, err) + + sessionInfo, err := stsService.ValidateSessionToken(ctx, token) + require.NoError(t, err) + + assert.Equal(t, []string{"UserDirectPolicy"}, sessionInfo.Policies, + "Without IAM manager, only identity policies should be embedded") +}