STS: add GetCallerIdentity support (#8893)

* STS: add GetCallerIdentity support

Implement the AWS STS GetCallerIdentity action, which returns the
ARN, account ID, and user ID of the caller based on SigV4 authentication.
This is commonly used by AWS SDKs and CLI tools (e.g. `aws sts get-caller-identity`)
to verify credentials and determine the authenticated identity.

* test: remove trivial GetCallerIdentity tests

Remove the XML unmarshal test (we don't consume this response as input)
and the routing constant test (just asserts a literal equals itself).

* fix: route GetCallerIdentity through STS in UnifiedPostHandler and use stable UserId

- UnifiedPostHandler only dispatched actions starting with "AssumeRole" to STS,
  so GetCallerIdentity in a POST body would fall through to the IAM path and
  get AccessDenied for non-admin users. Add explicit check for GetCallerIdentity.
- Use identity.Name as UserId instead of credential.AccessKey, which is a
  transient value and incorrect for STS assumed-role callers.
This commit is contained in:
Chris Lu
2026-04-02 15:59:09 -07:00
committed by GitHub
parent 772ad67f6b
commit 7c59b639c9
3 changed files with 105 additions and 3 deletions

View File

@@ -522,7 +522,7 @@ func (s3a *S3ApiServer) UnifiedPostHandler(w http.ResponseWriter, r *http.Reques
// 3. Dispatch // 3. Dispatch
action := r.Form.Get("Action") action := r.Form.Get("Action")
if strings.HasPrefix(action, "AssumeRole") { if strings.HasPrefix(action, "AssumeRole") || action == "GetCallerIdentity" {
// STS // STS
if s3a.stsHandlers == nil { if s3a.stsHandlers == nil {
s3err.WriteErrorResponse(w, r, s3err.ErrServiceUnavailable) s3err.WriteErrorResponse(w, r, s3err.ErrServiceUnavailable)
@@ -826,7 +826,11 @@ func (s3a *S3ApiServer) registerRouter(router *mux.Router) {
apiRouter.Methods(http.MethodPost).Path("/").Queries("Action", "AssumeRoleWithLDAPIdentity"). apiRouter.Methods(http.MethodPost).Path("/").Queries("Action", "AssumeRoleWithLDAPIdentity").
HandlerFunc(track(s3a.stsHandlers.HandleSTSRequest, "STS-LDAP")) HandlerFunc(track(s3a.stsHandlers.HandleSTSRequest, "STS-LDAP"))
glog.V(1).Infof("STS API enabled on S3 port (AssumeRole, AssumeRoleWithWebIdentity, AssumeRoleWithLDAPIdentity)") // GetCallerIdentity - returns caller identity based on SigV4 authentication
apiRouter.Methods(http.MethodPost).Path("/").Queries("Action", "GetCallerIdentity").
HandlerFunc(track(s3a.stsHandlers.HandleSTSRequest, "STS-GetCallerIdentity"))
glog.V(1).Infof("STS API enabled on S3 port (AssumeRole, AssumeRoleWithWebIdentity, AssumeRoleWithLDAPIdentity, GetCallerIdentity)")
} }
// Embedded IAM API endpoint // Embedded IAM API endpoint
@@ -849,7 +853,7 @@ func (s3a *S3ApiServer) registerRouter(router *mux.Router) {
// Action in Query String is handled by explicit STS routes above // Action in Query String is handled by explicit STS routes above
action := r.URL.Query().Get("Action") action := r.URL.Query().Get("Action")
if action == "AssumeRole" || action == "AssumeRoleWithWebIdentity" || action == "AssumeRoleWithLDAPIdentity" { if action == "AssumeRole" || action == "AssumeRoleWithWebIdentity" || action == "AssumeRoleWithLDAPIdentity" || action == "GetCallerIdentity" {
return false return false
} }

View File

@@ -36,6 +36,7 @@ const (
actionAssumeRole = "AssumeRole" actionAssumeRole = "AssumeRole"
actionAssumeRoleWithWebIdentity = "AssumeRoleWithWebIdentity" actionAssumeRoleWithWebIdentity = "AssumeRoleWithWebIdentity"
actionAssumeRoleWithLDAPIdentity = "AssumeRoleWithLDAPIdentity" actionAssumeRoleWithLDAPIdentity = "AssumeRoleWithLDAPIdentity"
actionGetCallerIdentity = "GetCallerIdentity"
// LDAP parameter names // LDAP parameter names
stsLDAPUsername = "LDAPUsername" stsLDAPUsername = "LDAPUsername"
@@ -121,6 +122,8 @@ func (h *STSHandlers) HandleSTSRequest(w http.ResponseWriter, r *http.Request) {
h.handleAssumeRoleWithWebIdentity(w, r) h.handleAssumeRoleWithWebIdentity(w, r)
case actionAssumeRoleWithLDAPIdentity: case actionAssumeRoleWithLDAPIdentity:
h.handleAssumeRoleWithLDAPIdentity(w, r) h.handleAssumeRoleWithLDAPIdentity(w, r)
case actionGetCallerIdentity:
h.handleGetCallerIdentity(w, r)
default: default:
h.writeSTSErrorResponse(w, r, STSErrInvalidAction, h.writeSTSErrorResponse(w, r, STSErrInvalidAction,
fmt.Errorf("unsupported action: %s", action)) fmt.Errorf("unsupported action: %s", action))
@@ -616,6 +619,52 @@ func (h *STSHandlers) prepareSTSCredentials(ctx context.Context, roleArn, roleSe
return stsCreds, assumedUser, nil return stsCreds, assumedUser, nil
} }
// handleGetCallerIdentity handles the GetCallerIdentity API action.
// It returns the identity (ARN, account, user ID) of the caller based on SigV4 authentication.
func (h *STSHandlers) handleGetCallerIdentity(w http.ResponseWriter, r *http.Request) {
if h.iam == nil {
h.writeSTSErrorResponse(w, r, STSErrSTSNotReady,
fmt.Errorf("IAM not configured for STS"))
return
}
identity, _, _, _, sigErrCode := h.iam.verifyV4Signature(r, false)
if sigErrCode != s3err.ErrNone {
glog.V(2).Infof("GetCallerIdentity SigV4 verification failed: %v", sigErrCode)
h.writeSTSErrorResponse(w, r, STSErrAccessDenied,
fmt.Errorf("invalid AWS signature: %v", sigErrCode))
return
}
if identity == nil {
h.writeSTSErrorResponse(w, r, STSErrAccessDenied,
fmt.Errorf("unable to identify caller"))
return
}
accountID := h.getAccountID()
arn := identity.PrincipalArn
if arn == "" {
arn = fmt.Sprintf("arn:aws:iam::%s:user/%s", accountID, identity.Name)
}
userId := identity.Name
glog.V(2).Infof("GetCallerIdentity: identity=%s, arn=%s, account=%s", identity.Name, arn, accountID)
xmlResponse := &GetCallerIdentityResponse{
Result: GetCallerIdentityResult{
Arn: arn,
UserId: userId,
Account: accountID,
},
}
xmlResponse.ResponseMetadata.RequestId = request_id.GetFromRequest(r)
s3err.WriteXMLResponse(w, r, http.StatusOK, xmlResponse)
}
// STS Response types for XML marshaling // STS Response types for XML marshaling
// AssumeRoleWithWebIdentityResponse is the response for AssumeRoleWithWebIdentity // AssumeRoleWithWebIdentityResponse is the response for AssumeRoleWithWebIdentity
@@ -678,6 +727,22 @@ type LDAPIdentityResult struct {
AssumedRoleUser *AssumedRoleUser `xml:"AssumedRoleUser,omitempty"` AssumedRoleUser *AssumedRoleUser `xml:"AssumedRoleUser,omitempty"`
} }
// GetCallerIdentityResponse is the response for GetCallerIdentity
type GetCallerIdentityResponse struct {
XMLName xml.Name `xml:"https://sts.amazonaws.com/doc/2011-06-15/ GetCallerIdentityResponse"`
Result GetCallerIdentityResult `xml:"GetCallerIdentityResult"`
ResponseMetadata struct {
RequestId string `xml:"RequestId,omitempty"`
} `xml:"ResponseMetadata,omitempty"`
}
// GetCallerIdentityResult contains the result of GetCallerIdentity
type GetCallerIdentityResult struct {
Arn string `xml:"Arn"`
UserId string `xml:"UserId"`
Account string `xml:"Account"`
}
// STS Error types // STS Error types
// STSErrorCode represents STS error codes // STSErrorCode represents STS error codes

View File

@@ -0,0 +1,33 @@
package s3api
import (
"encoding/xml"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetCallerIdentityResponse_XMLMarshal(t *testing.T) {
response := &GetCallerIdentityResponse{
Result: GetCallerIdentityResult{
Arn: fmt.Sprintf("arn:aws:iam::%s:user/alice", defaultAccountID),
UserId: "alice",
Account: defaultAccountID,
},
}
response.ResponseMetadata.RequestId = "test-request-id"
data, err := xml.MarshalIndent(response, "", " ")
require.NoError(t, err)
xmlStr := string(data)
assert.Contains(t, xmlStr, "GetCallerIdentityResponse")
assert.Contains(t, xmlStr, "GetCallerIdentityResult")
assert.Contains(t, xmlStr, "<Arn>arn:aws:iam::000000000000:user/alice</Arn>")
assert.Contains(t, xmlStr, "<UserId>alice</UserId>")
assert.Contains(t, xmlStr, "<Account>000000000000</Account>")
assert.Contains(t, xmlStr, "<RequestId>test-request-id</RequestId>")
assert.Contains(t, xmlStr, "https://sts.amazonaws.com/doc/2011-06-15/")
}