7a3f6e945d
Refactors TCP and UDP flow managers to enhance analyzer selection and flow binding accuracy, including O(1) UDP stream rebinding by 5-tuple. Introduces runtime stats tracking for engine and ruleset operations, exposing new APIs for granular performance and error metrics. Optimizes GeoMatcher with result caching and supports efficient geosite set matching, reducing redundant computation in ruleset expressions.
465 lines
11 KiB
Go
465 lines
11 KiB
Go
//go:build linux
|
|
// +build linux
|
|
|
|
package io
|
|
|
|
import (
|
|
"context"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"os/exec"
|
|
"strconv"
|
|
"strings"
|
|
"syscall"
|
|
|
|
"github.com/coreos/go-iptables/iptables"
|
|
"github.com/florianl/go-nfqueue"
|
|
"github.com/mdlayher/netlink"
|
|
"golang.org/x/sys/unix"
|
|
)
|
|
|
|
// NFQUEUE numbering and default sizing.
const (
	// nfqueueNumStart is the number of the first queue; additional queues
	// are numbered consecutively after it.
	nfqueueNumStart = 100
	// nfqueueDefaultQueueSize is the queue length applied when the
	// configuration leaves QueueSize at zero.
	nfqueueDefaultQueueSize = 128
	// nfqueueDefaultMaxLen is the maximum packet length applied when the
	// configuration leaves MaxPacketLen at zero.
	nfqueueDefaultMaxLen = 0xFFFF
)

// Connection marks referenced by the generated firewall rules to accept or
// drop a whole stream without queueing its packets again.
const (
	nfqueueConnMarkAccept = 1001
	nfqueueConnMarkDrop   = 1002
)

// Address family and name of the nftables table managed by this package.
const (
	nftFamily = "inet"
	nftTable  = "mellaris"
)
|
|
|
|
func generateNftRules(local, rst bool, numQueues int) (*nftTableSpec, error) {
|
|
if local && rst {
|
|
return nil, errors.New("tcp rst is not supported in local mode")
|
|
}
|
|
if numQueues < 1 {
|
|
numQueues = 1
|
|
}
|
|
table := &nftTableSpec{
|
|
Family: nftFamily,
|
|
Table: nftTable,
|
|
}
|
|
table.Defines = append(table.Defines, fmt.Sprintf("define ACCEPT_CTMARK=%d", nfqueueConnMarkAccept))
|
|
table.Defines = append(table.Defines, fmt.Sprintf("define DROP_CTMARK=%d", nfqueueConnMarkDrop))
|
|
queueEnd := nfqueueNumStart + numQueues - 1
|
|
if numQueues == 1 {
|
|
table.Defines = append(table.Defines, fmt.Sprintf("define QUEUE_NUM=%d", nfqueueNumStart))
|
|
} else {
|
|
table.Defines = append(table.Defines, fmt.Sprintf("define QUEUE_NUM=%d-%d", nfqueueNumStart, queueEnd))
|
|
}
|
|
if local {
|
|
table.Chains = []nftChainSpec{
|
|
{Chain: "INPUT", Header: "type filter hook input priority filter; policy accept;"},
|
|
{Chain: "OUTPUT", Header: "type filter hook output priority filter; policy accept;"},
|
|
}
|
|
} else {
|
|
table.Chains = []nftChainSpec{
|
|
{Chain: "FORWARD", Header: "type filter hook forward priority filter; policy accept;"},
|
|
}
|
|
}
|
|
for i := range table.Chains {
|
|
c := &table.Chains[i]
|
|
c.Rules = append(c.Rules, "meta mark $ACCEPT_CTMARK ct mark set $ACCEPT_CTMARK")
|
|
c.Rules = append(c.Rules, "ct mark $ACCEPT_CTMARK counter accept")
|
|
if rst {
|
|
c.Rules = append(c.Rules, "ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset")
|
|
}
|
|
c.Rules = append(c.Rules, "ct mark $DROP_CTMARK counter drop")
|
|
c.Rules = append(c.Rules, "counter queue num $QUEUE_NUM bypass")
|
|
}
|
|
return table, nil
|
|
}
|
|
|
|
func generateIptRules(local, rst bool, numQueues int) ([]iptRule, error) {
|
|
if local && rst {
|
|
return nil, errors.New("tcp rst is not supported in local mode")
|
|
}
|
|
if numQueues < 1 {
|
|
numQueues = 1
|
|
}
|
|
var chains []string
|
|
if local {
|
|
chains = []string{"INPUT", "OUTPUT"}
|
|
} else {
|
|
chains = []string{"FORWARD"}
|
|
}
|
|
rules := make([]iptRule, 0, 4*len(chains))
|
|
for _, chain := range chains {
|
|
rules = append(rules, iptRule{"filter", chain, []string{"-m", "mark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "CONNMARK", "--set-mark", strconv.Itoa(nfqueueConnMarkAccept)}})
|
|
rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}})
|
|
if rst {
|
|
rules = append(rules, iptRule{"filter", chain, []string{"-p", "tcp", "-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "REJECT", "--reject-with", "tcp-reset"}})
|
|
}
|
|
rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "DROP"}})
|
|
if numQueues == 1 {
|
|
rules = append(rules, iptRule{"filter", chain, []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNumStart), "--queue-bypass"}})
|
|
} else {
|
|
queueSpec := fmt.Sprintf("%d:%d", nfqueueNumStart, nfqueueNumStart+numQueues-1)
|
|
rules = append(rules, iptRule{"filter", chain, []string{"-j", "NFQUEUE", "--queue-balance", queueSpec, "--queue-bypass"}})
|
|
}
|
|
}
|
|
return rules, nil
|
|
}
|
|
|
|
var _ PacketIO = (*nfqueuePacketIO)(nil)
|
|
|
|
var errNotNFQueuePacket = errors.New("not an NFQueue packet")
|
|
|
|
type nfqueuePacketIO struct {
|
|
nqs []*nfqueue.Nfqueue
|
|
numQueues int
|
|
local bool
|
|
rst bool
|
|
rSet bool
|
|
|
|
ipt4 *iptables.IPTables
|
|
ipt6 *iptables.IPTables
|
|
|
|
protectedDialer *net.Dialer
|
|
}
|
|
|
|
// NFQueuePacketIOConfig configures NewNFQueuePacketIO.
type NFQueuePacketIOConfig struct {
	// QueueSize is the maximum length of each queue; zero selects
	// nfqueueDefaultQueueSize.
	QueueSize uint32
	// ReadBuffer and WriteBuffer, when positive, are applied to each queue's
	// netlink connection as its read and write buffer sizes.
	ReadBuffer, WriteBuffer int
	// Local selects the INPUT/OUTPUT chains instead of FORWARD. RST enables
	// TCP resets for dropped streams and is rejected together with Local.
	Local, RST bool
	// NumQueues is the number of queues to open; values <= 0 select 1.
	NumQueues int
	// MaxPacketLen is the maximum packet length requested from the kernel;
	// zero selects nfqueueDefaultMaxLen.
	MaxPacketLen uint32
}
|
|
|
|
func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
|
|
if config.QueueSize == 0 {
|
|
config.QueueSize = nfqueueDefaultQueueSize
|
|
}
|
|
if config.NumQueues <= 0 {
|
|
config.NumQueues = 1
|
|
}
|
|
if config.MaxPacketLen == 0 {
|
|
config.MaxPacketLen = nfqueueDefaultMaxLen
|
|
}
|
|
var ipt4, ipt6 *iptables.IPTables
|
|
var err error
|
|
if nftCheck() != nil {
|
|
ipt4, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ipt6, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
nqs := make([]*nfqueue.Nfqueue, config.NumQueues)
|
|
for i := range nqs {
|
|
n, err := nfqueue.Open(&nfqueue.Config{
|
|
NfQueue: uint16(nfqueueNumStart + i),
|
|
MaxPacketLen: config.MaxPacketLen,
|
|
MaxQueueLen: config.QueueSize,
|
|
Copymode: nfqueue.NfQnlCopyPacket,
|
|
Flags: nfqueue.NfQaCfgFlagConntrack,
|
|
})
|
|
if err != nil {
|
|
for j := 0; j < i; j++ {
|
|
nqs[j].Close()
|
|
}
|
|
return nil, err
|
|
}
|
|
if config.ReadBuffer > 0 {
|
|
err = n.Con.SetReadBuffer(config.ReadBuffer)
|
|
if err != nil {
|
|
for j := 0; j <= i; j++ {
|
|
nqs[j].Close()
|
|
}
|
|
return nil, err
|
|
}
|
|
}
|
|
if config.WriteBuffer > 0 {
|
|
err = n.Con.SetWriteBuffer(config.WriteBuffer)
|
|
if err != nil {
|
|
for j := 0; j <= i; j++ {
|
|
nqs[j].Close()
|
|
}
|
|
return nil, err
|
|
}
|
|
}
|
|
nqs[i] = n
|
|
}
|
|
|
|
return &nfqueuePacketIO{
|
|
nqs: nqs,
|
|
numQueues: config.NumQueues,
|
|
local: config.Local,
|
|
rst: config.RST,
|
|
ipt4: ipt4,
|
|
ipt6: ipt6,
|
|
protectedDialer: &net.Dialer{
|
|
Control: func(network, address string, c syscall.RawConn) error {
|
|
var err error
|
|
cErr := c.Control(func(fd uintptr) {
|
|
err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, nfqueueConnMarkAccept)
|
|
})
|
|
if cErr != nil {
|
|
return cErr
|
|
}
|
|
return err
|
|
},
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
func (nio *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error {
|
|
for _, nq := range nio.nqs {
|
|
nq := nq
|
|
err := nq.RegisterWithErrorFunc(ctx,
|
|
func(a nfqueue.Attribute) int {
|
|
if ok, verdict := nio.packetAttributeSanityCheck(a); !ok {
|
|
if a.PacketID != nil {
|
|
_ = nq.SetVerdict(*a.PacketID, verdict)
|
|
}
|
|
return 0
|
|
}
|
|
p := &nfqueuePacket{
|
|
id: *a.PacketID,
|
|
streamID: ctIDFromCtBytes(*a.Ct),
|
|
data: *a.Payload,
|
|
nq: nq,
|
|
}
|
|
return okBoolToInt(cb(p, nil))
|
|
},
|
|
func(e error) int {
|
|
if opErr := (*netlink.OpError)(nil); errors.As(e, &opErr) {
|
|
if errors.Is(opErr.Err, unix.ENOBUFS) {
|
|
return 0
|
|
}
|
|
}
|
|
return okBoolToInt(cb(nil, e))
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if !nio.rSet {
|
|
if nio.ipt4 != nil {
|
|
err := nio.setupIpt(nio.local, nio.rst, false)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
err := nio.setupNft(nio.local, nio.rst, false)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
nio.rSet = true
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (nio *nfqueuePacketIO) packetAttributeSanityCheck(a nfqueue.Attribute) (ok bool, verdict int) {
|
|
if a.PacketID == nil {
|
|
return false, -1
|
|
}
|
|
if a.Payload == nil || len(*a.Payload) < 20 {
|
|
return false, nfqueue.NfDrop
|
|
}
|
|
if a.Ct == nil {
|
|
if nio.local {
|
|
return false, nfqueue.NfAccept
|
|
}
|
|
return false, nfqueue.NfDrop
|
|
}
|
|
return true, -1
|
|
}
|
|
|
|
func (nio *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) error {
|
|
nP, ok := p.(*nfqueuePacket)
|
|
if !ok {
|
|
return &ErrInvalidPacket{Err: errNotNFQueuePacket}
|
|
}
|
|
switch v {
|
|
case VerdictAccept:
|
|
return nP.nq.SetVerdict(nP.id, nfqueue.NfAccept)
|
|
case VerdictAcceptModify:
|
|
return nP.nq.SetVerdictModPacket(nP.id, nfqueue.NfAccept, newPacket)
|
|
case VerdictAcceptStream:
|
|
return nP.nq.SetVerdictWithConnMark(nP.id, nfqueue.NfAccept, nfqueueConnMarkAccept)
|
|
case VerdictDrop:
|
|
return nP.nq.SetVerdict(nP.id, nfqueue.NfDrop)
|
|
case VerdictDropStream:
|
|
return nP.nq.SetVerdictWithConnMark(nP.id, nfqueue.NfDrop, nfqueueConnMarkDrop)
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (nio *nfqueuePacketIO) ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
|
return nio.protectedDialer.DialContext(ctx, network, address)
|
|
}
|
|
|
|
func (nio *nfqueuePacketIO) Close() error {
|
|
if nio.rSet {
|
|
if nio.ipt4 != nil {
|
|
_ = nio.setupIpt(nio.local, nio.rst, true)
|
|
} else {
|
|
_ = nio.setupNft(nio.local, nio.rst, true)
|
|
}
|
|
nio.rSet = false
|
|
}
|
|
var errs []error
|
|
for _, nq := range nio.nqs {
|
|
if err := nq.Close(); err != nil {
|
|
errs = append(errs, err)
|
|
}
|
|
}
|
|
if len(errs) > 0 {
|
|
return errs[0]
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (nio *nfqueuePacketIO) setupNft(local, rst, remove bool) error {
|
|
rules, err := generateNftRules(local, rst, nio.numQueues)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rulesText := rules.String()
|
|
if remove {
|
|
err = nftDelete(nftFamily, nftTable)
|
|
} else {
|
|
_ = nftDelete(nftFamily, nftTable)
|
|
err = nftAdd(rulesText)
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (nio *nfqueuePacketIO) setupIpt(local, rst, remove bool) error {
|
|
rules, err := generateIptRules(local, rst, nio.numQueues)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if remove {
|
|
err = iptsBatchDeleteIfExists([]*iptables.IPTables{nio.ipt4, nio.ipt6}, rules)
|
|
} else {
|
|
err = iptsBatchAppendUnique([]*iptables.IPTables{nio.ipt4, nio.ipt6}, rules)
|
|
}
|
|
return err
|
|
}
|
|
|
|
var _ Packet = (*nfqueuePacket)(nil)
|
|
|
|
type nfqueuePacket struct {
|
|
id uint32
|
|
streamID uint32
|
|
data []byte
|
|
nq *nfqueue.Nfqueue
|
|
}
|
|
|
|
func (p *nfqueuePacket) StreamID() uint32 { return p.streamID }
|
|
func (p *nfqueuePacket) Data() []byte { return p.data }
|
|
|
|
// okBoolToInt maps true to 0 and false to 1, the integer form returned by
// the hook functions passed to the queue in Register.
func okBoolToInt(ok bool) int {
	result := 1
	if ok {
		result = 0
	}
	return result
}
|
|
|
|
// nftCheck returns nil when an "nft" executable can be found via
// exec.LookPath, and the lookup error otherwise.
func nftCheck() error {
	if _, err := exec.LookPath("nft"); err != nil {
		return err
	}
	return nil
}
|
|
|
|
// nftAdd feeds input to "nft -f -" on standard input.
//
// On failure the returned error wraps the underlying exec error and, when
// nft printed anything, includes that output so the cause is visible to the
// caller rather than just an exit status.
func nftAdd(input string) error {
	cmd := exec.Command("nft", "-f", "-")
	cmd.Stdin = strings.NewReader(input)
	out, err := cmd.CombinedOutput()
	if err == nil {
		return nil
	}
	if msg := strings.TrimSpace(string(out)); msg != "" {
		return fmt.Errorf("nft -f -: %w: %s", err, msg)
	}
	return fmt.Errorf("nft -f -: %w", err)
}
|
|
|
|
// nftDelete runs "nft delete table <family> <table>".
//
// On failure the returned error wraps the underlying exec error and, when
// nft printed anything, includes that output so the cause is visible to the
// caller rather than just an exit status.
func nftDelete(family, table string) error {
	cmd := exec.Command("nft", "delete", "table", family, table)
	out, err := cmd.CombinedOutput()
	if err == nil {
		return nil
	}
	if msg := strings.TrimSpace(string(out)); msg != "" {
		return fmt.Errorf("nft delete table %s %s: %w: %s", family, table, err, msg)
	}
	return fmt.Errorf("nft delete table %s %s: %w", family, table, err)
}
|
|
|
|
type nftTableSpec struct {
|
|
Defines []string
|
|
Family, Table string
|
|
Chains []nftChainSpec
|
|
}
|
|
|
|
func (t *nftTableSpec) String() string {
|
|
chains := make([]string, 0, len(t.Chains))
|
|
for _, c := range t.Chains {
|
|
chains = append(chains, c.String())
|
|
}
|
|
return fmt.Sprintf(`
|
|
%s
|
|
|
|
table %s %s {
|
|
%s
|
|
}
|
|
`, strings.Join(t.Defines, "\n"), t.Family, t.Table, strings.Join(chains, ""))
|
|
}
|
|
|
|
// nftChainSpec describes one chain of an nftables table.
type nftChainSpec struct {
	// Chain is the chain name.
	Chain string
	// Header is emitted as the first line inside the chain block.
	Header string
	// Rules are emitted one per line after the header.
	Rules []string
}

// String renders the chain as a "chain <name> { ... }" block with the header
// and every rule on its own line, indented by four spaces.
func (c *nftChainSpec) String() string {
	const ruleIndent = "\x20\x20\x20\x20"
	rules := strings.Join(c.Rules, "\n"+ruleIndent)
	return fmt.Sprintf("\n  chain %s {\n"+ruleIndent+"%s\n"+ruleIndent+"%s\n  }\n", c.Chain, c.Header, rules)
}
|
|
|
|
// iptRule is one iptables rule: the table and chain it belongs to plus the
// rule specification arguments handed to the iptables library.
type iptRule struct {
	Table string
	Chain string
	// RuleSpec is expanded as the variadic rule arguments.
	RuleSpec []string
}
|
|
|
|
func iptsBatchAppendUnique(ipts []*iptables.IPTables, rules []iptRule) error {
|
|
for _, r := range rules {
|
|
for _, ipt := range ipts {
|
|
err := ipt.AppendUnique(r.Table, r.Chain, r.RuleSpec...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func iptsBatchDeleteIfExists(ipts []*iptables.IPTables, rules []iptRule) error {
|
|
for _, r := range rules {
|
|
for _, ipt := range ipts {
|
|
err := ipt.DeleteIfExists(r.Table, r.Chain, r.RuleSpec...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func ctIDFromCtBytes(ct []byte) uint32 {
|
|
ctAttrs, err := netlink.UnmarshalAttributes(ct)
|
|
if err != nil {
|
|
return 0
|
|
}
|
|
for _, attr := range ctAttrs {
|
|
if attr.Type == 12 { // CTA_ID
|
|
return binary.BigEndian.Uint32(attr.Data)
|
|
}
|
|
}
|
|
return 0
|
|
}
|