From a2dff48936cf82a4a5b8d3eeb3ac562c7ddc84cd Mon Sep 17 00:00:00 2001 From: Leonid Nikitin Date: Wed, 18 Mar 2026 21:05:11 +0500 Subject: [PATCH] Refactor request logic with reusable `fetch` function and add methods to handle separated IP parsing for both regular and ZIP responses. --- blocklist.go | 162 +++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 124 insertions(+), 38 deletions(-) diff --git a/blocklist.go b/blocklist.go index 712029c..9a5df68 100644 --- a/blocklist.go +++ b/blocklist.go @@ -89,29 +89,12 @@ func NewConfigZip(c Config) ConfigZip { // Get fetches data from the given URL, parses the response using the provided parser, and applies the given configuration. // It returns the parsed IPs and any errors that occurred during the process. func Get(fileUrl string, parser parser.Parser, c Config) (parser.IPs, error) { - parsedURL, err := url.Parse(fileUrl) - if err != nil { - return nil, fmt.Errorf("invalid url: %w", err) - } - if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { - return nil, fmt.Errorf("invalid url scheme: %s", parsedURL.Scheme) - } - ctx, cancel := context.WithTimeout(context.Background(), c.ContextTimeout) defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, fileUrl, nil) + res, err := fetch(fileUrl, ctx, c.RequestTimeout) if err != nil { - return nil, fmt.Errorf("create request: %w", err) - } - - client := &http.Client{ - Timeout: c.RequestTimeout, - } - - res, err := client.Do(req) - if err != nil { - return nil, fmt.Errorf("request failed: %w", err) + return nil, err } defer func() { _ = res.Body.Close() @@ -124,32 +107,36 @@ func Get(fileUrl string, parser parser.Parser, c Config) (parser.IPs, error) { return parser.Parse(res.Body, c.Validator, c.Limit) } +// GetSeparatedIPs fetches data from the given URL, parses the response using the provided parser, and applies the given configuration. +// It returns the parsed IPs and any errors that occurred during the process. +func GetSeparatedIPs(fileUrl string, parser parser.Parser, c Config) (ipV4 parser.IPs, ipV6 parser.IPs, err error) { + ctx, cancel := context.WithTimeout(context.Background(), c.ContextTimeout) + defer cancel() + + res, err := fetch(fileUrl, ctx, c.RequestTimeout) + if err != nil { + return nil, nil, err + } + defer func() { + _ = res.Body.Close() + }() + + if res.StatusCode != http.StatusOK { + return nil, nil, fmt.Errorf("unexpected status code: %d", res.StatusCode) + } + + return parser.ParseIPsByVersion(res.Body, c.Validator, c.Limit) +} + // GetZip fetches data from the given URL, parses the response using the provided parser, and applies the given configuration. // It returns the parsed IPs and any errors that occurred during the process. func GetZip(fileUrl string, parser parser.Parser, c ConfigZip) (parser.IPs, error) { - parsedURL, err := url.Parse(fileUrl) - if err != nil { - return nil, fmt.Errorf("invalid url: %w", err) - } - if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { - return nil, fmt.Errorf("invalid url scheme: %s", parsedURL.Scheme) - } - ctx, cancel := context.WithTimeout(context.Background(), c.Config.ContextTimeout) defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, fileUrl, nil) + res, err := fetch(fileUrl, ctx, c.Config.RequestTimeout) if err != nil { - return nil, fmt.Errorf("create request: %w", err) - } - - client := &http.Client{ - Timeout: c.Config.RequestTimeout, - } - - res, err := client.Do(req) - if err != nil { - return nil, fmt.Errorf("request failed: %w", err) + return nil, err } defer func() { _ = res.Body.Close() @@ -184,6 +171,49 @@ func GetZip(fileUrl string, parser parser.Parser, c ConfigZip) (parser.IPs, erro return parseZip(body, parser, c) } +// GetZipSeparatedIPs fetches data from the given URL, parses the response using the provided parser, and applies the given configuration. +// It returns the parsed IPs and any errors that occurred during the process. +func GetZipSeparatedIPs(fileUrl string, parser parser.Parser, c ConfigZip) (ipV4 parser.IPs, ipV6 parser.IPs, err error) { + ctx, cancel := context.WithTimeout(context.Background(), c.Config.ContextTimeout) + defer cancel() + + res, err := fetch(fileUrl, ctx, c.Config.RequestTimeout) + if err != nil { + return nil, nil, err + } + defer func() { + _ = res.Body.Close() + }() + + if res.StatusCode != http.StatusOK { + return nil, nil, fmt.Errorf("unexpected status code: %d", res.StatusCode) + } + + if c.MaxDownloadSize > 0 && res.ContentLength > c.MaxDownloadSize { + return nil, nil, fmt.Errorf("downloaded file is too large: content-length %d exceeds limit %d", res.ContentLength, c.MaxDownloadSize) + } + + reader := res.Body + if c.MaxDownloadSize > 0 { + reader = io.NopCloser(io.LimitReader(res.Body, c.MaxDownloadSize+1)) + } + + body, err := io.ReadAll(reader) + if err != nil { + return nil, nil, fmt.Errorf("read response body: %w", err) + } + + if c.MaxDownloadSize > 0 && int64(len(body)) > c.MaxDownloadSize { + return nil, nil, fmt.Errorf("downloaded file exceeds limit %d bytes", c.MaxDownloadSize) + } + + if !isZip(body) { + return nil, nil, fmt.Errorf("invalid zip archive") + } + + return parseZipSeparatedIPs(body, parser, c) +} + func isZip(body []byte) bool { return len(body) >= 4 && body[0] == 'P' && @@ -223,6 +253,37 @@ func parseZip(body []byte, p parser.Parser, c ConfigZip) (parser.IPs, error) { return p.Parse(zipReader, c.Config.Validator, c.Config.Limit) } +func parseZipSeparatedIPs(body []byte, p parser.Parser, c ConfigZip) (ipV4 parser.IPs, ipV6 parser.IPs, err error) { + reader, err := zip.NewReader(bytes.NewReader(body), int64(len(body))) + if err != nil { + return nil, nil, fmt.Errorf("open zip archive: %w", err) + } + + file := findArchiveFile(reader.File) + if file == nil { + return nil, nil, fmt.Errorf("zip archive does not contain a supported file") + } + + if c.MaxArchiveFileSize > 0 && file.UncompressedSize64 > c.MaxArchiveFileSize { + return nil, nil, fmt.Errorf("file %q in zip is too large: %d exceeds limit %d", file.Name, file.UncompressedSize64, c.MaxArchiveFileSize) + } + + rc, err := file.Open() + if err != nil { + return nil, nil, fmt.Errorf("open file %q from zip: %w", file.Name, err) + } + defer func() { + _ = rc.Close() + }() + + var zipReader io.Reader = rc + if c.MaxArchiveFileSize > 0 { + zipReader = io.LimitReader(rc, int64(c.MaxArchiveFileSize)+1) + } + + return p.ParseIPsByVersion(zipReader, c.Config.Validator, c.Config.Limit) +} + func findArchiveFile(files []*zip.File) *zip.File { var fallback *zip.File @@ -246,3 +307,28 @@ func findArchiveFile(files []*zip.File) *zip.File { return fallback } + +func fetch(fileUrl string, ctx context.Context, requestTimeout time.Duration) (*http.Response, error) { + parsedURL, err := url.Parse(fileUrl) + if err != nil { + return nil, fmt.Errorf("invalid url: %w", err) + } + if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { + return nil, fmt.Errorf("invalid url scheme: %s", parsedURL.Scheme) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fileUrl, nil) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + client := &http.Client{ + Timeout: requestTimeout, + } + + res, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + return res, nil +}