s3tables: improve account ID handling and define missing error codes

Updated getPrincipalFromRequest to prioritize X-Amz-Account-ID header and
added getAccountID helper. Defined ErrVersionTokenMismatch and ErrCodeConflict
for better optimistic concurrency support.
This commit is contained in:
Chris Lu
2026-01-28 13:25:22 -08:00
parent e381b81b47
commit 31867b6f75
2 changed files with 33 additions and 8 deletions

View File

@@ -2,6 +2,7 @@ package s3tables
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@@ -23,6 +24,17 @@ const (
ExtendedKeyTags = "s3tables.tags" ExtendedKeyTags = "s3tables.tags"
) )
var (
ErrVersionTokenMismatch = errors.New("version token mismatch")
)
type ResourceType string
const (
ResourceTypeBucket ResourceType = "bucket"
ResourceTypeTable ResourceType = "table"
)
// S3TablesHandler handles S3 Tables API requests // S3TablesHandler handles S3 Tables API requests
type S3TablesHandler struct { type S3TablesHandler struct {
region string region string
@@ -148,20 +160,32 @@ func (h *S3TablesHandler) getPrincipalFromRequest(r *http.Request) string {
} }
// Fallback to request header (e.g., for testing or legacy clients) // Fallback to request header (e.g., for testing or legacy clients)
principal := r.Header.Get("X-Amz-Principal") if principal := r.Header.Get("X-Amz-Principal"); principal != "" {
if principal != "" {
return principal return principal
} }
// Default to account ID (owner) // Fallback to the authenticated account ID
if accountID := r.Header.Get(s3_constants.AmzAccountId); accountID != "" {
return accountID
}
// Default to handler's default account ID
return h.accountID
}
// getAccountID returns the authenticated account ID from the request or the handler's default
func (h *S3TablesHandler) getAccountID(r *http.Request) string {
if accountID := r.Header.Get(s3_constants.AmzAccountId); accountID != "" {
return accountID
}
return h.accountID return h.accountID
} }
// Request/Response helpers // Request/Response helpers
func (h *S3TablesHandler) readRequestBody(r *http.Request, v interface{}) error { func (h *S3TablesHandler) readRequestBody(r *http.Request, v interface{}) error {
body, err := io.ReadAll(r.Body)
defer r.Body.Close() defer r.Body.Close()
body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
return fmt.Errorf("failed to read request body: %w", err) return fmt.Errorf("failed to read request body: %w", err)
} }
@@ -203,10 +227,10 @@ func (h *S3TablesHandler) writeError(w http.ResponseWriter, status int, code, me
// ARN generation helpers // ARN generation helpers
func (h *S3TablesHandler) generateTableBucketARN(bucketName string) string { func (h *S3TablesHandler) generateTableBucketARN(r *http.Request, bucketName string) string {
return fmt.Sprintf("arn:aws:s3tables:%s:%s:bucket/%s", h.region, h.accountID, bucketName) return fmt.Sprintf("arn:aws:s3tables:%s:%s:bucket/%s", h.region, h.getAccountID(r), bucketName)
} }
func (h *S3TablesHandler) generateTableARN(bucketName, tableID string) string { func (h *S3TablesHandler) generateTableARN(r *http.Request, bucketName, tableID string) string {
return fmt.Sprintf("arn:aws:s3tables:%s:%s:bucket/%s/table/%s", h.region, h.accountID, bucketName, tableID) return fmt.Sprintf("arn:aws:s3tables:%s:%s:bucket/%s/table/%s", h.region, h.getAccountID(r), bucketName, tableID)
} }

View File

@@ -287,4 +287,5 @@ const (
ErrCodeInvalidRequest = "InvalidRequest" ErrCodeInvalidRequest = "InvalidRequest"
ErrCodeInternalError = "InternalError" ErrCodeInternalError = "InternalError"
ErrCodeNoSuchPolicy = "NoSuchPolicy" ErrCodeNoSuchPolicy = "NoSuchPolicy"
ErrCodeConflict = "Conflict"
) )