v0.1.0 #1
@@ -0,0 +1,240 @@
|
|||||||
|
package mmdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"archive/tar"
|
||||||
|
"compress/gzip"
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// requestTimeout defines the maximum duration for request operations before timing out.
|
||||||
|
requestTimeout = 60 * time.Second
|
||||||
|
|
||||||
|
// maxDownloadSize defines the maximum allowed size of the downloaded file in bytes.
|
||||||
|
maxDownloadSize int64 = 200 << 20 // 200 MiB
|
||||||
|
|
||||||
|
// maxArchiveFileSize defines the maximum allowed size of the extracted file from ZIP in bytes.
|
||||||
|
maxArchiveFileSize uint64 = 500 << 20 // 500 MiB
|
||||||
|
)
|
||||||
|
|
||||||
|
// DownloadConfig defines the configuration for downloading MaxMind GeoIP2 database.
|
||||||
|
type DownloadConfig struct {
|
||||||
|
// RequestTimeout defines the maximum duration for request operations before timing out.
|
||||||
|
RequestTimeout time.Duration
|
||||||
|
// MaxDownloadSize defines the maximum allowed size of the downloaded file in bytes.
|
||||||
|
MaxDownloadSize int64
|
||||||
|
// MaxArchiveFileSize defines the maximum allowed size of the extracted file from ZIP in bytes.
|
||||||
|
MaxArchiveFileSize uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultDownloadConfig returns the default configuration for downloading MaxMind GeoIP2 database.
|
||||||
|
func DefaultDownloadConfig() DownloadConfig {
|
||||||
|
return DownloadConfig{
|
||||||
|
RequestTimeout: requestTimeout,
|
||||||
|
MaxDownloadSize: maxDownloadSize,
|
||||||
|
MaxArchiveFileSize: maxArchiveFileSize,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Download is a struct for downloading MaxMind GeoIP2 database.
|
||||||
|
type Download struct {
|
||||||
|
// url is a URL for downloading MaxMind GeoIP2 database.
|
||||||
|
url string
|
||||||
|
// username is a username for downloading MaxMind GeoIP2 database.
|
||||||
|
username string
|
||||||
|
// password is a password for downloading MaxMind GeoIP2 database.
|
||||||
|
password string
|
||||||
|
// Cfg is a configuration for downloading MaxMind GeoIP2 database.
|
||||||
|
Cfg DownloadConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDownload creates a new Download struct.
|
||||||
|
// @param downloadURL - a URL for downloading MaxMind GeoIP2 database
|
||||||
|
// @param username - a username for downloading MaxMind GeoIP2 database
|
||||||
|
// @param password - a password for downloading MaxMind GeoIP2 database
|
||||||
|
// @param cfg - a configuration for downloading MaxMind GeoIP2 database
|
||||||
|
// @return (*Download, error) - a Download struct and an error if the download URL is invalid
|
||||||
|
func NewDownload(downloadURL, username, password string, cfg DownloadConfig) (*Download, error) {
|
||||||
|
parsedURL, err := url.Parse(downloadURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid url: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
|
||||||
|
return nil, fmt.Errorf("invalid url scheme: %s", parsedURL.Scheme)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Download{
|
||||||
|
url: downloadURL,
|
||||||
|
username: username,
|
||||||
|
password: password,
|
||||||
|
Cfg: cfg,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Download downloads MaxMind GeoIP2 database.
|
||||||
|
// @param dir - a directory to save the downloaded database
|
||||||
|
// @param ctx - a context for cancelling the download
|
||||||
|
// @return error - an error if the download failed
|
||||||
|
func (d *Download) Download(dir string, ctx context.Context) error {
|
||||||
|
if err := createDir(dir); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create request: %w", err)
|
||||||
|
}
|
||||||
|
req.SetBasicAuth(d.username, d.password)
|
||||||
|
|
||||||
|
client := &http.Client{
|
||||||
|
Timeout: d.Cfg.RequestTimeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("download archive: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("download archive: unexpected status %s", resp.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
filename := rand.Text() + ".tar.gz"
|
||||||
|
archivePath := filepath.Join(dir, filename)
|
||||||
|
file, err := os.Create(archivePath)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsExist(err) {
|
||||||
|
return fmt.Errorf("archive file already exists: %w", err)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("create archive file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
limitedReader := io.LimitReader(resp.Body, d.Cfg.MaxDownloadSize+1)
|
||||||
|
written, copyErr := io.Copy(file, limitedReader)
|
||||||
|
closeErr := file.Close()
|
||||||
|
if copyErr != nil {
|
||||||
|
if removeErr := os.Remove(archivePath); removeErr != nil {
|
||||||
|
return fmt.Errorf("remove archive: %w", removeErr)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("save archive: %w", copyErr)
|
||||||
|
}
|
||||||
|
if closeErr != nil {
|
||||||
|
if removeErr := os.Remove(archivePath); removeErr != nil {
|
||||||
|
return fmt.Errorf("remove archive: %w", removeErr)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("close archive file: %w", closeErr)
|
||||||
|
}
|
||||||
|
if written > d.Cfg.MaxDownloadSize {
|
||||||
|
if removeErr := os.Remove(archivePath); removeErr != nil {
|
||||||
|
return fmt.Errorf("remove archive: %w", removeErr)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("download archive too large: %d bytes", written)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := extractTarGz(archivePath, dir, d.Cfg.MaxArchiveFileSize); err != nil {
|
||||||
|
if removeErr := os.Remove(archivePath); removeErr != nil {
|
||||||
|
return fmt.Errorf("remove archive: %w", removeErr)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if removeErr := os.Remove(archivePath); removeErr != nil {
|
||||||
|
return fmt.Errorf("remove archive: %w", removeErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// createDir creates a directory if it does not exist.
|
||||||
|
// @param dir - a directory to create
|
||||||
|
// @return error - an error if the directory could not be created
|
||||||
|
func createDir(dir string) error {
|
||||||
|
if err := os.MkdirAll(dir, 0750); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractTarGz extracts a tar.gz archive to a specified directory.
|
||||||
|
// @param archivePath - path to the archive to extract
|
||||||
|
// @param dir - path to the directory to extract the archive to
|
||||||
|
// @param maxExtractSize - maximum size of the extracted archive in bytes
|
||||||
|
// @return error - an error if the archive could not be extracted
|
||||||
|
func extractTarGz(archivePath string, dir string, maxExtractSize uint64) error {
|
||||||
|
file, err := os.Open(archivePath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("open archive: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = file.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
gzr, err := gzip.NewReader(file)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create gzip reader: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = gzr.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
tr := tar.NewReader(gzr)
|
||||||
|
|
||||||
|
var extractedSize uint64
|
||||||
|
|
||||||
|
for {
|
||||||
|
header, err := tr.Next()
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read tar archive: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
targetPath := filepath.Join(dir, header.Name)
|
||||||
|
|
||||||
|
switch header.Typeflag {
|
||||||
|
case tar.TypeDir:
|
||||||
|
if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil {
|
||||||
|
return fmt.Errorf("create directory %q: %w", targetPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
case tar.TypeReg:
|
||||||
|
if err := os.MkdirAll(filepath.Dir(targetPath), 0750); err != nil {
|
||||||
|
return fmt.Errorf("create parent directory %q: %w", targetPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
outFile, err := os.OpenFile(targetPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create file %q: %w", targetPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := io.Copy(outFile, tr)
|
||||||
|
closeErr := outFile.Close()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("write file %q: %w", targetPath, err)
|
||||||
|
}
|
||||||
|
if closeErr != nil {
|
||||||
|
return fmt.Errorf("close file %q: %w", targetPath, closeErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
extractedSize += uint64(n)
|
||||||
|
if extractedSize > maxExtractSize {
|
||||||
|
return fmt.Errorf("extracted archive too large: %d bytes", extractedSize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user