From 7c59b639c9312ced50762717688bd192a56fcff9 Mon Sep 17 00:00:00 2001 From: Chris Lu Date: Thu, 2 Apr 2026 15:59:09 -0700 Subject: [PATCH] STS: add GetCallerIdentity support (#8893) * STS: add GetCallerIdentity support Implement the AWS STS GetCallerIdentity action, which returns the ARN, account ID, and user ID of the caller based on SigV4 authentication. This is commonly used by AWS SDKs and CLI tools (e.g. `aws sts get-caller-identity`) to verify credentials and determine the authenticated identity. * test: remove trivial GetCallerIdentity tests Remove the XML unmarshal test (we don't consume this response as input) and the routing constant test (just asserts a literal equals itself). * fix: route GetCallerIdentity through STS in UnifiedPostHandler and use stable UserId - UnifiedPostHandler only dispatched actions starting with "AssumeRole" to STS, so GetCallerIdentity in a POST body would fall through to the IAM path and get AccessDenied for non-admin users. Add explicit check for GetCallerIdentity. - Use identity.Name as UserId instead of credential.AccessKey, which is a transient value and incorrect for STS assumed-role callers. --- weed/s3api/s3api_server.go | 10 ++- weed/s3api/s3api_sts.go | 65 +++++++++++++++++++ .../s3api_sts_get_caller_identity_test.go | 33 ++++++++++ 3 files changed, 105 insertions(+), 3 deletions(-) create mode 100644 weed/s3api/s3api_sts_get_caller_identity_test.go diff --git a/weed/s3api/s3api_server.go b/weed/s3api/s3api_server.go index 35f3684d3..d6a1c1437 100644 --- a/weed/s3api/s3api_server.go +++ b/weed/s3api/s3api_server.go @@ -522,7 +522,7 @@ func (s3a *S3ApiServer) UnifiedPostHandler(w http.ResponseWriter, r *http.Reques // 3. Dispatch action := r.Form.Get("Action") - if strings.HasPrefix(action, "AssumeRole") { + if strings.HasPrefix(action, "AssumeRole") || action == "GetCallerIdentity" { // STS if s3a.stsHandlers == nil { s3err.WriteErrorResponse(w, r, s3err.ErrServiceUnavailable) @@ -826,7 +826,11 @@ func (s3a *S3ApiServer) registerRouter(router *mux.Router) { apiRouter.Methods(http.MethodPost).Path("/").Queries("Action", "AssumeRoleWithLDAPIdentity"). HandlerFunc(track(s3a.stsHandlers.HandleSTSRequest, "STS-LDAP")) - glog.V(1).Infof("STS API enabled on S3 port (AssumeRole, AssumeRoleWithWebIdentity, AssumeRoleWithLDAPIdentity)") + // GetCallerIdentity - returns caller identity based on SigV4 authentication + apiRouter.Methods(http.MethodPost).Path("/").Queries("Action", "GetCallerIdentity"). + HandlerFunc(track(s3a.stsHandlers.HandleSTSRequest, "STS-GetCallerIdentity")) + + glog.V(1).Infof("STS API enabled on S3 port (AssumeRole, AssumeRoleWithWebIdentity, AssumeRoleWithLDAPIdentity, GetCallerIdentity)") } // Embedded IAM API endpoint @@ -849,7 +853,7 @@ func (s3a *S3ApiServer) registerRouter(router *mux.Router) { // Action in Query String is handled by explicit STS routes above action := r.URL.Query().Get("Action") - if action == "AssumeRole" || action == "AssumeRoleWithWebIdentity" || action == "AssumeRoleWithLDAPIdentity" { + if action == "AssumeRole" || action == "AssumeRoleWithWebIdentity" || action == "AssumeRoleWithLDAPIdentity" || action == "GetCallerIdentity" { return false } diff --git a/weed/s3api/s3api_sts.go b/weed/s3api/s3api_sts.go index 8015f3595..478c09139 100644 --- a/weed/s3api/s3api_sts.go +++ b/weed/s3api/s3api_sts.go @@ -36,6 +36,7 @@ const ( actionAssumeRole = "AssumeRole" actionAssumeRoleWithWebIdentity = "AssumeRoleWithWebIdentity" actionAssumeRoleWithLDAPIdentity = "AssumeRoleWithLDAPIdentity" + actionGetCallerIdentity = "GetCallerIdentity" // LDAP parameter names stsLDAPUsername = "LDAPUsername" @@ -121,6 +122,8 @@ func (h *STSHandlers) HandleSTSRequest(w http.ResponseWriter, r *http.Request) { h.handleAssumeRoleWithWebIdentity(w, r) case actionAssumeRoleWithLDAPIdentity: h.handleAssumeRoleWithLDAPIdentity(w, r) + case actionGetCallerIdentity: + h.handleGetCallerIdentity(w, r) default: h.writeSTSErrorResponse(w, r, STSErrInvalidAction, fmt.Errorf("unsupported action: %s", action)) @@ -616,6 +619,52 @@ func (h *STSHandlers) prepareSTSCredentials(ctx context.Context, roleArn, roleSe return stsCreds, assumedUser, nil } +// handleGetCallerIdentity handles the GetCallerIdentity API action. +// It returns the identity (ARN, account, user ID) of the caller based on SigV4 authentication. +func (h *STSHandlers) handleGetCallerIdentity(w http.ResponseWriter, r *http.Request) { + if h.iam == nil { + h.writeSTSErrorResponse(w, r, STSErrSTSNotReady, + fmt.Errorf("IAM not configured for STS")) + return + } + + identity, _, _, _, sigErrCode := h.iam.verifyV4Signature(r, false) + if sigErrCode != s3err.ErrNone { + glog.V(2).Infof("GetCallerIdentity SigV4 verification failed: %v", sigErrCode) + h.writeSTSErrorResponse(w, r, STSErrAccessDenied, + fmt.Errorf("invalid AWS signature: %v", sigErrCode)) + return + } + + if identity == nil { + h.writeSTSErrorResponse(w, r, STSErrAccessDenied, + fmt.Errorf("unable to identify caller")) + return + } + + accountID := h.getAccountID() + + arn := identity.PrincipalArn + if arn == "" { + arn = fmt.Sprintf("arn:aws:iam::%s:user/%s", accountID, identity.Name) + } + + userId := identity.Name + + glog.V(2).Infof("GetCallerIdentity: identity=%s, arn=%s, account=%s", identity.Name, arn, accountID) + + xmlResponse := &GetCallerIdentityResponse{ + Result: GetCallerIdentityResult{ + Arn: arn, + UserId: userId, + Account: accountID, + }, + } + xmlResponse.ResponseMetadata.RequestId = request_id.GetFromRequest(r) + + s3err.WriteXMLResponse(w, r, http.StatusOK, xmlResponse) +} + // STS Response types for XML marshaling // AssumeRoleWithWebIdentityResponse is the response for AssumeRoleWithWebIdentity @@ -678,6 +727,22 @@ type LDAPIdentityResult struct { AssumedRoleUser *AssumedRoleUser `xml:"AssumedRoleUser,omitempty"` } +// GetCallerIdentityResponse is the response for GetCallerIdentity +type GetCallerIdentityResponse struct { + XMLName xml.Name `xml:"https://sts.amazonaws.com/doc/2011-06-15/ GetCallerIdentityResponse"` + Result GetCallerIdentityResult `xml:"GetCallerIdentityResult"` + ResponseMetadata struct { + RequestId string `xml:"RequestId,omitempty"` + } `xml:"ResponseMetadata,omitempty"` +} + +// GetCallerIdentityResult contains the result of GetCallerIdentity +type GetCallerIdentityResult struct { + Arn string `xml:"Arn"` + UserId string `xml:"UserId"` + Account string `xml:"Account"` +} + // STS Error types // STSErrorCode represents STS error codes diff --git a/weed/s3api/s3api_sts_get_caller_identity_test.go b/weed/s3api/s3api_sts_get_caller_identity_test.go new file mode 100644 index 000000000..4aa34a141 --- /dev/null +++ b/weed/s3api/s3api_sts_get_caller_identity_test.go @@ -0,0 +1,33 @@ +package s3api + +import ( + "encoding/xml" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetCallerIdentityResponse_XMLMarshal(t *testing.T) { + response := &GetCallerIdentityResponse{ + Result: GetCallerIdentityResult{ + Arn: fmt.Sprintf("arn:aws:iam::%s:user/alice", defaultAccountID), + UserId: "alice", + Account: defaultAccountID, + }, + } + response.ResponseMetadata.RequestId = "test-request-id" + + data, err := xml.MarshalIndent(response, "", " ") + require.NoError(t, err) + + xmlStr := string(data) + assert.Contains(t, xmlStr, "GetCallerIdentityResponse") + assert.Contains(t, xmlStr, "GetCallerIdentityResult") + assert.Contains(t, xmlStr, "arn:aws:iam::000000000000:user/alice") + assert.Contains(t, xmlStr, "alice") + assert.Contains(t, xmlStr, "000000000000") + assert.Contains(t, xmlStr, "test-request-id") + assert.Contains(t, xmlStr, "https://sts.amazonaws.com/doc/2011-06-15/") +}