v0.1.0 #1
@@ -0,0 +1,240 @@
|
||||
package mmdb
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// requestTimeout defines the maximum duration for request operations before timing out.
|
||||
requestTimeout = 60 * time.Second
|
||||
|
||||
// maxDownloadSize defines the maximum allowed size of the downloaded file in bytes.
|
||||
maxDownloadSize int64 = 200 << 20 // 200 MiB
|
||||
|
||||
// maxArchiveFileSize defines the maximum allowed size of the extracted file from ZIP in bytes.
|
||||
maxArchiveFileSize uint64 = 500 << 20 // 500 MiB
|
||||
)
|
||||
|
||||
// DownloadConfig defines the configuration for downloading MaxMind GeoIP2 database.
|
||||
type DownloadConfig struct {
|
||||
// RequestTimeout defines the maximum duration for request operations before timing out.
|
||||
RequestTimeout time.Duration
|
||||
// MaxDownloadSize defines the maximum allowed size of the downloaded file in bytes.
|
||||
MaxDownloadSize int64
|
||||
// MaxArchiveFileSize defines the maximum allowed size of the extracted file from ZIP in bytes.
|
||||
MaxArchiveFileSize uint64
|
||||
}
|
||||
|
||||
// DefaultDownloadConfig returns the default configuration for downloading MaxMind GeoIP2 database.
|
||||
func DefaultDownloadConfig() DownloadConfig {
|
||||
return DownloadConfig{
|
||||
RequestTimeout: requestTimeout,
|
||||
MaxDownloadSize: maxDownloadSize,
|
||||
MaxArchiveFileSize: maxArchiveFileSize,
|
||||
}
|
||||
}
|
||||
|
||||
// Download is a struct for downloading MaxMind GeoIP2 database.
|
||||
type Download struct {
|
||||
// url is a URL for downloading MaxMind GeoIP2 database.
|
||||
url string
|
||||
// username is a username for downloading MaxMind GeoIP2 database.
|
||||
username string
|
||||
// password is a password for downloading MaxMind GeoIP2 database.
|
||||
password string
|
||||
// Cfg is a configuration for downloading MaxMind GeoIP2 database.
|
||||
Cfg DownloadConfig
|
||||
}
|
||||
|
||||
// NewDownload creates a new Download struct.
|
||||
// @param downloadURL - a URL for downloading MaxMind GeoIP2 database
|
||||
// @param username - a username for downloading MaxMind GeoIP2 database
|
||||
// @param password - a password for downloading MaxMind GeoIP2 database
|
||||
// @param cfg - a configuration for downloading MaxMind GeoIP2 database
|
||||
// @return (*Download, error) - a Download struct and an error if the download URL is invalid
|
||||
func NewDownload(downloadURL, username, password string, cfg DownloadConfig) (*Download, error) {
|
||||
parsedURL, err := url.Parse(downloadURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid url: %w", err)
|
||||
}
|
||||
|
||||
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
|
||||
return nil, fmt.Errorf("invalid url scheme: %s", parsedURL.Scheme)
|
||||
}
|
||||
|
||||
return &Download{
|
||||
url: downloadURL,
|
||||
username: username,
|
||||
password: password,
|
||||
Cfg: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Download downloads MaxMind GeoIP2 database.
|
||||
// @param dir - a directory to save the downloaded database
|
||||
// @param ctx - a context for cancelling the download
|
||||
// @return error - an error if the download failed
|
||||
func (d *Download) Download(dir string, ctx context.Context) error {
|
||||
if err := createDir(dir); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.url, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.SetBasicAuth(d.username, d.password)
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: d.Cfg.RequestTimeout,
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("download archive: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("download archive: unexpected status %s", resp.Status)
|
||||
}
|
||||
|
||||
filename := rand.Text() + ".tar.gz"
|
||||
archivePath := filepath.Join(dir, filename)
|
||||
file, err := os.Create(archivePath)
|
||||
if err != nil {
|
||||
if os.IsExist(err) {
|
||||
return fmt.Errorf("archive file already exists: %w", err)
|
||||
}
|
||||
return fmt.Errorf("create archive file: %w", err)
|
||||
}
|
||||
|
||||
limitedReader := io.LimitReader(resp.Body, d.Cfg.MaxDownloadSize+1)
|
||||
written, copyErr := io.Copy(file, limitedReader)
|
||||
closeErr := file.Close()
|
||||
if copyErr != nil {
|
||||
if removeErr := os.Remove(archivePath); removeErr != nil {
|
||||
return fmt.Errorf("remove archive: %w", removeErr)
|
||||
}
|
||||
return fmt.Errorf("save archive: %w", copyErr)
|
||||
}
|
||||
if closeErr != nil {
|
||||
if removeErr := os.Remove(archivePath); removeErr != nil {
|
||||
return fmt.Errorf("remove archive: %w", removeErr)
|
||||
}
|
||||
return fmt.Errorf("close archive file: %w", closeErr)
|
||||
}
|
||||
if written > d.Cfg.MaxDownloadSize {
|
||||
if removeErr := os.Remove(archivePath); removeErr != nil {
|
||||
return fmt.Errorf("remove archive: %w", removeErr)
|
||||
}
|
||||
return fmt.Errorf("download archive too large: %d bytes", written)
|
||||
}
|
||||
|
||||
if err := extractTarGz(archivePath, dir, d.Cfg.MaxArchiveFileSize); err != nil {
|
||||
if removeErr := os.Remove(archivePath); removeErr != nil {
|
||||
return fmt.Errorf("remove archive: %w", removeErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if removeErr := os.Remove(archivePath); removeErr != nil {
|
||||
return fmt.Errorf("remove archive: %w", removeErr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createDir creates a directory if it does not exist.
|
||||
// @param dir - a directory to create
|
||||
// @return error - an error if the directory could not be created
|
||||
func createDir(dir string) error {
|
||||
if err := os.MkdirAll(dir, 0750); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractTarGz extracts a tar.gz archive to a specified directory.
|
||||
// @param archivePath - path to the archive to extract
|
||||
// @param dir - path to the directory to extract the archive to
|
||||
// @param maxExtractSize - maximum size of the extracted archive in bytes
|
||||
// @return error - an error if the archive could not be extracted
|
||||
func extractTarGz(archivePath string, dir string, maxExtractSize uint64) error {
|
||||
file, err := os.Open(archivePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open archive: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = file.Close()
|
||||
}()
|
||||
|
||||
gzr, err := gzip.NewReader(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create gzip reader: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = gzr.Close()
|
||||
}()
|
||||
|
||||
tr := tar.NewReader(gzr)
|
||||
|
||||
var extractedSize uint64
|
||||
|
||||
for {
|
||||
header, err := tr.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("read tar archive: %w", err)
|
||||
}
|
||||
|
||||
targetPath := filepath.Join(dir, header.Name)
|
||||
|
||||
switch header.Typeflag {
|
||||
case tar.TypeDir:
|
||||
if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil {
|
||||
return fmt.Errorf("create directory %q: %w", targetPath, err)
|
||||
}
|
||||
|
||||
case tar.TypeReg:
|
||||
if err := os.MkdirAll(filepath.Dir(targetPath), 0750); err != nil {
|
||||
return fmt.Errorf("create parent directory %q: %w", targetPath, err)
|
||||
}
|
||||
|
||||
outFile, err := os.OpenFile(targetPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create file %q: %w", targetPath, err)
|
||||
}
|
||||
|
||||
n, err := io.Copy(outFile, tr)
|
||||
closeErr := outFile.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("write file %q: %w", targetPath, err)
|
||||
}
|
||||
if closeErr != nil {
|
||||
return fmt.Errorf("close file %q: %w", targetPath, closeErr)
|
||||
}
|
||||
|
||||
extractedSize += uint64(n)
|
||||
if extractedSize > maxExtractSize {
|
||||
return fmt.Errorf("extracted archive too large: %d bytes", extractedSize)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user