Files
geoip2/service/maxmind/mmdb/download.go
T

241 lines
6.9 KiB
Go

package mmdb
import (
"archive/tar"
"compress/gzip"
"context"
"crypto/rand"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"time"
)
const (
// requestTimeout is the default HTTP request timeout returned by DefaultDownloadConfig.
requestTimeout = 60 * time.Second
// maxDownloadSize is the default upper bound, in bytes, on the size of the downloaded archive.
maxDownloadSize int64 = 200 << 20 // 200 MiB
// maxArchiveFileSize is the default upper bound, in bytes, on the total size of the
// data extracted from the downloaded tar.gz archive.
maxArchiveFileSize uint64 = 500 << 20 // 500 MiB
)
// DownloadConfig defines the configuration for downloading MaxMind GeoIP2 database.
type DownloadConfig struct {
// RequestTimeout is used as the Timeout of the HTTP client that performs the download.
RequestTimeout time.Duration
// MaxDownloadSize defines the maximum allowed size of the downloaded archive in bytes.
// A download that produces more bytes than this is rejected.
MaxDownloadSize int64
// MaxArchiveFileSize defines the maximum allowed cumulative size, in bytes, of the
// regular files extracted from the downloaded tar.gz archive.
MaxArchiveFileSize uint64
}
// DefaultDownloadConfig builds a DownloadConfig whose fields are set to the
// package defaults (requestTimeout, maxDownloadSize and maxArchiveFileSize).
// @return DownloadConfig - a configuration populated with the default limits
func DefaultDownloadConfig() DownloadConfig {
	var defaults DownloadConfig
	defaults.RequestTimeout = requestTimeout
	defaults.MaxDownloadSize = maxDownloadSize
	defaults.MaxArchiveFileSize = maxArchiveFileSize
	return defaults
}
// Download holds the URL, credentials and limits needed to download and
// extract a MaxMind GeoIP2 database archive.
type Download struct {
// url is the URL the database archive is requested from.
url string
// username is sent as the user name of the HTTP basic auth credentials.
username string
// password is sent as the password of the HTTP basic auth credentials.
password string
// Cfg holds the request timeout and the size limits applied to the download and extraction.
Cfg DownloadConfig
}
// NewDownload creates a new Download struct after validating the download URL.
// The URL must parse, use the http or https scheme and name a host; URLs such as
// "http:foo" or "http:///path" are rejected here instead of failing later when
// the request is sent.
// @param downloadURL - a URL for downloading MaxMind GeoIP2 database
// @param username - a username for downloading MaxMind GeoIP2 database
// @param password - a password for downloading MaxMind GeoIP2 database
// @param cfg - a configuration for downloading MaxMind GeoIP2 database
// @return (*Download, error) - a Download struct and an error if the download URL is invalid
func NewDownload(downloadURL, username, password string, cfg DownloadConfig) (*Download, error) {
	parsedURL, err := url.Parse(downloadURL)
	if err != nil {
		return nil, fmt.Errorf("invalid url: %w", err)
	}
	if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
		return nil, fmt.Errorf("invalid url scheme: %s", parsedURL.Scheme)
	}
	// The URL itself is deliberately not echoed in the message: it may carry
	// credentials or a license key in its query string.
	if parsedURL.Hostname() == "" {
		return nil, fmt.Errorf("invalid url: missing host")
	}

	return &Download{
		url:      downloadURL,
		username: username,
		password: password,
		Cfg:      cfg,
	}, nil
}
// Download downloads MaxMind GeoIP2 database.
// It creates dir if needed, requests the archive with HTTP basic auth, stores the
// response body in a randomly named temporary .tar.gz file inside dir, extracts
// that archive into dir and finally removes the temporary file. The temporary
// file is removed on every path once it has been created; if that removal fails,
// the removal error is reported together with the original failure (if any).
// @param dir - a directory to save the downloaded database
// @param ctx - a context for cancelling the download
// @return error - an error if the download failed
func (d *Download) Download(dir string, ctx context.Context) (err error) {
	if err = createDir(dir); err != nil {
		return err
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.url, nil)
	if err != nil {
		return fmt.Errorf("create request: %w", err)
	}
	req.SetBasicAuth(d.username, d.password)

	client := &http.Client{Timeout: d.Cfg.RequestTimeout}
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("download archive: %w", err)
	}
	defer func() {
		// Best effort: the body is only read, so a close error is not actionable.
		_ = resp.Body.Close()
	}()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("download archive: unexpected status %s", resp.Status)
	}

	archivePath := filepath.Join(dir, rand.Text()+".tar.gz")
	// O_EXCL makes the "already exists" branch below reachable (os.Create would
	// silently truncate an existing file) and refuses to write through a
	// pre-existing file or symlink at the temporary path.
	file, err := os.OpenFile(archivePath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0600)
	if err != nil {
		if os.IsExist(err) {
			return fmt.Errorf("archive file already exists: %w", err)
		}
		return fmt.Errorf("create archive file: %w", err)
	}
	// From here on the temporary archive must be removed however the function exits.
	// The file is always closed explicitly before any return below, so the removal
	// never races with an open handle.
	defer func() {
		removeErr := os.Remove(archivePath)
		if removeErr == nil {
			return
		}
		removeErr = fmt.Errorf("remove archive: %w", removeErr)
		if err != nil {
			// Keep the original cause instead of replacing it with the cleanup error.
			err = fmt.Errorf("%w; %w", err, removeErr)
			return
		}
		err = removeErr
	}()

	// Read one byte past the limit so an oversized body can be told apart from
	// one that is exactly MaxDownloadSize bytes long.
	body := io.LimitReader(resp.Body, d.Cfg.MaxDownloadSize+1)
	written, copyErr := io.Copy(file, body)
	closeErr := file.Close()
	switch {
	case copyErr != nil:
		return fmt.Errorf("save archive: %w", copyErr)
	case closeErr != nil:
		return fmt.Errorf("close archive file: %w", closeErr)
	case written > d.Cfg.MaxDownloadSize:
		return fmt.Errorf("download archive too large: %d bytes", written)
	}

	return extractTarGz(archivePath, dir, d.Cfg.MaxArchiveFileSize)
}
// createDir makes sure dir exists, creating it and any missing parents with
// permission bits 0750.
// @param dir - a directory to create
// @return error - the error reported by os.MkdirAll, or nil on success
func createDir(dir string) error {
	return os.MkdirAll(dir, 0750)
}
// extractTarGz extracts a tar.gz archive to a specified directory.
// Only directory and regular-file entries are materialised; every other entry
// type (symlinks, devices, ...) is skipped. Entry names are treated as untrusted:
// a name that is absolute or escapes dir through ".." aborts the extraction
// before anything is written for that entry. The size limit is enforced against
// the size declared in each entry header before the entry is written, so an
// oversized entry never reaches the disk.
// @param archivePath - path to the archive to extract
// @param dir - path to the directory to extract the archive to
// @param maxExtractSize - maximum cumulative size of the extracted regular files in bytes
// @return error - an error if the archive could not be extracted
func extractTarGz(archivePath string, dir string, maxExtractSize uint64) error {
	file, err := os.Open(archivePath)
	if err != nil {
		return fmt.Errorf("open archive: %w", err)
	}
	defer func() {
		// Read-only handle: a close error is not actionable.
		_ = file.Close()
	}()

	gzr, err := gzip.NewReader(file)
	if err != nil {
		return fmt.Errorf("create gzip reader: %w", err)
	}
	defer func() {
		_ = gzr.Close()
	}()

	tr := tar.NewReader(gzr)
	// extractedSize never exceeds maxExtractSize, so the subtraction below cannot underflow.
	var extractedSize uint64
	for {
		header, err := tr.Next()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return fmt.Errorf("read tar archive: %w", err)
		}

		if header.Typeflag != tar.TypeDir && header.Typeflag != tar.TypeReg {
			continue
		}

		// Reject names such as "../x" or "/etc/x" that would land outside dir.
		if !filepath.IsLocal(header.Name) {
			return fmt.Errorf("unsafe path in archive: %q", header.Name)
		}
		targetPath := filepath.Join(dir, header.Name)
		// Apply only the permission bits recorded in the archive.
		perm := os.FileMode(header.Mode).Perm()

		if header.Typeflag == tar.TypeDir {
			if err := os.MkdirAll(targetPath, perm); err != nil {
				return fmt.Errorf("create directory %q: %w", targetPath, err)
			}
			continue
		}

		// The tar reader yields at most header.Size bytes for this entry, so
		// checking the declared size up front bounds what is written to disk.
		if header.Size < 0 || uint64(header.Size) > maxExtractSize-extractedSize {
			return fmt.Errorf("extracted archive too large: exceeds limit of %d bytes", maxExtractSize)
		}

		if err := os.MkdirAll(filepath.Dir(targetPath), 0750); err != nil {
			return fmt.Errorf("create parent directory %q: %w", targetPath, err)
		}
		outFile, err := os.OpenFile(targetPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, perm)
		if err != nil {
			return fmt.Errorf("create file %q: %w", targetPath, err)
		}
		n, copyErr := io.Copy(outFile, tr)
		closeErr := outFile.Close()
		if copyErr != nil {
			return fmt.Errorf("write file %q: %w", targetPath, copyErr)
		}
		if closeErr != nil {
			return fmt.Errorf("close file %q: %w", targetPath, closeErr)
		}
		extractedSize += uint64(n)
	}
}