adding cors support (#6987)

* adding cors support

* address some comments

* optimize matchesWildcard

* address comments

* fix for tests

* address comments

* address comments

* address comments

* path building

* refactor

* Update weed/s3api/s3api_bucket_config.go

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* address comment

Service-level responses need both Access-Control-Allow-Methods and Access-Control-Allow-Headers. After setting Access-Control-Allow-Origin and Access-Control-Expose-Headers, also set Access-Control-Allow-Methods: * and Access-Control-Allow-Headers: * so service endpoints satisfy CORS preflight requirements.

* Update weed/s3api/s3api_bucket_config.go

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update weed/s3api/s3api_object_handlers.go

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update weed/s3api/s3api_object_handlers.go

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix

* refactor

* Update weed/s3api/s3api_bucket_config.go

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update weed/s3api/s3api_object_handlers.go

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update weed/s3api/s3api_server.go

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* simplify

* add cors tests

* fix tests

* fix tests

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Chris Lu
2025-07-15 00:23:54 -07:00
committed by GitHub
parent 548fa0b50a
commit 4b040e8a87
17 changed files with 3756 additions and 53 deletions

649
weed/s3api/cors/cors.go Normal file
View File

@@ -0,0 +1,649 @@
package cors
import (
"context"
"encoding/json"
"encoding/xml"
"fmt"
"net/http"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/seaweedfs/seaweedfs/weed/glog"
"github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
)
// S3 metadata file name constant to avoid typos and reduce duplication
const S3MetadataFileName = ".s3metadata"
// CORSRule represents a single CORS rule
type CORSRule struct {
ID string `xml:"ID,omitempty" json:"ID,omitempty"`
AllowedMethods []string `xml:"AllowedMethod" json:"AllowedMethods"`
AllowedOrigins []string `xml:"AllowedOrigin" json:"AllowedOrigins"`
AllowedHeaders []string `xml:"AllowedHeader,omitempty" json:"AllowedHeaders,omitempty"`
ExposeHeaders []string `xml:"ExposeHeader,omitempty" json:"ExposeHeaders,omitempty"`
MaxAgeSeconds *int `xml:"MaxAgeSeconds,omitempty" json:"MaxAgeSeconds,omitempty"`
}
// CORSConfiguration represents the CORS configuration for a bucket
type CORSConfiguration struct {
XMLName xml.Name `xml:"CORSConfiguration"`
CORSRules []CORSRule `xml:"CORSRule" json:"CORSRules"`
}
// CORSRequest represents a CORS request
type CORSRequest struct {
Origin string
Method string
RequestHeaders []string
IsPreflightRequest bool
AccessControlRequestMethod string
AccessControlRequestHeaders []string
}
// CORSResponse represents CORS response headers
type CORSResponse struct {
AllowOrigin string
AllowMethods string
AllowHeaders string
ExposeHeaders string
MaxAge string
AllowCredentials bool
}
// ValidateConfiguration validates a CORS configuration
func ValidateConfiguration(config *CORSConfiguration) error {
if config == nil {
return fmt.Errorf("CORS configuration cannot be nil")
}
if len(config.CORSRules) == 0 {
return fmt.Errorf("CORS configuration must have at least one rule")
}
if len(config.CORSRules) > 100 {
return fmt.Errorf("CORS configuration cannot have more than 100 rules")
}
for i, rule := range config.CORSRules {
if err := validateRule(&rule); err != nil {
return fmt.Errorf("invalid CORS rule at index %d: %v", i, err)
}
}
return nil
}
// validateRule validates a single CORS rule
func validateRule(rule *CORSRule) error {
if len(rule.AllowedMethods) == 0 {
return fmt.Errorf("AllowedMethods cannot be empty")
}
if len(rule.AllowedOrigins) == 0 {
return fmt.Errorf("AllowedOrigins cannot be empty")
}
// Validate allowed methods
validMethods := map[string]bool{
"GET": true,
"PUT": true,
"POST": true,
"DELETE": true,
"HEAD": true,
}
for _, method := range rule.AllowedMethods {
if !validMethods[method] {
return fmt.Errorf("invalid HTTP method: %s", method)
}
}
// Validate origins
for _, origin := range rule.AllowedOrigins {
if origin == "*" {
continue
}
if err := validateOrigin(origin); err != nil {
return fmt.Errorf("invalid origin %s: %v", origin, err)
}
}
// Validate MaxAgeSeconds
if rule.MaxAgeSeconds != nil && *rule.MaxAgeSeconds < 0 {
return fmt.Errorf("MaxAgeSeconds cannot be negative")
}
return nil
}
// validateOrigin validates an origin string
func validateOrigin(origin string) error {
if origin == "" {
return fmt.Errorf("origin cannot be empty")
}
// Special case: "*" is always valid
if origin == "*" {
return nil
}
// Count wildcards
wildcardCount := strings.Count(origin, "*")
if wildcardCount > 1 {
return fmt.Errorf("origin can contain at most one wildcard")
}
// If there's a wildcard, it should be in a valid position
if wildcardCount == 1 {
// Must be in the format: http://*.example.com or https://*.example.com
if !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") {
return fmt.Errorf("origin with wildcard must start with http:// or https://")
}
}
return nil
}
// ParseRequest parses an HTTP request to extract CORS information
func ParseRequest(r *http.Request) *CORSRequest {
corsReq := &CORSRequest{
Origin: r.Header.Get("Origin"),
Method: r.Method,
}
// Check if this is a preflight request
if r.Method == "OPTIONS" {
corsReq.IsPreflightRequest = true
corsReq.AccessControlRequestMethod = r.Header.Get("Access-Control-Request-Method")
if headers := r.Header.Get("Access-Control-Request-Headers"); headers != "" {
corsReq.AccessControlRequestHeaders = strings.Split(headers, ",")
for i := range corsReq.AccessControlRequestHeaders {
corsReq.AccessControlRequestHeaders[i] = strings.TrimSpace(corsReq.AccessControlRequestHeaders[i])
}
}
}
return corsReq
}
// EvaluateRequest evaluates a CORS request against a CORS configuration
func EvaluateRequest(config *CORSConfiguration, corsReq *CORSRequest) (*CORSResponse, error) {
if config == nil || corsReq == nil {
return nil, fmt.Errorf("config and corsReq cannot be nil")
}
if corsReq.Origin == "" {
return nil, fmt.Errorf("origin header is required for CORS requests")
}
// Find the first rule that matches the origin
for _, rule := range config.CORSRules {
if matchesOrigin(rule.AllowedOrigins, corsReq.Origin) {
// For preflight requests, we need more detailed validation
if corsReq.IsPreflightRequest {
return buildPreflightResponse(&rule, corsReq), nil
} else {
// For actual requests, check method
if contains(rule.AllowedMethods, corsReq.Method) {
return buildResponse(&rule, corsReq), nil
}
}
}
}
return nil, fmt.Errorf("no matching CORS rule found")
}
// matchesRule checks if a CORS request matches a CORS rule
func matchesRule(rule *CORSRule, corsReq *CORSRequest) bool {
// Check origin - this is the primary matching criterion
if !matchesOrigin(rule.AllowedOrigins, corsReq.Origin) {
return false
}
// For preflight requests, we need to validate both the requested method and headers
if corsReq.IsPreflightRequest {
// Check if the requested method is allowed
if corsReq.AccessControlRequestMethod != "" {
if !contains(rule.AllowedMethods, corsReq.AccessControlRequestMethod) {
return false
}
}
// Check if all requested headers are allowed
if len(corsReq.AccessControlRequestHeaders) > 0 {
for _, requestedHeader := range corsReq.AccessControlRequestHeaders {
if !matchesHeader(rule.AllowedHeaders, requestedHeader) {
return false
}
}
}
return true
}
// For non-preflight requests, check method matching
method := corsReq.Method
if !contains(rule.AllowedMethods, method) {
return false
}
return true
}
// matchesOrigin checks if an origin matches any of the allowed origins
func matchesOrigin(allowedOrigins []string, origin string) bool {
for _, allowedOrigin := range allowedOrigins {
if allowedOrigin == "*" {
return true
}
if allowedOrigin == origin {
return true
}
// Check wildcard matching
if strings.Contains(allowedOrigin, "*") {
if matchesWildcard(allowedOrigin, origin) {
return true
}
}
}
return false
}
// matchesWildcard checks if an origin matches a wildcard pattern
// Uses string manipulation instead of regex for better performance
func matchesWildcard(pattern, origin string) bool {
// Handle simple cases first
if pattern == "*" {
return true
}
if pattern == origin {
return true
}
// For CORS, we typically only deal with * wildcards (not ? wildcards)
// Use string manipulation for * wildcards only (more efficient than regex)
// Split pattern by wildcards
parts := strings.Split(pattern, "*")
if len(parts) == 1 {
// No wildcards, exact match
return pattern == origin
}
// Check if string starts with first part
if len(parts[0]) > 0 && !strings.HasPrefix(origin, parts[0]) {
return false
}
// Check if string ends with last part
if len(parts[len(parts)-1]) > 0 && !strings.HasSuffix(origin, parts[len(parts)-1]) {
return false
}
// Check middle parts
searchStr := origin
if len(parts[0]) > 0 {
searchStr = searchStr[len(parts[0]):]
}
if len(parts[len(parts)-1]) > 0 {
searchStr = searchStr[:len(searchStr)-len(parts[len(parts)-1])]
}
for i := 1; i < len(parts)-1; i++ {
if len(parts[i]) > 0 {
index := strings.Index(searchStr, parts[i])
if index == -1 {
return false
}
searchStr = searchStr[index+len(parts[i]):]
}
}
return true
}
// matchesHeader checks if a header matches allowed headers
func matchesHeader(allowedHeaders []string, header string) bool {
if len(allowedHeaders) == 0 {
return true // No restrictions
}
for _, allowedHeader := range allowedHeaders {
if allowedHeader == "*" {
return true
}
if strings.EqualFold(allowedHeader, header) {
return true
}
// Check wildcard matching for headers
if strings.Contains(allowedHeader, "*") {
if matchesWildcard(strings.ToLower(allowedHeader), strings.ToLower(header)) {
return true
}
}
}
return false
}
// buildPreflightResponse builds a CORS response for preflight requests
// This function allows partial matches - origin can match while methods/headers may not
func buildPreflightResponse(rule *CORSRule, corsReq *CORSRequest) *CORSResponse {
response := &CORSResponse{
AllowOrigin: corsReq.Origin,
}
// Check if the requested method is allowed
methodAllowed := corsReq.AccessControlRequestMethod == "" || contains(rule.AllowedMethods, corsReq.AccessControlRequestMethod)
// Check requested headers
var allowedRequestHeaders []string
allHeadersAllowed := true
if len(corsReq.AccessControlRequestHeaders) > 0 {
// Check if wildcard is allowed
hasWildcard := false
for _, header := range rule.AllowedHeaders {
if header == "*" {
hasWildcard = true
break
}
}
if hasWildcard {
// All requested headers are allowed with wildcard
allowedRequestHeaders = corsReq.AccessControlRequestHeaders
} else {
// Check each requested header individually
for _, requestedHeader := range corsReq.AccessControlRequestHeaders {
if matchesHeader(rule.AllowedHeaders, requestedHeader) {
allowedRequestHeaders = append(allowedRequestHeaders, requestedHeader)
} else {
allHeadersAllowed = false
}
}
}
}
// Only set method and header info if both method and ALL headers are allowed
if methodAllowed && allHeadersAllowed {
response.AllowMethods = strings.Join(rule.AllowedMethods, ", ")
if len(allowedRequestHeaders) > 0 {
response.AllowHeaders = strings.Join(allowedRequestHeaders, ", ")
}
// Set exposed headers
if len(rule.ExposeHeaders) > 0 {
response.ExposeHeaders = strings.Join(rule.ExposeHeaders, ", ")
}
// Set max age
if rule.MaxAgeSeconds != nil {
response.MaxAge = strconv.Itoa(*rule.MaxAgeSeconds)
}
}
return response
}
// buildResponse builds a CORS response from a matching rule
func buildResponse(rule *CORSRule, corsReq *CORSRequest) *CORSResponse {
response := &CORSResponse{
AllowOrigin: corsReq.Origin,
}
// Set allowed methods - for preflight requests, return all allowed methods
if corsReq.IsPreflightRequest {
response.AllowMethods = strings.Join(rule.AllowedMethods, ", ")
} else {
// For non-preflight requests, return all allowed methods
response.AllowMethods = strings.Join(rule.AllowedMethods, ", ")
}
// Set allowed headers
if corsReq.IsPreflightRequest && len(rule.AllowedHeaders) > 0 {
// For preflight requests, check if wildcard is allowed
hasWildcard := false
for _, header := range rule.AllowedHeaders {
if header == "*" {
hasWildcard = true
break
}
}
if hasWildcard && len(corsReq.AccessControlRequestHeaders) > 0 {
// Return the specific headers that were requested when wildcard is allowed
response.AllowHeaders = strings.Join(corsReq.AccessControlRequestHeaders, ", ")
} else if len(corsReq.AccessControlRequestHeaders) > 0 {
// For non-wildcard cases, return the requested headers (preserving case)
// since we already validated they are allowed in matchesRule
response.AllowHeaders = strings.Join(corsReq.AccessControlRequestHeaders, ", ")
} else {
// Fallback to configured headers if no specific headers were requested
response.AllowHeaders = strings.Join(rule.AllowedHeaders, ", ")
}
} else if len(rule.AllowedHeaders) > 0 {
// For non-preflight requests, return the allowed headers from the rule
response.AllowHeaders = strings.Join(rule.AllowedHeaders, ", ")
}
// Set exposed headers
if len(rule.ExposeHeaders) > 0 {
response.ExposeHeaders = strings.Join(rule.ExposeHeaders, ", ")
}
// Set max age
if rule.MaxAgeSeconds != nil {
response.MaxAge = strconv.Itoa(*rule.MaxAgeSeconds)
}
return response
}
// contains checks if a slice contains a string
func contains(slice []string, item string) bool {
for _, s := range slice {
if s == item {
return true
}
}
return false
}
// ApplyHeaders applies CORS headers to an HTTP response
func ApplyHeaders(w http.ResponseWriter, corsResp *CORSResponse) {
if corsResp == nil {
return
}
if corsResp.AllowOrigin != "" {
w.Header().Set("Access-Control-Allow-Origin", corsResp.AllowOrigin)
}
if corsResp.AllowMethods != "" {
w.Header().Set("Access-Control-Allow-Methods", corsResp.AllowMethods)
}
if corsResp.AllowHeaders != "" {
w.Header().Set("Access-Control-Allow-Headers", corsResp.AllowHeaders)
}
if corsResp.ExposeHeaders != "" {
w.Header().Set("Access-Control-Expose-Headers", corsResp.ExposeHeaders)
}
if corsResp.MaxAge != "" {
w.Header().Set("Access-Control-Max-Age", corsResp.MaxAge)
}
if corsResp.AllowCredentials {
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
}
// FilerClient interface for dependency injection
type FilerClient interface {
WithFilerClient(streamingMode bool, fn func(filer_pb.SeaweedFilerClient) error) error
}
// EntryGetter interface for getting filer entries
type EntryGetter interface {
GetEntry(directory, name string) (*filer_pb.Entry, error)
}
// Storage provides CORS configuration storage operations
type Storage struct {
filerClient FilerClient
entryGetter EntryGetter
bucketsPath string
}
// NewStorage creates a new CORS storage instance
func NewStorage(filerClient FilerClient, entryGetter EntryGetter, bucketsPath string) *Storage {
return &Storage{
filerClient: filerClient,
entryGetter: entryGetter,
bucketsPath: bucketsPath,
}
}
// Store stores CORS configuration in the filer
func (s *Storage) Store(bucket string, config *CORSConfiguration) error {
// Store in bucket metadata
bucketMetadataPath := filepath.Join(s.bucketsPath, bucket, S3MetadataFileName)
// Get existing metadata
existingEntry, err := s.entryGetter.GetEntry("", bucketMetadataPath)
var metadata map[string]interface{}
if err == nil && existingEntry != nil && len(existingEntry.Content) > 0 {
if err := json.Unmarshal(existingEntry.Content, &metadata); err != nil {
glog.V(1).Infof("Failed to unmarshal existing metadata: %v", err)
metadata = make(map[string]interface{})
}
} else {
metadata = make(map[string]interface{})
}
metadata["cors"] = config
metadataBytes, err := json.Marshal(metadata)
if err != nil {
return fmt.Errorf("failed to marshal bucket metadata: %v", err)
}
// Store metadata
return s.filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error {
request := &filer_pb.CreateEntryRequest{
Directory: s.bucketsPath + "/" + bucket,
Entry: &filer_pb.Entry{
Name: S3MetadataFileName,
IsDirectory: false,
Attributes: &filer_pb.FuseAttributes{
Crtime: time.Now().Unix(),
Mtime: time.Now().Unix(),
FileMode: 0644,
},
Content: metadataBytes,
},
}
_, err := client.CreateEntry(context.Background(), request)
return err
})
}
// Load loads CORS configuration from the filer
func (s *Storage) Load(bucket string) (*CORSConfiguration, error) {
bucketMetadataPath := filepath.Join(s.bucketsPath, bucket, S3MetadataFileName)
entry, err := s.entryGetter.GetEntry("", bucketMetadataPath)
if err != nil || entry == nil {
return nil, fmt.Errorf("no CORS configuration found")
}
if len(entry.Content) == 0 {
return nil, fmt.Errorf("no CORS configuration found")
}
var metadata map[string]interface{}
if err := json.Unmarshal(entry.Content, &metadata); err != nil {
return nil, fmt.Errorf("failed to unmarshal metadata: %v", err)
}
corsData, exists := metadata["cors"]
if !exists {
return nil, fmt.Errorf("no CORS configuration found")
}
// Convert back to CORSConfiguration
corsBytes, err := json.Marshal(corsData)
if err != nil {
return nil, fmt.Errorf("failed to marshal CORS data: %v", err)
}
var config CORSConfiguration
if err := json.Unmarshal(corsBytes, &config); err != nil {
return nil, fmt.Errorf("failed to unmarshal CORS configuration: %v", err)
}
return &config, nil
}
// Delete deletes CORS configuration from the filer
func (s *Storage) Delete(bucket string) error {
bucketMetadataPath := filepath.Join(s.bucketsPath, bucket, S3MetadataFileName)
entry, err := s.entryGetter.GetEntry("", bucketMetadataPath)
if err != nil || entry == nil {
return nil // Already deleted or doesn't exist
}
var metadata map[string]interface{}
if len(entry.Content) > 0 {
if err := json.Unmarshal(entry.Content, &metadata); err != nil {
return fmt.Errorf("failed to unmarshal metadata: %v", err)
}
} else {
return nil // No metadata to delete
}
// Remove CORS configuration
delete(metadata, "cors")
metadataBytes, err := json.Marshal(metadata)
if err != nil {
return fmt.Errorf("failed to marshal metadata: %v", err)
}
// Update metadata
return s.filerClient.WithFilerClient(false, func(client filer_pb.SeaweedFilerClient) error {
request := &filer_pb.CreateEntryRequest{
Directory: s.bucketsPath + "/" + bucket,
Entry: &filer_pb.Entry{
Name: S3MetadataFileName,
IsDirectory: false,
Attributes: &filer_pb.FuseAttributes{
Crtime: time.Now().Unix(),
Mtime: time.Now().Unix(),
FileMode: 0644,
},
Content: metadataBytes,
},
}
_, err := client.CreateEntry(context.Background(), request)
return err
})
}

View File

@@ -0,0 +1,526 @@
package cors
import (
"net/http"
"net/http/httptest"
"reflect"
"testing"
)
func TestValidateConfiguration(t *testing.T) {
tests := []struct {
name string
config *CORSConfiguration
wantErr bool
}{
{
name: "nil config",
config: nil,
wantErr: true,
},
{
name: "empty rules",
config: &CORSConfiguration{
CORSRules: []CORSRule{},
},
wantErr: true,
},
{
name: "valid single rule",
config: &CORSConfiguration{
CORSRules: []CORSRule{
{
AllowedMethods: []string{"GET", "POST"},
AllowedOrigins: []string{"*"},
},
},
},
wantErr: false,
},
{
name: "too many rules",
config: &CORSConfiguration{
CORSRules: make([]CORSRule, 101),
},
wantErr: true,
},
{
name: "invalid method",
config: &CORSConfiguration{
CORSRules: []CORSRule{
{
AllowedMethods: []string{"INVALID"},
AllowedOrigins: []string{"*"},
},
},
},
wantErr: true,
},
{
name: "empty origins",
config: &CORSConfiguration{
CORSRules: []CORSRule{
{
AllowedMethods: []string{"GET"},
AllowedOrigins: []string{},
},
},
},
wantErr: true,
},
{
name: "invalid origin with multiple wildcards",
config: &CORSConfiguration{
CORSRules: []CORSRule{
{
AllowedMethods: []string{"GET"},
AllowedOrigins: []string{"http://*.*.example.com"},
},
},
},
wantErr: true,
},
{
name: "negative MaxAgeSeconds",
config: &CORSConfiguration{
CORSRules: []CORSRule{
{
AllowedMethods: []string{"GET"},
AllowedOrigins: []string{"*"},
MaxAgeSeconds: intPtr(-1),
},
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateConfiguration(tt.config)
if (err != nil) != tt.wantErr {
t.Errorf("ValidateConfiguration() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestValidateOrigin(t *testing.T) {
tests := []struct {
name string
origin string
wantErr bool
}{
{
name: "empty origin",
origin: "",
wantErr: true,
},
{
name: "valid origin",
origin: "http://example.com",
wantErr: false,
},
{
name: "wildcard origin",
origin: "*",
wantErr: false,
},
{
name: "valid wildcard origin",
origin: "http://*.example.com",
wantErr: false,
},
{
name: "https wildcard origin",
origin: "https://*.example.com",
wantErr: false,
},
{
name: "invalid wildcard origin",
origin: "*.example.com",
wantErr: true,
},
{
name: "multiple wildcards",
origin: "http://*.*.example.com",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateOrigin(tt.origin)
if (err != nil) != tt.wantErr {
t.Errorf("validateOrigin() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestParseRequest(t *testing.T) {
tests := []struct {
name string
req *http.Request
want *CORSRequest
}{
{
name: "simple GET request",
req: &http.Request{
Method: "GET",
Header: http.Header{
"Origin": []string{"http://example.com"},
},
},
want: &CORSRequest{
Origin: "http://example.com",
Method: "GET",
IsPreflightRequest: false,
},
},
{
name: "OPTIONS preflight request",
req: &http.Request{
Method: "OPTIONS",
Header: http.Header{
"Origin": []string{"http://example.com"},
"Access-Control-Request-Method": []string{"PUT"},
"Access-Control-Request-Headers": []string{"Content-Type, Authorization"},
},
},
want: &CORSRequest{
Origin: "http://example.com",
Method: "OPTIONS",
IsPreflightRequest: true,
AccessControlRequestMethod: "PUT",
AccessControlRequestHeaders: []string{"Content-Type", "Authorization"},
},
},
{
name: "request without origin",
req: &http.Request{
Method: "GET",
Header: http.Header{},
},
want: &CORSRequest{
Origin: "",
Method: "GET",
IsPreflightRequest: false,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ParseRequest(tt.req)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("ParseRequest() = %v, want %v", got, tt.want)
}
})
}
}
func TestMatchesOrigin(t *testing.T) {
tests := []struct {
name string
allowedOrigins []string
origin string
want bool
}{
{
name: "wildcard match",
allowedOrigins: []string{"*"},
origin: "http://example.com",
want: true,
},
{
name: "exact match",
allowedOrigins: []string{"http://example.com"},
origin: "http://example.com",
want: true,
},
{
name: "no match",
allowedOrigins: []string{"http://example.com"},
origin: "http://other.com",
want: false,
},
{
name: "wildcard subdomain match",
allowedOrigins: []string{"http://*.example.com"},
origin: "http://api.example.com",
want: true,
},
{
name: "wildcard subdomain no match",
allowedOrigins: []string{"http://*.example.com"},
origin: "http://example.com",
want: false,
},
{
name: "multiple origins with match",
allowedOrigins: []string{"http://example.com", "http://other.com"},
origin: "http://other.com",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := matchesOrigin(tt.allowedOrigins, tt.origin)
if got != tt.want {
t.Errorf("matchesOrigin() = %v, want %v", got, tt.want)
}
})
}
}
func TestMatchesHeader(t *testing.T) {
tests := []struct {
name string
allowedHeaders []string
header string
want bool
}{
{
name: "empty allowed headers",
allowedHeaders: []string{},
header: "Content-Type",
want: true,
},
{
name: "wildcard match",
allowedHeaders: []string{"*"},
header: "Content-Type",
want: true,
},
{
name: "exact match",
allowedHeaders: []string{"Content-Type"},
header: "Content-Type",
want: true,
},
{
name: "case insensitive match",
allowedHeaders: []string{"content-type"},
header: "Content-Type",
want: true,
},
{
name: "no match",
allowedHeaders: []string{"Authorization"},
header: "Content-Type",
want: false,
},
{
name: "wildcard prefix match",
allowedHeaders: []string{"x-amz-*"},
header: "x-amz-date",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := matchesHeader(tt.allowedHeaders, tt.header)
if got != tt.want {
t.Errorf("matchesHeader() = %v, want %v", got, tt.want)
}
})
}
}
func TestEvaluateRequest(t *testing.T) {
config := &CORSConfiguration{
CORSRules: []CORSRule{
{
AllowedMethods: []string{"GET", "POST"},
AllowedOrigins: []string{"http://example.com"},
AllowedHeaders: []string{"Content-Type"},
ExposeHeaders: []string{"ETag"},
MaxAgeSeconds: intPtr(3600),
},
{
AllowedMethods: []string{"PUT"},
AllowedOrigins: []string{"*"},
},
},
}
tests := []struct {
name string
config *CORSConfiguration
corsReq *CORSRequest
want *CORSResponse
wantErr bool
}{
{
name: "matching first rule",
config: config,
corsReq: &CORSRequest{
Origin: "http://example.com",
Method: "GET",
},
want: &CORSResponse{
AllowOrigin: "http://example.com",
AllowMethods: "GET, POST",
AllowHeaders: "Content-Type",
ExposeHeaders: "ETag",
MaxAge: "3600",
},
wantErr: false,
},
{
name: "matching second rule",
config: config,
corsReq: &CORSRequest{
Origin: "http://other.com",
Method: "PUT",
},
want: &CORSResponse{
AllowOrigin: "http://other.com",
AllowMethods: "PUT",
},
wantErr: false,
},
{
name: "no matching rule",
config: config,
corsReq: &CORSRequest{
Origin: "http://forbidden.com",
Method: "GET",
},
want: nil,
wantErr: true,
},
{
name: "preflight request",
config: config,
corsReq: &CORSRequest{
Origin: "http://example.com",
Method: "OPTIONS",
IsPreflightRequest: true,
AccessControlRequestMethod: "POST",
AccessControlRequestHeaders: []string{"Content-Type"},
},
want: &CORSResponse{
AllowOrigin: "http://example.com",
AllowMethods: "GET, POST",
AllowHeaders: "Content-Type",
ExposeHeaders: "ETag",
MaxAge: "3600",
},
wantErr: false,
},
{
name: "preflight request with forbidden header",
config: config,
corsReq: &CORSRequest{
Origin: "http://example.com",
Method: "OPTIONS",
IsPreflightRequest: true,
AccessControlRequestMethod: "POST",
AccessControlRequestHeaders: []string{"Authorization"},
},
want: &CORSResponse{
AllowOrigin: "http://example.com",
// No AllowMethods or AllowHeaders because the requested header is forbidden
},
wantErr: false,
},
{
name: "request without origin",
config: config,
corsReq: &CORSRequest{
Origin: "",
Method: "GET",
},
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := EvaluateRequest(tt.config, tt.corsReq)
if (err != nil) != tt.wantErr {
t.Errorf("EvaluateRequest() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("EvaluateRequest() = %v, want %v", got, tt.want)
}
})
}
}
func TestApplyHeaders(t *testing.T) {
tests := []struct {
name string
corsResp *CORSResponse
want map[string]string
}{
{
name: "nil response",
corsResp: nil,
want: map[string]string{},
},
{
name: "complete response",
corsResp: &CORSResponse{
AllowOrigin: "http://example.com",
AllowMethods: "GET, POST",
AllowHeaders: "Content-Type",
ExposeHeaders: "ETag",
MaxAge: "3600",
},
want: map[string]string{
"Access-Control-Allow-Origin": "http://example.com",
"Access-Control-Allow-Methods": "GET, POST",
"Access-Control-Allow-Headers": "Content-Type",
"Access-Control-Expose-Headers": "ETag",
"Access-Control-Max-Age": "3600",
},
},
{
name: "with credentials",
corsResp: &CORSResponse{
AllowOrigin: "http://example.com",
AllowMethods: "GET",
AllowCredentials: true,
},
want: map[string]string{
"Access-Control-Allow-Origin": "http://example.com",
"Access-Control-Allow-Methods": "GET",
"Access-Control-Allow-Credentials": "true",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create a proper response writer using httptest
w := httptest.NewRecorder()
ApplyHeaders(w, tt.corsResp)
// Extract headers from the response
headers := make(map[string]string)
for key, values := range w.Header() {
if len(values) > 0 {
headers[key] = values[0]
}
}
if !reflect.DeepEqual(headers, tt.want) {
t.Errorf("ApplyHeaders() headers = %v, want %v", headers, tt.want)
}
})
}
}
// Helper functions and types for testing
func intPtr(i int) *int {
return &i
}

View File

@@ -0,0 +1,143 @@
package cors
import (
"net/http"
"github.com/seaweedfs/seaweedfs/weed/glog"
"github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
"github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
)
// BucketChecker interface for checking bucket existence
type BucketChecker interface {
CheckBucket(r *http.Request, bucket string) s3err.ErrorCode
}
// CORSConfigGetter interface for getting CORS configuration
type CORSConfigGetter interface {
GetCORSConfiguration(bucket string) (*CORSConfiguration, s3err.ErrorCode)
}
// Middleware handles CORS evaluation for all S3 API requests
type Middleware struct {
storage *Storage
bucketChecker BucketChecker
corsConfigGetter CORSConfigGetter
}
// NewMiddleware creates a new CORS middleware instance
func NewMiddleware(storage *Storage, bucketChecker BucketChecker, corsConfigGetter CORSConfigGetter) *Middleware {
return &Middleware{
storage: storage,
bucketChecker: bucketChecker,
corsConfigGetter: corsConfigGetter,
}
}
// evaluateCORSRequest performs the common CORS request evaluation logic
// Returns: (corsResponse, responseWritten, shouldContinue)
// - corsResponse: the CORS response if evaluation succeeded
// - responseWritten: true if an error response was already written
// - shouldContinue: true if the request should continue to the next handler
func (m *Middleware) evaluateCORSRequest(w http.ResponseWriter, r *http.Request) (*CORSResponse, bool, bool) {
// Parse CORS request
corsReq := ParseRequest(r)
if corsReq.Origin == "" {
// Not a CORS request
return nil, false, true
}
// Extract bucket from request
bucket, _ := s3_constants.GetBucketAndObject(r)
if bucket == "" {
return nil, false, true
}
// Check if bucket exists
if err := m.bucketChecker.CheckBucket(r, bucket); err != s3err.ErrNone {
// For non-existent buckets, let the normal handler deal with it
return nil, false, true
}
// Load CORS configuration from cache
config, errCode := m.corsConfigGetter.GetCORSConfiguration(bucket)
if errCode != s3err.ErrNone || config == nil {
// No CORS configuration, handle based on request type
if corsReq.IsPreflightRequest {
// Preflight request without CORS config should fail
s3err.WriteErrorResponse(w, r, s3err.ErrAccessDenied)
return nil, true, false // Response written, don't continue
}
// Non-preflight request, continue normally
return nil, false, true
}
// Evaluate CORS request
corsResp, err := EvaluateRequest(config, corsReq)
if err != nil {
glog.V(3).Infof("CORS evaluation failed for bucket %s: %v", bucket, err)
if corsReq.IsPreflightRequest {
// Preflight request that doesn't match CORS rules should fail
s3err.WriteErrorResponse(w, r, s3err.ErrAccessDenied)
return nil, true, false // Response written, don't continue
}
// Non-preflight request, continue normally but without CORS headers
return nil, false, true
}
return corsResp, false, false
}
// Handler returns the CORS middleware handler
func (m *Middleware) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Use the common evaluation logic
corsResp, responseWritten, shouldContinue := m.evaluateCORSRequest(w, r)
if responseWritten {
// Response was already written (error case)
return
}
if shouldContinue {
// Continue with normal request processing
next.ServeHTTP(w, r)
return
}
// Parse request to check if it's a preflight request
corsReq := ParseRequest(r)
// Apply CORS headers to response
ApplyHeaders(w, corsResp)
// Handle preflight requests
if corsReq.IsPreflightRequest {
// Preflight request should return 200 OK with just CORS headers
w.WriteHeader(http.StatusOK)
return
}
// Continue with normal request processing
next.ServeHTTP(w, r)
})
}
// HandleOptionsRequest handles OPTIONS requests for CORS preflight
func (m *Middleware) HandleOptionsRequest(w http.ResponseWriter, r *http.Request) {
// Use the common evaluation logic
corsResp, responseWritten, shouldContinue := m.evaluateCORSRequest(w, r)
if responseWritten {
// Response was already written (error case)
return
}
if shouldContinue || corsResp == nil {
// Not a CORS request or should continue normally
w.WriteHeader(http.StatusOK)
return
}
// Apply CORS headers and return success
ApplyHeaders(w, corsResp)
w.WriteHeader(http.StatusOK)
}