228 lines
5.5 KiB
Go
228 lines
5.5 KiB
Go
package blocklist
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"git.kor-elf.net/kor-elf-shield/kor-elf-shield/internal/daemon/db/entity"
|
|
"git.kor-elf.net/kor-elf-shield/kor-elf-shield/internal/daemon/db/repository"
|
|
"git.kor-elf.net/kor-elf-shield/kor-elf-shield/internal/daemon/firewall/nft/block"
|
|
"git.kor-elf.net/kor-elf-shield/kor-elf-shield/internal/log"
|
|
)
|
|
|
|
// newBlocklist is a factory signature: given a name it returns a
// block.Blocklist or an error.
// NOTE(review): not referenced in this file; presumably used elsewhere in the
// package to build the per-source nftables sets — confirm.
type newBlocklist func(name string) (block.Blocklist, error)
|
|
|
|
type Blocklist interface {
|
|
Names() []string
|
|
NftReload(blocks map[string]block.Blocklist) error
|
|
Run()
|
|
Close() error
|
|
}
|
|
|
|
// updateSource is a refresh request for a single source, sent over
// launchChannel and consumed by processUpdateData.
type updateSource struct {
	// forcedly makes processUpdateData refresh without consulting the cached
	// list's freshness; runSourceWorker sets it on ticker-driven requests and
	// clears it on the first request it sends.
	forcedly bool

	// source is the configuration of the source to refresh.
	source *SourceConfig
}
|
|
|
|
// blocklist is the Blocklist implementation returned by New.
type blocklist struct {
	// Sources is the configured source list; Run skips nil entries and
	// entries without a Name.
	Sources []*SourceConfig

	// blocklistRepository stores the last fetched IP lists per source name,
	// together with the time they were fetched (entity.Blocklist).
	blocklistRepository repository.BlocklistRepository

	logger log.Logger

	// ctx starts as the context passed to New; Run replaces it with a
	// cancellable child and keeps the matching cancel func in cancel.
	ctx context.Context

	// cancel is non-nil while Run is active; Close calls it and resets it.
	cancel context.CancelFunc

	// wg tracks goroutines started by Run; Close waits on it.
	wg sync.WaitGroup

	// nftBlocklists maps a source name to its nftables set; NftReload
	// replaces the whole map.
	nftBlocklists map[string]block.Blocklist

	// mu is locked by NftReload while it replaces nftBlocklists.
	mu sync.Mutex

	// launchChannel carries refresh requests from the per-source workers to
	// the processUpdateData goroutine (buffered; capacity chosen in New).
	launchChannel chan updateSource
}
|
|
|
|
func New(config Config, ctx context.Context, logger log.Logger) (Blocklist, error) {
|
|
return &blocklist{
|
|
Sources: config.Sources,
|
|
blocklistRepository: config.BlocklistRepository,
|
|
logger: logger,
|
|
ctx: ctx,
|
|
|
|
nftBlocklists: map[string]block.Blocklist{},
|
|
mu: sync.Mutex{},
|
|
|
|
launchChannel: make(chan updateSource, 50),
|
|
}, nil
|
|
}
|
|
|
|
func (b *blocklist) Names() []string {
|
|
names := []string{}
|
|
for _, source := range b.Sources {
|
|
if source.Name != "" {
|
|
names = append(names, source.Name)
|
|
}
|
|
}
|
|
return names
|
|
}
|
|
|
|
func (b *blocklist) NftReload(blocks map[string]block.Blocklist) error {
|
|
b.logger.Debug("Reload blocklist")
|
|
|
|
b.mu.Lock()
|
|
b.nftBlocklists = blocks
|
|
b.mu.Unlock()
|
|
|
|
for _, source := range b.Sources {
|
|
if nftBlocklist, ok := b.nftBlocklists[source.Name]; ok {
|
|
if listEntity, err := b.blocklistRepository.Get(source.Name); err != nil {
|
|
b.logger.Error(fmt.Sprintf("Failed to get blocklist %s: %s", source.Name, err))
|
|
} else if listEntity.IsFresh(source.Interval) {
|
|
if err := nftBlocklist.ReplaceElementsIPv4(listEntity.IPsV4); len(listEntity.IPsV4) > 0 && err != nil {
|
|
b.logger.Error(fmt.Sprintf("Failed to replace elements (IPv4): %s", err))
|
|
}
|
|
|
|
if err := nftBlocklist.ReplaceElementsIPv6(listEntity.IPsV6); len(listEntity.IPsV6) > 0 && err != nil {
|
|
b.logger.Error(fmt.Sprintf("Failed to replace elements (IPv6): %s", err))
|
|
}
|
|
}
|
|
} else {
|
|
b.logger.Error(fmt.Sprintf("NFTables sets blocklist %s not found", source.Name))
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (b *blocklist) Run() {
|
|
b.logger.Debug("Starting blocklist")
|
|
if b.cancel != nil {
|
|
// already started
|
|
b.logger.Warn("Blocklist already started")
|
|
return
|
|
}
|
|
b.ctx, b.cancel = context.WithCancel(b.ctx)
|
|
go b.processUpdateData(b.ctx)
|
|
|
|
for _, src := range b.Sources {
|
|
if src == nil || src.Name == "" {
|
|
continue
|
|
}
|
|
|
|
interval := src.Interval
|
|
if interval <= 0 {
|
|
interval = 5 * time.Minute // дефолт
|
|
}
|
|
|
|
b.wg.Add(1)
|
|
go b.runSourceWorker(src, interval)
|
|
}
|
|
}
|
|
|
|
func (b *blocklist) runSourceWorker(sourceConfig *SourceConfig, interval time.Duration) {
|
|
defer b.wg.Done()
|
|
|
|
b.launchChannel <- updateSource{
|
|
forcedly: false,
|
|
source: sourceConfig,
|
|
}
|
|
|
|
ticker := time.NewTicker(interval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-b.ctx.Done():
|
|
b.logger.Debug(fmt.Sprintf("source %s stopped", sourceConfig.Name))
|
|
return
|
|
case <-ticker.C:
|
|
b.logger.Debug(fmt.Sprintf("source %s tick", sourceConfig.Name))
|
|
b.launchChannel <- updateSource{
|
|
forcedly: true,
|
|
source: sourceConfig,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (b *blocklist) processUpdateData(ctx context.Context) {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case updSource, ok := <-b.launchChannel:
|
|
if !ok {
|
|
// Channel closed
|
|
return
|
|
}
|
|
|
|
if updSource.forcedly {
|
|
b.refreshSource(updSource.source)
|
|
continue
|
|
}
|
|
|
|
if listEntity, err := b.blocklistRepository.Get(updSource.source.Name); err != nil {
|
|
b.logger.Error(fmt.Sprintf("Failed to get blocklist %s: %s", updSource.source.Name, err))
|
|
continue
|
|
} else if listEntity.IsFresh(updSource.source.Interval) {
|
|
b.logger.Debug(fmt.Sprintf("blocklist %s is fresh", updSource.source.Name))
|
|
continue
|
|
}
|
|
|
|
b.refreshSource(updSource.source)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (b *blocklist) refreshSource(sourceConfig *SourceConfig) {
|
|
ipsV4, ipsV6, err := sourceConfig.Source.Get()
|
|
if err != nil {
|
|
b.logger.Error(fmt.Sprintf("Failed to get IPs from source %s: %s", sourceConfig.Name, err))
|
|
return
|
|
}
|
|
|
|
if nftBlocklist, ok := b.nftBlocklists[sourceConfig.Name]; ok {
|
|
listEntity := &entity.Blocklist{
|
|
UpdatedAtUnix: time.Now().Unix(),
|
|
IPsV4: nil,
|
|
IPsV6: nil,
|
|
}
|
|
|
|
if len(ipsV4) > 0 {
|
|
if err := nftBlocklist.ReplaceElementsIPv4(ipsV4); err != nil {
|
|
b.logger.Error(fmt.Sprintf("Failed to replace elements (IPv4): %s", err))
|
|
} else {
|
|
listEntity.IPsV4 = ipsV4
|
|
}
|
|
}
|
|
|
|
if len(ipsV6) > 0 {
|
|
if err := nftBlocklist.ReplaceElementsIPv6(ipsV6); err != nil {
|
|
b.logger.Error(fmt.Sprintf("Failed to replace elements (IPv6): %s", err))
|
|
} else {
|
|
listEntity.IPsV6 = ipsV6
|
|
}
|
|
}
|
|
|
|
if err := b.blocklistRepository.Update(sourceConfig.Name, listEntity); err != nil {
|
|
b.logger.Error(fmt.Sprintf("Failed to update blocklist %s: %s", sourceConfig.Name, err))
|
|
}
|
|
|
|
} else {
|
|
b.logger.Error(fmt.Sprintf("NFTables sets blocklist %s not found", sourceConfig.Name))
|
|
return
|
|
}
|
|
|
|
b.logger.Debug(fmt.Sprintf("refresh blocklist from %s", sourceConfig.Name))
|
|
}
|
|
|
|
func (b *blocklist) Close() error {
|
|
b.logger.Debug("Stopping blocklist")
|
|
if b.cancel != nil {
|
|
b.cancel()
|
|
b.wg.Wait()
|
|
b.cancel = nil
|
|
}
|
|
close(b.launchChannel)
|
|
return nil
|
|
}
|