diff --git a/.github/workflows/container_release_unified.yml b/.github/workflows/container_release_unified.yml index 391002a64..8218aaba9 100644 --- a/.github/workflows/container_release_unified.yml +++ b/.github/workflows/container_release_unified.yml @@ -39,7 +39,80 @@ concurrency: cancel-in-progress: false jobs: + + # ── Pre-build Rust volume server binaries natively ────────────────── + # Cross-compiles for amd64 and arm64 without QEMU, turning a 5-hour + # emulated cargo build into ~15 minutes of native compilation. + build-rust-binaries: + runs-on: ubuntu-22.04 + strategy: + matrix: + include: + - target: x86_64-unknown-linux-musl + arch: amd64 + - target: aarch64-unknown-linux-musl + arch: arm64 + cross: true + steps: + - name: Checkout + uses: actions/checkout@v6 + + - name: Install protobuf compiler + run: sudo apt-get update && sudo apt-get install -y protobuf-compiler + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + targets: ${{ matrix.target }} + + - name: Install musl tools (amd64) + if: ${{ !matrix.cross }} + run: sudo apt-get install -y musl-tools + + - name: Install cross-compilation tools (arm64) + if: matrix.cross + run: | + sudo apt-get install -y gcc-aarch64-linux-gnu + echo "CARGO_TARGET_AARCH64_UNKNOWN_LINUX_MUSL_LINKER=aarch64-linux-gnu-gcc" >> "$GITHUB_ENV" + + - name: Cache cargo registry and target + uses: actions/cache@v5 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + seaweed-volume/target + key: rust-docker-${{ matrix.target }}-${{ hashFiles('seaweed-volume/Cargo.lock') }} + restore-keys: | + rust-docker-${{ matrix.target }}- + + - name: Build large-disk variant + env: + SEAWEEDFS_COMMIT: ${{ github.sha }} + run: | + cd seaweed-volume + cargo build --release --target ${{ matrix.target }} + cp target/${{ matrix.target }}/release/weed-volume ../weed-volume-large-disk-${{ matrix.arch }} + + - name: Build normal variant + env: + SEAWEEDFS_COMMIT: ${{ github.sha }} + run: | + cd seaweed-volume + cargo build --release --target ${{ matrix.target }} --no-default-features + cp target/${{ matrix.target }}/release/weed-volume ../weed-volume-normal-${{ matrix.arch }} + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + name: rust-volume-${{ matrix.arch }} + path: | + weed-volume-large-disk-${{ matrix.arch }} + weed-volume-normal-${{ matrix.arch }} + + # ── Build Docker containers ───────────────────────────────────────── build: + needs: [build-rust-binaries] runs-on: ubuntu-latest strategy: # Build sequentially to avoid rate limits @@ -52,20 +125,23 @@ jobs: dockerfile: ./docker/Dockerfile.go_build build_args: "" tag_suffix: "" - - # Large disk - multi-arch + rust_variant: normal + + # Large disk - multi-arch - variant: large_disk platforms: linux/amd64,linux/arm64,linux/arm/v7,linux/386 dockerfile: ./docker/Dockerfile.go_build build_args: TAGS=5BytesOffset tag_suffix: _large_disk - + rust_variant: large-disk + # Full tags - multi-arch - variant: full platforms: linux/amd64,linux/arm64 dockerfile: ./docker/Dockerfile.go_build build_args: TAGS=elastic,gocdk,rclone,sqlite,tarantool,tikv,ydb tag_suffix: _full + rust_variant: normal # Large disk + full tags - multi-arch - variant: large_disk_full @@ -73,19 +149,42 @@ jobs: dockerfile: ./docker/Dockerfile.go_build build_args: TAGS=5BytesOffset,elastic,gocdk,rclone,sqlite,tarantool,tikv,ydb tag_suffix: _large_disk_full - + rust_variant: large-disk + # RocksDB large disk - amd64 only - variant: rocksdb platforms: linux/amd64 dockerfile: ./docker/Dockerfile.rocksdb_large build_args: "" tag_suffix: _large_disk_rocksdb + rust_variant: large-disk steps: - name: Checkout if: github.event_name != 'workflow_dispatch' || github.event.inputs.variant == 'all' || github.event.inputs.variant == matrix.variant uses: actions/checkout@v6 - + + - name: Download pre-built Rust binaries + if: github.event_name != 'workflow_dispatch' || github.event.inputs.variant == 'all' || github.event.inputs.variant == matrix.variant + uses: actions/download-artifact@v4 + with: + pattern: rust-volume-* + merge-multiple: true + path: ./rust-bins + + - name: Place Rust binaries in Docker context + if: github.event_name != 'workflow_dispatch' || github.event.inputs.variant == 'all' || github.event.inputs.variant == matrix.variant + run: | + mkdir -p docker/weed-volume-prebuilt + for arch in amd64 arm64; do + src="./rust-bins/weed-volume-${{ matrix.rust_variant }}-${arch}" + if [ -f "$src" ]; then + cp "$src" "docker/weed-volume-prebuilt/weed-volume-${arch}" + echo "Placed pre-built Rust binary for ${arch}" + fi + done + ls -la docker/weed-volume-prebuilt/ + - name: Free Disk Space if: github.event_name != 'workflow_dispatch' || github.event.inputs.variant == 'all' || github.event.inputs.variant == matrix.variant run: | diff --git a/docker/Dockerfile.go_build b/docker/Dockerfile.go_build index daf76af43..1f7777d72 100644 --- a/docker/Dockerfile.go_build +++ b/docker/Dockerfile.go_build @@ -16,15 +16,21 @@ RUN cd /go/src/github.com/seaweedfs/seaweedfs/weed \ && export LDFLAGS="-X github.com/seaweedfs/seaweedfs/weed/util/version.COMMIT=$(git rev-parse --short HEAD)" \ && CGO_ENABLED=0 go install -tags "$TAGS" -ldflags "-extldflags -static ${LDFLAGS}" -# Rust volume server builder. Alpine packages avoid depending on the -# upstream rust:alpine manifest list, which no longer includes linux/386. +# Rust volume server: use pre-built binary from CI when available (placed in +# weed-volume-prebuilt/ by the build-rust-binaries job), otherwise compile +# from source. Pre-building avoids a multi-hour QEMU-emulated cargo build +# for non-native architectures. FROM alpine:3.23 as rust_builder ARG TARGETARCH +ARG TAGS +COPY weed-volume-prebuilt/ /prebuilt/ COPY --from=builder /go/src/github.com/seaweedfs/seaweedfs/seaweed-volume /build/seaweed-volume COPY --from=builder /go/src/github.com/seaweedfs/seaweedfs/weed /build/weed WORKDIR /build/seaweed-volume -ARG TAGS -RUN if [ "$TARGETARCH" = "amd64" ] || [ "$TARGETARCH" = "arm64" ]; then \ +RUN if [ -f "/prebuilt/weed-volume-${TARGETARCH}" ]; then \ + echo "Using pre-built Rust binary for ${TARGETARCH}" && \ + cp "/prebuilt/weed-volume-${TARGETARCH}" /weed-volume; \ + elif [ "$TARGETARCH" = "amd64" ] || [ "$TARGETARCH" = "arm64" ]; then \ apk add --no-cache musl-dev openssl-dev protobuf-dev git rust cargo; \ if [ "$TAGS" = "5BytesOffset" ]; then \ cargo build --release; \ diff --git a/docker/weed-volume-prebuilt/.gitkeep b/docker/weed-volume-prebuilt/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/test/s3/iam/s3_sts_get_federation_token_test.go b/test/s3/iam/s3_sts_get_federation_token_test.go new file mode 100644 index 000000000..2a718cba9 --- /dev/null +++ b/test/s3/iam/s3_sts_get_federation_token_test.go @@ -0,0 +1,511 @@ +package iam + +import ( + "encoding/xml" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + v4 "github.com/aws/aws-sdk-go/aws/signer/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// GetFederationTokenTestResponse represents the STS GetFederationToken response +type GetFederationTokenTestResponse struct { + XMLName xml.Name `xml:"GetFederationTokenResponse"` + Result struct { + Credentials struct { + AccessKeyId string `xml:"AccessKeyId"` + SecretAccessKey string `xml:"SecretAccessKey"` + SessionToken string `xml:"SessionToken"` + Expiration string `xml:"Expiration"` + } `xml:"Credentials"` + FederatedUser struct { + FederatedUserId string `xml:"FederatedUserId"` + Arn string `xml:"Arn"` + } `xml:"FederatedUser"` + } `xml:"GetFederationTokenResult"` +} + +func getTestCredentials() (string, string) { + accessKey := os.Getenv("STS_TEST_ACCESS_KEY") + if accessKey == "" { + accessKey = "admin" + } + secretKey := os.Getenv("STS_TEST_SECRET_KEY") + if secretKey == "" { + secretKey = "admin" + } + return accessKey, secretKey +} + +// isGetFederationTokenImplemented checks if the running server supports GetFederationToken +func isGetFederationTokenImplemented(t *testing.T) bool { + accessKey, secretKey := getTestCredentials() + resp, err := callSTSAPIWithSigV4(t, url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + "Name": {"probe"}, + }, accessKey, secretKey) + if err != nil { + return false + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + var errResp STSErrorTestResponse + if xml.Unmarshal(body, &errResp) == nil { + if errResp.Error.Code == "InvalidAction" || errResp.Error.Code == "NotImplemented" { + return false + } + } + return true +} + +// TestSTSGetFederationTokenValidation tests input validation for the GetFederationToken endpoint +func TestSTSGetFederationTokenValidation(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + if !isSTSEndpointRunning(t) { + t.Fatal("SeaweedFS STS endpoint is not running at", TestSTSEndpoint, "- please run 'make setup-all-tests' first") + } + + if !isGetFederationTokenImplemented(t) { + t.Fatal("GetFederationToken action is not implemented in the running server") + } + + accessKey, secretKey := getTestCredentials() + + t.Run("missing_name", func(t *testing.T) { + resp, err := callSTSAPIWithSigV4(t, url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + // Name is missing + }, accessKey, secretKey) + require.NoError(t, err) + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + var errResp STSErrorTestResponse + require.NoError(t, xml.Unmarshal(body, &errResp), "Failed to parse: %s", string(body)) + assert.Equal(t, "MissingParameter", errResp.Error.Code) + }) + + t.Run("name_too_short", func(t *testing.T) { + resp, err := callSTSAPIWithSigV4(t, url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + "Name": {"A"}, + }, accessKey, secretKey) + require.NoError(t, err) + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + var errResp STSErrorTestResponse + require.NoError(t, xml.Unmarshal(body, &errResp), "Failed to parse: %s", string(body)) + assert.Equal(t, "InvalidParameterValue", errResp.Error.Code) + }) + + t.Run("name_too_long", func(t *testing.T) { + resp, err := callSTSAPIWithSigV4(t, url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + "Name": {strings.Repeat("A", 33)}, + }, accessKey, secretKey) + require.NoError(t, err) + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + var errResp STSErrorTestResponse + require.NoError(t, xml.Unmarshal(body, &errResp), "Failed to parse: %s", string(body)) + assert.Equal(t, "InvalidParameterValue", errResp.Error.Code) + }) + + t.Run("name_invalid_characters", func(t *testing.T) { + resp, err := callSTSAPIWithSigV4(t, url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + "Name": {"bad name"}, + }, accessKey, secretKey) + require.NoError(t, err) + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + var errResp STSErrorTestResponse + require.NoError(t, xml.Unmarshal(body, &errResp), "Failed to parse: %s", string(body)) + assert.Equal(t, "InvalidParameterValue", errResp.Error.Code) + }) + + t.Run("duration_too_short", func(t *testing.T) { + resp, err := callSTSAPIWithSigV4(t, url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + "Name": {"TestApp"}, + "DurationSeconds": {"100"}, + }, accessKey, secretKey) + require.NoError(t, err) + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + var errResp STSErrorTestResponse + require.NoError(t, xml.Unmarshal(body, &errResp), "Failed to parse: %s", string(body)) + assert.Equal(t, "InvalidParameterValue", errResp.Error.Code) + }) + + t.Run("duration_too_long", func(t *testing.T) { + resp, err := callSTSAPIWithSigV4(t, url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + "Name": {"TestApp"}, + "DurationSeconds": {"200000"}, + }, accessKey, secretKey) + require.NoError(t, err) + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + var errResp STSErrorTestResponse + require.NoError(t, xml.Unmarshal(body, &errResp), "Failed to parse: %s", string(body)) + assert.Equal(t, "InvalidParameterValue", errResp.Error.Code) + }) + + t.Run("malformed_policy", func(t *testing.T) { + resp, err := callSTSAPIWithSigV4(t, url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + "Name": {"TestApp"}, + "Policy": {"not-valid-json"}, + }, accessKey, secretKey) + require.NoError(t, err) + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + var errResp STSErrorTestResponse + require.NoError(t, xml.Unmarshal(body, &errResp), "Failed to parse: %s", string(body)) + assert.Equal(t, "MalformedPolicyDocument", errResp.Error.Code) + }) + + t.Run("anonymous_rejected", func(t *testing.T) { + // GetFederationToken requires SigV4, anonymous should fail + resp, err := callSTSAPI(t, url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + "Name": {"TestApp"}, + }) + require.NoError(t, err) + defer resp.Body.Close() + + assert.NotEqual(t, http.StatusOK, resp.StatusCode) + }) +} + +// TestSTSGetFederationTokenRejectTemporaryCredentials tests that temporary +// credentials (session tokens) are rejected by GetFederationToken +func TestSTSGetFederationTokenRejectTemporaryCredentials(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + if !isSTSEndpointRunning(t) { + t.Skip("SeaweedFS STS endpoint is not running at", TestSTSEndpoint) + } + + if !isGetFederationTokenImplemented(t) { + t.Skip("GetFederationToken not implemented") + } + + accessKey, secretKey := getTestCredentials() + + // First, obtain temporary credentials via AssumeRole + resp, err := callSTSAPIWithSigV4(t, url.Values{ + "Action": {"AssumeRole"}, + "Version": {"2011-06-15"}, + "RoleArn": {"arn:aws:iam::role/admin"}, + "RoleSessionName": {"temp-session"}, + }, accessKey, secretKey) + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + if resp.StatusCode != http.StatusOK { + t.Skipf("AssumeRole failed (may not be configured): status=%d body=%s", resp.StatusCode, string(body)) + } + + var assumeResp AssumeRoleTestResponse + require.NoError(t, xml.Unmarshal(body, &assumeResp), "Parse AssumeRole response: %s", string(body)) + + tempAccessKey := assumeResp.Result.Credentials.AccessKeyId + tempSecretKey := assumeResp.Result.Credentials.SecretAccessKey + tempSessionToken := assumeResp.Result.Credentials.SessionToken + require.NotEmpty(t, tempAccessKey) + require.NotEmpty(t, tempSessionToken) + + // Now try GetFederationToken with the temporary credentials + // Include X-Amz-Security-Token header which marks this as a temp credential call + params := url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + "Name": {"ShouldFail"}, + } + + reqBody := params.Encode() + req, err := http.NewRequest(http.MethodPost, TestSTSEndpoint+"/", strings.NewReader(reqBody)) + require.NoError(t, err) + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("X-Amz-Security-Token", tempSessionToken) + + creds := credentials.NewStaticCredentials(tempAccessKey, tempSecretKey, tempSessionToken) + signer := v4.NewSigner(creds) + _, err = signer.Sign(req, strings.NewReader(reqBody), "sts", "us-east-1", time.Now()) + require.NoError(t, err) + + client := &http.Client{Timeout: 30 * time.Second} + resp2, err := client.Do(req) + require.NoError(t, err) + defer resp2.Body.Close() + + body2, _ := io.ReadAll(resp2.Body) + assert.Equal(t, http.StatusForbidden, resp2.StatusCode, + "GetFederationToken should reject temporary credentials: %s", string(body2)) + assert.Contains(t, string(body2), "temporary credentials", + "Error should mention temporary credentials") +} + +// TestSTSGetFederationTokenSuccess tests a successful GetFederationToken call +// and verifies the returned credentials can be used to access S3 +func TestSTSGetFederationTokenSuccess(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + if !isSTSEndpointRunning(t) { + t.Skip("SeaweedFS STS endpoint is not running at", TestSTSEndpoint) + } + + if !isGetFederationTokenImplemented(t) { + t.Skip("GetFederationToken not implemented") + } + + accessKey, secretKey := getTestCredentials() + + t.Run("basic_success", func(t *testing.T) { + resp, err := callSTSAPIWithSigV4(t, url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + "Name": {"AppClient"}, + }, accessKey, secretKey) + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + t.Logf("Response status: %d, body: %s", resp.StatusCode, string(body)) + + if resp.StatusCode != http.StatusOK { + var errResp STSErrorTestResponse + _ = xml.Unmarshal(body, &errResp) + t.Fatalf("GetFederationToken failed: code=%s message=%s", errResp.Error.Code, errResp.Error.Message) + } + + var stsResp GetFederationTokenTestResponse + require.NoError(t, xml.Unmarshal(body, &stsResp), "Parse response: %s", string(body)) + + creds := stsResp.Result.Credentials + assert.NotEmpty(t, creds.AccessKeyId) + assert.NotEmpty(t, creds.SecretAccessKey) + assert.NotEmpty(t, creds.SessionToken) + assert.NotEmpty(t, creds.Expiration) + + fedUser := stsResp.Result.FederatedUser + assert.Contains(t, fedUser.Arn, "federated-user/AppClient") + assert.Contains(t, fedUser.FederatedUserId, "AppClient") + }) + + t.Run("with_custom_duration", func(t *testing.T) { + resp, err := callSTSAPIWithSigV4(t, url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + "Name": {"DurationTest"}, + "DurationSeconds": {"3600"}, + }, accessKey, secretKey) + require.NoError(t, err) + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + t.Logf("Response status: %d, body: %s", resp.StatusCode, string(body)) + + if resp.StatusCode == http.StatusOK { + var stsResp GetFederationTokenTestResponse + require.NoError(t, xml.Unmarshal(body, &stsResp)) + assert.NotEmpty(t, stsResp.Result.Credentials.AccessKeyId) + + // Verify expiration is roughly 1 hour from now + expTime, err := time.Parse(time.RFC3339, stsResp.Result.Credentials.Expiration) + require.NoError(t, err) + diff := time.Until(expTime) + assert.InDelta(t, 3600, diff.Seconds(), 60, + "Expiration should be ~1 hour from now") + } + }) + + t.Run("with_36_hour_duration", func(t *testing.T) { + // GetFederationToken allows up to 36 hours (unlike AssumeRole's 12h max) + resp, err := callSTSAPIWithSigV4(t, url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + "Name": {"LongDuration"}, + "DurationSeconds": {"129600"}, // 36 hours + }, accessKey, secretKey) + require.NoError(t, err) + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode == http.StatusOK { + var stsResp GetFederationTokenTestResponse + require.NoError(t, xml.Unmarshal(body, &stsResp)) + + expTime, err := time.Parse(time.RFC3339, stsResp.Result.Credentials.Expiration) + require.NoError(t, err) + diff := time.Until(expTime) + assert.InDelta(t, 129600, diff.Seconds(), 60, + "Expiration should be ~36 hours from now") + } else { + // Duration should not cause a rejection + var errResp STSErrorTestResponse + _ = xml.Unmarshal(body, &errResp) + assert.NotContains(t, errResp.Error.Message, "DurationSeconds", + "36-hour duration should be accepted by GetFederationToken") + } + }) +} + +// TestSTSGetFederationTokenWithSessionPolicy tests that vended credentials +// are scoped down by an inline session policy +func TestSTSGetFederationTokenWithSessionPolicy(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + if !isSTSEndpointRunning(t) { + t.Skip("SeaweedFS STS endpoint is not running at", TestSTSEndpoint) + } + + if !isGetFederationTokenImplemented(t) { + t.Skip("GetFederationToken not implemented") + } + + accessKey, secretKey := getTestCredentials() + + // Create a test bucket using admin credentials + adminSess, err := session.NewSession(&aws.Config{ + Region: aws.String("us-east-1"), + Endpoint: aws.String(TestSTSEndpoint), + DisableSSL: aws.Bool(true), + S3ForcePathStyle: aws.Bool(true), + Credentials: credentials.NewStaticCredentials(accessKey, secretKey, ""), + }) + require.NoError(t, err) + + adminS3 := s3.New(adminSess) + bucket := fmt.Sprintf("fed-token-test-%d", time.Now().UnixNano()) + + _, err = adminS3.CreateBucket(&s3.CreateBucketInput{Bucket: aws.String(bucket)}) + require.NoError(t, err) + defer adminS3.DeleteBucket(&s3.DeleteBucketInput{Bucket: aws.String(bucket)}) + + _, err = adminS3.PutObject(&s3.PutObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String("test.txt"), + Body: strings.NewReader("hello"), + }) + require.NoError(t, err) + defer adminS3.DeleteObject(&s3.DeleteObjectInput{Bucket: aws.String(bucket), Key: aws.String("test.txt")}) + + // Get federated credentials with a session policy that only allows GetObject + sessionPolicy := fmt.Sprintf(`{ + "Version": "2012-10-17", + "Statement": [{ + "Effect": "Allow", + "Action": ["s3:GetObject"], + "Resource": ["arn:aws:s3:::%s/*"] + }] + }`, bucket) + + resp, err := callSTSAPIWithSigV4(t, url.Values{ + "Action": {"GetFederationToken"}, + "Version": {"2011-06-15"}, + "Name": {"ScopedClient"}, + "Policy": {sessionPolicy}, + }, accessKey, secretKey) + require.NoError(t, err) + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + t.Logf("GetFederationToken response: status=%d body=%s", resp.StatusCode, string(body)) + + if resp.StatusCode != http.StatusOK { + t.Skipf("GetFederationToken failed (may need IAM policy config): %s", string(body)) + } + + var stsResp GetFederationTokenTestResponse + require.NoError(t, xml.Unmarshal(body, &stsResp)) + + fedCreds := stsResp.Result.Credentials + require.NotEmpty(t, fedCreds.AccessKeyId) + require.NotEmpty(t, fedCreds.SessionToken) + + // Create S3 client with the federated credentials + fedSess, err := session.NewSession(&aws.Config{ + Region: aws.String("us-east-1"), + Endpoint: aws.String(TestSTSEndpoint), + DisableSSL: aws.Bool(true), + S3ForcePathStyle: aws.Bool(true), + Credentials: credentials.NewStaticCredentials( + fedCreds.AccessKeyId, fedCreds.SecretAccessKey, fedCreds.SessionToken), + }) + require.NoError(t, err) + + fedS3 := s3.New(fedSess) + + // GetObject should succeed (allowed by session policy) + getResp, err := fedS3.GetObject(&s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String("test.txt"), + }) + if err == nil { + defer getResp.Body.Close() + t.Log("GetObject with federated credentials succeeded (as expected)") + } else { + t.Logf("GetObject with federated credentials: %v (session policy enforcement may vary)", err) + } + + // PutObject should be denied (not allowed by session policy) + _, err = fedS3.PutObject(&s3.PutObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String("denied.txt"), + Body: strings.NewReader("should fail"), + }) + if err != nil { + t.Log("PutObject correctly denied with federated credentials") + assert.Contains(t, err.Error(), "AccessDenied", + "PutObject should be denied by session policy") + } else { + // Clean up if unexpectedly succeeded + adminS3.DeleteObject(&s3.DeleteObjectInput{Bucket: aws.String(bucket), Key: aws.String("denied.txt")}) + t.Log("PutObject unexpectedly succeeded — session policy enforcement may not be active") + } +} diff --git a/test/s3/versioning/s3_versioning_pagination_stress_test.go b/test/s3/versioning/s3_versioning_pagination_stress_test.go index ae18cac65..f22074fab 100644 --- a/test/s3/versioning/s3_versioning_pagination_stress_test.go +++ b/test/s3/versioning/s3_versioning_pagination_stress_test.go @@ -289,6 +289,157 @@ func TestVersioningPaginationMultipleObjectsManyVersions(t *testing.T) { }) } +// TestVersioningPaginationDeepDirectoryHierarchy tests that paginated ListObjectVersions +// correctly skips directory subtrees before the key-marker. This reproduces the +// real-world scenario where Veeam backup objects are spread across many subdirectories +// (e.g., Mailboxes//ItemsData/) and pagination becomes exponentially +// slower as the marker advances through the tree. +// +// Run with: ENABLE_STRESS_TESTS=true go test -v -run TestVersioningPaginationDeepDirectoryHierarchy -timeout 10m +func TestVersioningPaginationDeepDirectoryHierarchy(t *testing.T) { + if os.Getenv("ENABLE_STRESS_TESTS") != "true" { + t.Skip("Skipping stress test. Set ENABLE_STRESS_TESTS=true to run.") + } + + client := getS3Client(t) + bucketName := getNewBucketName() + + createBucket(t, client, bucketName) + defer deleteBucket(t, client, bucketName) + + enableVersioning(t, client, bucketName) + checkVersioningStatus(t, client, bucketName, types.BucketVersioningStatusEnabled) + + // Create a deep directory structure mimicking Veeam 365 backup layout: + // Backup/Organizations//Mailboxes//ItemsData/ + numMailboxes := 20 + filesPerMailbox := 5 + totalObjects := numMailboxes * filesPerMailbox + orgPrefix := "Backup/Organizations/org-001/Mailboxes" + + t.Logf("Creating %d objects across %d subdirectories (depth=6)...", totalObjects, numMailboxes) + startTime := time.Now() + + allKeys := make([]string, 0, totalObjects) + for i := 0; i < numMailboxes; i++ { + mailboxId := fmt.Sprintf("mbx-%03d", i) + for j := 0; j < filesPerMailbox; j++ { + key := fmt.Sprintf("%s/%s/ItemsData/file-%03d.dat", orgPrefix, mailboxId, j) + _, err := client.PutObject(context.TODO(), &s3.PutObjectInput{ + Bucket: aws.String(bucketName), + Key: aws.String(key), + Body: strings.NewReader(fmt.Sprintf("content-%d-%d", i, j)), + }) + require.NoError(t, err) + allKeys = append(allKeys, key) + } + } + t.Logf("Created %d objects in %v", totalObjects, time.Since(startTime)) + + // Test 1: Paginate through all versions with a broad prefix and small maxKeys. + // This forces multiple pages that must skip earlier subdirectory trees. + t.Run("PaginateAcrossSubdirectories", func(t *testing.T) { + maxKeys := int32(10) // Force many pages to exercise marker skipping + var allVersions []types.ObjectVersion + var keyMarker, versionIdMarker *string + pageCount := 0 + pageStartTimes := make([]time.Duration, 0) + + for { + pageStart := time.Now() + resp, err := client.ListObjectVersions(context.TODO(), &s3.ListObjectVersionsInput{ + Bucket: aws.String(bucketName), + Prefix: aws.String(orgPrefix + "/"), + MaxKeys: aws.Int32(maxKeys), + KeyMarker: keyMarker, + VersionIdMarker: versionIdMarker, + }) + pageDuration := time.Since(pageStart) + require.NoError(t, err) + pageCount++ + pageStartTimes = append(pageStartTimes, pageDuration) + + allVersions = append(allVersions, resp.Versions...) + t.Logf("Page %d: %d versions in %v (marker: %v)", + pageCount, len(resp.Versions), pageDuration, + keyMarker) + + if resp.IsTruncated == nil || !*resp.IsTruncated { + break + } + keyMarker = resp.NextKeyMarker + versionIdMarker = resp.NextVersionIdMarker + } + + assert.Greater(t, pageCount, 1, "Should require multiple pages") + + // Verify listed keys exactly match the keys we created (same elements, same order) + listedKeys := make([]string, 0, len(allVersions)) + for _, v := range allVersions { + listedKeys = append(listedKeys, *v.Key) + } + assert.Equal(t, allKeys, listedKeys, + "Listed version keys should exactly match created keys") + + // Check that later pages don't take dramatically longer than earlier ones. + // Before the fix, later pages were exponentially slower because they + // re-traversed the entire tree. With the fix, all pages should be similar. + if len(pageStartTimes) >= 4 { + firstQuarter := pageStartTimes[0] + lastQuarter := pageStartTimes[len(pageStartTimes)-1] + t.Logf("First page: %v, Last page: %v, Ratio: %.1fx", + firstQuarter, lastQuarter, + float64(lastQuarter)/float64(firstQuarter)) + // Allow generous 10x ratio to avoid flakiness; before the fix + // the ratio was 100x+ on large datasets + if lastQuarter > firstQuarter*10 && lastQuarter > 500*time.Millisecond { + t.Errorf("Last page took %.1fx longer than first page (%v vs %v) — possible pagination regression", + float64(lastQuarter)/float64(firstQuarter), lastQuarter, firstQuarter) + } + } + }) + + // Test 2: Paginate with delimiter to verify CommonPrefixes interaction + t.Run("PaginateWithDelimiterAcrossSubdirs", func(t *testing.T) { + maxKeys := int32(5) + var allPrefixes []string + var keyMarker *string + pageCount := 0 + + for { + resp, err := client.ListObjectVersions(context.TODO(), &s3.ListObjectVersionsInput{ + Bucket: aws.String(bucketName), + Prefix: aws.String(orgPrefix + "/"), + Delimiter: aws.String("/"), + MaxKeys: aws.Int32(maxKeys), + KeyMarker: keyMarker, + }) + require.NoError(t, err) + pageCount++ + + for _, cp := range resp.CommonPrefixes { + allPrefixes = append(allPrefixes, *cp.Prefix) + } + + if resp.IsTruncated == nil || !*resp.IsTruncated { + break + } + keyMarker = resp.NextKeyMarker + } + + assert.Greater(t, pageCount, 1, "Should require multiple pages with maxKeys=%d", maxKeys) + + // Build the exact expected prefixes + expectedPrefixes := make([]string, 0, numMailboxes) + for i := 0; i < numMailboxes; i++ { + expectedPrefixes = append(expectedPrefixes, + fmt.Sprintf("%s/mbx-%03d/", orgPrefix, i)) + } + assert.Equal(t, expectedPrefixes, allPrefixes, + "CommonPrefixes should exactly match expected mailbox prefixes") + }) +} + // listAllVersions is a helper to list all versions of a specific object using pagination func listAllVersions(t *testing.T, client *s3.Client, bucketName, objectKey string) []types.ObjectVersion { var allVersions []types.ObjectVersion diff --git a/weed/filer/redis/universal_redis_store.go b/weed/filer/redis/universal_redis_store.go index 7c2a0e47b..4e599fe6f 100644 --- a/weed/filer/redis/universal_redis_store.go +++ b/weed/filer/redis/universal_redis_store.go @@ -173,14 +173,16 @@ func (store *UniversalRedisStore) ListDirectoryEntries(ctx context.Context, dirP members = members[:limit] } + var entry *filer.Entry // fetch entry meta for _, fileName := range members { path := util.NewFullPath(string(dirPath), fileName) - entry, err := store.FindEntry(ctx, path) + entry, err = store.FindEntry(ctx, path) lastFileName = fileName if err != nil { glog.V(0).InfofCtx(ctx, "list %s : %v", path, err) if err == filer_pb.ErrNotFound { + err = nil continue } } else { diff --git a/weed/iam/integration/iam_manager.go b/weed/iam/integration/iam_manager.go index 2e1225a89..fb8a47895 100644 --- a/weed/iam/integration/iam_manager.go +++ b/weed/iam/integration/iam_manager.go @@ -725,6 +725,23 @@ func (m *IAMManager) ExpireSessionForTesting(ctx context.Context, sessionToken s return m.stsService.ExpireSessionForTesting(ctx, sessionToken) } +// GetPoliciesForUser returns the policy names attached to an IAM user. +// Returns an error if the user store is not configured or the lookup fails, +// so callers can fail closed on policy-resolution failures. +func (m *IAMManager) GetPoliciesForUser(ctx context.Context, username string) ([]string, error) { + if m.userStore == nil { + return nil, fmt.Errorf("user store not configured") + } + user, err := m.userStore.GetUser(ctx, username) + if err != nil { + return nil, fmt.Errorf("failed to look up user %q: %w", username, err) + } + if user == nil { + return nil, nil + } + return user.PolicyNames, nil +} + // GetSTSService returns the STS service instance func (m *IAMManager) GetSTSService() *sts.STSService { return m.stsService diff --git a/weed/iam/sts/constants.go b/weed/iam/sts/constants.go index 021aca906..6e293028b 100644 --- a/weed/iam/sts/constants.go +++ b/weed/iam/sts/constants.go @@ -124,6 +124,7 @@ const ( ActionAssumeRole = "sts:AssumeRole" ActionAssumeRoleWithWebIdentity = "sts:AssumeRoleWithWebIdentity" ActionAssumeRoleWithCredentials = "sts:AssumeRoleWithCredentials" + ActionGetFederationToken = "sts:GetFederationToken" ActionValidateSession = "sts:ValidateSession" ) diff --git a/weed/s3api/s3_action_resolver.go b/weed/s3api/s3_action_resolver.go index 1a9edfca8..fa2e0a134 100644 --- a/weed/s3api/s3_action_resolver.go +++ b/weed/s3api/s3_action_resolver.go @@ -296,8 +296,8 @@ func resolveBucketLevelAction(method string, baseAction string) string { // mapBaseActionToS3Format converts coarse-grained base actions to S3 format // This is the fallback when no specific resolution is found func mapBaseActionToS3Format(baseAction string) string { - // Handle actions that already have s3: or iam: prefix - if strings.HasPrefix(baseAction, "s3:") || strings.HasPrefix(baseAction, "iam:") { + // Handle actions that already have a known service prefix + if strings.HasPrefix(baseAction, "s3:") || strings.HasPrefix(baseAction, "iam:") || strings.HasPrefix(baseAction, "sts:") { return baseAction } diff --git a/weed/s3api/s3_action_resolver_test.go b/weed/s3api/s3_action_resolver_test.go index c95ec3972..9e11f1dcb 100644 --- a/weed/s3api/s3_action_resolver_test.go +++ b/weed/s3api/s3_action_resolver_test.go @@ -7,6 +7,59 @@ import ( "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" ) +// TestMapBaseActionToS3Format_ServicePrefixPassthrough verifies that actions +// with known service prefixes (s3:, iam:, sts:) are returned unchanged. +func TestMapBaseActionToS3Format_ServicePrefixPassthrough(t *testing.T) { + tests := []struct { + name string + input string + expect string + }{ + {"s3 prefix", "s3:GetObject", "s3:GetObject"}, + {"iam prefix", "iam:CreateUser", "iam:CreateUser"}, + {"sts:AssumeRole", "sts:AssumeRole", "sts:AssumeRole"}, + {"sts:GetFederationToken", "sts:GetFederationToken", "sts:GetFederationToken"}, + {"sts:GetCallerIdentity", "sts:GetCallerIdentity", "sts:GetCallerIdentity"}, + {"coarse Read maps to s3:GetObject", "Read", s3_constants.S3_ACTION_GET_OBJECT}, + {"coarse Write maps to s3:PutObject", "Write", s3_constants.S3_ACTION_PUT_OBJECT}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := mapBaseActionToS3Format(tt.input) + if got != tt.expect { + t.Errorf("mapBaseActionToS3Format(%q) = %q, want %q", tt.input, got, tt.expect) + } + }) + } +} + +// TestResolveS3Action_STSActionsPassthrough verifies that STS actions flow +// through ResolveS3Action unchanged, both with and without an HTTP request. +func TestResolveS3Action_STSActionsPassthrough(t *testing.T) { + stsActions := []string{ + "sts:AssumeRole", + "sts:GetFederationToken", + "sts:GetCallerIdentity", + } + + for _, action := range stsActions { + t.Run("nil_request_"+action, func(t *testing.T) { + got := ResolveS3Action(nil, action, "", "") + if got != action { + t.Errorf("ResolveS3Action(nil, %q) = %q, want %q", action, got, action) + } + }) + t.Run("with_request_"+action, func(t *testing.T) { + r, _ := http.NewRequest(http.MethodPost, "http://localhost/", nil) + got := ResolveS3Action(r, action, "", "") + if got != action { + t.Errorf("ResolveS3Action(r, %q) = %q, want %q", action, got, action) + } + }) + } +} + func TestResolveS3Action_AttributesBeforeVersionId(t *testing.T) { tests := []struct { name string diff --git a/weed/s3api/s3api_object_handlers_list_versioned_test.go b/weed/s3api/s3api_object_handlers_list_versioned_test.go index c362dfe3c..fd2695e03 100644 --- a/weed/s3api/s3api_object_handlers_list_versioned_test.go +++ b/weed/s3api/s3api_object_handlers_list_versioned_test.go @@ -630,3 +630,57 @@ func TestListObjectVersions_PrefixWithLeadingSlash(t *testing.T) { }) } } + +func TestComputeStartFrom(t *testing.T) { + tests := []struct { + name string + keyMarker string + relativePath string + wantStart string + wantInclusive bool + }{ + {"empty marker", "", "", "", false}, + {"empty marker with path", "", "dir", "", false}, + {"root level file", "file1.txt", "", "file1.txt", true}, + {"root level with subpath", "Mailboxes/5ac/file1", "", "Mailboxes", true}, + {"matching subdir", "Mailboxes/5ac/file1", "Mailboxes", "5ac", true}, + {"deeper subdir", "Mailboxes/5ac/ItemsData/file1", "Mailboxes/5ac", "ItemsData", true}, + {"at leaf level", "Mailboxes/5ac/ItemsData/file1", "Mailboxes/5ac/ItemsData", "file1", true}, + {"unrelated directory", "other/path", "Mailboxes", "", false}, + {"marker equals relativePath", "Mailboxes", "Mailboxes", "", false}, + {"marker before directory", "aaa/file", "zzz", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + vc := &versionCollector{keyMarker: tt.keyMarker} + startFrom, inclusive := vc.computeStartFrom(tt.relativePath) + assert.Equal(t, tt.wantStart, startFrom) + assert.Equal(t, tt.wantInclusive, inclusive) + }) + } +} + +func TestProcessDirectorySkipsBeforeMarker(t *testing.T) { + tests := []struct { + name string + keyMarker string + entryPath string + shouldSkip bool + }{ + {"no marker", "", "dir_a", false}, + {"dir before marker", "dir_b/file", "dir_a", true}, + {"marker descends into dir", "dir_a/file", "dir_a", false}, + {"dir after marker", "dir_a/file", "dir_b", false}, + {"same prefix different suffix", "dir_a0/file", "dir_a", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + skip := tt.keyMarker != "" && + !strings.HasPrefix(tt.keyMarker, tt.entryPath+"/") && + tt.entryPath+"/" < tt.keyMarker + assert.Equal(t, tt.shouldSkip, skip) + }) + } +} diff --git a/weed/s3api/s3api_object_versioning.go b/weed/s3api/s3api_object_versioning.go index 07f3cc250..60d3ad907 100644 --- a/weed/s3api/s3api_object_versioning.go +++ b/weed/s3api/s3api_object_versioning.go @@ -427,6 +427,34 @@ func (vc *versionCollector) matchesPrefixFilter(entryPath string, isDirectory bo return isMatch || canDescend } +// computeStartFrom extracts the first path component from keyMarker that applies +// to the given directory level (relativePath), allowing the directory listing to +// skip directly to the marker position instead of scanning from the beginning. +// Returns ("", false) when no optimization is possible. +func (vc *versionCollector) computeStartFrom(relativePath string) (startFrom string, inclusive bool) { + if vc.keyMarker == "" { + return "", false + } + + var remainder string + if relativePath == "" { + remainder = vc.keyMarker + } else if strings.HasPrefix(vc.keyMarker, relativePath+"/") { + remainder = vc.keyMarker[len(relativePath)+1:] + } else { + return "", false + } + + if remainder == "" { + return "", false + } + + if idx := strings.Index(remainder, "/"); idx >= 0 { + return remainder[:idx], true + } + return remainder, true +} + // shouldSkipObjectForMarker returns true if the object should be skipped based on keyMarker func (vc *versionCollector) shouldSkipObjectForMarker(objectKey string) bool { if vc.keyMarker == "" { @@ -639,12 +667,21 @@ func (s3a *S3ApiServer) findVersionsRecursively(currentPath, relativePath string // collectVersions recursively collects versions from the given path func (vc *versionCollector) collectVersions(currentPath, relativePath string) error { startFrom := "" + inclusive := false + // On the first iteration, skip ahead to the marker position to avoid + // re-scanning all entries before the marker on every paginated request. + if markerStart, ok := vc.computeStartFrom(relativePath); ok && markerStart != "" { + startFrom = markerStart + inclusive = true + } for { if vc.isFull() { return nil } - entries, isLast, err := vc.s3a.list(currentPath, "", startFrom, false, filer.PaginationSize) + entries, isLast, err := vc.s3a.list(currentPath, "", startFrom, inclusive, filer.PaginationSize) + // After the first batch, use exclusive mode for standard pagination + inclusive = false if err != nil { return err } @@ -731,6 +768,14 @@ func (vc *versionCollector) processDirectory(currentPath, entryPath string, entr vc.processExplicitDirectory(entryPath, entry) } + // Skip entire subdirectory if all keys within it are before the keyMarker. + // All object keys under this directory start with entryPath+"/". If the marker + // doesn't descend into this directory and entryPath+"/" sorts before the marker, + // then every key in this subtree was already returned in a previous page. + if vc.keyMarker != "" && !strings.HasPrefix(vc.keyMarker, entryPath+"/") && entryPath+"/" < vc.keyMarker { + return nil + } + // Recursively search subdirectory fullPath := path.Join(currentPath, entry.Name) if err := vc.collectVersions(fullPath, entryPath); err != nil { diff --git a/weed/s3api/s3api_server.go b/weed/s3api/s3api_server.go index 35f3684d3..d6a1c1437 100644 --- a/weed/s3api/s3api_server.go +++ b/weed/s3api/s3api_server.go @@ -522,7 +522,7 @@ func (s3a *S3ApiServer) UnifiedPostHandler(w http.ResponseWriter, r *http.Reques // 3. Dispatch action := r.Form.Get("Action") - if strings.HasPrefix(action, "AssumeRole") { + if strings.HasPrefix(action, "AssumeRole") || action == "GetCallerIdentity" { // STS if s3a.stsHandlers == nil { s3err.WriteErrorResponse(w, r, s3err.ErrServiceUnavailable) @@ -826,7 +826,11 @@ func (s3a *S3ApiServer) registerRouter(router *mux.Router) { apiRouter.Methods(http.MethodPost).Path("/").Queries("Action", "AssumeRoleWithLDAPIdentity"). HandlerFunc(track(s3a.stsHandlers.HandleSTSRequest, "STS-LDAP")) - glog.V(1).Infof("STS API enabled on S3 port (AssumeRole, AssumeRoleWithWebIdentity, AssumeRoleWithLDAPIdentity)") + // GetCallerIdentity - returns caller identity based on SigV4 authentication + apiRouter.Methods(http.MethodPost).Path("/").Queries("Action", "GetCallerIdentity"). + HandlerFunc(track(s3a.stsHandlers.HandleSTSRequest, "STS-GetCallerIdentity")) + + glog.V(1).Infof("STS API enabled on S3 port (AssumeRole, AssumeRoleWithWebIdentity, AssumeRoleWithLDAPIdentity, GetCallerIdentity)") } // Embedded IAM API endpoint @@ -849,7 +853,7 @@ func (s3a *S3ApiServer) registerRouter(router *mux.Router) { // Action in Query String is handled by explicit STS routes above action := r.URL.Query().Get("Action") - if action == "AssumeRole" || action == "AssumeRoleWithWebIdentity" || action == "AssumeRoleWithLDAPIdentity" { + if action == "AssumeRole" || action == "AssumeRoleWithWebIdentity" || action == "AssumeRoleWithLDAPIdentity" || action == "GetCallerIdentity" { return false } diff --git a/weed/s3api/s3api_sts.go b/weed/s3api/s3api_sts.go index 8015f3595..f9167d271 100644 --- a/weed/s3api/s3api_sts.go +++ b/weed/s3api/s3api_sts.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "net/http" + "regexp" "strconv" "time" @@ -36,6 +37,11 @@ const ( actionAssumeRole = "AssumeRole" actionAssumeRoleWithWebIdentity = "AssumeRoleWithWebIdentity" actionAssumeRoleWithLDAPIdentity = "AssumeRoleWithLDAPIdentity" + actionGetCallerIdentity = "GetCallerIdentity" + actionGetFederationToken = "GetFederationToken" + + // GetFederationToken-specific parameters + stsFederationName = "Name" // LDAP parameter names stsLDAPUsername = "LDAPUsername" @@ -43,15 +49,20 @@ const ( stsLDAPProviderName = "LDAPProviderName" ) +// federationNameRegex validates the Name parameter for GetFederationToken per AWS spec +var federationNameRegex = regexp.MustCompile(`^[\w+=,.@-]+$`) + // STS duration constants (AWS specification) const ( - minDurationSeconds = int64(900) // 15 minutes - maxDurationSeconds = int64(43200) // 12 hours + minDurationSeconds = int64(900) // 15 minutes + maxDurationSeconds = int64(43200) // 12 hours (AssumeRole) + defaultFederationDurationSeconds = int64(43200) // 12 hours (GetFederationToken default) + maxFederationDurationSeconds = int64(129600) // 36 hours (GetFederationToken max) ) -// parseDurationSeconds parses and validates the DurationSeconds parameter -// Returns nil if the parameter is not provided, or a pointer to the parsed value -func parseDurationSeconds(r *http.Request) (*int64, STSErrorCode, error) { +// parseDurationSecondsWithBounds parses and validates the DurationSeconds parameter +// against the given min and max bounds. Returns nil if the parameter is not provided. +func parseDurationSecondsWithBounds(r *http.Request, minSec, maxSec int64) (*int64, STSErrorCode, error) { dsStr := r.FormValue("DurationSeconds") if dsStr == "" { return nil, "", nil @@ -62,14 +73,19 @@ func parseDurationSeconds(r *http.Request) (*int64, STSErrorCode, error) { return nil, STSErrInvalidParameterValue, fmt.Errorf("invalid DurationSeconds: %w", err) } - if ds < minDurationSeconds || ds > maxDurationSeconds { + if ds < minSec || ds > maxSec { return nil, STSErrInvalidParameterValue, - fmt.Errorf("DurationSeconds must be between %d and %d seconds", minDurationSeconds, maxDurationSeconds) + fmt.Errorf("DurationSeconds must be between %d and %d seconds", minSec, maxSec) } return &ds, "", nil } +// parseDurationSeconds parses DurationSeconds for AssumeRole (15 min to 12 hours) +func parseDurationSeconds(r *http.Request) (*int64, STSErrorCode, error) { + return parseDurationSecondsWithBounds(r, minDurationSeconds, maxDurationSeconds) +} + // Removed generateSecureCredentials - now using STS service's JWT token generation // The STS service generates proper JWT tokens with embedded claims that can be validated // across distributed instances without shared state. @@ -121,6 +137,10 @@ func (h *STSHandlers) HandleSTSRequest(w http.ResponseWriter, r *http.Request) { h.handleAssumeRoleWithWebIdentity(w, r) case actionAssumeRoleWithLDAPIdentity: h.handleAssumeRoleWithLDAPIdentity(w, r) + case actionGetCallerIdentity: + h.handleGetCallerIdentity(w, r) + case actionGetFederationToken: + h.handleGetFederationToken(w, r) default: h.writeSTSErrorResponse(w, r, STSErrInvalidAction, fmt.Errorf("unsupported action: %s", action)) @@ -293,7 +313,7 @@ func (h *STSHandlers) handleAssumeRole(w http.ResponseWriter, r *http.Request) { // Check authorizations if roleArn != "" { // Check if the caller is authorized to assume the role (sts:AssumeRole permission) - if authErr := h.iam.VerifyActionPermission(r, identity, Action("sts:AssumeRole"), "", roleArn); authErr != s3err.ErrNone { + if authErr := h.iam.VerifyActionPermission(r, identity, Action(sts.ActionAssumeRole), "", roleArn); authErr != s3err.ErrNone { glog.V(2).Infof("AssumeRole: caller %s is not authorized to assume role %s", identity.Name, roleArn) h.writeSTSErrorResponse(w, r, STSErrAccessDenied, fmt.Errorf("user %s is not authorized to assume role %s", identity.Name, roleArn)) @@ -317,7 +337,7 @@ func (h *STSHandlers) handleAssumeRole(w http.ResponseWriter, r *http.Request) { // For safety/consistency with previous logic, we keep the check but strictly it might not be required by AWS for GetSessionToken. // But since this IS AssumeRole, let's keep it. // Admin/Global check when no specific role is requested - if authErr := h.iam.VerifyActionPermission(r, identity, Action("sts:AssumeRole"), "", ""); authErr != s3err.ErrNone { + if authErr := h.iam.VerifyActionPermission(r, identity, Action(sts.ActionAssumeRole), "", ""); authErr != s3err.ErrNone { glog.Warningf("AssumeRole: caller %s attempted to assume role without RoleArn and lacks global sts:AssumeRole permission", identity.Name) h.writeSTSErrorResponse(w, r, STSErrAccessDenied, fmt.Errorf("access denied")) return @@ -502,6 +522,202 @@ func (h *STSHandlers) handleAssumeRoleWithLDAPIdentity(w http.ResponseWriter, r s3err.WriteXMLResponse(w, r, http.StatusOK, xmlResponse) } +// handleGetFederationToken handles the GetFederationToken API action. +// This allows long-term IAM users to obtain temporary credentials scoped down +// by an optional inline session policy. Temporary credentials cannot call this action. +func (h *STSHandlers) handleGetFederationToken(w http.ResponseWriter, r *http.Request) { + // Extract parameters + name := r.FormValue(stsFederationName) + + // Validate required parameters + if name == "" { + h.writeSTSErrorResponse(w, r, STSErrMissingParameter, + fmt.Errorf("Name is required")) + return + } + + // AWS requires Name to be 2-32 characters matching [\w+=,.@-]+ + if len(name) < 2 || len(name) > 32 { + h.writeSTSErrorResponse(w, r, STSErrInvalidParameterValue, + fmt.Errorf("Name must be between 2 and 32 characters")) + return + } + if !federationNameRegex.MatchString(name) { + h.writeSTSErrorResponse(w, r, STSErrInvalidParameterValue, + fmt.Errorf("Name contains invalid characters, must match [\\w+=,.@-]+")) + return + } + + // Parse and validate DurationSeconds (GetFederationToken allows up to 36 hours) + durationSeconds, errCode, err := parseDurationSecondsWithBounds(r, minDurationSeconds, maxFederationDurationSeconds) + if err != nil { + h.writeSTSErrorResponse(w, r, errCode, err) + return + } + + // Reject calls from temporary credentials (session tokens) early, + // before SigV4 verification — no need to authenticate first. + // GetFederationToken can only be called by long-term IAM users. + securityToken := r.Header.Get("X-Amz-Security-Token") + if securityToken == "" { + securityToken = r.URL.Query().Get("X-Amz-Security-Token") + } + if securityToken != "" { + h.writeSTSErrorResponse(w, r, STSErrAccessDenied, + fmt.Errorf("GetFederationToken cannot be called with temporary credentials")) + return + } + + // Check if STS service is initialized + if h.stsService == nil || !h.stsService.IsInitialized() { + h.writeSTSErrorResponse(w, r, STSErrSTSNotReady, + fmt.Errorf("STS service not initialized")) + return + } + + // Check if IAM is available for SigV4 verification + if h.iam == nil { + h.writeSTSErrorResponse(w, r, STSErrSTSNotReady, + fmt.Errorf("IAM not configured for STS")) + return + } + + // Validate AWS SigV4 authentication + identity, _, _, _, sigErrCode := h.iam.verifyV4Signature(r, false) + if sigErrCode != s3err.ErrNone { + glog.V(2).Infof("GetFederationToken SigV4 verification failed: %v", sigErrCode) + h.writeSTSErrorResponse(w, r, STSErrAccessDenied, + fmt.Errorf("invalid AWS signature: %v", sigErrCode)) + return + } + + if identity == nil { + h.writeSTSErrorResponse(w, r, STSErrAccessDenied, + fmt.Errorf("unable to identify caller")) + return + } + + glog.V(2).Infof("GetFederationToken: caller identity=%s, name=%s", identity.Name, name) + + // Check if the caller is authorized to call GetFederationToken + if authErr := h.iam.VerifyActionPermission(r, identity, Action(sts.ActionGetFederationToken), "", ""); authErr != s3err.ErrNone { + glog.V(2).Infof("GetFederationToken: caller %s is not authorized to call GetFederationToken", identity.Name) + h.writeSTSErrorResponse(w, r, STSErrAccessDenied, + fmt.Errorf("user %s is not authorized to call GetFederationToken", identity.Name)) + return + } + + // Validate session policy if provided + sessionPolicyJSON, err := sts.NormalizeSessionPolicy(r.FormValue("Policy")) + if err != nil { + h.writeSTSErrorResponse(w, r, STSErrMalformedPolicyDocument, + fmt.Errorf("invalid Policy document: %w", err)) + return + } + + // Calculate duration (default 12 hours for GetFederationToken) + duration := time.Duration(defaultFederationDurationSeconds) * time.Second + if durationSeconds != nil { + duration = time.Duration(*durationSeconds) * time.Second + } + + // Generate session ID + sessionId, err := sts.GenerateSessionId() + if err != nil { + h.writeSTSErrorResponse(w, r, STSErrInternalError, + fmt.Errorf("failed to generate session ID: %w", err)) + return + } + + expiration := time.Now().Add(duration) + accountID := h.getAccountID() + + // Build federated user ARN: arn:aws:sts:::federated-user/ + federatedUserArn := fmt.Sprintf("arn:aws:sts::%s:federated-user/%s", accountID, name) + federatedUserId := fmt.Sprintf("%s:%s", accountID, name) + + // Create session claims — use the caller's principal ARN as the RoleArn + // so that policy evaluation resolves the caller's attached policies + claims := sts.NewSTSSessionClaims(sessionId, h.stsService.Config.Issuer, expiration). + WithSessionName(name). + WithRoleInfo(identity.PrincipalArn, federatedUserId, federatedUserArn) + + // Embed the caller's effective policies into the token. + // Merge identity.PolicyNames (from SigV4 identity) with policies resolved + // from the IAM manager (which may include group-attached policies). + policySet := make(map[string]struct{}) + for _, p := range identity.PolicyNames { + policySet[p] = struct{}{} + } + + var policyManager *integration.IAMManager + if h.iam.iamIntegration != nil { + if provider, ok := h.iam.iamIntegration.(IAMManagerProvider); ok { + policyManager = provider.GetIAMManager() + } + } + if policyManager != nil { + userPolicies, err := policyManager.GetPoliciesForUser(r.Context(), identity.Name) + if err != nil { + glog.V(2).Infof("GetFederationToken: failed to resolve policies for %s: %v", identity.Name, err) + h.writeSTSErrorResponse(w, r, STSErrInternalError, + fmt.Errorf("failed to resolve caller policies")) + return + } + for _, p := range userPolicies { + policySet[p] = struct{}{} + } + } + + if len(policySet) > 0 { + merged := make([]string, 0, len(policySet)) + for p := range policySet { + merged = append(merged, p) + } + claims.WithPolicies(merged) + } + + if sessionPolicyJSON != "" { + claims.WithSessionPolicy(sessionPolicyJSON) + } + + // Generate JWT session token + sessionToken, err := h.stsService.GetTokenGenerator().GenerateJWTWithClaims(claims) + if err != nil { + h.writeSTSErrorResponse(w, r, STSErrInternalError, + fmt.Errorf("failed to generate session token: %w", err)) + return + } + + // Generate temporary credentials + stsCredGen := sts.NewCredentialGenerator() + stsCredsDet, err := stsCredGen.GenerateTemporaryCredentials(sessionId, expiration) + if err != nil { + h.writeSTSErrorResponse(w, r, STSErrInternalError, + fmt.Errorf("failed to generate temporary credentials: %w", err)) + return + } + + // Build and return response + xmlResponse := &GetFederationTokenResponse{ + Result: GetFederationTokenResult{ + Credentials: STSCredentials{ + AccessKeyId: stsCredsDet.AccessKeyId, + SecretAccessKey: stsCredsDet.SecretAccessKey, + SessionToken: sessionToken, + Expiration: expiration.Format(time.RFC3339), + }, + FederatedUser: FederatedUser{ + FederatedUserId: federatedUserId, + Arn: federatedUserArn, + }, + }, + } + xmlResponse.ResponseMetadata.RequestId = request_id.GetFromRequest(r) + + s3err.WriteXMLResponse(w, r, http.StatusOK, xmlResponse) +} + // prepareSTSCredentials extracts common shared logic for credential generation func (h *STSHandlers) prepareSTSCredentials(ctx context.Context, roleArn, roleSessionName string, durationSeconds *int64, sessionPolicy string, modifyClaims func(*sts.STSSessionClaims)) (STSCredentials, *AssumedRoleUser, error) { @@ -616,6 +832,52 @@ func (h *STSHandlers) prepareSTSCredentials(ctx context.Context, roleArn, roleSe return stsCreds, assumedUser, nil } +// handleGetCallerIdentity handles the GetCallerIdentity API action. +// It returns the identity (ARN, account, user ID) of the caller based on SigV4 authentication. +func (h *STSHandlers) handleGetCallerIdentity(w http.ResponseWriter, r *http.Request) { + if h.iam == nil { + h.writeSTSErrorResponse(w, r, STSErrSTSNotReady, + fmt.Errorf("IAM not configured for STS")) + return + } + + identity, _, _, _, sigErrCode := h.iam.verifyV4Signature(r, false) + if sigErrCode != s3err.ErrNone { + glog.V(2).Infof("GetCallerIdentity SigV4 verification failed: %v", sigErrCode) + h.writeSTSErrorResponse(w, r, STSErrAccessDenied, + fmt.Errorf("invalid AWS signature: %v", sigErrCode)) + return + } + + if identity == nil { + h.writeSTSErrorResponse(w, r, STSErrAccessDenied, + fmt.Errorf("unable to identify caller")) + return + } + + accountID := h.getAccountID() + + arn := identity.PrincipalArn + if arn == "" { + arn = fmt.Sprintf("arn:aws:iam::%s:user/%s", accountID, identity.Name) + } + + userId := identity.Name + + glog.V(2).Infof("GetCallerIdentity: identity=%s, arn=%s, account=%s", identity.Name, arn, accountID) + + xmlResponse := &GetCallerIdentityResponse{ + Result: GetCallerIdentityResult{ + Arn: arn, + UserId: userId, + Account: accountID, + }, + } + xmlResponse.ResponseMetadata.RequestId = request_id.GetFromRequest(r) + + s3err.WriteXMLResponse(w, r, http.StatusOK, xmlResponse) +} + // STS Response types for XML marshaling // AssumeRoleWithWebIdentityResponse is the response for AssumeRoleWithWebIdentity @@ -678,6 +940,43 @@ type LDAPIdentityResult struct { AssumedRoleUser *AssumedRoleUser `xml:"AssumedRoleUser,omitempty"` } +// GetCallerIdentityResponse is the response for GetCallerIdentity +type GetCallerIdentityResponse struct { + XMLName xml.Name `xml:"https://sts.amazonaws.com/doc/2011-06-15/ GetCallerIdentityResponse"` + Result GetCallerIdentityResult `xml:"GetCallerIdentityResult"` + ResponseMetadata struct { + RequestId string `xml:"RequestId,omitempty"` + } `xml:"ResponseMetadata,omitempty"` +} + +// GetCallerIdentityResult contains the result of GetCallerIdentity +type GetCallerIdentityResult struct { + Arn string `xml:"Arn"` + UserId string `xml:"UserId"` + Account string `xml:"Account"` +} + +// GetFederationTokenResponse is the response for GetFederationToken +type GetFederationTokenResponse struct { + XMLName xml.Name `xml:"https://sts.amazonaws.com/doc/2011-06-15/ GetFederationTokenResponse"` + Result GetFederationTokenResult `xml:"GetFederationTokenResult"` + ResponseMetadata struct { + RequestId string `xml:"RequestId,omitempty"` + } `xml:"ResponseMetadata,omitempty"` +} + +// GetFederationTokenResult contains the result of GetFederationToken +type GetFederationTokenResult struct { + Credentials STSCredentials `xml:"Credentials"` + FederatedUser FederatedUser `xml:"FederatedUser"` +} + +// FederatedUser contains information about the federated user +type FederatedUser struct { + FederatedUserId string `xml:"FederatedUserId"` + Arn string `xml:"Arn"` +} + // STS Error types // STSErrorCode represents STS error codes diff --git a/weed/s3api/s3api_sts_get_caller_identity_test.go b/weed/s3api/s3api_sts_get_caller_identity_test.go new file mode 100644 index 000000000..4aa34a141 --- /dev/null +++ b/weed/s3api/s3api_sts_get_caller_identity_test.go @@ -0,0 +1,33 @@ +package s3api + +import ( + "encoding/xml" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetCallerIdentityResponse_XMLMarshal(t *testing.T) { + response := &GetCallerIdentityResponse{ + Result: GetCallerIdentityResult{ + Arn: fmt.Sprintf("arn:aws:iam::%s:user/alice", defaultAccountID), + UserId: "alice", + Account: defaultAccountID, + }, + } + response.ResponseMetadata.RequestId = "test-request-id" + + data, err := xml.MarshalIndent(response, "", " ") + require.NoError(t, err) + + xmlStr := string(data) + assert.Contains(t, xmlStr, "GetCallerIdentityResponse") + assert.Contains(t, xmlStr, "GetCallerIdentityResult") + assert.Contains(t, xmlStr, "arn:aws:iam::000000000000:user/alice") + assert.Contains(t, xmlStr, "alice") + assert.Contains(t, xmlStr, "000000000000") + assert.Contains(t, xmlStr, "test-request-id") + assert.Contains(t, xmlStr, "https://sts.amazonaws.com/doc/2011-06-15/") +} diff --git a/weed/s3api/s3api_sts_get_federation_token_test.go b/weed/s3api/s3api_sts_get_federation_token_test.go new file mode 100644 index 000000000..7375a1822 --- /dev/null +++ b/weed/s3api/s3api_sts_get_federation_token_test.go @@ -0,0 +1,746 @@ +package s3api + +import ( + "context" + "encoding/xml" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "sort" + "strings" + "testing" + "time" + + "github.com/seaweedfs/seaweedfs/weed/iam/integration" + "github.com/seaweedfs/seaweedfs/weed/iam/policy" + "github.com/seaweedfs/seaweedfs/weed/iam/sts" + "github.com/seaweedfs/seaweedfs/weed/pb/iam_pb" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockUserStore implements integration.UserStore for testing GetPoliciesForUser +type mockUserStore struct { + users map[string]*iam_pb.Identity +} + +func (m *mockUserStore) GetUser(_ context.Context, username string) (*iam_pb.Identity, error) { + u, ok := m.users[username] + if !ok { + return nil, nil + } + return u, nil +} + +// TestGetFederationToken_BasicFlow tests basic credential generation for GetFederationToken +func TestGetFederationToken_BasicFlow(t *testing.T) { + stsService, _ := setupTestSTSService(t) + + iam := &IdentityAccessManagement{ + iamIntegration: &MockIAMIntegration{}, + } + stsHandlers := NewSTSHandlers(stsService, iam) + + // Simulate the core logic of handleGetFederationToken + name := "BobApp" + callerIdentity := &Identity{ + Name: "alice", + PrincipalArn: fmt.Sprintf("arn:aws:iam::%s:user/alice", defaultAccountID), + PolicyNames: []string{"S3ReadPolicy"}, + } + + accountID := stsHandlers.getAccountID() + + // Generate session ID and credentials + sessionId, err := sts.GenerateSessionId() + require.NoError(t, err) + + expiration := time.Now().Add(12 * time.Hour) + federatedUserArn := fmt.Sprintf("arn:aws:sts::%s:federated-user/%s", accountID, name) + federatedUserId := fmt.Sprintf("%s:%s", accountID, name) + + claims := sts.NewSTSSessionClaims(sessionId, stsService.Config.Issuer, expiration). + WithSessionName(name). + WithRoleInfo(callerIdentity.PrincipalArn, federatedUserId, federatedUserArn). + WithPolicies(callerIdentity.PolicyNames) + + sessionToken, err := stsService.GetTokenGenerator().GenerateJWTWithClaims(claims) + require.NoError(t, err) + + // Validate the session token + sessionInfo, err := stsService.ValidateSessionToken(context.Background(), sessionToken) + require.NoError(t, err) + require.NotNil(t, sessionInfo) + + // Verify the session info contains caller's policies + assert.Equal(t, []string{"S3ReadPolicy"}, sessionInfo.Policies) + + // Verify principal is the federated user ARN + assert.Equal(t, federatedUserArn, sessionInfo.Principal) + + // Verify the RoleArn points to the caller's identity (for policy resolution) + assert.Equal(t, callerIdentity.PrincipalArn, sessionInfo.RoleArn) + + // Verify session name + assert.Equal(t, name, sessionInfo.SessionName) +} + +// TestGetFederationToken_WithSessionPolicy tests session policy scoping +func TestGetFederationToken_WithSessionPolicy(t *testing.T) { + stsService, _ := setupTestSTSService(t) + + stsHandlers := NewSTSHandlers(stsService, &IdentityAccessManagement{ + iamIntegration: &MockIAMIntegration{}, + }) + + accountID := stsHandlers.getAccountID() + name := "ScopedApp" + + sessionPolicyJSON := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":["s3:GetObject"],"Resource":["arn:aws:s3:::my-bucket/*"]}]}` + normalizedPolicy, err := sts.NormalizeSessionPolicy(sessionPolicyJSON) + require.NoError(t, err) + + sessionId, err := sts.GenerateSessionId() + require.NoError(t, err) + + expiration := time.Now().Add(12 * time.Hour) + federatedUserArn := fmt.Sprintf("arn:aws:sts::%s:federated-user/%s", accountID, name) + federatedUserId := fmt.Sprintf("%s:%s", accountID, name) + + claims := sts.NewSTSSessionClaims(sessionId, stsService.Config.Issuer, expiration). + WithSessionName(name). + WithRoleInfo("arn:aws:iam::000000000000:user/caller", federatedUserId, federatedUserArn). + WithPolicies([]string{"S3FullAccess"}). + WithSessionPolicy(normalizedPolicy) + + sessionToken, err := stsService.GetTokenGenerator().GenerateJWTWithClaims(claims) + require.NoError(t, err) + + sessionInfo, err := stsService.ValidateSessionToken(context.Background(), sessionToken) + require.NoError(t, err) + require.NotNil(t, sessionInfo) + + // Verify session policy is embedded + assert.NotEmpty(t, sessionInfo.SessionPolicy) + assert.Contains(t, sessionInfo.SessionPolicy, "s3:GetObject") + + // Verify caller's policies are still present + assert.Equal(t, []string{"S3FullAccess"}, sessionInfo.Policies) +} + +// TestGetFederationToken_RejectTemporaryCredentials tests that requests with +// session tokens are rejected. +func TestGetFederationToken_RejectTemporaryCredentials(t *testing.T) { + stsService, _ := setupTestSTSService(t) + stsHandlers := NewSTSHandlers(stsService, &IdentityAccessManagement{ + iamIntegration: &MockIAMIntegration{}, + }) + + tests := []struct { + name string + setToken func(r *http.Request) + description string + }{ + { + name: "SessionTokenInHeader", + setToken: func(r *http.Request) { + r.Header.Set("X-Amz-Security-Token", "some-session-token") + }, + description: "Session token in X-Amz-Security-Token header should be rejected", + }, + { + name: "SessionTokenInQuery", + setToken: func(r *http.Request) { + q := r.URL.Query() + q.Set("X-Amz-Security-Token", "some-session-token") + r.URL.RawQuery = q.Encode() + }, + description: "Session token in query string should be rejected", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + form := url.Values{} + form.Set("Action", "GetFederationToken") + form.Set("Name", "TestUser") + form.Set("Version", "2011-06-15") + + req := httptest.NewRequest("POST", "/", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + tt.setToken(req) + + // Parse form so the handler can read it + require.NoError(t, req.ParseForm()) + // Re-set values after parse + req.Form.Set("Action", "GetFederationToken") + req.Form.Set("Name", "TestUser") + req.Form.Set("Version", "2011-06-15") + + rr := httptest.NewRecorder() + stsHandlers.HandleSTSRequest(rr, req) + + // The handler rejects temporary credentials before SigV4 verification + assert.Equal(t, http.StatusForbidden, rr.Code, tt.description) + assert.Contains(t, rr.Body.String(), "AccessDenied") + assert.Contains(t, rr.Body.String(), "cannot be called with temporary credentials") + }) + } +} + +// TestGetFederationToken_MissingName tests that a missing Name parameter returns an error +func TestGetFederationToken_MissingName(t *testing.T) { + stsService, _ := setupTestSTSService(t) + stsHandlers := NewSTSHandlers(stsService, &IdentityAccessManagement{ + iamIntegration: &MockIAMIntegration{}, + }) + + req := httptest.NewRequest("POST", "/", nil) + req.Form = url.Values{} + req.Form.Set("Action", "GetFederationToken") + req.Form.Set("Version", "2011-06-15") + // Name is intentionally omitted + + rr := httptest.NewRecorder() + stsHandlers.HandleSTSRequest(rr, req) + + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Name is required") +} + +// TestGetFederationToken_NameValidation tests Name parameter validation +func TestGetFederationToken_NameValidation(t *testing.T) { + stsService, _ := setupTestSTSService(t) + stsHandlers := NewSTSHandlers(stsService, &IdentityAccessManagement{ + iamIntegration: &MockIAMIntegration{}, + }) + + tests := []struct { + name string + federName string + expectError bool + errContains string + }{ + { + name: "TooShort", + federName: "A", + expectError: true, + errContains: "between 2 and 32", + }, + { + name: "TooLong", + federName: strings.Repeat("A", 33), + expectError: true, + errContains: "between 2 and 32", + }, + { + name: "MinLength", + federName: "AB", + expectError: false, + }, + { + name: "MaxLength", + federName: strings.Repeat("A", 32), + expectError: false, + }, + { + name: "ValidSpecialChars", + federName: "user+=,.@-test", + expectError: false, + }, + { + name: "InvalidChars_Space", + federName: "bad name", + expectError: true, + errContains: "invalid characters", + }, + { + name: "InvalidChars_Slash", + federName: "bad/name", + expectError: true, + errContains: "invalid characters", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "/", nil) + req.Form = url.Values{} + req.Form.Set("Action", "GetFederationToken") + req.Form.Set("Name", tt.federName) + req.Form.Set("Version", "2011-06-15") + + rr := httptest.NewRecorder() + stsHandlers.HandleSTSRequest(rr, req) + + if tt.expectError { + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), tt.errContains) + } else { + // Valid name should proceed past validation — will fail at SigV4 + // (returns 403 because we have no real signature) + assert.NotEqual(t, http.StatusBadRequest, rr.Code, + "Valid name should not produce a 400 for name validation") + } + }) + } +} + +// TestGetFederationToken_DurationValidation tests DurationSeconds validation +func TestGetFederationToken_DurationValidation(t *testing.T) { + stsService, _ := setupTestSTSService(t) + stsHandlers := NewSTSHandlers(stsService, &IdentityAccessManagement{ + iamIntegration: &MockIAMIntegration{}, + }) + + tests := []struct { + name string + duration string + expectError bool + errContains string + }{ + { + name: "BelowMinimum", + duration: "899", + expectError: true, + errContains: "between", + }, + { + name: "AboveMaximum", + duration: "129601", + expectError: true, + errContains: "between", + }, + { + name: "InvalidFormat", + duration: "not-a-number", + expectError: true, + errContains: "invalid DurationSeconds", + }, + { + name: "MinimumValid", + duration: "900", + expectError: false, + }, + { + name: "MaximumValid_36Hours", + duration: "129600", + expectError: false, + }, + { + name: "Default12Hours", + duration: "43200", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "/", nil) + req.Form = url.Values{} + req.Form.Set("Action", "GetFederationToken") + req.Form.Set("Name", "TestUser") + req.Form.Set("DurationSeconds", tt.duration) + req.Form.Set("Version", "2011-06-15") + + rr := httptest.NewRecorder() + stsHandlers.HandleSTSRequest(rr, req) + + if tt.expectError { + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), tt.errContains) + } else { + // Valid duration should proceed past validation — will fail at SigV4 + assert.NotEqual(t, http.StatusBadRequest, rr.Code, + "Valid duration should not produce a 400 for duration validation") + } + }) + } +} + +// TestGetFederationToken_ResponseFormat tests the XML response structure +func TestGetFederationToken_ResponseFormat(t *testing.T) { + // Verify the response XML structure matches AWS format + response := GetFederationTokenResponse{ + Result: GetFederationTokenResult{ + Credentials: STSCredentials{ + AccessKeyId: "ASIA1234567890", + SecretAccessKey: "secret123", + SessionToken: "token123", + Expiration: "2026-04-02T12:00:00Z", + }, + FederatedUser: FederatedUser{ + FederatedUserId: "000000000000:BobApp", + Arn: "arn:aws:sts::000000000000:federated-user/BobApp", + }, + }, + } + response.ResponseMetadata.RequestId = "test-request-id" + + data, err := xml.MarshalIndent(response, "", " ") + require.NoError(t, err) + + xmlStr := string(data) + assert.Contains(t, xmlStr, "GetFederationTokenResponse") + assert.Contains(t, xmlStr, "GetFederationTokenResult") + assert.Contains(t, xmlStr, "FederatedUser") + assert.Contains(t, xmlStr, "FederatedUserId") + assert.Contains(t, xmlStr, "federated-user/BobApp") + assert.Contains(t, xmlStr, "ASIA1234567890") + assert.Contains(t, xmlStr, "test-request-id") + + // Verify it can be unmarshaled back + var parsed GetFederationTokenResponse + err = xml.Unmarshal(data, &parsed) + require.NoError(t, err) + assert.Equal(t, "ASIA1234567890", parsed.Result.Credentials.AccessKeyId) + assert.Equal(t, "arn:aws:sts::000000000000:federated-user/BobApp", parsed.Result.FederatedUser.Arn) + assert.Equal(t, "000000000000:BobApp", parsed.Result.FederatedUser.FederatedUserId) +} + +// TestGetFederationToken_PolicyEmbedding tests that the caller's policies are embedded +// into the session token using the IAM integration manager +func TestGetFederationToken_PolicyEmbedding(t *testing.T) { + ctx := context.Background() + manager := newTestSTSIntegrationManager(t) + + // Create a policy that the user has attached + userPolicy := &policy.PolicyDocument{ + Version: "2012-10-17", + Statement: []policy.Statement{ + { + Effect: "Allow", + Action: []string{"s3:GetObject", "s3:PutObject"}, + Resource: []string{"arn:aws:s3:::user-bucket/*"}, + }, + }, + } + require.NoError(t, manager.CreatePolicy(ctx, "", "UserS3Policy", userPolicy)) + + stsService := manager.GetSTSService() + + // Simulate what handleGetFederationToken does for policy embedding + name := "AppClient" + callerPolicies := []string{"UserS3Policy"} + + sessionId, err := sts.GenerateSessionId() + require.NoError(t, err) + + expiration := time.Now().Add(12 * time.Hour) + accountID := defaultAccountID + federatedUserArn := fmt.Sprintf("arn:aws:sts::%s:federated-user/%s", accountID, name) + federatedUserId := fmt.Sprintf("%s:%s", accountID, name) + + claims := sts.NewSTSSessionClaims(sessionId, stsService.Config.Issuer, expiration). + WithSessionName(name). + WithRoleInfo("arn:aws:iam::000000000000:user/caller", federatedUserId, federatedUserArn). + WithPolicies(callerPolicies) + + sessionToken, err := stsService.GetTokenGenerator().GenerateJWTWithClaims(claims) + require.NoError(t, err) + + sessionInfo, err := stsService.ValidateSessionToken(ctx, sessionToken) + require.NoError(t, err) + require.NotNil(t, sessionInfo) + + // Verify the caller's policy names are embedded + assert.Equal(t, []string{"UserS3Policy"}, sessionInfo.Policies) +} + +// TestGetFederationToken_PolicyIntersection tests that both the caller's base policies +// and the restrictive session policy are embedded in the token, enabling the +// authorization layer to compute their intersection at request time. +func TestGetFederationToken_PolicyIntersection(t *testing.T) { + ctx := context.Background() + manager := newTestSTSIntegrationManager(t) + + // Create a broad policy for the caller + broadPolicy := &policy.PolicyDocument{ + Version: "2012-10-17", + Statement: []policy.Statement{ + { + Effect: "Allow", + Action: []string{"s3:*"}, + Resource: []string{"arn:aws:s3:::*", "arn:aws:s3:::*/*"}, + }, + }, + } + require.NoError(t, manager.CreatePolicy(ctx, "", "S3FullAccess", broadPolicy)) + + stsService := manager.GetSTSService() + + // Session policy restricts to one bucket and one action + sessionPolicyJSON := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":["s3:GetObject"],"Resource":["arn:aws:s3:::restricted-bucket/*"]}]}` + normalizedPolicy, err := sts.NormalizeSessionPolicy(sessionPolicyJSON) + require.NoError(t, err) + + sessionId, err := sts.GenerateSessionId() + require.NoError(t, err) + + expiration := time.Now().Add(12 * time.Hour) + name := "RestrictedApp" + accountID := defaultAccountID + federatedUserArn := fmt.Sprintf("arn:aws:sts::%s:federated-user/%s", accountID, name) + federatedUserId := fmt.Sprintf("%s:%s", accountID, name) + + claims := sts.NewSTSSessionClaims(sessionId, stsService.Config.Issuer, expiration). + WithSessionName(name). + WithRoleInfo("arn:aws:iam::000000000000:user/caller", federatedUserId, federatedUserArn). + WithPolicies([]string{"S3FullAccess"}). + WithSessionPolicy(normalizedPolicy) + + sessionToken, err := stsService.GetTokenGenerator().GenerateJWTWithClaims(claims) + require.NoError(t, err) + + sessionInfo, err := stsService.ValidateSessionToken(ctx, sessionToken) + require.NoError(t, err) + require.NotNil(t, sessionInfo) + + // Verify both the broad base policies and the restrictive session policy are embedded + // The authorization layer computes intersection at request time + assert.Equal(t, []string{"S3FullAccess"}, sessionInfo.Policies, + "Caller's base policies should be embedded in token") + assert.Contains(t, sessionInfo.SessionPolicy, "restricted-bucket", + "Session policy should restrict to specific bucket") + assert.Contains(t, sessionInfo.SessionPolicy, "s3:GetObject", + "Session policy should restrict to specific action") +} + +// TestGetFederationToken_MalformedPolicy tests that invalid policy JSON is rejected +// by the session policy normalization used in the handler +func TestGetFederationToken_MalformedPolicy(t *testing.T) { + tests := []struct { + name string + policyStr string + expectErr bool + }{ + { + name: "InvalidJSON", + policyStr: "not-valid-json", + expectErr: true, + }, + { + name: "EmptyObject", + policyStr: "{}", + expectErr: true, + }, + { + name: "TooLarge", + policyStr: `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":["s3:GetObject"],"Resource":["` + strings.Repeat("a", 2048) + `"]}]}`, + expectErr: true, + }, + { + name: "ValidPolicy", + policyStr: `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":["s3:GetObject"],"Resource":["arn:aws:s3:::bucket/*"]}]}`, + expectErr: false, + }, + { + name: "EmptyString", + policyStr: "", + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := sts.NormalizeSessionPolicy(tt.policyStr) + if tt.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestGetFederationToken_STSNotReady tests that the handler returns 503 when STS is not initialized +func TestGetFederationToken_STSNotReady(t *testing.T) { + // Create handlers with nil STS service + stsHandlers := NewSTSHandlers(nil, &IdentityAccessManagement{ + iamIntegration: &MockIAMIntegration{}, + }) + + req := httptest.NewRequest("POST", "/", nil) + req.Form = url.Values{} + req.Form.Set("Action", "GetFederationToken") + req.Form.Set("Name", "TestUser") + req.Form.Set("Version", "2011-06-15") + + rr := httptest.NewRecorder() + stsHandlers.HandleSTSRequest(rr, req) + + assert.Equal(t, http.StatusServiceUnavailable, rr.Code) + assert.Contains(t, rr.Body.String(), "ServiceUnavailable") +} + +// TestGetFederationToken_DefaultDuration tests that the default duration is 12 hours +func TestGetFederationToken_DefaultDuration(t *testing.T) { + assert.Equal(t, int64(43200), defaultFederationDurationSeconds, "Default duration should be 12 hours (43200 seconds)") + assert.Equal(t, int64(129600), maxFederationDurationSeconds, "Max duration should be 36 hours (129600 seconds)") +} + +// TestGetFederationToken_GetPoliciesForUser tests that GetPoliciesForUser +// correctly resolves user policies from the UserStore and returns errors +// when the store is unavailable. +func TestGetFederationToken_GetPoliciesForUser(t *testing.T) { + ctx := context.Background() + manager := newTestSTSIntegrationManager(t) + + t.Run("NoUserStore", func(t *testing.T) { + // UserStore not set — should return error + policies, err := manager.GetPoliciesForUser(ctx, "alice") + assert.Error(t, err) + assert.Nil(t, policies) + assert.Contains(t, err.Error(), "user store not configured") + }) + + t.Run("UserNotFound", func(t *testing.T) { + manager.SetUserStore(&mockUserStore{users: map[string]*iam_pb.Identity{}}) + policies, err := manager.GetPoliciesForUser(ctx, "nonexistent") + assert.NoError(t, err) + assert.Nil(t, policies) + }) + + t.Run("UserWithPolicies", func(t *testing.T) { + manager.SetUserStore(&mockUserStore{ + users: map[string]*iam_pb.Identity{ + "alice": { + Name: "alice", + PolicyNames: []string{"GroupReadPolicy", "GroupWritePolicy"}, + }, + }, + }) + policies, err := manager.GetPoliciesForUser(ctx, "alice") + assert.NoError(t, err) + assert.Equal(t, []string{"GroupReadPolicy", "GroupWritePolicy"}, policies) + }) + + t.Run("UserWithNoPolicies", func(t *testing.T) { + manager.SetUserStore(&mockUserStore{ + users: map[string]*iam_pb.Identity{ + "bob": {Name: "bob"}, + }, + }) + policies, err := manager.GetPoliciesForUser(ctx, "bob") + assert.NoError(t, err) + assert.Empty(t, policies) + }) +} + +// TestGetFederationToken_PolicyMergeAndDedup tests that the handler's policy +// merge logic correctly combines identity.PolicyNames with IAM-manager-resolved +// policies and deduplicates the result. +func TestGetFederationToken_PolicyMergeAndDedup(t *testing.T) { + ctx := context.Background() + manager := newTestSTSIntegrationManager(t) + + // Create policies so they exist in the engine + for _, name := range []string{"DirectPolicy", "GroupPolicy", "SharedPolicy"} { + require.NoError(t, manager.CreatePolicy(ctx, "", name, &policy.PolicyDocument{ + Version: "2012-10-17", + Statement: []policy.Statement{ + {Effect: "Allow", Action: []string{"s3:GetObject"}, Resource: []string{"arn:aws:s3:::*/*"}}, + }, + })) + } + + // Set up a user store that returns group-attached policies + manager.SetUserStore(&mockUserStore{ + users: map[string]*iam_pb.Identity{ + "alice": { + Name: "alice", + PolicyNames: []string{"GroupPolicy", "SharedPolicy"}, + }, + }, + }) + + stsService := manager.GetSTSService() + + // Simulate what the handler does: merge identity.PolicyNames with GetPoliciesForUser + identityPolicies := []string{"DirectPolicy", "SharedPolicy"} // SharedPolicy overlaps + + policySet := make(map[string]struct{}) + for _, p := range identityPolicies { + policySet[p] = struct{}{} + } + + userPolicies, err := manager.GetPoliciesForUser(ctx, "alice") + require.NoError(t, err) + for _, p := range userPolicies { + policySet[p] = struct{}{} + } + + merged := make([]string, 0, len(policySet)) + for p := range policySet { + merged = append(merged, p) + } + sort.Strings(merged) // deterministic for assertion + + // Should contain all three unique policies, no duplicates + assert.Equal(t, []string{"DirectPolicy", "GroupPolicy", "SharedPolicy"}, merged) + + // Verify the merged policies can be embedded in a token and recovered + sessionId, err := sts.GenerateSessionId() + require.NoError(t, err) + + expiration := time.Now().Add(time.Hour) + claims := sts.NewSTSSessionClaims(sessionId, stsService.Config.Issuer, expiration). + WithSessionName("test"). + WithRoleInfo("arn:aws:iam::000000000000:user/alice", "000000000000:test", "arn:aws:sts::000000000000:federated-user/test"). + WithPolicies(merged) + + token, err := stsService.GetTokenGenerator().GenerateJWTWithClaims(claims) + require.NoError(t, err) + + sessionInfo, err := stsService.ValidateSessionToken(ctx, token) + require.NoError(t, err) + + sort.Strings(sessionInfo.Policies) + assert.Equal(t, []string{"DirectPolicy", "GroupPolicy", "SharedPolicy"}, sessionInfo.Policies, + "Token should contain the deduplicated merge of identity and group policies") +} + +// TestGetFederationToken_PolicyMergeNoManager tests that when the IAM manager +// is unavailable, identity.PolicyNames alone are still embedded. +func TestGetFederationToken_PolicyMergeNoManager(t *testing.T) { + ctx := context.Background() + stsService, _ := setupTestSTSService(t) + + // No IAM manager — only identity.PolicyNames should be used + identityPolicies := []string{"UserDirectPolicy"} + + policySet := make(map[string]struct{}) + for _, p := range identityPolicies { + policySet[p] = struct{}{} + } + + // IAM manager is nil — skip GetPoliciesForUser (mirrors handler logic) + var policyManager *integration.IAMManager // nil + if policyManager != nil { + t.Fatal("policyManager should be nil in this test") + } + + merged := make([]string, 0, len(policySet)) + for p := range policySet { + merged = append(merged, p) + } + + sessionId, err := sts.GenerateSessionId() + require.NoError(t, err) + + expiration := time.Now().Add(time.Hour) + claims := sts.NewSTSSessionClaims(sessionId, stsService.Config.Issuer, expiration). + WithSessionName("test"). + WithRoleInfo("arn:aws:iam::000000000000:user/alice", "000000000000:test", "arn:aws:sts::000000000000:federated-user/test"). + WithPolicies(merged) + + token, err := stsService.GetTokenGenerator().GenerateJWTWithClaims(claims) + require.NoError(t, err) + + sessionInfo, err := stsService.ValidateSessionToken(ctx, token) + require.NoError(t, err) + + assert.Equal(t, []string{"UserDirectPolicy"}, sessionInfo.Policies, + "Without IAM manager, only identity policies should be embedded") +} diff --git a/weed/worker/tasks/ec_balance/ec_balance_task.go b/weed/worker/tasks/ec_balance/ec_balance_task.go index c58c34809..2b0729535 100644 --- a/weed/worker/tasks/ec_balance/ec_balance_task.go +++ b/weed/worker/tasks/ec_balance/ec_balance_task.go @@ -212,7 +212,8 @@ func (t *ECBalanceTask) GetProgress() float64 { // reportProgress updates the stored progress and reports it via the callback func (t *ECBalanceTask) reportProgress(progress float64, stage string) { t.progress = progress - t.reportProgress(progress, stage) + t.ReportProgressWithStage(progress, stage) + glog.Infof("EC balance volume %d: [%.2f] %s", t.volumeID, progress, stage) } // isDedupPhase checks if this is a dedup-phase task (source and target are the same node)