package mmdb

import (
	"context"
	"errors"
	"fmt"
	"net/netip"
	"os"
	"sync"

	"git.kor-elf.net/kor-elf-shield/geoip2"
	"git.kor-elf.net/kor-elf-shield/geoip2/internal/pkg"
	oschwaldGeoip2 "github.com/oschwald/geoip2-golang/v2"
)

// download abstracts fetching a fresh copy of the MaxMind GeoIP2 database.
type download interface {
	// Download places the MaxMind GeoIP2 database into dir.
	Download(dir string, ctx context.Context) error
}

// info performs a lookup of a single IP address against an open reader.
// @param ip - IP address
// @param reader - MaxMind GeoIP2 database reader
// @return geoip2.Info - information about the IP address
type info func(ip netip.Addr, reader *oschwaldGeoip2.Reader) (geoip2.Info, error)

// service serves GeoIP2 lookups from a MaxMind database kept in dir and can
// swap in a freshly downloaded copy at runtime.
type service struct {
	// download fetches new copies of the database.
	download download
	// info resolves an IP address using the currently open reader.
	info info
	// logger receives errors from background work.
	logger geoip2.Logger
	// dir holds the temporary and current database directories.
	dir *pkg.Dir

	// isRefreshing is set while a refresh is in progress.
	isRefreshing bool

	// mu synchronizes access to reader.
	mu sync.RWMutex
	// reader is the open database; nil until one has been loaded.
	reader *oschwaldGeoip2.Reader
}

// NewMMDB builds a MaxMind GeoIP2 database service and immediately tries to
// load a database that already exists in dir; if none can be opened, a
// download is started in the background.
// @param download - an interface for downloading MaxMind GeoIP2 database
// @param info - a function that returns information about the IP address
// @param logger - a logger
// @param dir - a directory for storing MaxMind GeoIP2 database
// @return geoip2.RefreshableGeoIP2 - a MaxMind GeoIP2 database service
func NewMMDB(download download, info info, logger geoip2.Logger, dir *pkg.Dir) geoip2.RefreshableGeoIP2 {
	svc := new(service)
	svc.download = download
	svc.info = info
	svc.logger = logger
	svc.dir = dir

	svc.init()

	return svc
}

// Info returns information about the IP address.
// @param ip - IP address // @return geoip2.Info - information about the IP address // @return error - an error if the IP address could not be found func (s *service) Info(ip netip.Addr) (geoip2.Info, error) { s.mu.RLock() defer s.mu.RUnlock() if s.reader == nil { return geoip2.Info{}, errors.New("geoip reader is not ready") } return s.info(ip, s.reader) } // Refresh refreshes the MaxMind GeoIP2 database. // @param ctx - a context // @return error - an error if the database could not be refreshed func (s *service) Refresh(ctx context.Context) error { if s.isRefreshing { return fmt.Errorf("process is refreshing") } s.isRefreshing = true defer func() { s.isRefreshing = false }() if err := s.fetch(ctx); err != nil { return err } newReader, err := s.openReader() if err != nil { return err } s.mu.Lock() oldReader := s.reader s.reader = newReader s.mu.Unlock() if oldReader != nil { _ = oldReader.Close() } return nil } // Close closes the MaxMind GeoIP2 database. // @return error - an error if the database could not be closed func (s *service) Close() error { if s.reader != nil { return s.reader.Close() } return nil } // init initializes the MaxMind GeoIP2 database service. func (s *service) init() { reader, err := s.openReader() if err != nil { go func() { // If there is no data yet, then we try to get the data. if err := s.Refresh(context.Background()); err != nil { s.logger.Error(err) } }() return } s.mu.Lock() defer s.mu.Unlock() s.reader = reader } // openReader opens the MaxMind GeoIP2 database reader. // @return (*oschwaldGeoip2.Reader, error) - a MaxMind GeoIP2 database reader and an error if the database could not be opened // 1. If the database is already open, then it is returned. // 2. If the database is not open, then it is opened and returned. 
func (s *service) openReader() (*oschwaldGeoip2.Reader, error) { path, err := pkg.FindMMDBFile(s.dir.PathCurrentDir) if err != nil { return nil, err } return oschwaldGeoip2.Open(path) } // fetch downloads the MaxMind GeoIP2 database. // @param ctx - a context // @return error - an error if the database could not be downloaded func (s *service) fetch(ctx context.Context) error { tmpDir, err := s.dir.CreateRandomTmpDir() if err != nil { return err } defer func() { if err := os.RemoveAll(tmpDir); err != nil { s.logger.Error(err) } }() if err := s.download.Download(tmpDir, ctx); err != nil { return err } tmpDirForFiles, err := s.dir.CreateRandomTmpDir() if err != nil { return err } defer func() { if err := os.RemoveAll(tmpDir); err != nil { s.logger.Error(err) } }() if err := pkg.MovingFiles(tmpDir, tmpDirForFiles); err != nil { return err } if err := s.dir.ReplaceDirToCurrent(tmpDirForFiles); err != nil { return err } return nil }