s3tables: improve account ID handling and define missing error codes
Updated getPrincipalFromRequest to prioritize X-Amz-Account-ID header and added getAccountID helper. Defined ErrVersionTokenMismatch and ErrCodeConflict for better optimistic concurrency support.
This commit is contained in:
@@ -2,6 +2,7 @@ package s3tables
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -23,6 +24,17 @@ const (
|
|||||||
ExtendedKeyTags = "s3tables.tags"
|
ExtendedKeyTags = "s3tables.tags"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrVersionTokenMismatch = errors.New("version token mismatch")
|
||||||
|
)
|
||||||
|
|
||||||
|
type ResourceType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ResourceTypeBucket ResourceType = "bucket"
|
||||||
|
ResourceTypeTable ResourceType = "table"
|
||||||
|
)
|
||||||
|
|
||||||
// S3TablesHandler handles S3 Tables API requests
|
// S3TablesHandler handles S3 Tables API requests
|
||||||
type S3TablesHandler struct {
|
type S3TablesHandler struct {
|
||||||
region string
|
region string
|
||||||
@@ -148,20 +160,32 @@ func (h *S3TablesHandler) getPrincipalFromRequest(r *http.Request) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fallback to request header (e.g., for testing or legacy clients)
|
// Fallback to request header (e.g., for testing or legacy clients)
|
||||||
principal := r.Header.Get("X-Amz-Principal")
|
if principal := r.Header.Get("X-Amz-Principal"); principal != "" {
|
||||||
if principal != "" {
|
|
||||||
return principal
|
return principal
|
||||||
}
|
}
|
||||||
|
|
||||||
// Default to account ID (owner)
|
// Fallback to the authenticated account ID
|
||||||
|
if accountID := r.Header.Get(s3_constants.AmzAccountId); accountID != "" {
|
||||||
|
return accountID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to handler's default account ID
|
||||||
|
return h.accountID
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAccountID returns the authenticated account ID from the request or the handler's default
|
||||||
|
func (h *S3TablesHandler) getAccountID(r *http.Request) string {
|
||||||
|
if accountID := r.Header.Get(s3_constants.AmzAccountId); accountID != "" {
|
||||||
|
return accountID
|
||||||
|
}
|
||||||
return h.accountID
|
return h.accountID
|
||||||
}
|
}
|
||||||
|
|
||||||
// Request/Response helpers
|
// Request/Response helpers
|
||||||
|
|
||||||
func (h *S3TablesHandler) readRequestBody(r *http.Request, v interface{}) error {
|
func (h *S3TablesHandler) readRequestBody(r *http.Request, v interface{}) error {
|
||||||
body, err := io.ReadAll(r.Body)
|
|
||||||
defer r.Body.Close()
|
defer r.Body.Close()
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to read request body: %w", err)
|
return fmt.Errorf("failed to read request body: %w", err)
|
||||||
}
|
}
|
||||||
@@ -203,10 +227,10 @@ func (h *S3TablesHandler) writeError(w http.ResponseWriter, status int, code, me
|
|||||||
|
|
||||||
// ARN generation helpers
|
// ARN generation helpers
|
||||||
|
|
||||||
func (h *S3TablesHandler) generateTableBucketARN(bucketName string) string {
|
func (h *S3TablesHandler) generateTableBucketARN(r *http.Request, bucketName string) string {
|
||||||
return fmt.Sprintf("arn:aws:s3tables:%s:%s:bucket/%s", h.region, h.accountID, bucketName)
|
return fmt.Sprintf("arn:aws:s3tables:%s:%s:bucket/%s", h.region, h.getAccountID(r), bucketName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *S3TablesHandler) generateTableARN(bucketName, tableID string) string {
|
func (h *S3TablesHandler) generateTableARN(r *http.Request, bucketName, tableID string) string {
|
||||||
return fmt.Sprintf("arn:aws:s3tables:%s:%s:bucket/%s/table/%s", h.region, h.accountID, bucketName, tableID)
|
return fmt.Sprintf("arn:aws:s3tables:%s:%s:bucket/%s/table/%s", h.region, h.getAccountID(r), bucketName, tableID)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -287,4 +287,5 @@ const (
|
|||||||
ErrCodeInvalidRequest = "InvalidRequest"
|
ErrCodeInvalidRequest = "InvalidRequest"
|
||||||
ErrCodeInternalError = "InternalError"
|
ErrCodeInternalError = "InternalError"
|
||||||
ErrCodeNoSuchPolicy = "NoSuchPolicy"
|
ErrCodeNoSuchPolicy = "NoSuchPolicy"
|
||||||
|
ErrCodeConflict = "Conflict"
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user