diff --git a/service/maxmind/mmdb/download.go b/service/maxmind/mmdb/download.go new file mode 100644 index 0000000..37da82b --- /dev/null +++ b/service/maxmind/mmdb/download.go @@ -0,0 +1,240 @@ +package mmdb + +import ( + "archive/tar" + "compress/gzip" + "context" + "crypto/rand" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "time" +) + +const ( + // requestTimeout defines the maximum duration for request operations before timing out. + requestTimeout = 60 * time.Second + + // maxDownloadSize defines the maximum allowed size of the downloaded file in bytes. + maxDownloadSize int64 = 200 << 20 // 200 MiB + + // maxArchiveFileSize defines the maximum allowed size of the extracted file from ZIP in bytes. + maxArchiveFileSize uint64 = 500 << 20 // 500 MiB +) + +// DownloadConfig defines the configuration for downloading MaxMind GeoIP2 database. +type DownloadConfig struct { + // RequestTimeout defines the maximum duration for request operations before timing out. + RequestTimeout time.Duration + // MaxDownloadSize defines the maximum allowed size of the downloaded file in bytes. + MaxDownloadSize int64 + // MaxArchiveFileSize defines the maximum allowed size of the extracted file from ZIP in bytes. + MaxArchiveFileSize uint64 +} + +// DefaultDownloadConfig returns the default configuration for downloading MaxMind GeoIP2 database. +func DefaultDownloadConfig() DownloadConfig { + return DownloadConfig{ + RequestTimeout: requestTimeout, + MaxDownloadSize: maxDownloadSize, + MaxArchiveFileSize: maxArchiveFileSize, + } +} + +// Download is a struct for downloading MaxMind GeoIP2 database. +type Download struct { + // url is a URL for downloading MaxMind GeoIP2 database. + url string + // username is a username for downloading MaxMind GeoIP2 database. + username string + // password is a password for downloading MaxMind GeoIP2 database. + password string + // Cfg is a configuration for downloading MaxMind GeoIP2 database. 
+ Cfg DownloadConfig +} + +// NewDownload creates a new Download struct. +// @param downloadURL - a URL for downloading MaxMind GeoIP2 database +// @param username - a username for downloading MaxMind GeoIP2 database +// @param password - a password for downloading MaxMind GeoIP2 database +// @param cfg - a configuration for downloading MaxMind GeoIP2 database +// @return (*Download, error) - a Download struct and an error if the download URL is invalid +func NewDownload(downloadURL, username, password string, cfg DownloadConfig) (*Download, error) { + parsedURL, err := url.Parse(downloadURL) + if err != nil { + return nil, fmt.Errorf("invalid url: %w", err) + } + + if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { + return nil, fmt.Errorf("invalid url scheme: %s", parsedURL.Scheme) + } + + return &Download{ + url: downloadURL, + username: username, + password: password, + Cfg: cfg, + }, nil +} + +// Download downloads MaxMind GeoIP2 database. +// @param dir - a directory to save the downloaded database +// @param ctx - a context for cancelling the download +// @return error - an error if the download failed +func (d *Download) Download(dir string, ctx context.Context) error { + if err := createDir(dir); err != nil { + return err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.url, nil) + if err != nil { + return fmt.Errorf("create request: %w", err) + } + req.SetBasicAuth(d.username, d.password) + + client := &http.Client{ + Timeout: d.Cfg.RequestTimeout, + } + + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("download archive: %w", err) + } + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("download archive: unexpected status %s", resp.Status) + } + + filename := rand.Text() + ".tar.gz" + archivePath := filepath.Join(dir, filename) + file, err := os.Create(archivePath) + if err != nil { + if os.IsExist(err) { + return fmt.Errorf("archive file already 
exists: %w", err) + } + return fmt.Errorf("create archive file: %w", err) + } + + limitedReader := io.LimitReader(resp.Body, d.Cfg.MaxDownloadSize+1) + written, copyErr := io.Copy(file, limitedReader) + closeErr := file.Close() + if copyErr != nil { + if removeErr := os.Remove(archivePath); removeErr != nil { + return fmt.Errorf("remove archive: %w", removeErr) + } + return fmt.Errorf("save archive: %w", copyErr) + } + if closeErr != nil { + if removeErr := os.Remove(archivePath); removeErr != nil { + return fmt.Errorf("remove archive: %w", removeErr) + } + return fmt.Errorf("close archive file: %w", closeErr) + } + if written > d.Cfg.MaxDownloadSize { + if removeErr := os.Remove(archivePath); removeErr != nil { + return fmt.Errorf("remove archive: %w", removeErr) + } + return fmt.Errorf("download archive too large: %d bytes", written) + } + + if err := extractTarGz(archivePath, dir, d.Cfg.MaxArchiveFileSize); err != nil { + if removeErr := os.Remove(archivePath); removeErr != nil { + return fmt.Errorf("remove archive: %w", removeErr) + } + return err + } + + if removeErr := os.Remove(archivePath); removeErr != nil { + return fmt.Errorf("remove archive: %w", removeErr) + } + + return nil +} + +// createDir creates a directory if it does not exist. +// @param dir - a directory to create +// @return error - an error if the directory could not be created +func createDir(dir string) error { + if err := os.MkdirAll(dir, 0750); err != nil { + return err + } + return nil +} + +// extractTarGz extracts a tar.gz archive to a specified directory. 
// extractTarGz extracts a tar.gz archive to a specified directory.
// Only directories and regular files are extracted; every other entry type is skipped.
// Entries whose name is absolute or escapes dir (for example "../x") are rejected.
// @param archivePath - path to the archive to extract
// @param dir - path to the directory to extract the archive to
// @param maxExtractSize - maximum total size of the extracted files in bytes
// @return error - an error if the archive could not be extracted
func extractTarGz(archivePath string, dir string, maxExtractSize uint64) error {
	file, err := os.Open(archivePath)
	if err != nil {
		return fmt.Errorf("open archive: %w", err)
	}
	defer func() {
		_ = file.Close() // read-only handle; a close error cannot lose data
	}()

	gzr, err := gzip.NewReader(file)
	if err != nil {
		return fmt.Errorf("create gzip reader: %w", err)
	}
	defer func() {
		_ = gzr.Close()
	}()

	tr := tar.NewReader(gzr)

	var extractedSize uint64

	for {
		header, err := tr.Next()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return fmt.Errorf("read tar archive: %w", err)
		}

		if header.Typeflag != tar.TypeDir && header.Typeflag != tar.TypeReg {
			continue
		}

		// The archive is untrusted input: refuse names that would resolve
		// outside dir once joined.
		if !filepath.IsLocal(header.Name) {
			return fmt.Errorf("unsafe path in archive: %q", header.Name)
		}
		targetPath := filepath.Join(dir, header.Name)

		// FileInfo translates the raw tar mode bits; Perm keeps only the permission bits.
		perm := header.FileInfo().Mode().Perm()

		if header.Typeflag == tar.TypeDir {
			// Owner rwx is forced so that files can still be written into the directory.
			if err := os.MkdirAll(targetPath, perm|0o700); err != nil {
				return fmt.Errorf("create directory %q: %w", targetPath, err)
			}
			continue
		}

		// Enforce the limit before writing anything to disk. extractedSize never
		// exceeds maxExtractSize, so the subtraction cannot underflow.
		size := uint64(header.Size)
		if size > maxExtractSize-extractedSize {
			return fmt.Errorf("extracted archive too large: %d bytes", extractedSize+size)
		}

		n, err := extractTarFile(tr, targetPath, perm)
		if err != nil {
			return err
		}
		extractedSize += n
	}
}

// extractTarFile writes the current tar entry read from src to targetPath.
// @param src - reader positioned at the entry's content
// @param targetPath - path of the file to create or truncate
// @param perm - permission bits used if the file is created
// @return (uint64, error) - number of bytes written, or an error
func extractTarFile(src io.Reader, targetPath string, perm os.FileMode) (uint64, error) {
	parent := filepath.Dir(targetPath)
	if err := os.MkdirAll(parent, 0o750); err != nil {
		return 0, fmt.Errorf("create parent directory %q: %w", parent, err)
	}

	outFile, err := os.OpenFile(targetPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, perm)
	if err != nil {
		return 0, fmt.Errorf("create file %q: %w", targetPath, err)
	}

	n, copyErr := io.Copy(outFile, src)
	closeErr := outFile.Close()
	if copyErr != nil {
		return 0, fmt.Errorf("write file %q: %w", targetPath, copyErr)
	}
	if closeErr != nil {
		return 0, fmt.Errorf("close file %q: %w", targetPath, closeErr)
	}

	return uint64(n), nil
}