Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion router/router_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ func deleteServer(c *gin.Context) {
// Remove the install log from this server
filename := filepath.Join(config.Get().System.LogDirectory, "install", ID+".log")
err := os.Remove(filename)
if err != nil {
if err != nil && !os.IsNotExist(err) {
log.WithFields(log.Fields{"server_id": ID, "error": err}).Warn("failed to remove server install log during deletion process")
}

Expand Down
188 changes: 148 additions & 40 deletions router/router_transfer.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package router

import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
Expand All @@ -12,12 +11,14 @@ import (
"mime/multipart"
"net/http"
"os"
"path/filepath"
"strings"

"github.com/apex/log"
"github.com/gin-gonic/gin"
"github.com/google/uuid"

"github.com/pelican-dev/wings/config"
"github.com/pelican-dev/wings/router/middleware"
"github.com/pelican-dev/wings/router/tokens"
"github.com/pelican-dev/wings/server"
Expand Down Expand Up @@ -141,18 +142,18 @@ func postTransfers(c *gin.Context) {
return
}

// Used to calculate the hash of the file as it is being uploaded.
h := sha256.New()

// Used to read the file and checksum from the request body.
mr := multipart.NewReader(c.Request.Body, params["boundary"])

// Loop through the parts of the request body and process them.
var (
hasArchive bool
hasChecksum bool
checksumVerified bool
hasArchive bool
archiveChecksum string
archiveChecksumReceived string
backupChecksumsCalculated = make(map[string]string)
backupChecksumsReceived = make(map[string]string)
)
// Process multipart form
out:
for {
select {
Expand All @@ -169,76 +170,183 @@ out:
}

name := p.FormName()
switch name {
case "archive":

switch {
case name == "archive":
trnsfr.Log().Debug("received archive")
hasArchive = true

if err := trnsfr.Server.EnsureDataDirectoryExists(); err != nil {
middleware.CaptureAndAbort(c, err)
return
}

tee := io.TeeReader(p, h)
// Calculate checksum while streaming to extraction
archiveHasher := sha256.New()
tee := io.TeeReader(p, archiveHasher)

// Stream directly to extraction while calculating checksum
if err := trnsfr.Server.Filesystem().ExtractStreamUnsafe(ctx, "/", tee); err != nil {
middleware.CaptureAndAbort(c, err)
return
}

hasArchive = true
case "checksum":
trnsfr.Log().Debug("received checksum")
// Store the CALCULATED checksum for later verification
archiveChecksum = hex.EncodeToString(archiveHasher.Sum(nil))

if !hasArchive {
middleware.CaptureAndAbort(c, errors.New("archive must be sent before the checksum"))
return
}
trnsfr.Log().Debug("archive extracted and checksum calculated")

hasChecksum = true

v, err := io.ReadAll(p)
case strings.HasPrefix(name, "checksum_archive"):
trnsfr.Log().Debug("received archive checksum")
checksumData, err := io.ReadAll(p)
if err != nil {
middleware.CaptureAndAbort(c, err)
return
}
// Store the RECEIVED checksum for verification
archiveChecksumReceived = string(checksumData)

case name == "install_logs":
trnsfr.Log().Debug("received install logs")

// Create install log directory if it doesn't exist
cfg := config.Get()
installLogDir := filepath.Join(cfg.System.LogDirectory, "install")
if err := os.MkdirAll(installLogDir, 0755); err != nil {
// Don't fail transfer for install logs, just log and continue
trnsfr.Log().WithError(err).Warn("failed to create install log directory, skipping")
break
}

// Use the correct install log path with server UUID
installLogPath := filepath.Join(installLogDir, trnsfr.Server.ID()+".log")

// Create the install log file
installLogFile, err := os.Create(installLogPath)
if err != nil {
// Don't fail transfer for install logs, just log and continue
trnsfr.Log().WithError(err).Warn("failed to create install log file, skipping")
break
}

// Stream the install logs to file
if _, err := io.Copy(installLogFile, p); err != nil {
installLogFile.Close()
// Don't fail transfer for install logs, just log and continue
trnsfr.Log().WithError(err).Warn("failed to stream install logs to file, skipping")
break
}

if err := installLogFile.Close(); err != nil {
// Don't fail transfer for install logs, just log and continue
trnsfr.Log().WithError(err).Warn("failed to close install log file")
}

trnsfr.Log().WithField("path", installLogPath).Debug("install logs saved successfully")

case strings.HasPrefix(name, "backup_"):
backupName := strings.TrimPrefix(name, "backup_")
trnsfr.Log().WithField("backup", backupName).Debug("received backup file")

// Create backup directory if it doesn't exist
cfg := config.Get()
backupDir := filepath.Join(cfg.System.BackupDirectory, trnsfr.Server.ID())
if err := os.MkdirAll(backupDir, 0755); err != nil {
middleware.CaptureAndAbort(c, fmt.Errorf("failed to create backup directory: %w", err))
return
}

expected := make([]byte, hex.DecodedLen(len(v)))
n, err := hex.Decode(expected, v)
backupPath := filepath.Join(backupDir, backupName)

// Create the backup file and stream directly to disk
backupFile, err := os.Create(backupPath)
if err != nil {
middleware.CaptureAndAbort(c, err)
middleware.CaptureAndAbort(c, fmt.Errorf("failed to create backup file %s: %w", backupPath, err))
return
}
actual := h.Sum(nil)

trnsfr.Log().WithFields(log.Fields{
"expected": hex.EncodeToString(expected),
"actual": hex.EncodeToString(actual),
}).Debug("checksums")
// Stream and calculate checksum simultaneously
hasher := sha256.New()
tee := io.TeeReader(p, hasher)

if !bytes.Equal(expected[:n], actual) {
middleware.CaptureAndAbort(c, errors.New("checksums don't match"))
if _, err := io.Copy(backupFile, tee); err != nil {
backupFile.Close()
middleware.CaptureAndAbort(c, fmt.Errorf("failed to stream backup file %s: %w", backupName, err))
return
}

trnsfr.Log().Debug("checksums match")
checksumVerified = true
default:
continue
if err := backupFile.Close(); err != nil {
middleware.CaptureAndAbort(c, fmt.Errorf("failed to close backup file %s: %w", backupName, err))
return
}

// Store the checksum for later verification
backupChecksumsCalculated[backupName] = hex.EncodeToString(hasher.Sum(nil))

trnsfr.Log().WithField("backup", backupName).Debug("backup streamed to disk successfully")

case strings.HasPrefix(name, "checksum_backup_"):
backupName := strings.TrimPrefix(name, "checksum_backup_")
trnsfr.Log().WithField("backup", backupName).Debug("received backup checksum")

checksumData, err := io.ReadAll(p)
if err != nil {
middleware.CaptureAndAbort(c, err)
return
}
backupChecksumsReceived[backupName] = string(checksumData)
}
}
}

if !hasArchive || !hasChecksum {
middleware.CaptureAndAbort(c, errors.New("missing archive or checksum"))
return
// Verify main archive checksum
if hasArchive {
if archiveChecksumReceived == "" {
middleware.CaptureAndAbort(c, errors.New("archive checksum missing"))
return
}

// Compare the calculated checksum with the received checksum
if archiveChecksum != archiveChecksumReceived {
trnsfr.Log().WithFields(log.Fields{
"expected": archiveChecksumReceived,
"actual": archiveChecksum,
}).Error("archive checksum mismatch")
middleware.CaptureAndAbort(c, errors.New("archive checksum mismatch"))
return
}

trnsfr.Log().Debug("archive checksum verified")
}

// Verify backup checksums
for backupName, calculatedChecksum := range backupChecksumsCalculated {
receivedChecksum, exists := backupChecksumsReceived[backupName]
if !exists {
middleware.CaptureAndAbort(c, fmt.Errorf("checksum missing for backup %s", backupName))
return
}

if calculatedChecksum != receivedChecksum {
trnsfr.Log().WithFields(log.Fields{
"backup": backupName,
"expected": receivedChecksum,
"actual": calculatedChecksum,
}).Error("backup checksum mismatch")
middleware.CaptureAndAbort(c, fmt.Errorf("backup %s checksum mismatch", backupName))
return
}

trnsfr.Log().WithField("backup", backupName).Debug("backup checksum verified")
}

if !checksumVerified {
middleware.CaptureAndAbort(c, errors.New("checksums don't match"))
if !hasArchive {
middleware.CaptureAndAbort(c, errors.New("missing archive"))
return
}

// Transfer is almost complete, we just want to ensure the environment is
// configured correctly. We might want to not fail the transfer at this
// configured correctly. We might want to not fail the transfer at this
// stage, but we will just to be safe.

// Ensure the server environment gets configured.
Expand Down
Loading
Loading