diff --git a/weed/s3api/s3api_server.go b/weed/s3api/s3api_server.go index 35f3684d3..d6a1c1437 100644 --- a/weed/s3api/s3api_server.go +++ b/weed/s3api/s3api_server.go @@ -522,7 +522,7 @@ func (s3a *S3ApiServer) UnifiedPostHandler(w http.ResponseWriter, r *http.Reques // 3. Dispatch action := r.Form.Get("Action") - if strings.HasPrefix(action, "AssumeRole") { + if strings.HasPrefix(action, "AssumeRole") || action == "GetCallerIdentity" { // STS if s3a.stsHandlers == nil { s3err.WriteErrorResponse(w, r, s3err.ErrServiceUnavailable) @@ -826,7 +826,11 @@ func (s3a *S3ApiServer) registerRouter(router *mux.Router) { apiRouter.Methods(http.MethodPost).Path("/").Queries("Action", "AssumeRoleWithLDAPIdentity"). HandlerFunc(track(s3a.stsHandlers.HandleSTSRequest, "STS-LDAP")) - glog.V(1).Infof("STS API enabled on S3 port (AssumeRole, AssumeRoleWithWebIdentity, AssumeRoleWithLDAPIdentity)") + // GetCallerIdentity - returns caller identity based on SigV4 authentication + apiRouter.Methods(http.MethodPost).Path("/").Queries("Action", "GetCallerIdentity"). + HandlerFunc(track(s3a.stsHandlers.HandleSTSRequest, "STS-GetCallerIdentity")) + + glog.V(1).Infof("STS API enabled on S3 port (AssumeRole, AssumeRoleWithWebIdentity, AssumeRoleWithLDAPIdentity, GetCallerIdentity)") } // Embedded IAM API endpoint @@ -849,7 +853,7 @@ func (s3a *S3ApiServer) registerRouter(router *mux.Router) { // Action in Query String is handled by explicit STS routes above action := r.URL.Query().Get("Action") - if action == "AssumeRole" || action == "AssumeRoleWithWebIdentity" || action == "AssumeRoleWithLDAPIdentity" { + if action == "AssumeRole" || action == "AssumeRoleWithWebIdentity" || action == "AssumeRoleWithLDAPIdentity" || action == "GetCallerIdentity" { return false } diff --git a/weed/s3api/s3api_sts.go b/weed/s3api/s3api_sts.go index 8015f3595..478c09139 100644 --- a/weed/s3api/s3api_sts.go +++ b/weed/s3api/s3api_sts.go @@ -36,6 +36,7 @@ const ( actionAssumeRole = "AssumeRole" actionAssumeRoleWithWebIdentity = "AssumeRoleWithWebIdentity" actionAssumeRoleWithLDAPIdentity = "AssumeRoleWithLDAPIdentity" + actionGetCallerIdentity = "GetCallerIdentity" // LDAP parameter names stsLDAPUsername = "LDAPUsername" @@ -121,6 +122,8 @@ func (h *STSHandlers) HandleSTSRequest(w http.ResponseWriter, r *http.Request) { h.handleAssumeRoleWithWebIdentity(w, r) case actionAssumeRoleWithLDAPIdentity: h.handleAssumeRoleWithLDAPIdentity(w, r) + case actionGetCallerIdentity: + h.handleGetCallerIdentity(w, r) default: h.writeSTSErrorResponse(w, r, STSErrInvalidAction, fmt.Errorf("unsupported action: %s", action)) @@ -616,6 +619,52 @@ func (h *STSHandlers) prepareSTSCredentials(ctx context.Context, roleArn, roleSe return stsCreds, assumedUser, nil } +// handleGetCallerIdentity handles the GetCallerIdentity API action. +// It returns the identity (ARN, account, user ID) of the caller based on SigV4 authentication. +func (h *STSHandlers) handleGetCallerIdentity(w http.ResponseWriter, r *http.Request) { + if h.iam == nil { + h.writeSTSErrorResponse(w, r, STSErrSTSNotReady, + fmt.Errorf("IAM not configured for STS")) + return + } + + identity, _, _, _, sigErrCode := h.iam.verifyV4Signature(r, false) + if sigErrCode != s3err.ErrNone { + glog.V(2).Infof("GetCallerIdentity SigV4 verification failed: %v", sigErrCode) + h.writeSTSErrorResponse(w, r, STSErrAccessDenied, + fmt.Errorf("invalid AWS signature: %v", sigErrCode)) + return + } + + if identity == nil { + h.writeSTSErrorResponse(w, r, STSErrAccessDenied, + fmt.Errorf("unable to identify caller")) + return + } + + accountID := h.getAccountID() + + arn := identity.PrincipalArn + if arn == "" { + arn = fmt.Sprintf("arn:aws:iam::%s:user/%s", accountID, identity.Name) + } + + userId := identity.Name + + glog.V(2).Infof("GetCallerIdentity: identity=%s, arn=%s, account=%s", identity.Name, arn, accountID) + + xmlResponse := &GetCallerIdentityResponse{ + Result: GetCallerIdentityResult{ + Arn: arn, + UserId: userId, + Account: accountID, + }, + } + xmlResponse.ResponseMetadata.RequestId = request_id.GetFromRequest(r) + + s3err.WriteXMLResponse(w, r, http.StatusOK, xmlResponse) +} + // STS Response types for XML marshaling // AssumeRoleWithWebIdentityResponse is the response for AssumeRoleWithWebIdentity @@ -678,6 +727,22 @@ type LDAPIdentityResult struct { AssumedRoleUser *AssumedRoleUser `xml:"AssumedRoleUser,omitempty"` } +// GetCallerIdentityResponse is the response for GetCallerIdentity +type GetCallerIdentityResponse struct { + XMLName xml.Name `xml:"https://sts.amazonaws.com/doc/2011-06-15/ GetCallerIdentityResponse"` + Result GetCallerIdentityResult `xml:"GetCallerIdentityResult"` + ResponseMetadata struct { + RequestId string `xml:"RequestId,omitempty"` + } `xml:"ResponseMetadata,omitempty"` +} + +// GetCallerIdentityResult contains the result of GetCallerIdentity +type GetCallerIdentityResult struct { + Arn string `xml:"Arn"` + UserId string `xml:"UserId"` + Account string `xml:"Account"` +} + // STS Error types // STSErrorCode represents STS error codes diff --git a/weed/s3api/s3api_sts_get_caller_identity_test.go b/weed/s3api/s3api_sts_get_caller_identity_test.go new file mode 100644 index 000000000..4aa34a141 --- /dev/null +++ b/weed/s3api/s3api_sts_get_caller_identity_test.go @@ -0,0 +1,33 @@ +package s3api + +import ( + "encoding/xml" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetCallerIdentityResponse_XMLMarshal(t *testing.T) { + response := &GetCallerIdentityResponse{ + Result: GetCallerIdentityResult{ + Arn: fmt.Sprintf("arn:aws:iam::%s:user/alice", defaultAccountID), + UserId: "alice", + Account: defaultAccountID, + }, + } + response.ResponseMetadata.RequestId = "test-request-id" + + data, err := xml.MarshalIndent(response, "", " ") + require.NoError(t, err) + + xmlStr := string(data) + assert.Contains(t, xmlStr, "GetCallerIdentityResponse") + assert.Contains(t, xmlStr, "GetCallerIdentityResult") + assert.Contains(t, xmlStr, "arn:aws:iam::000000000000:user/alice") + assert.Contains(t, xmlStr, "alice") + assert.Contains(t, xmlStr, "000000000000") + assert.Contains(t, xmlStr, "test-request-id") + assert.Contains(t, xmlStr, "https://sts.amazonaws.com/doc/2011-06-15/") +}