package mmdb import ( "archive/tar" "compress/gzip" "context" "crypto/rand" "fmt" "io" "net/http" "net/url" "os" "path/filepath" "time" ) const ( // requestTimeout defines the maximum duration for request operations before timing out. requestTimeout = 60 * time.Second // maxDownloadSize defines the maximum allowed size of the downloaded file in bytes. maxDownloadSize int64 = 200 << 20 // 200 MiB // maxArchiveFileSize defines the maximum allowed size of the extracted file from ZIP in bytes. maxArchiveFileSize uint64 = 500 << 20 // 500 MiB ) // DownloadConfig defines the configuration for downloading MaxMind GeoIP2 database. type DownloadConfig struct { // RequestTimeout defines the maximum duration for request operations before timing out. RequestTimeout time.Duration // MaxDownloadSize defines the maximum allowed size of the downloaded file in bytes. MaxDownloadSize int64 // MaxArchiveFileSize defines the maximum allowed size of the extracted file from ZIP in bytes. MaxArchiveFileSize uint64 } // DefaultDownloadConfig returns the default configuration for downloading MaxMind GeoIP2 database. func DefaultDownloadConfig() DownloadConfig { return DownloadConfig{ RequestTimeout: requestTimeout, MaxDownloadSize: maxDownloadSize, MaxArchiveFileSize: maxArchiveFileSize, } } // Download is a struct for downloading MaxMind GeoIP2 database. type Download struct { // url is a URL for downloading MaxMind GeoIP2 database. url string // username is a username for downloading MaxMind GeoIP2 database. username string // password is a password for downloading MaxMind GeoIP2 database. password string // Cfg is a configuration for downloading MaxMind GeoIP2 database. Cfg DownloadConfig } // NewDownload creates a new Download struct. 
// @param downloadURL - a URL for downloading MaxMind GeoIP2 database // @param username - a username for downloading MaxMind GeoIP2 database // @param password - a password for downloading MaxMind GeoIP2 database // @param cfg - a configuration for downloading MaxMind GeoIP2 database // @return (*Download, error) - a Download struct and an error if the download URL is invalid func NewDownload(downloadURL, username, password string, cfg DownloadConfig) (*Download, error) { parsedURL, err := url.Parse(downloadURL) if err != nil { return nil, fmt.Errorf("invalid url: %w", err) } if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { return nil, fmt.Errorf("invalid url scheme: %s", parsedURL.Scheme) } return &Download{ url: downloadURL, username: username, password: password, Cfg: cfg, }, nil } // Download downloads MaxMind GeoIP2 database. // @param dir - a directory to save the downloaded database // @param ctx - a context for cancelling the download // @return error - an error if the download failed func (d *Download) Download(dir string, ctx context.Context) error { if err := createDir(dir); err != nil { return err } req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.url, nil) if err != nil { return fmt.Errorf("create request: %w", err) } req.SetBasicAuth(d.username, d.password) client := &http.Client{ Timeout: d.Cfg.RequestTimeout, } resp, err := client.Do(req) if err != nil { return fmt.Errorf("download archive: %w", err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { return fmt.Errorf("download archive: unexpected status %s", resp.Status) } filename := rand.Text() + ".tar.gz" archivePath := filepath.Join(dir, filename) file, err := os.Create(archivePath) if err != nil { if os.IsExist(err) { return fmt.Errorf("archive file already exists: %w", err) } return fmt.Errorf("create archive file: %w", err) } limitedReader := io.LimitReader(resp.Body, d.Cfg.MaxDownloadSize+1) written, copyErr := io.Copy(file, 
limitedReader) closeErr := file.Close() if copyErr != nil { if removeErr := os.Remove(archivePath); removeErr != nil { return fmt.Errorf("remove archive: %w", removeErr) } return fmt.Errorf("save archive: %w", copyErr) } if closeErr != nil { if removeErr := os.Remove(archivePath); removeErr != nil { return fmt.Errorf("remove archive: %w", removeErr) } return fmt.Errorf("close archive file: %w", closeErr) } if written > d.Cfg.MaxDownloadSize { if removeErr := os.Remove(archivePath); removeErr != nil { return fmt.Errorf("remove archive: %w", removeErr) } return fmt.Errorf("download archive too large: %d bytes", written) } if err := extractTarGz(archivePath, dir, d.Cfg.MaxArchiveFileSize); err != nil { if removeErr := os.Remove(archivePath); removeErr != nil { return fmt.Errorf("remove archive: %w", removeErr) } return err } if removeErr := os.Remove(archivePath); removeErr != nil { return fmt.Errorf("remove archive: %w", removeErr) } return nil } // createDir creates a directory if it does not exist. // @param dir - a directory to create // @return error - an error if the directory could not be created func createDir(dir string) error { if err := os.MkdirAll(dir, 0750); err != nil { return err } return nil } // extractTarGz extracts a tar.gz archive to a specified directory. 
// @param archivePath - path of the .tar.gz file to read
// @param dir - destination directory; every extracted entry must resolve inside it
// @param maxExtractSize - maximum total number of bytes, summed over all regular
// files, that may be written; each entry is checked against its header size
// before it is written, so an oversized archive is rejected without filling the disk
// @return error - an error if the archive is malformed, contains an entry whose
// name is absolute or escapes dir, exceeds maxExtractSize, or cannot be written
//
// Only directories and regular files are extracted; other entry types (symlinks,
// hard links, devices, ...) are skipped. Files extracted before a failure are
// left in place.
func extractTarGz(archivePath string, dir string, maxExtractSize uint64) error {
	file, err := os.Open(archivePath)
	if err != nil {
		return fmt.Errorf("open archive: %w", err)
	}
	// Read-only handle: a close error carries nothing actionable.
	defer func() { _ = file.Close() }()

	gzr, err := gzip.NewReader(file)
	if err != nil {
		return fmt.Errorf("create gzip reader: %w", err)
	}
	defer func() { _ = gzr.Close() }()

	tr := tar.NewReader(gzr)
	var extractedSize uint64
	for {
		header, err := tr.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			return fmt.Errorf("read tar archive: %w", err)
		}

		if header.Typeflag != tar.TypeDir && header.Typeflag != tar.TypeReg {
			// Symlinks, hard links, devices, FIFOs, ... are intentionally not extracted.
			continue
		}

		// Reject absolute names and names with ".." that would climb out of dir
		// ("zip slip"); filepath.Join alone cleans but does not contain the path.
		if !filepath.IsLocal(header.Name) {
			return fmt.Errorf("invalid path in archive: %q", header.Name)
		}
		targetPath := filepath.Join(dir, header.Name)
		// Only the permission bits from the archive are honoured.
		perm := header.FileInfo().Mode().Perm()

		switch header.Typeflag {
		case tar.TypeDir:
			if err := os.MkdirAll(targetPath, perm); err != nil {
				return fmt.Errorf("create directory %q: %w", targetPath, err)
			}
		case tar.TypeReg:
			size := header.Size
			if size < 0 {
				return fmt.Errorf("invalid size %d for %q", size, header.Name)
			}
			// The tar reader yields exactly header.Size bytes for a regular file,
			// so the limit is enforced before anything touches the disk.
			// extractedSize <= maxExtractSize always holds, so the subtraction
			// cannot underflow.
			if uint64(size) > maxExtractSize-extractedSize {
				return fmt.Errorf("extracted archive too large: %d bytes", extractedSize+uint64(size))
			}

			if err := os.MkdirAll(filepath.Dir(targetPath), 0750); err != nil {
				return fmt.Errorf("create parent directory %q: %w", targetPath, err)
			}
			outFile, err := os.OpenFile(targetPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, perm)
			if err != nil {
				return fmt.Errorf("create file %q: %w", targetPath, err)
			}
			n, copyErr := io.Copy(outFile, tr)
			closeErr := outFile.Close()
			if copyErr != nil {
				return fmt.Errorf("write file %q: %w", targetPath, copyErr)
			}
			if closeErr != nil {
				return fmt.Errorf("close file %q: %w", targetPath, closeErr)
			}
			extractedSize += uint64(n)
		}
	}
	return nil
}