Files
Mellaris/io/nfqueue.go
T
hayzam 7a3f6e945d Improves flow handling and adds runtime stats APIs
Refactors TCP and UDP flow managers to enhance analyzer selection and flow binding accuracy, including O(1) UDP stream rebinding by 5-tuple.
Introduces runtime stats tracking for engine and ruleset operations, exposing new APIs for granular performance and error metrics.
Optimizes GeoMatcher with result caching and supports efficient geosite set matching, reducing redundant computation in ruleset expressions.
2026-05-13 06:10:38 +05:30

465 lines
11 KiB
Go

//go:build linux
// +build linux
package io
import (
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"os/exec"
"strconv"
"strings"
"syscall"
"github.com/coreos/go-iptables/iptables"
"github.com/florianl/go-nfqueue"
"github.com/mdlayher/netlink"
"golang.org/x/sys/unix"
)
// Netfilter queue numbers, connection marks and nftables identifiers shared
// by the rule generators and the packet I/O implementation in this file.
const (
	// nfqueueNumStart is the number of the first queue; additional queues are
	// numbered consecutively after it.
	nfqueueNumStart = 100
	// nfqueueDefaultQueueSize replaces a zero NFQueuePacketIOConfig.QueueSize.
	nfqueueDefaultQueueSize = 128
	// nfqueueDefaultMaxLen replaces a zero NFQueuePacketIOConfig.MaxPacketLen.
	nfqueueDefaultMaxLen = 0xFFFF
	// nfqueueConnMarkAccept is the mark for connections (and protected
	// sockets) that the generated rules accept without queueing.
	nfqueueConnMarkAccept = 1001
	// nfqueueConnMarkDrop is the connection mark for connections that the
	// generated rules drop (or reset) without queueing.
	nfqueueConnMarkDrop = 1002
	// nftFamily is the nftables family of the table managed by this package.
	nftFamily = "inet"
	// nftTable is the name of the nftables table managed by this package.
	nftTable = "mellaris"
)
// generateNftRules builds the nftables table that steers traffic into the
// netfilter queues.
//
// local selects the INPUT and OUTPUT hooks instead of FORWARD. rst adds a
// rule that answers TCP packets of drop-marked connections with a TCP reset;
// it cannot be combined with local and doing so returns an error. numQueues
// values below 1 are treated as 1; with more than one queue the QUEUE_NUM
// define becomes an inclusive range starting at nfqueueNumStart.
//
// Every chain receives the same rules, in this order: copy the accept packet
// mark to the connection mark, accept accept-marked connections, optionally
// reset drop-marked TCP connections, drop drop-marked connections, and queue
// everything else with bypass.
func generateNftRules(local, rst bool, numQueues int) (*nftTableSpec, error) {
	if local && rst {
		return nil, errors.New("tcp rst is not supported in local mode")
	}
	if numQueues < 1 {
		numQueues = 1
	}

	// A single queue is referenced by number, several by an inclusive range.
	queueDefine := fmt.Sprintf("define QUEUE_NUM=%d", nfqueueNumStart)
	if numQueues > 1 {
		queueDefine = fmt.Sprintf("define QUEUE_NUM=%d-%d", nfqueueNumStart, nfqueueNumStart+numQueues-1)
	}

	table := &nftTableSpec{
		Family: nftFamily,
		Table:  nftTable,
		Defines: []string{
			fmt.Sprintf("define ACCEPT_CTMARK=%d", nfqueueConnMarkAccept),
			fmt.Sprintf("define DROP_CTMARK=%d", nfqueueConnMarkDrop),
			queueDefine,
		},
	}
	if local {
		table.Chains = []nftChainSpec{
			{Chain: "INPUT", Header: "type filter hook input priority filter; policy accept;"},
			{Chain: "OUTPUT", Header: "type filter hook output priority filter; policy accept;"},
		}
	} else {
		table.Chains = []nftChainSpec{
			{Chain: "FORWARD", Header: "type filter hook forward priority filter; policy accept;"},
		}
	}

	rules := []string{
		"meta mark $ACCEPT_CTMARK ct mark set $ACCEPT_CTMARK",
		"ct mark $ACCEPT_CTMARK counter accept",
	}
	if rst {
		// "meta l4proto tcp" matches TCP over both IPv4 and IPv6. The table
		// lives in the inet family, where "ip protocol tcp" would restrict
		// the reset to IPv4 and leave IPv6 flows to the plain drop rule.
		rules = append(rules, "meta l4proto tcp ct mark $DROP_CTMARK counter reject with tcp reset")
	}
	rules = append(rules,
		"ct mark $DROP_CTMARK counter drop",
		"counter queue num $QUEUE_NUM bypass",
	)
	for i := range table.Chains {
		// Give each chain its own copy so the rule slices never alias.
		table.Chains[i].Rules = append([]string(nil), rules...)
	}
	return table, nil
}
// generateIptRules returns the iptables rules that steer traffic into the
// netfilter queues, in the order they must be appended.
//
// local selects the INPUT and OUTPUT chains instead of FORWARD. rst adds a
// TCP-reset REJECT rule for drop-marked connections and cannot be combined
// with local (an error is returned). numQueues values below 1 are treated as
// 1; more than one queue switches from --queue-num to --queue-balance over
// the inclusive range starting at nfqueueNumStart.
func generateIptRules(local, rst bool, numQueues int) ([]iptRule, error) {
	if local && rst {
		return nil, errors.New("tcp rst is not supported in local mode")
	}
	if numQueues < 1 {
		numQueues = 1
	}

	chains := []string{"FORWARD"}
	if local {
		chains = []string{"INPUT", "OUTPUT"}
	}

	acceptMark := strconv.Itoa(nfqueueConnMarkAccept)
	dropMark := strconv.Itoa(nfqueueConnMarkDrop)

	// The queue selector is the only part of the NFQUEUE rule that depends on
	// the number of queues.
	queueFlag, queueValue := "--queue-num", strconv.Itoa(nfqueueNumStart)
	if numQueues > 1 {
		queueFlag = "--queue-balance"
		queueValue = fmt.Sprintf("%d:%d", nfqueueNumStart, nfqueueNumStart+numQueues-1)
	}

	rules := make([]iptRule, 0, 4*len(chains))
	add := func(chain string, spec ...string) {
		rules = append(rules, iptRule{Table: "filter", Chain: chain, RuleSpec: spec})
	}
	for _, chain := range chains {
		add(chain, "-m", "mark", "--mark", acceptMark, "-j", "CONNMARK", "--set-mark", acceptMark)
		add(chain, "-m", "connmark", "--mark", acceptMark, "-j", "ACCEPT")
		if rst {
			add(chain, "-p", "tcp", "-m", "connmark", "--mark", dropMark, "-j", "REJECT", "--reject-with", "tcp-reset")
		}
		add(chain, "-m", "connmark", "--mark", dropMark, "-j", "DROP")
		add(chain, "-j", "NFQUEUE", queueFlag, queueValue, "--queue-bypass")
	}
	return rules, nil
}
// Compile-time check that *nfqueuePacketIO implements PacketIO.
var _ PacketIO = (*nfqueuePacketIO)(nil)

// errNotNFQueuePacket is reported (wrapped in ErrInvalidPacket) by SetVerdict
// when the supplied packet is not an *nfqueuePacket.
var errNotNFQueuePacket = errors.New("not an NFQueue packet")
// nfqueuePacketIO is the PacketIO implementation backed by netfilter queues.
type nfqueuePacketIO struct {
	// nqs holds the open queue handles, one per configured queue.
	nqs []*nfqueue.Nfqueue
	// numQueues is the number of queues, passed to the rule generators.
	numQueues int
	// local is passed to the rule generators and makes packets without
	// conntrack data accepted instead of dropped.
	local bool
	// rst is passed to the rule generators to request the TCP reset rule.
	rst bool
	// rSet is true while the firewall rules installed by Register are in
	// place; Close removes them and clears it.
	rSet bool
	// ipt4 and ipt6 are set only when the nft binary was not found; a nil
	// ipt4 means rules are managed through nft.
	ipt4 *iptables.IPTables
	ipt6 *iptables.IPTables
	// protectedDialer sets SO_MARK to nfqueueConnMarkAccept on the sockets
	// it creates.
	protectedDialer *net.Dialer
}
// NFQueuePacketIOConfig configures NewNFQueuePacketIO.
type NFQueuePacketIOConfig struct {
	// QueueSize is the maximum queue length of each queue; 0 selects
	// nfqueueDefaultQueueSize.
	QueueSize uint32
	// ReadBuffer, when positive, is applied to each queue connection with
	// SetReadBuffer.
	ReadBuffer int
	// WriteBuffer, when positive, is applied to each queue connection with
	// SetWriteBuffer.
	WriteBuffer int
	// Local hooks the INPUT and OUTPUT chains instead of FORWARD.
	Local bool
	// RST requests a TCP reset for drop-marked TCP connections. The rule
	// generators reject it in combination with Local.
	RST bool
	// NumQueues is the number of queues to open; values <= 0 select 1.
	NumQueues int
	// MaxPacketLen is the maximum packet length copied from the kernel;
	// 0 selects nfqueueDefaultMaxLen.
	MaxPacketLen uint32
}
// NewNFQueuePacketIO opens config.NumQueues netfilter queues, numbered
// consecutively from nfqueueNumStart, and returns a PacketIO backed by them.
//
// A zero QueueSize or MaxPacketLen and a non-positive NumQueues are replaced
// by their defaults. iptables handles are created only when the nft binary
// cannot be found; otherwise the rules are managed through nft. No firewall
// rules are installed here; Register does that.
//
// If opening or configuring any queue fails, every queue opened so far
// (including the one being configured) is closed before the error is
// returned.
func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
	if config.QueueSize == 0 {
		config.QueueSize = nfqueueDefaultQueueSize
	}
	if config.NumQueues <= 0 {
		config.NumQueues = 1
	}
	if config.MaxPacketLen == 0 {
		config.MaxPacketLen = nfqueueDefaultMaxLen
	}

	var ipt4, ipt6 *iptables.IPTables
	var err error
	if nftCheck() != nil {
		// nft is unavailable: fall back to iptables for both address families.
		ipt4, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
		if err != nil {
			return nil, err
		}
		ipt6, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
		if err != nil {
			return nil, err
		}
	}

	nqs := make([]*nfqueue.Nfqueue, 0, config.NumQueues)
	// closeOpened releases every queue opened so far. Close errors are
	// ignored because the caller already receives the original failure.
	closeOpened := func() {
		for _, nq := range nqs {
			_ = nq.Close()
		}
	}
	for i := 0; i < config.NumQueues; i++ {
		n, err := nfqueue.Open(&nfqueue.Config{
			NfQueue:      uint16(nfqueueNumStart + i),
			MaxPacketLen: config.MaxPacketLen,
			MaxQueueLen:  config.QueueSize,
			Copymode:     nfqueue.NfQnlCopyPacket,
			Flags:        nfqueue.NfQaCfgFlagConntrack,
		})
		if err != nil {
			closeOpened()
			return nil, err
		}
		// Track the queue before configuring it so that the failure paths
		// below close it too instead of leaking it or touching a nil entry.
		nqs = append(nqs, n)
		if config.ReadBuffer > 0 {
			if err := n.Con.SetReadBuffer(config.ReadBuffer); err != nil {
				closeOpened()
				return nil, err
			}
		}
		if config.WriteBuffer > 0 {
			if err := n.Con.SetWriteBuffer(config.WriteBuffer); err != nil {
				closeOpened()
				return nil, err
			}
		}
	}

	return &nfqueuePacketIO{
		nqs:       nqs,
		numQueues: config.NumQueues,
		local:     config.Local,
		rst:       config.RST,
		ipt4:      ipt4,
		ipt6:      ipt6,
		protectedDialer: &net.Dialer{
			// Mark outgoing sockets with the accept mark, which the generated
			// firewall rules match on to accept traffic without queueing it.
			Control: func(network, address string, c syscall.RawConn) error {
				var err error
				cErr := c.Control(func(fd uintptr) {
					err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, nfqueueConnMarkAccept)
				})
				if cErr != nil {
					return cErr
				}
				return err
			},
		},
	}, nil
}
// Register attaches cb to every queue and, on the first successful call,
// installs the firewall rules (iptables when ipt4 is set, nft otherwise).
//
// Packets failing packetAttributeSanityCheck never reach cb: they receive the
// verdict chosen by the check when a packet ID is available. Receive errors
// are forwarded to cb as (nil, err), except ENOBUFS netlink errors, which are
// swallowed. The first registration or rule-setup error is returned.
func (nio *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error {
	for _, q := range nio.nqs {
		queue := q // per-iteration copy captured by the closures below

		onPacket := func(attr nfqueue.Attribute) int {
			ok, verdict := nio.packetAttributeSanityCheck(attr)
			if !ok {
				if attr.PacketID != nil {
					// Best-effort verdict; a failure here is ignored.
					_ = queue.SetVerdict(*attr.PacketID, verdict)
				}
				return 0
			}
			pkt := &nfqueuePacket{
				id:       *attr.PacketID,
				streamID: ctIDFromCtBytes(*attr.Ct),
				data:     *attr.Payload,
				nq:       queue,
			}
			return okBoolToInt(cb(pkt, nil))
		}

		onError := func(e error) int {
			var opErr *netlink.OpError
			if errors.As(e, &opErr) && errors.Is(opErr.Err, unix.ENOBUFS) {
				// Receive buffer overrun: not reported to cb.
				return 0
			}
			return okBoolToInt(cb(nil, e))
		}

		if err := queue.RegisterWithErrorFunc(ctx, onPacket, onError); err != nil {
			return err
		}
	}

	if nio.rSet {
		return nil
	}
	var err error
	if nio.ipt4 != nil {
		err = nio.setupIpt(nio.local, nio.rst, false)
	} else {
		err = nio.setupNft(nio.local, nio.rst, false)
	}
	if err != nil {
		return err
	}
	nio.rSet = true
	return nil
}
// packetAttributeSanityCheck reports whether a carries everything needed to
// build an nfqueuePacket. When ok is false, verdict is the verdict to issue
// for the packet; it is -1 when there is no packet ID to issue one for, and
// also -1 (unused) when ok is true.
func (nio *nfqueuePacketIO) packetAttributeSanityCheck(a nfqueue.Attribute) (ok bool, verdict int) {
	switch {
	case a.PacketID == nil:
		// Without an ID no verdict can be issued at all.
		return false, -1
	case a.Payload == nil || len(*a.Payload) < 20:
		// Missing or too-short payload (20 is presumably the minimum IPv4
		// header length - TODO confirm).
		return false, nfqueue.NfDrop
	case a.Ct != nil:
		return true, -1
	case nio.local:
		// No conntrack data: accepted in local mode.
		return false, nfqueue.NfAccept
	default:
		// No conntrack data outside local mode: dropped.
		return false, nfqueue.NfDrop
	}
}
// SetVerdict issues verdict v for packet p on the queue the packet came from.
//
// p must be an *nfqueuePacket, otherwise an *ErrInvalidPacket is returned.
// newPacket is only used with VerdictAcceptModify. The stream verdicts also
// set the accept or drop connection mark. Unknown verdicts are ignored and
// return nil.
func (nio *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) error {
	pkt, isNFQ := p.(*nfqueuePacket)
	if !isNFQ {
		return &ErrInvalidPacket{Err: errNotNFQueuePacket}
	}
	queue, id := pkt.nq, pkt.id
	switch v {
	case VerdictAccept:
		return queue.SetVerdict(id, nfqueue.NfAccept)
	case VerdictAcceptModify:
		return queue.SetVerdictModPacket(id, nfqueue.NfAccept, newPacket)
	case VerdictAcceptStream:
		return queue.SetVerdictWithConnMark(id, nfqueue.NfAccept, nfqueueConnMarkAccept)
	case VerdictDrop:
		return queue.SetVerdict(id, nfqueue.NfDrop)
	case VerdictDropStream:
		return queue.SetVerdictWithConnMark(id, nfqueue.NfDrop, nfqueueConnMarkDrop)
	}
	// Any other verdict value: nothing is sent to the queue.
	return nil
}
// ProtectedDialContext dials address on network using the protected dialer,
// whose Control hook is configured in NewNFQueuePacketIO.
func (nio *nfqueuePacketIO) ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error) {
	dialer := nio.protectedDialer
	return dialer.DialContext(ctx, network, address)
}
// Close removes the firewall rules if Register installed them and then closes
// every queue. Rule-removal errors are discarded. All queues are closed even
// if some fail; the first close error, if any, is returned.
func (nio *nfqueuePacketIO) Close() error {
	if nio.rSet {
		// Best-effort rule removal: errors are intentionally ignored so the
		// queues below are always closed.
		if nio.ipt4 == nil {
			_ = nio.setupNft(nio.local, nio.rst, true)
		} else {
			_ = nio.setupIpt(nio.local, nio.rst, true)
		}
		nio.rSet = false
	}

	var firstErr error
	for _, nq := range nio.nqs {
		if err := nq.Close(); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}
// setupNft installs (remove == false) or deletes (remove == true) the
// nftables table described by generateNftRules. Invalid local/rst
// combinations are rejected in both modes.
func (nio *nfqueuePacketIO) setupNft(local, rst, remove bool) error {
	spec, err := generateNftRules(local, rst, nio.numQueues)
	if err != nil {
		return err
	}
	if remove {
		return nftDelete(nftFamily, nftTable)
	}
	// Delete any existing table first; the error is ignored (presumably
	// because the table usually does not exist yet).
	_ = nftDelete(nftFamily, nftTable)
	return nftAdd(spec.String())
}
// setupIpt appends (remove == false) or deletes (remove == true) the rules
// from generateIptRules through both the IPv4 and the IPv6 iptables handles.
func (nio *nfqueuePacketIO) setupIpt(local, rst, remove bool) error {
	rules, err := generateIptRules(local, rst, nio.numQueues)
	if err != nil {
		return err
	}
	handles := []*iptables.IPTables{nio.ipt4, nio.ipt6}
	if remove {
		return iptsBatchDeleteIfExists(handles, rules)
	}
	return iptsBatchAppendUnique(handles, rules)
}
// Compile-time check that *nfqueuePacket implements Packet.
var _ Packet = (*nfqueuePacket)(nil)

// nfqueuePacket is a packet received from a netfilter queue.
type nfqueuePacket struct {
	// id is the packet ID reported by the queue, used when setting a verdict.
	id uint32
	// streamID is the value ctIDFromCtBytes extracted from the packet's
	// conntrack attribute (0 when none could be extracted).
	streamID uint32
	// data is the packet payload as delivered by the queue.
	data []byte
	// nq is the queue the packet arrived on; verdicts are sent through it.
	nq *nfqueue.Nfqueue
}

// StreamID returns the packet's stream identifier.
func (p *nfqueuePacket) StreamID() uint32 { return p.streamID }

// Data returns the packet payload.
func (p *nfqueuePacket) Data() []byte { return p.data }
// okBoolToInt converts a callback's boolean result into the integer returned
// from the queue hook functions: 0 for true, 1 for false.
func okBoolToInt(ok bool) int {
	if !ok {
		return 1
	}
	return 0
}
// nftCheck reports whether the nft binary can be found in PATH; it returns
// the lookup error when it cannot and nil otherwise.
func nftCheck() error {
	if _, err := exec.LookPath("nft"); err != nil {
		return err
	}
	return nil
}
// nftAdd feeds input to "nft -f -" on standard input and returns the error
// from running the command, if any.
func nftAdd(input string) error {
	nft := exec.Command("nft", "-f", "-")
	nft.Stdin = strings.NewReader(input)
	if err := nft.Run(); err != nil {
		return err
	}
	return nil
}
// nftDelete runs "nft delete table <family> <table>" and returns the error
// from running the command, if any.
func nftDelete(family, table string) error {
	args := []string{"delete", "table", family, table}
	return exec.Command("nft", args...).Run()
}
// nftTableSpec describes an nftables table: the define statements emitted
// before it, its family and name, and the chains it contains.
type nftTableSpec struct {
	Defines       []string
	Family, Table string
	Chains        []nftChainSpec
}

// String renders the table as nft script text: the defines, one per line,
// followed by the table block containing every chain back to back.
func (t *nftTableSpec) String() string {
	var body strings.Builder
	for i := range t.Chains {
		body.WriteString(t.Chains[i].String())
	}
	defines := strings.Join(t.Defines, "\n")
	return fmt.Sprintf(`
%s
table %s %s {
%s
}
`, defines, t.Family, t.Table, body.String())
}
// nftChainSpec describes one chain of an nftables table: its name, its header
// line (type, hook, priority and policy) and its rules in evaluation order.
type nftChainSpec struct {
	Chain  string
	Header string
	Rules  []string
}

// String renders the chain as an nft "chain" block with the header first and
// then the rules, each on its own line.
func (c *nftChainSpec) String() string {
	rules := strings.Join(c.Rules, "\n\x20\x20\x20\x20")
	return fmt.Sprintf(`
chain %s {
%s
%s
}
`, c.Chain, c.Header, rules)
}
// iptRule is a single iptables rule: the table and chain it belongs to and
// the rule specification arguments passed to the iptables handle.
type iptRule struct {
	Table, Chain string
	RuleSpec []string
}
// iptsBatchAppendUnique appends every rule, in order, to each iptables handle
// using AppendUnique. It stops at and returns the first error, leaving any
// rules already appended in place.
func iptsBatchAppendUnique(ipts []*iptables.IPTables, rules []iptRule) error {
	for i := range rules {
		rule := &rules[i]
		for _, handle := range ipts {
			if err := handle.AppendUnique(rule.Table, rule.Chain, rule.RuleSpec...); err != nil {
				return err
			}
		}
	}
	return nil
}
// iptsBatchDeleteIfExists removes every rule, in order, from each iptables
// handle using DeleteIfExists. It stops at and returns the first error.
func iptsBatchDeleteIfExists(ipts []*iptables.IPTables, rules []iptRule) error {
	for i := range rules {
		rule := &rules[i]
		for _, handle := range ipts {
			if err := handle.DeleteIfExists(rule.Table, rule.Chain, rule.RuleSpec...); err != nil {
				return err
			}
		}
	}
	return nil
}
// ctIDFromCtBytes extracts the conntrack ID from a serialized conntrack
// attribute list. It returns 0 when the bytes cannot be parsed, when no
// CTA_ID attribute is present, or when that attribute is too short to hold a
// 32-bit value.
func ctIDFromCtBytes(ct []byte) uint32 {
	const ctaID = 12 // CTA_ID

	ctAttrs, err := netlink.UnmarshalAttributes(ct)
	if err != nil {
		return 0
	}
	for _, attr := range ctAttrs {
		if attr.Type != ctaID {
			continue
		}
		// Uint32 panics on fewer than 4 bytes; treat a truncated attribute
		// the same as a missing one.
		if len(attr.Data) < 4 {
			return 0
		}
		return binary.BigEndian.Uint32(attr.Data)
	}
	return 0
}