refactor: engine/tcp/worker perf improvements

This commit is contained in:
2026-05-12 15:16:11 +00:00
parent dc16b979e7
commit ecc2cde1c2
9 changed files with 743 additions and 546 deletions
+157 -127
View File
@@ -18,9 +18,9 @@ import (
)
const (
nfqueueNum = 100
nfqueueMaxPacketLen = 0xFFFF
nfqueueNumStart = 100
nfqueueDefaultQueueSize = 128
nfqueueDefaultMaxLen = 0xFFFF
nfqueueConnMarkAccept = 1001
nfqueueConnMarkDrop = 1002
@@ -29,17 +29,25 @@ const (
nftTable = "mellaris"
)
func generateNftRules(local, rst bool) (*nftTableSpec, error) {
func generateNftRules(local, rst bool, numQueues int) (*nftTableSpec, error) {
if local && rst {
return nil, errors.New("tcp rst is not supported in local mode")
}
if numQueues < 1 {
numQueues = 1
}
table := &nftTableSpec{
Family: nftFamily,
Table: nftTable,
}
table.Defines = append(table.Defines, fmt.Sprintf("define ACCEPT_CTMARK=%d", nfqueueConnMarkAccept))
table.Defines = append(table.Defines, fmt.Sprintf("define DROP_CTMARK=%d", nfqueueConnMarkDrop))
table.Defines = append(table.Defines, fmt.Sprintf("define QUEUE_NUM=%d", nfqueueNum))
queueEnd := nfqueueNumStart + numQueues - 1
if numQueues == 1 {
table.Defines = append(table.Defines, fmt.Sprintf("define QUEUE_NUM=%d", nfqueueNumStart))
} else {
table.Defines = append(table.Defines, fmt.Sprintf("define QUEUE_NUM=%d-%d", nfqueueNumStart, queueEnd))
}
if local {
table.Chains = []nftChainSpec{
{Chain: "INPUT", Header: "type filter hook input priority filter; policy accept;"},
@@ -52,7 +60,7 @@ func generateNftRules(local, rst bool) (*nftTableSpec, error) {
}
for i := range table.Chains {
c := &table.Chains[i]
c.Rules = append(c.Rules, "meta mark $ACCEPT_CTMARK ct mark set $ACCEPT_CTMARK") // Bypass protected connections
c.Rules = append(c.Rules, "meta mark $ACCEPT_CTMARK ct mark set $ACCEPT_CTMARK")
c.Rules = append(c.Rules, "ct mark $ACCEPT_CTMARK counter accept")
if rst {
c.Rules = append(c.Rules, "ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset")
@@ -63,10 +71,13 @@ func generateNftRules(local, rst bool) (*nftTableSpec, error) {
return table, nil
}
func generateIptRules(local, rst bool) ([]iptRule, error) {
func generateIptRules(local, rst bool, numQueues int) ([]iptRule, error) {
if local && rst {
return nil, errors.New("tcp rst is not supported in local mode")
}
if numQueues < 1 {
numQueues = 1
}
var chains []string
if local {
chains = []string{"INPUT", "OUTPUT"}
@@ -75,16 +86,19 @@ func generateIptRules(local, rst bool) ([]iptRule, error) {
}
rules := make([]iptRule, 0, 4*len(chains))
for _, chain := range chains {
// Bypass protected connections
rules = append(rules, iptRule{"filter", chain, []string{"-m", "mark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "CONNMARK", "--set-mark", strconv.Itoa(nfqueueConnMarkAccept)}})
rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}})
if rst {
rules = append(rules, iptRule{"filter", chain, []string{"-p", "tcp", "-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "REJECT", "--reject-with", "tcp-reset"}})
}
rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "DROP"}})
rules = append(rules, iptRule{"filter", chain, []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNum), "--queue-bypass"}})
if numQueues == 1 {
rules = append(rules, iptRule{"filter", chain, []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNumStart), "--queue-bypass"}})
} else {
queueSpec := fmt.Sprintf("%d:%d", nfqueueNumStart, nfqueueNumStart+numQueues-1)
rules = append(rules, iptRule{"filter", chain, []string{"-j", "NFQUEUE", "--queue-balance", queueSpec, "--queue-bypass"}})
}
}
return rules, nil
}
@@ -93,12 +107,12 @@ var _ PacketIO = (*nfqueuePacketIO)(nil)
var errNotNFQueuePacket = errors.New("not an NFQueue packet")
type nfqueuePacketIO struct {
n *nfqueue.Nfqueue
local bool
rst bool
rSet bool // whether the nftables/iptables rules have been set
nqs []*nfqueue.Nfqueue
numQueues int
local bool
rst bool
rSet bool
// iptables not nil = use iptables instead of nftables
ipt4 *iptables.IPTables
ipt6 *iptables.IPTables
@@ -106,21 +120,28 @@ type nfqueuePacketIO struct {
}
// NFQueuePacketIOConfig holds the settings accepted by NewNFQueuePacketIO.
// Zero values of QueueSize, NumQueues and MaxPacketLen are replaced with
// defaults by NewNFQueuePacketIO.
type NFQueuePacketIOConfig struct {
	// QueueSize is the maximum queue length requested for each queue.
	// 0 selects nfqueueDefaultQueueSize.
	QueueSize uint32
	// ReadBuffer, when positive, is applied as the read buffer size of each
	// queue's netlink connection.
	ReadBuffer int
	// WriteBuffer, when positive, is applied as the write buffer size of each
	// queue's netlink connection.
	WriteBuffer int
	// Local makes the firewall rules hook the INPUT/OUTPUT chains.
	Local bool
	// RST adds a "reject with TCP reset" rule for TCP connections carrying the
	// drop connmark. It cannot be combined with Local.
	RST bool
	// NumQueues is the number of consecutive queues to open, starting at
	// nfqueueNumStart. Values <= 0 are treated as 1.
	NumQueues int
	// MaxPacketLen is the maximum packet length copied to user space.
	// 0 selects nfqueueDefaultMaxLen.
	MaxPacketLen uint32
}
func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
if config.QueueSize == 0 {
config.QueueSize = nfqueueDefaultQueueSize
}
if config.NumQueues <= 0 {
config.NumQueues = 1
}
if config.MaxPacketLen == 0 {
config.MaxPacketLen = nfqueueDefaultMaxLen
}
var ipt4, ipt6 *iptables.IPTables
var err error
if nftCheck() != nil {
// We prefer nftables, but if it's not available, fall back to iptables
ipt4, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return nil, err
@@ -130,36 +151,50 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
return nil, err
}
}
n, err := nfqueue.Open(&nfqueue.Config{
NfQueue: nfqueueNum,
MaxPacketLen: nfqueueMaxPacketLen,
MaxQueueLen: config.QueueSize,
Copymode: nfqueue.NfQnlCopyPacket,
Flags: nfqueue.NfQaCfgFlagConntrack,
})
if err != nil {
return nil, err
}
if config.ReadBuffer > 0 {
err = n.Con.SetReadBuffer(config.ReadBuffer)
nqs := make([]*nfqueue.Nfqueue, config.NumQueues)
for i := range nqs {
n, err := nfqueue.Open(&nfqueue.Config{
NfQueue: uint16(nfqueueNumStart + i),
MaxPacketLen: config.MaxPacketLen,
MaxQueueLen: config.QueueSize,
Copymode: nfqueue.NfQnlCopyPacket,
Flags: nfqueue.NfQaCfgFlagConntrack,
})
if err != nil {
_ = n.Close()
for j := 0; j < i; j++ {
nqs[j].Close()
}
return nil, err
}
}
if config.WriteBuffer > 0 {
err = n.Con.SetWriteBuffer(config.WriteBuffer)
if err != nil {
_ = n.Close()
return nil, err
if config.ReadBuffer > 0 {
err = n.Con.SetReadBuffer(config.ReadBuffer)
if err != nil {
for j := 0; j <= i; j++ {
nqs[j].Close()
}
return nil, err
}
}
if config.WriteBuffer > 0 {
err = n.Con.SetWriteBuffer(config.WriteBuffer)
if err != nil {
for j := 0; j <= i; j++ {
nqs[j].Close()
}
return nil, err
}
}
nqs[i] = n
}
return &nfqueuePacketIO{
n: n,
local: config.Local,
rst: config.RST,
ipt4: ipt4,
ipt6: ipt6,
nqs: nqs,
numQueues: config.NumQueues,
local: config.Local,
rst: config.RST,
ipt4: ipt4,
ipt6: ipt6,
protectedDialer: &net.Dialer{
Control: func(network, address string, c syscall.RawConn) error {
var err error
@@ -175,60 +210,63 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
}, nil
}
func (n *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error {
err := n.n.RegisterWithErrorFunc(ctx,
func(a nfqueue.Attribute) int {
if ok, verdict := n.packetAttributeSanityCheck(a); !ok {
if a.PacketID != nil {
_ = n.n.SetVerdict(*a.PacketID, verdict)
}
return 0
}
p := &nfqueuePacket{
id: *a.PacketID,
streamID: ctIDFromCtBytes(*a.Ct),
data: *a.Payload,
}
return okBoolToInt(cb(p, nil))
},
func(e error) int {
if opErr := (*netlink.OpError)(nil); errors.As(e, &opErr) {
if errors.Is(opErr.Err, unix.ENOBUFS) {
// Kernel buffer temporarily full, ignore
func (nio *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error {
for i, nq := range nio.nqs {
nq := nq
err := nq.RegisterWithErrorFunc(ctx,
func(a nfqueue.Attribute) int {
if ok, verdict := nio.packetAttributeSanityCheck(a); !ok {
if a.PacketID != nil {
_ = nq.SetVerdict(*a.PacketID, verdict)
}
return 0
}
}
return okBoolToInt(cb(nil, e))
})
if err != nil {
return err
}
if !n.rSet {
if n.ipt4 != nil {
err = n.setupIpt(n.local, n.rst, false)
} else {
err = n.setupNft(n.local, n.rst, false)
}
p := &nfqueuePacket{
id: *a.PacketID,
streamID: ctIDFromCtBytes(*a.Ct),
data: *a.Payload,
nq: nq,
}
return okBoolToInt(cb(p, nil))
},
func(e error) int {
if opErr := (*netlink.OpError)(nil); errors.As(e, &opErr) {
if errors.Is(opErr.Err, unix.ENOBUFS) {
return 0
}
}
return okBoolToInt(cb(nil, e))
})
if err != nil {
return err
}
n.rSet = true
}
if !nio.rSet {
if nio.ipt4 != nil {
err := nio.setupIpt(nio.local, nio.rst, false)
if err != nil {
return err
}
} else {
err := nio.setupNft(nio.local, nio.rst, false)
if err != nil {
return err
}
}
nio.rSet = true
}
return nil
}
func (n *nfqueuePacketIO) packetAttributeSanityCheck(a nfqueue.Attribute) (ok bool, verdict int) {
func (nio *nfqueuePacketIO) packetAttributeSanityCheck(a nfqueue.Attribute) (ok bool, verdict int) {
if a.PacketID == nil {
// Re-inject to NFQUEUE is actually not possible in this condition
return false, -1
}
if a.Payload == nil || len(*a.Payload) < 20 {
// 20 is the minimum possible size of an IP packet
return false, nfqueue.NfDrop
}
if a.Ct == nil {
// Multicast packets may not have a conntrack, but only appear in local mode
if n.local {
if nio.local {
return false, nfqueue.NfAccept
}
return false, nfqueue.NfDrop
@@ -236,46 +274,54 @@ func (n *nfqueuePacketIO) packetAttributeSanityCheck(a nfqueue.Attribute) (ok bo
return true, -1
}
func (n *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) error {
func (nio *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) error {
nP, ok := p.(*nfqueuePacket)
if !ok {
return &ErrInvalidPacket{Err: errNotNFQueuePacket}
}
switch v {
case VerdictAccept:
return n.n.SetVerdict(nP.id, nfqueue.NfAccept)
return nP.nq.SetVerdict(nP.id, nfqueue.NfAccept)
case VerdictAcceptModify:
return n.n.SetVerdictModPacket(nP.id, nfqueue.NfAccept, newPacket)
return nP.nq.SetVerdictModPacket(nP.id, nfqueue.NfAccept, newPacket)
case VerdictAcceptStream:
return n.n.SetVerdictWithConnMark(nP.id, nfqueue.NfAccept, nfqueueConnMarkAccept)
return nP.nq.SetVerdictWithConnMark(nP.id, nfqueue.NfAccept, nfqueueConnMarkAccept)
case VerdictDrop:
return n.n.SetVerdict(nP.id, nfqueue.NfDrop)
return nP.nq.SetVerdict(nP.id, nfqueue.NfDrop)
case VerdictDropStream:
return n.n.SetVerdictWithConnMark(nP.id, nfqueue.NfDrop, nfqueueConnMarkDrop)
return nP.nq.SetVerdictWithConnMark(nP.id, nfqueue.NfDrop, nfqueueConnMarkDrop)
default:
// Invalid verdict, ignore for now
return nil
}
}
func (n *nfqueuePacketIO) ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error) {
return n.protectedDialer.DialContext(ctx, network, address)
func (nio *nfqueuePacketIO) ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error) {
return nio.protectedDialer.DialContext(ctx, network, address)
}
func (n *nfqueuePacketIO) Close() error {
if n.rSet {
if n.ipt4 != nil {
_ = n.setupIpt(n.local, n.rst, true)
func (nio *nfqueuePacketIO) Close() error {
if nio.rSet {
if nio.ipt4 != nil {
_ = nio.setupIpt(nio.local, nio.rst, true)
} else {
_ = n.setupNft(n.local, n.rst, true)
_ = nio.setupNft(nio.local, nio.rst, true)
}
n.rSet = false
nio.rSet = false
}
return n.n.Close()
var errs []error
for _, nq := range nio.nqs {
if err := nq.Close(); err != nil {
errs = append(errs, err)
}
}
if len(errs) > 0 {
return errs[0]
}
return nil
}
func (n *nfqueuePacketIO) setupNft(local, rst, remove bool) error {
rules, err := generateNftRules(local, rst)
func (nio *nfqueuePacketIO) setupNft(local, rst, remove bool) error {
rules, err := generateNftRules(local, rst, nio.numQueues)
if err != nil {
return err
}
@@ -283,30 +329,23 @@ func (n *nfqueuePacketIO) setupNft(local, rst, remove bool) error {
if remove {
err = nftDelete(nftFamily, nftTable)
} else {
// Delete first to make sure no leftover rules
_ = nftDelete(nftFamily, nftTable)
err = nftAdd(rulesText)
}
if err != nil {
return err
}
return nil
return err
}
func (n *nfqueuePacketIO) setupIpt(local, rst, remove bool) error {
rules, err := generateIptRules(local, rst)
func (nio *nfqueuePacketIO) setupIpt(local, rst, remove bool) error {
rules, err := generateIptRules(local, rst, nio.numQueues)
if err != nil {
return err
}
if remove {
err = iptsBatchDeleteIfExists([]*iptables.IPTables{n.ipt4, n.ipt6}, rules)
err = iptsBatchDeleteIfExists([]*iptables.IPTables{nio.ipt4, nio.ipt6}, rules)
} else {
err = iptsBatchAppendUnique([]*iptables.IPTables{n.ipt4, n.ipt6}, rules)
err = iptsBatchAppendUnique([]*iptables.IPTables{nio.ipt4, nio.ipt6}, rules)
}
if err != nil {
return err
}
return nil
return err
}
var _ Packet = (*nfqueuePacket)(nil)
@@ -315,30 +354,22 @@ type nfqueuePacket struct {
id uint32
streamID uint32
data []byte
nq *nfqueue.Nfqueue
}
// StreamID returns the stream identifier stored on the packet.
func (p *nfqueuePacket) StreamID() uint32 { return p.streamID }

// Data returns the packet payload stored on the packet.
func (p *nfqueuePacket) Data() []byte { return p.data }
// okBoolToInt converts a callback's boolean result into the integer form
// returned from the nfqueue hook functions: true maps to 0, false to 1.
func okBoolToInt(ok bool) int {
	if ok {
		return 0
	}
	return 1
}
// nftCheck reports whether an "nft" executable can be found via PATH.
// It returns nil when found and the lookup error otherwise.
func nftCheck() error {
	_, err := exec.LookPath("nft")
	return err
}
func nftAdd(input string) error {
@@ -363,7 +394,6 @@ func (t *nftTableSpec) String() string {
for _, c := range t.Chains {
chains = append(chains, c.String())
}
return fmt.Sprintf(`
%s