diff --git a/weed/command/filer_backup.go b/weed/command/filer_backup.go index fb67e1b25..84cbf4828 100644 --- a/weed/command/filer_backup.go +++ b/weed/command/filer_backup.go @@ -11,6 +11,7 @@ import ( "github.com/seaweedfs/seaweedfs/weed/glog" "github.com/seaweedfs/seaweedfs/weed/pb" "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" + "github.com/seaweedfs/seaweedfs/weed/replication/repl_util" "github.com/seaweedfs/seaweedfs/weed/replication/source" "github.com/seaweedfs/seaweedfs/weed/security" "github.com/seaweedfs/seaweedfs/weed/util" @@ -143,6 +144,10 @@ func doFilerBackup(grpcDialOption grpc.DialOption, backupOption *FilerBackupOpti sourceFiler.ToGrpcAddress(), sourcePath, *backupOption.proxyByFiler) + + if err := repl_util.InitializeSSEForReplication(filerSource); err != nil { + return fmt.Errorf("SSE initialization failed: %v", err) + } dataSink.SetSourceFiler(filerSource) var processEventFn func(*filer_pb.SubscribeMetadataResponse) error diff --git a/weed/replication/repl_util/replication_util.go b/weed/replication/repl_util/replication_util.go index c9812382c..6e13e0359 100644 --- a/weed/replication/repl_util/replication_util.go +++ b/weed/replication/repl_util/replication_util.go @@ -2,14 +2,73 @@ package repl_util import ( "context" + "io" "github.com/seaweedfs/seaweedfs/weed/filer" "github.com/seaweedfs/seaweedfs/weed/glog" + "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" "github.com/seaweedfs/seaweedfs/weed/replication/source" util_http "github.com/seaweedfs/seaweedfs/weed/util/http" ) -func CopyFromChunkViews(chunkViews *filer.IntervalList[*filer.ChunkView], filerSource *source.FilerSource, writeFunc func(data []byte) error) error { +// CopyFromChunkViews copies chunk data with optional SSE decryption. +// If entry has SSE-encrypted chunks, data is decrypted before writing. +func CopyFromChunkViews(chunkViews *filer.IntervalList[*filer.ChunkView], filerSource *source.FilerSource, writeFunc func(data []byte) error, entry *filer_pb.Entry) error { + if entry != nil { + sseType, err := detectSSEType(entry) + if err != nil { + return err + } + if sseType != filer_pb.SSEType_NONE { + return copyWithDecryption(filerSource, entry, writeFunc) + } + } + return copyChunkViews(chunkViews, filerSource, writeFunc) +} + +func copyWithDecryption(filerSource *source.FilerSource, entry *filer_pb.Entry, writeFunc func(data []byte) error) error { + reader := filer.NewFileReader(filerSource, entry) + decrypted, err := MaybeDecryptReader(reader, entry) + if err != nil { + CloseReader(reader) + return err + } + defer CloseMaybeDecryptedReader(reader, decrypted) + buf := make([]byte, 128*1024) + for { + n, readErr := decrypted.Read(buf) + if n > 0 { + if writeErr := writeFunc(buf[:n]); writeErr != nil { + return writeErr + } + } + if readErr == io.EOF { + return nil + } + if readErr != nil { + return readErr + } + } +} + +// CloseReader closes r if it implements io.Closer. +func CloseReader(r io.Reader) { + if closer, ok := r.(io.Closer); ok { + closer.Close() + } +} + +// CloseMaybeDecryptedReader closes the decrypted reader if it implements io.Closer, +// otherwise falls back to closing the original reader. +func CloseMaybeDecryptedReader(original, decrypted io.Reader) { + if closer, ok := decrypted.(io.Closer); ok { + closer.Close() + } else { + CloseReader(original) + } +} + +func copyChunkViews(chunkViews *filer.IntervalList[*filer.ChunkView], filerSource *source.FilerSource, writeFunc func(data []byte) error) error { for x := chunkViews.Front(); x != nil; x = x.Next { chunk := x.Value diff --git a/weed/replication/repl_util/sse_init.go b/weed/replication/repl_util/sse_init.go new file mode 100644 index 000000000..9ddf46aa0 --- /dev/null +++ b/weed/replication/repl_util/sse_init.go @@ -0,0 +1,55 @@ +package repl_util + +import ( + "sync" + + "github.com/seaweedfs/seaweedfs/weed/glog" + "github.com/seaweedfs/seaweedfs/weed/kms" + "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" + "github.com/seaweedfs/seaweedfs/weed/s3api" + "github.com/seaweedfs/seaweedfs/weed/util" +) + +var ( + sseInitMu sync.Mutex + sseInitialized bool +) + +// InitializeSSEForReplication sets up SSE-S3 and SSE-KMS decryption so that +// replication sinks can transparently decrypt encrypted objects. +// SSE-S3 is initialized from the filer (KEK stored on filer). +// SSE-KMS is initialized from Viper config (security.toml [kms] section or +// WEED_KMS_* environment variables). +// SSE-C cannot be decrypted (customer key not available) and will error at +// decryption time. +// +// Safe to call multiple times; only the first successful initialization takes +// effect. Failed attempts do not prevent future retries. +func InitializeSSEForReplication(filerSource filer_pb.FilerClient) error { + sseInitMu.Lock() + defer sseInitMu.Unlock() + if sseInitialized { + return nil + } + + // Initialize SSE-S3 key manager from filer + if err := s3api.GetSSES3KeyManager().InitializeWithFiler(filerSource); err != nil { + return err + } + + // Attempt KMS initialization from Viper config. + // KMS configuration is typically in the S3 config file which the + // replication commands don't load directly. Support loading from + // security.toml [kms] section or WEED_KMS_* environment variables. + loader := kms.NewConfigLoader(util.GetViper()) + if err := loader.LoadConfigurations(); err != nil { + glog.Warningf("KMS initialization from config failed: %v (SSE-KMS decryption will not be available)", err) + } else if err := loader.ValidateConfiguration(); err != nil { + glog.Warningf("KMS configuration validation failed: %v (SSE-KMS decryption will not be available)", err) + } else { + glog.V(0).Infof("KMS initialized for replication") + } + + sseInitialized = true + return nil +} diff --git a/weed/replication/repl_util/sse_reader.go b/weed/replication/repl_util/sse_reader.go new file mode 100644 index 000000000..9a2a426d6 --- /dev/null +++ b/weed/replication/repl_util/sse_reader.go @@ -0,0 +1,150 @@ +package repl_util + +import ( + "bytes" + "fmt" + "io" + + "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" + "github.com/seaweedfs/seaweedfs/weed/s3api" + "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" +) + +// MaybeDecryptReader wraps reader with SSE decryption if the entry has encrypted chunks. +// Returns the original reader unchanged if no SSE encryption is detected. +func MaybeDecryptReader(reader io.Reader, entry *filer_pb.Entry) (io.Reader, error) { + if entry == nil { + return reader, nil + } + + sseType, err := detectSSEType(entry) + if err != nil { + return nil, err + } + if sseType == filer_pb.SSEType_NONE { + return reader, nil + } + + switch sseType { + case filer_pb.SSEType_SSE_S3: + return decryptSSES3(reader, entry) + case filer_pb.SSEType_SSE_KMS: + return decryptSSEKMS(reader, entry) + case filer_pb.SSEType_SSE_C: + return nil, fmt.Errorf("SSE-C encrypted object cannot be decrypted during replication (customer key not available)") + } + + return nil, fmt.Errorf("unsupported SSE type: %v", sseType) +} + +// MaybeDecryptContent decrypts inline entry content if SSE-encrypted. +// Returns the original content unchanged if no SSE encryption is detected. +func MaybeDecryptContent(content []byte, entry *filer_pb.Entry) ([]byte, error) { + if entry == nil || len(content) == 0 { + return content, nil + } + + sseType, err := detectSSEType(entry) + if err != nil { + return nil, err + } + if sseType == filer_pb.SSEType_NONE { + return content, nil + } + + reader := bytes.NewReader(content) + decrypted, err := MaybeDecryptReader(reader, entry) + if err != nil { + return nil, err + } + return io.ReadAll(decrypted) +} + +func detectSSEType(entry *filer_pb.Entry) (filer_pb.SSEType, error) { + // Check chunk metadata first + var detected filer_pb.SSEType + for _, chunk := range entry.GetChunks() { + if chunk.SseType != filer_pb.SSEType_NONE { + if detected == filer_pb.SSEType_NONE { + detected = chunk.SseType + } else if chunk.SseType != detected { + return filer_pb.SSEType_NONE, fmt.Errorf("mixed SSE types in chunks: %v and %v", detected, chunk.SseType) + } + } + } + if detected != filer_pb.SSEType_NONE { + return detected, nil + } + + // Fall back to extended metadata for inline objects (no chunks) + if entry.Extended != nil { + hasS3 := len(entry.Extended[s3_constants.SeaweedFSSSES3Key]) > 0 + hasKMS := len(entry.Extended[s3_constants.SeaweedFSSSEKMSKey]) > 0 + hasC := len(entry.Extended[s3_constants.SeaweedFSSSEIV]) > 0 + count := 0 + if hasS3 { + count++ + } + if hasKMS { + count++ + } + if hasC { + count++ + } + if count > 1 { + return filer_pb.SSEType_NONE, fmt.Errorf("conflicting SSE metadata in entry: multiple SSE key types present") + } + if hasS3 { + return filer_pb.SSEType_SSE_S3, nil + } + if hasKMS { + return filer_pb.SSEType_SSE_KMS, nil + } + if hasC { + return filer_pb.SSEType_SSE_C, nil + } + } + return filer_pb.SSEType_NONE, nil +} + +func decryptSSES3(reader io.Reader, entry *filer_pb.Entry) (io.Reader, error) { + if entry.Extended == nil { + return nil, fmt.Errorf("SSE-S3 encrypted entry has no extended metadata") + } + + keyData := entry.Extended[s3_constants.SeaweedFSSSES3Key] + if len(keyData) == 0 { + return nil, fmt.Errorf("SSE-S3 key metadata not found in entry") + } + + keyManager := s3api.GetSSES3KeyManager() + sseS3Key, err := s3api.DeserializeSSES3Metadata(keyData, keyManager) + if err != nil { + return nil, fmt.Errorf("deserialize SSE-S3 metadata: %w", err) + } + + iv, err := s3api.GetSSES3IV(entry, sseS3Key, keyManager) + if err != nil { + return nil, fmt.Errorf("get SSE-S3 IV: %w", err) + } + + return s3api.CreateSSES3DecryptedReader(reader, sseS3Key, iv) +} + +func decryptSSEKMS(reader io.Reader, entry *filer_pb.Entry) (io.Reader, error) { + if entry.Extended == nil { + return nil, fmt.Errorf("SSE-KMS encrypted entry has no extended metadata") + } + + kmsMetadata := entry.Extended[s3_constants.SeaweedFSSSEKMSKey] + if len(kmsMetadata) == 0 { + return nil, fmt.Errorf("SSE-KMS key metadata not found in entry") + } + + sseKMSKey, err := s3api.DeserializeSSEKMSMetadata(kmsMetadata) + if err != nil { + return nil, fmt.Errorf("deserialize SSE-KMS metadata: %w", err) + } + + return s3api.CreateSSEKMSDecryptedReader(reader, sseKMSKey) +} diff --git a/weed/replication/repl_util/sse_reader_test.go b/weed/replication/repl_util/sse_reader_test.go new file mode 100644 index 000000000..277fd26cd --- /dev/null +++ b/weed/replication/repl_util/sse_reader_test.go @@ -0,0 +1,534 @@ +package repl_util + +import ( + "bytes" + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "os" + "testing" + + "github.com/seaweedfs/seaweedfs/weed/kms" + "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" + "github.com/seaweedfs/seaweedfs/weed/s3api" + "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants" + "github.com/seaweedfs/seaweedfs/weed/util" +) + +func TestDetectSSEType(t *testing.T) { + tests := []struct { + name string + entry *filer_pb.Entry + wantType filer_pb.SSEType + wantError bool + }{ + { + name: "no chunks no extended", + entry: &filer_pb.Entry{}, + wantType: filer_pb.SSEType_NONE, + }, + { + name: "plaintext chunks", + entry: &filer_pb.Entry{ + Chunks: []*filer_pb.FileChunk{ + {SseType: filer_pb.SSEType_NONE}, + {SseType: filer_pb.SSEType_NONE}, + }, + }, + wantType: filer_pb.SSEType_NONE, + }, + { + name: "uniform SSE-S3 chunks", + entry: &filer_pb.Entry{ + Chunks: []*filer_pb.FileChunk{ + {SseType: filer_pb.SSEType_SSE_S3}, + {SseType: filer_pb.SSEType_SSE_S3}, + }, + }, + wantType: filer_pb.SSEType_SSE_S3, + }, + { + name: "uniform SSE-KMS chunks", + entry: &filer_pb.Entry{ + Chunks: []*filer_pb.FileChunk{ + {SseType: filer_pb.SSEType_SSE_KMS}, + }, + }, + wantType: filer_pb.SSEType_SSE_KMS, + }, + { + name: "mixed chunk SSE types", + entry: &filer_pb.Entry{ + Chunks: []*filer_pb.FileChunk{ + {SseType: filer_pb.SSEType_SSE_S3}, + {SseType: filer_pb.SSEType_SSE_KMS}, + }, + }, + wantError: true, + }, + { + name: "inline SSE-S3 via extended", + entry: &filer_pb.Entry{ + Extended: map[string][]byte{ + s3_constants.SeaweedFSSSES3Key: {0x01}, + }, + }, + wantType: filer_pb.SSEType_SSE_S3, + }, + { + name: "inline SSE-KMS via extended", + entry: &filer_pb.Entry{ + Extended: map[string][]byte{ + s3_constants.SeaweedFSSSEKMSKey: {0x01}, + }, + }, + wantType: filer_pb.SSEType_SSE_KMS, + }, + { + name: "inline SSE-C via extended", + entry: &filer_pb.Entry{ + Extended: map[string][]byte{ + s3_constants.SeaweedFSSSEIV: {0x01}, + }, + }, + wantType: filer_pb.SSEType_SSE_C, + }, + { + name: "conflicting extended metadata", + entry: &filer_pb.Entry{ + Extended: map[string][]byte{ + s3_constants.SeaweedFSSSES3Key: {0x01}, + s3_constants.SeaweedFSSSEKMSKey: {0x02}, + }, + }, + wantError: true, + }, + { + name: "chunks take precedence over extended", + entry: &filer_pb.Entry{ + Chunks: []*filer_pb.FileChunk{ + {SseType: filer_pb.SSEType_SSE_S3}, + }, + Extended: map[string][]byte{ + s3_constants.SeaweedFSSSEKMSKey: {0x01}, + }, + }, + wantType: filer_pb.SSEType_SSE_S3, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := detectSSEType(tt.entry) + if tt.wantError { + if err == nil { + t.Fatalf("expected error, got type %v", got) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != tt.wantType { + t.Errorf("got %v, want %v", got, tt.wantType) + } + }) + } +} + +func TestMaybeDecryptReader_Plaintext(t *testing.T) { + content := []byte("hello world") + entry := &filer_pb.Entry{} + reader := bytes.NewReader(content) + + got, err := MaybeDecryptReader(reader, entry) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + result, err := io.ReadAll(got) + if err != nil { + t.Fatalf("ReadAll error: %v", err) + } + if !bytes.Equal(result, content) { + t.Errorf("got %q, want %q", result, content) + } +} + +func TestMaybeDecryptReader_NilEntry(t *testing.T) { + content := []byte("hello") + reader := bytes.NewReader(content) + + got, err := MaybeDecryptReader(reader, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + result, err := io.ReadAll(got) + if err != nil { + t.Fatalf("ReadAll error: %v", err) + } + if !bytes.Equal(result, content) { + t.Errorf("got %q, want %q", result, content) + } +} + +func TestMaybeDecryptReader_SSEC_Error(t *testing.T) { + entry := &filer_pb.Entry{ + Extended: map[string][]byte{ + s3_constants.SeaweedFSSSEIV: {0x01}, + }, + } + reader := bytes.NewReader([]byte("data")) + + _, err := MaybeDecryptReader(reader, entry) + if err == nil { + t.Fatal("expected error for SSE-C") + } +} + +func TestMaybeDecryptContent_Plaintext(t *testing.T) { + content := []byte("hello world") + entry := &filer_pb.Entry{} + + got, err := MaybeDecryptContent(content, entry) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !bytes.Equal(got, content) { + t.Errorf("got %q, want %q", got, content) + } +} + +func TestMaybeDecryptContent_NilEntry(t *testing.T) { + content := []byte("data") + got, err := MaybeDecryptContent(content, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !bytes.Equal(got, content) { + t.Errorf("got %q, want %q", got, content) + } +} + +func TestMaybeDecryptContent_Empty(t *testing.T) { + got, err := MaybeDecryptContent(nil, &filer_pb.Entry{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != nil { + t.Errorf("expected nil, got %v", got) + } +} + +func TestMaybeDecryptContent_SSEC_Error(t *testing.T) { + entry := &filer_pb.Entry{ + Extended: map[string][]byte{ + s3_constants.SeaweedFSSSEIV: {0x01}, + }, + } + + _, err := MaybeDecryptContent([]byte("data"), entry) + if err == nil { + t.Fatal("expected error for SSE-C") + } +} + +func TestMaybeDecryptContent_MixedExtended_Error(t *testing.T) { + entry := &filer_pb.Entry{ + Extended: map[string][]byte{ + s3_constants.SeaweedFSSSES3Key: {0x01}, + s3_constants.SeaweedFSSSEKMSKey: {0x02}, + }, + } + + _, err := MaybeDecryptContent([]byte("data"), entry) + if err == nil { + t.Fatal("expected error for conflicting SSE metadata") + } +} + +// --- SSE-S3 integration tests --- +// These tests exercise the full MaybeDecryptReader/MaybeDecryptContent path +// for SSE-S3: detectSSEType → decryptSSES3 → DeserializeSSES3Metadata → +// GetSSES3IV → CreateSSES3DecryptedReader. A test KEK is injected via +// WEED_S3_SSE_KEK env var and a mock filer client. + +// testFilerClient is a minimal filer_pb.FilerClient mock that returns +// ErrNotFound for all lookups (no KEK on filer — we use env var instead). +type testFilerClient struct{} + +func (c *testFilerClient) WithFilerClient(_ bool, fn func(filer_pb.SeaweedFilerClient) error) error { + return fmt.Errorf("%w", filer_pb.ErrNotFound) +} +func (c *testFilerClient) AdjustedUrl(loc *filer_pb.Location) string { return loc.Url } +func (c *testFilerClient) GetDataCenter() string { return "" } + +// setupTestSSES3 initializes the global SSE-S3 key manager with a test KEK +// via the WEED_S3_SSE_KEK env var and returns the KEK bytes + cleanup func. +func setupTestSSES3(t *testing.T) (kek []byte, cleanup func()) { + t.Helper() + + kek = make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, kek); err != nil { + t.Fatal(err) + } + + // Force Viper to pick up the new env var + os.Setenv("WEED_S3_SSE_KEK", hex.EncodeToString(kek)) + + // Reset Viper cache so it reads the new env var + v := util.GetViper() + v.AutomaticEnv() + + // Re-initialize the global key manager with the KEK from env + km := s3api.GetSSES3KeyManager() + if err := km.InitializeWithFiler(&testFilerClient{}); err != nil { + os.Unsetenv("WEED_S3_SSE_KEK") + t.Fatalf("InitializeWithFiler: %v", err) + } + + return kek, func() { + os.Unsetenv("WEED_S3_SSE_KEK") + // Re-initialize with no KEK to clear the super key + km.InitializeWithFiler(&testFilerClient{}) + } +} + +func TestMaybeDecryptReader_SSES3(t *testing.T) { + _, cleanup := setupTestSSES3(t) + defer cleanup() + + plaintext := []byte("SSE-S3 encrypted content for testing round-trip decryption") + + // Generate a DEK and encrypt + sseKey, err := s3api.GenerateSSES3Key() + if err != nil { + t.Fatal(err) + } + encReader, encIV, err := s3api.CreateSSES3EncryptedReader(bytes.NewReader(plaintext), sseKey) + if err != nil { + t.Fatalf("encrypt: %v", err) + } + ciphertext, err := io.ReadAll(encReader) + if err != nil { + t.Fatalf("read ciphertext: %v", err) + } + + // Build serialized SSE-S3 metadata (uses the global key manager to + // envelope-encrypt the DEK with the test KEK) + sseKey.IV = encIV + metadataBytes, err := s3api.SerializeSSES3Metadata(sseKey) + if err != nil { + t.Fatalf("serialize metadata: %v", err) + } + + entry := &filer_pb.Entry{ + Extended: map[string][]byte{ + s3_constants.SeaweedFSSSES3Key: metadataBytes, + }, + } + + // Test full path: MaybeDecryptReader → decryptSSES3 → DeserializeSSES3Metadata → CreateSSES3DecryptedReader + decrypted, err := MaybeDecryptReader(bytes.NewReader(ciphertext), entry) + if err != nil { + t.Fatalf("MaybeDecryptReader: %v", err) + } + result, err := io.ReadAll(decrypted) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + if !bytes.Equal(result, plaintext) { + t.Errorf("SSE-S3 round-trip failed: got %q, want %q", result, plaintext) + } +} + +func TestMaybeDecryptContent_SSES3(t *testing.T) { + _, cleanup := setupTestSSES3(t) + defer cleanup() + + plaintext := []byte("inline SSE-S3 content") + + // Generate a DEK and encrypt inline content + sseKey, err := s3api.GenerateSSES3Key() + if err != nil { + t.Fatal(err) + } + encReader, encIV, err := s3api.CreateSSES3EncryptedReader(bytes.NewReader(plaintext), sseKey) + if err != nil { + t.Fatalf("encrypt: %v", err) + } + ciphertext, err := io.ReadAll(encReader) + if err != nil { + t.Fatalf("read ciphertext: %v", err) + } + + sseKey.IV = encIV + metadataBytes, err := s3api.SerializeSSES3Metadata(sseKey) + if err != nil { + t.Fatalf("serialize metadata: %v", err) + } + + entry := &filer_pb.Entry{ + Extended: map[string][]byte{ + s3_constants.SeaweedFSSSES3Key: metadataBytes, + }, + } + + // Test full path: MaybeDecryptContent → MaybeDecryptReader → decryptSSES3 + result, err := MaybeDecryptContent(ciphertext, entry) + if err != nil { + t.Fatalf("MaybeDecryptContent: %v", err) + } + if !bytes.Equal(result, plaintext) { + t.Errorf("SSE-S3 round-trip failed: got %q, want %q", result, plaintext) + } +} + +// --- SSE-KMS integration tests --- + +// testKMSProvider is a minimal KMSProvider mock for testing. +type testKMSProvider struct { + keyID string + plaintext []byte // the DEK plaintext returned by Decrypt +} + +func (p *testKMSProvider) GenerateDataKey(_ context.Context, _ *kms.GenerateDataKeyRequest) (*kms.GenerateDataKeyResponse, error) { + return nil, nil +} + +func (p *testKMSProvider) Decrypt(_ context.Context, _ *kms.DecryptRequest) (*kms.DecryptResponse, error) { + return &kms.DecryptResponse{ + KeyID: p.keyID, + Plaintext: append([]byte(nil), p.plaintext...), // return a copy + }, nil +} + +func (p *testKMSProvider) DescribeKey(_ context.Context, _ *kms.DescribeKeyRequest) (*kms.DescribeKeyResponse, error) { + return nil, nil +} + +func (p *testKMSProvider) GetKeyID(_ context.Context, keyIdentifier string) (string, error) { + return p.keyID, nil +} + +func (p *testKMSProvider) Close() error { return nil } + +func TestMaybeDecryptReader_SSEKMS(t *testing.T) { + plaintext := []byte("SSE-KMS encrypted content for testing") + + // Generate a random DEK and IV + dek := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, dek); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + // Encrypt with AES-CTR (same cipher mode as SSE-KMS) + block, err := aes.NewCipher(dek) + if err != nil { + t.Fatal(err) + } + ciphertext := make([]byte, len(plaintext)) + cipher.NewCTR(block, iv).XORKeyStream(ciphertext, plaintext) + + // Set up a mock KMS provider that returns our DEK + keyID := "test-kms-key-1" + encryptedDEK := []byte("fake-encrypted-dek") // mock doesn't validate + kms.SetGlobalKMSProvider(&testKMSProvider{ + keyID: keyID, + plaintext: dek, + }) + defer kms.SetGlobalKMSProvider(nil) + + // Build serialized KMS metadata + kmsMetadata := s3api.SSEKMSMetadata{ + Algorithm: s3_constants.SSEAlgorithmKMS, + KeyID: keyID, + EncryptedDataKey: base64.StdEncoding.EncodeToString(encryptedDEK), + IV: base64.StdEncoding.EncodeToString(iv), + } + metadataBytes, err := json.Marshal(kmsMetadata) + if err != nil { + t.Fatal(err) + } + + entry := &filer_pb.Entry{ + Extended: map[string][]byte{ + s3_constants.SeaweedFSSSEKMSKey: metadataBytes, + }, + } + + // Test MaybeDecryptReader + reader := bytes.NewReader(ciphertext) + decrypted, err := MaybeDecryptReader(reader, entry) + if err != nil { + t.Fatalf("MaybeDecryptReader: %v", err) + } + result, err := io.ReadAll(decrypted) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + if !bytes.Equal(result, plaintext) { + t.Errorf("SSE-KMS round-trip failed: got %q, want %q", result, plaintext) + } +} + +func TestMaybeDecryptContent_SSEKMS(t *testing.T) { + plaintext := []byte("inline SSE-KMS content") + + dek := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, dek); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + block, err := aes.NewCipher(dek) + if err != nil { + t.Fatal(err) + } + ciphertext := make([]byte, len(plaintext)) + cipher.NewCTR(block, iv).XORKeyStream(ciphertext, plaintext) + + keyID := "test-kms-key-2" + kms.SetGlobalKMSProvider(&testKMSProvider{ + keyID: keyID, + plaintext: dek, + }) + defer kms.SetGlobalKMSProvider(nil) + + kmsMetadata := s3api.SSEKMSMetadata{ + Algorithm: s3_constants.SSEAlgorithmKMS, + KeyID: keyID, + EncryptedDataKey: base64.StdEncoding.EncodeToString([]byte("fake-encrypted-dek")), + IV: base64.StdEncoding.EncodeToString(iv), + } + metadataBytes, err := json.Marshal(kmsMetadata) + if err != nil { + t.Fatal(err) + } + + entry := &filer_pb.Entry{ + Extended: map[string][]byte{ + s3_constants.SeaweedFSSSEKMSKey: metadataBytes, + }, + } + + result, err := MaybeDecryptContent(ciphertext, entry) + if err != nil { + t.Fatalf("MaybeDecryptContent: %v", err) + } + if !bytes.Equal(result, plaintext) { + t.Errorf("SSE-KMS round-trip failed: got %q, want %q", result, plaintext) + } +} diff --git a/weed/replication/replicator.go b/weed/replication/replicator.go index c992906fa..10c576552 100644 --- a/weed/replication/replicator.go +++ b/weed/replication/replicator.go @@ -8,6 +8,7 @@ import ( "github.com/seaweedfs/seaweedfs/weed/glog" "github.com/seaweedfs/seaweedfs/weed/pb" "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" + "github.com/seaweedfs/seaweedfs/weed/replication/repl_util" "github.com/seaweedfs/seaweedfs/weed/replication/sink" "github.com/seaweedfs/seaweedfs/weed/replication/source" "github.com/seaweedfs/seaweedfs/weed/util" @@ -25,6 +26,10 @@ func NewReplicator(sourceConfig util.Configuration, configPrefix string, dataSin source := &source.FilerSource{} source.Initialize(sourceConfig, configPrefix) + if err := repl_util.InitializeSSEForReplication(source); err != nil { + glog.Warningf("SSE initialization failed: %v (encrypted objects may fail to replicate)", err) + } + dataSink.SetSourceFiler(source) return &Replicator{ diff --git a/weed/replication/sink/azuresink/azure_sink.go b/weed/replication/sink/azuresink/azure_sink.go index 1c2eb7944..d73b86fc8 100644 --- a/weed/replication/sink/azuresink/azure_sink.go +++ b/weed/replication/sink/azuresink/azure_sink.go @@ -186,13 +186,17 @@ func (g *AzureSink) CreateEntry(key string, entry *filer_pb.Entry, signatures [] } if len(entry.Content) > 0 { - if err := writeFunc(entry.Content); err != nil { + content, err := repl_util.MaybeDecryptContent(entry.Content, entry) + if err != nil { + return cleanupOnError(fmt.Errorf("decrypt inline SSE content: %w", err)) + } + if err := writeFunc(content); err != nil { return cleanupOnError(err) } return nil } - if err := repl_util.CopyFromChunkViews(chunkViews, g.filerSource, writeFunc); err != nil { + if err := repl_util.CopyFromChunkViews(chunkViews, g.filerSource, writeFunc, entry); err != nil { return cleanupOnError(err) } diff --git a/weed/replication/sink/b2sink/b2_sink.go b/weed/replication/sink/b2sink/b2_sink.go index 910465ead..c69dffa33 100644 --- a/weed/replication/sink/b2sink/b2_sink.go +++ b/weed/replication/sink/b2sink/b2_sink.go @@ -2,6 +2,7 @@ package B2Sink import ( "context" + "fmt" "strings" "github.com/kurin/blazer/b2" @@ -116,10 +117,14 @@ func (g *B2Sink) CreateEntry(key string, entry *filer_pb.Entry, signatures []int } if len(entry.Content) > 0 { - return writeFunc(entry.Content) + content, err := repl_util.MaybeDecryptContent(entry.Content, entry) + if err != nil { + return fmt.Errorf("decrypt inline SSE content: %w", err) + } + return writeFunc(content) } - if err := repl_util.CopyFromChunkViews(chunkViews, g.filerSource, writeFunc); err != nil { + if err := repl_util.CopyFromChunkViews(chunkViews, g.filerSource, writeFunc, entry); err != nil { return err } diff --git a/weed/replication/sink/gcssink/gcs_sink.go b/weed/replication/sink/gcssink/gcs_sink.go index 97d9217a2..f0508f1fc 100644 --- a/weed/replication/sink/gcssink/gcs_sink.go +++ b/weed/replication/sink/gcssink/gcs_sink.go @@ -127,9 +127,14 @@ func (g *GcsSink) CreateEntry(key string, entry *filer_pb.Entry, signatures []in var writeErr error if len(entry.Content) > 0 { - writeErr = writeFunc(entry.Content) + content, decErr := repl_util.MaybeDecryptContent(entry.Content, entry) + if decErr != nil { + writeErr = fmt.Errorf("decrypt inline SSE content: %w", decErr) + } else { + writeErr = writeFunc(content) + } } else { - writeErr = repl_util.CopyFromChunkViews(chunkViews, g.filerSource, writeFunc) + writeErr = repl_util.CopyFromChunkViews(chunkViews, g.filerSource, writeFunc, entry) } if writeErr != nil { diff --git a/weed/replication/sink/localsink/local_sink.go b/weed/replication/sink/localsink/local_sink.go index 7f036dbbf..47e1f6dae 100644 --- a/weed/replication/sink/localsink/local_sink.go +++ b/weed/replication/sink/localsink/local_sink.go @@ -2,6 +2,7 @@ package localsink import ( "context" + "fmt" "os" "path/filepath" "strings" @@ -94,36 +95,55 @@ func (localsink *LocalSink) CreateEntry(key string, entry *filer_pb.Entry, signa mode := os.FileMode(entry.Attributes.FileMode) shortFileName := util.ToShortFileName(key) - if err := os.Remove(shortFileName); err != nil && !os.IsNotExist(err) { - return err - } - dstFile, err := os.OpenFile(shortFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode) - if err != nil { - return err - } - defer dstFile.Close() - fi, err := dstFile.Stat() + // Write to a temp file in the same directory, then atomically rename + // on success. This prevents leaving a truncated/empty file if + // decryption or chunk copy fails. + tmpFile, err := os.CreateTemp(dir, ".seaweedfs-tmp-*") if err != nil { return err } - if fi.Mode() != mode { - glog.V(4).Infof("Modify file mode: %o -> %o", fi.Mode(), mode) - if err := dstFile.Chmod(mode); err != nil { - return err + tmpName := tmpFile.Name() + defer func() { + // Clean up temp file on any error (rename removes it on success) + if tmpFile != nil { + tmpFile.Close() + os.Remove(tmpName) } + }() + + if err := tmpFile.Chmod(mode); err != nil { + return err } writeFunc := func(data []byte) error { - _, writeErr := dstFile.Write(data) + _, writeErr := tmpFile.Write(data) return writeErr } if len(entry.Content) > 0 { - return writeFunc(entry.Content) + content, err := repl_util.MaybeDecryptContent(entry.Content, entry) + if err != nil { + return fmt.Errorf("decrypt inline SSE content: %w", err) + } + if err := writeFunc(content); err != nil { + return err + } + } else { + if err := repl_util.CopyFromChunkViews(chunkViews, localsink.filerSource, writeFunc, entry); err != nil { + return err + } } - if err := repl_util.CopyFromChunkViews(chunkViews, localsink.filerSource, writeFunc); err != nil { + // Close before rename so the data is flushed + if err := tmpFile.Close(); err != nil { + return err + } + tmpFile = nil // prevent deferred cleanup + + // Atomic rename into final destination + if err := os.Rename(tmpName, shortFileName); err != nil { + os.Remove(tmpName) return err } diff --git a/weed/replication/sink/s3sink/s3_sink.go b/weed/replication/sink/s3sink/s3_sink.go index b554d14f4..9d295292f 100644 --- a/weed/replication/sink/s3sink/s3_sink.go +++ b/weed/replication/sink/s3sink/s3_sink.go @@ -18,6 +18,7 @@ import ( "github.com/seaweedfs/seaweedfs/weed/filer" "github.com/seaweedfs/seaweedfs/weed/glog" "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb" + "github.com/seaweedfs/seaweedfs/weed/replication/repl_util" "github.com/seaweedfs/seaweedfs/weed/replication/sink" "github.com/seaweedfs/seaweedfs/weed/replication/source" "github.com/seaweedfs/seaweedfs/weed/util" @@ -161,6 +162,14 @@ func (s3sink *S3Sink) CreateEntry(key string, entry *filer_pb.Entry, signatures reader := filer.NewFileReader(s3sink.filerSource, entry) + // Decrypt SSE-encrypted objects so the destination receives plaintext + decryptedReader, err := repl_util.MaybeDecryptReader(reader, entry) + if err != nil { + repl_util.CloseReader(reader) + return fmt.Errorf("decrypt SSE object: %w", err) + } + defer repl_util.CloseMaybeDecryptedReader(reader, decryptedReader) + // Create an uploader with the session and custom options uploader := s3manager.NewUploaderWithClient(s3sink.conn, func(u *s3manager.Uploader) { u.PartSize = int64(s3sink.uploaderPartSizeMb * 1024 * 1024) @@ -195,7 +204,7 @@ func (s3sink *S3Sink) CreateEntry(key string, entry *filer_pb.Entry, signatures uploadInput := s3manager.UploadInput{ Bucket: aws.String(s3sink.bucket), Key: aws.String(key), - Body: reader, + Body: decryptedReader, } if tags != "" { uploadInput.Tagging = aws.String(tags)