package io

import (
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"net"
	"os/exec"
	"strconv"
	"strings"
	"syscall"

	"github.com/coreos/go-iptables/iptables"
	"github.com/florianl/go-nfqueue"
	"github.com/mdlayher/netlink"
	"golang.org/x/sys/unix"
)

const (
	// nfqueueNumStart is the number of the first NFQUEUE queue; when several
	// queues are used they occupy consecutive numbers starting here.
	nfqueueNumStart = 100
	// Defaults applied by NewNFQueuePacketIO when the config leaves the
	// corresponding field at zero.
	nfqueueDefaultQueueSize = 128
	nfqueueDefaultMaxLen    = 0xFFFF

	// Marks used to remember a per-connection verdict. The generated firewall
	// rules accept or drop packets carrying these connection marks before they
	// reach the queue rule.
	nfqueueConnMarkAccept = 1001
	nfqueueConnMarkDrop   = 1002

	// Family and name of the nftables table managed by this file.
	nftFamily = "inet"
	nftTable  = "mellaris"
)

// generateNftRules builds the nftables table that steers traffic into the
// NFQUEUE queues.
//
// local selects the INPUT and OUTPUT chains instead of FORWARD. rst adds a
// "reject with tcp reset" rule for TCP connections marked as dropped; it is
// rejected with an error when combined with local. numQueues below 1 is
// treated as 1; for more than one queue the rule targets the range
// nfqueueNumStart..nfqueueNumStart+numQueues-1.
func generateNftRules(local, rst bool, numQueues int) (*nftTableSpec, error) {
	if local && rst {
		return nil, errors.New("tcp rst is not supported in local mode")
	}
	if numQueues < 1 {
		numQueues = 1
	}

	queueDefine := fmt.Sprintf("define QUEUE_NUM=%d", nfqueueNumStart)
	if numQueues > 1 {
		queueDefine = fmt.Sprintf("define QUEUE_NUM=%d-%d", nfqueueNumStart, nfqueueNumStart+numQueues-1)
	}

	table := &nftTableSpec{
		Family: nftFamily,
		Table:  nftTable,
		Defines: []string{
			fmt.Sprintf("define ACCEPT_CTMARK=%d", nfqueueConnMarkAccept),
			fmt.Sprintf("define DROP_CTMARK=%d", nfqueueConnMarkDrop),
			queueDefine,
		},
	}

	if local {
		table.Chains = []nftChainSpec{
			{Chain: "INPUT", Header: "type filter hook input priority filter; policy accept;"},
			{Chain: "OUTPUT", Header: "type filter hook output priority filter; policy accept;"},
		}
	} else {
		table.Chains = []nftChainSpec{
			{Chain: "FORWARD", Header: "type filter hook forward priority filter; policy accept;"},
		}
	}

	// Every chain gets the same rule sequence; order matters because the mark
	// checks must run before the final queue rule.
	for i := range table.Chains {
		c := &table.Chains[i]
		c.Rules = append(c.Rules,
			// Packets sent with the accept mark on the socket (see the
			// protected dialer) mark their whole connection as accepted.
			"meta mark $ACCEPT_CTMARK ct mark set $ACCEPT_CTMARK",
			"ct mark $ACCEPT_CTMARK counter accept",
		)
		if rst {
			c.Rules = append(c.Rules, "ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset")
		}
		c.Rules = append(c.Rules,
			"ct mark $DROP_CTMARK counter drop",
			"counter queue num $QUEUE_NUM bypass",
		)
	}

	return table, nil
}

// generateIptRules builds the iptables equivalent of generateNftRules: for
// each affected chain of the "filter" table it emits the connmark
// accept/drop rules followed by the NFQUEUE rule.
//
// The parameters have the same meaning and the same restrictions as in
// generateNftRules. With more than one queue the NFQUEUE rule uses
// --queue-balance over the queue range instead of --queue-num.
func generateIptRules(local, rst bool, numQueues int) ([]iptRule, error) {
	if local && rst {
		return nil, errors.New("tcp rst is not supported in local mode")
	}
	if numQueues < 1 {
		numQueues = 1
	}

	chains := []string{"FORWARD"}
	if local {
		chains = []string{"INPUT", "OUTPUT"}
	}

	acceptMark := strconv.Itoa(nfqueueConnMarkAccept)
	dropMark := strconv.Itoa(nfqueueConnMarkDrop)

	queueRule := []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNumStart), "--queue-bypass"}
	if numQueues > 1 {
		queueSpec := fmt.Sprintf("%d:%d", nfqueueNumStart, nfqueueNumStart+numQueues-1)
		queueRule = []string{"-j", "NFQUEUE", "--queue-balance", queueSpec, "--queue-bypass"}
	}

	// At most five rules per chain (the tcp-reset rule is optional).
	rules := make([]iptRule, 0, 5*len(chains))
	for _, chain := range chains {
		rules = append(rules,
			iptRule{Table: "filter", Chain: chain, RuleSpec: []string{"-m", "mark", "--mark", acceptMark, "-j", "CONNMARK", "--set-mark", acceptMark}},
			iptRule{Table: "filter", Chain: chain, RuleSpec: []string{"-m", "connmark", "--mark", acceptMark, "-j", "ACCEPT"}},
		)
		if rst {
			rules = append(rules, iptRule{Table: "filter", Chain: chain, RuleSpec: []string{"-p", "tcp", "-m", "connmark", "--mark", dropMark, "-j", "REJECT", "--reject-with", "tcp-reset"}})
		}
		rules = append(rules,
			iptRule{Table: "filter", Chain: chain, RuleSpec: []string{"-m", "connmark", "--mark", dropMark, "-j", "DROP"}},
			// Each chain gets its own copy so the rules never share a backing array.
			iptRule{Table: "filter", Chain: chain, RuleSpec: append([]string(nil), queueRule...)},
		)
	}

	return rules, nil
}

var _ PacketIO = (*nfqueuePacketIO)(nil)

var errNotNFQueuePacket = errors.New("not an NFQueue packet")

// nfqueuePacketIO is the PacketIO implementation backed by one or more
// NFQUEUE queues.
type nfqueuePacketIO struct {
	// nqs holds one open handle per queue.
	nqs []*nfqueue.Nfqueue
	// numQueues is the queue count used when generating firewall rules.
	numQueues int

	local bool
	rst   bool
	// rSet records whether the firewall rules are currently installed.
	rSet bool

	// ipt4 and ipt6 are only set when the iptables backend is used; they are
	// nil when nftables is used.
	ipt4 *iptables.IPTables
	ipt6 *iptables.IPTables

	// protectedDialer dials with SO_MARK set to nfqueueConnMarkAccept.
	protectedDialer *net.Dialer
}

// NFQueuePacketIOConfig configures NewNFQueuePacketIO.
type NFQueuePacketIOConfig struct {
	// QueueSize is passed to go-nfqueue as MaxQueueLen; zero selects
	// nfqueueDefaultQueueSize.
	QueueSize uint32

	// ReadBuffer, when positive, is applied to each queue's netlink
	// connection via SetReadBuffer.
	ReadBuffer int

	// WriteBuffer, when positive, is applied to each queue's netlink
	// connection via SetWriteBuffer.
	WriteBuffer int

	// Local hooks the INPUT and OUTPUT chains instead of FORWARD.
	Local bool

	// RST adds a TCP reset rule for dropped connections. Rule generation
	// fails if it is combined with Local.
	RST bool

	// NumQueues is the number of consecutive queues to open, starting at
	// nfqueueNumStart; values <= 0 are treated as 1.
	NumQueues int

	// MaxPacketLen is passed to go-nfqueue as MaxPacketLen; zero selects
	// nfqueueDefaultMaxLen.
	MaxPacketLen uint32
}

// NewNFQueuePacketIO opens the configured NFQUEUE queues and returns a
// PacketIO on top of them. Firewall rules are not installed here; that
// happens in Register.
//
// The iptables backend is prepared only when the nft binary cannot be found.
// If any queue fails to open or to be configured, every queue opened so far
// (including the one being configured) is closed before the error is
// returned.
func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
	if config.QueueSize == 0 {
		config.QueueSize = nfqueueDefaultQueueSize
	}
	if config.NumQueues <= 0 {
		config.NumQueues = 1
	}
	if config.MaxPacketLen == 0 {
		config.MaxPacketLen = nfqueueDefaultMaxLen
	}

	var ipt4, ipt6 *iptables.IPTables
	var err error
	if nftCheck() != nil {
		// nft is not available, fall back to iptables.
		ipt4, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
		if err != nil {
			return nil, err
		}
		ipt6, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
		if err != nil {
			return nil, err
		}
	}

	nqs := make([]*nfqueue.Nfqueue, 0, config.NumQueues)
	// closeOpened releases every queue opened so far. Close errors are
	// ignored because the caller already receives the error that triggered
	// the cleanup.
	closeOpened := func() {
		for _, nq := range nqs {
			_ = nq.Close()
		}
	}

	for i := 0; i < config.NumQueues; i++ {
		n, err := nfqueue.Open(&nfqueue.Config{
			NfQueue:      uint16(nfqueueNumStart + i),
			MaxPacketLen: config.MaxPacketLen,
			MaxQueueLen:  config.QueueSize,
			Copymode:     nfqueue.NfQnlCopyPacket,
			Flags:        nfqueue.NfQaCfgFlagConntrack,
		})
		if err != nil {
			closeOpened()
			return nil, err
		}
		// Track the handle right away so the error paths below close it too,
		// instead of leaking it or touching a slot that is still nil.
		nqs = append(nqs, n)

		if config.ReadBuffer > 0 {
			if err := n.Con.SetReadBuffer(config.ReadBuffer); err != nil {
				closeOpened()
				return nil, err
			}
		}
		if config.WriteBuffer > 0 {
			if err := n.Con.SetWriteBuffer(config.WriteBuffer); err != nil {
				closeOpened()
				return nil, err
			}
		}
	}

	return &nfqueuePacketIO{
		nqs:       nqs,
		numQueues: config.NumQueues,
		local:     config.Local,
		rst:       config.RST,
		ipt4:      ipt4,
		ipt6:      ipt6,
		protectedDialer: &net.Dialer{
			// Mark outgoing sockets with the accept mark; the generated
			// rules accept connections whose packets carry it.
			Control: func(network, address string, c syscall.RawConn) error {
				var err error
				cErr := c.Control(func(fd uintptr) {
					err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, nfqueueConnMarkAccept)
				})
				if cErr != nil {
					return cErr
				}
				return err
			},
		},
	}, nil
}

// Register attaches cb to every queue and then, on the first successful
// call, installs the firewall rules (iptables when ipt4 is set, nftables
// otherwise).
//
// Packets failing packetAttributeSanityCheck are answered with the verdict
// it suggests and never reach cb. Receive errors are forwarded to cb as
// cb(nil, err), except ENOBUFS wrapped in a *netlink.OpError, which is
// ignored.
func (nio *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error {
	for _, nq := range nio.nqs {
		nq := nq // per-iteration copy for the closures below (required before Go 1.22)
		err := nq.RegisterWithErrorFunc(ctx,
			func(a nfqueue.Attribute) int {
				if ok, verdict := nio.packetAttributeSanityCheck(a); !ok {
					// Without a packet ID there is nothing to set a verdict on.
					if a.PacketID != nil {
						_ = nq.SetVerdict(*a.PacketID, verdict)
					}
					return 0
				}
				p := &nfqueuePacket{
					id:       *a.PacketID,
					streamID: ctIDFromCtBytes(*a.Ct),
					data:     *a.Payload,
					nq:       nq,
				}
				return okBoolToInt(cb(p, nil))
			},
			func(e error) int {
				var opErr *netlink.OpError
				if errors.As(e, &opErr) && errors.Is(opErr.Err, unix.ENOBUFS) {
					// Returning 0 keeps the receive loop going instead of
					// surfacing the error to cb.
					return 0
				}
				return okBoolToInt(cb(nil, e))
			})
		if err != nil {
			return err
		}
	}

	if nio.rSet {
		return nil
	}
	var err error
	if nio.ipt4 != nil {
		err = nio.setupIpt(nio.local, nio.rst, false)
	} else {
		err = nio.setupNft(nio.local, nio.rst, false)
	}
	if err != nil {
		return err
	}
	nio.rSet = true
	return nil
}

// packetAttributeSanityCheck verifies that a carries everything needed to
// build an nfqueuePacket. When ok is false, verdict is the verdict to apply
// to the packet; it is -1 when no packet ID is present (no verdict can be
// issued) and also when ok is true.
func (nio *nfqueuePacketIO) packetAttributeSanityCheck(a nfqueue.Attribute) (ok bool, verdict int) {
	if a.PacketID == nil {
		return false, -1
	}
	// NOTE(review): 20 is presumably the minimum IPv4 header length - confirm.
	if a.Payload == nil || len(*a.Payload) < 20 {
		return false, nfqueue.NfDrop
	}
	if a.Ct == nil {
		// No conntrack information: accepted in local mode, dropped otherwise.
		if nio.local {
			return false, nfqueue.NfAccept
		}
		return false, nfqueue.NfDrop
	}
	return true, -1
}

// SetVerdict issues verdict v for packet p on the queue the packet came from.
// The *Stream verdicts additionally set the connection mark so the firewall
// rules handle the rest of the connection. newPacket is only used for
// VerdictAcceptModify. Unknown verdicts are ignored and return nil. An
// *ErrInvalidPacket is returned when p did not originate from this
// implementation.
func (nio *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) error {
	nP, ok := p.(*nfqueuePacket)
	if !ok {
		return &ErrInvalidPacket{Err: errNotNFQueuePacket}
	}
	switch v {
	case VerdictAccept:
		return nP.nq.SetVerdict(nP.id, nfqueue.NfAccept)
	case VerdictAcceptModify:
		return nP.nq.SetVerdictModPacket(nP.id, nfqueue.NfAccept, newPacket)
	case VerdictAcceptStream:
		return nP.nq.SetVerdictWithConnMark(nP.id, nfqueue.NfAccept, nfqueueConnMarkAccept)
	case VerdictDrop:
		return nP.nq.SetVerdict(nP.id, nfqueue.NfDrop)
	case VerdictDropStream:
		return nP.nq.SetVerdictWithConnMark(nP.id, nfqueue.NfDrop, nfqueueConnMarkDrop)
	default:
		return nil
	}
}

// ProtectedDialContext dials using the dialer that sets SO_MARK to
// nfqueueConnMarkAccept on its sockets.
func (nio *nfqueuePacketIO) ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error) {
	return nio.protectedDialer.DialContext(ctx, network, address)
}

// Close removes the firewall rules if they were installed (best effort,
// errors ignored) and closes every queue. All queues are closed even if some
// fail; the first close error is returned.
func (nio *nfqueuePacketIO) Close() error {
	if nio.rSet {
		// Best effort: a failure to remove rules must not prevent the queues
		// from being closed.
		if nio.ipt4 != nil {
			_ = nio.setupIpt(nio.local, nio.rst, true)
		} else {
			_ = nio.setupNft(nio.local, nio.rst, true)
		}
		nio.rSet = false
	}

	var firstErr error
	for _, nq := range nio.nqs {
		if err := nq.Close(); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}

// setupNft installs (remove == false) or deletes (remove == true) the
// nftables table. Before installing, any existing table with the same name
// is deleted and the result of that deletion is ignored.
func (nio *nfqueuePacketIO) setupNft(local, rst, remove bool) error {
	rules, err := generateNftRules(local, rst, nio.numQueues)
	if err != nil {
		return err
	}
	if remove {
		return nftDelete(nftFamily, nftTable)
	}
	// Start from a clean slate; the table may simply not exist yet.
	_ = nftDelete(nftFamily, nftTable)
	return nftAdd(rules.String())
}

// setupIpt appends (remove == false) or deletes (remove == true) the
// generated rules on both the IPv4 and the IPv6 iptables handles.
func (nio *nfqueuePacketIO) setupIpt(local, rst, remove bool) error {
	rules, err := generateIptRules(local, rst, nio.numQueues)
	if err != nil {
		return err
	}
	ipts := []*iptables.IPTables{nio.ipt4, nio.ipt6}
	if remove {
		return iptsBatchDeleteIfExists(ipts, rules)
	}
	return iptsBatchAppendUnique(ipts, rules)
}

var _ Packet = (*nfqueuePacket)(nil)

// nfqueuePacket is a packet received from an NFQUEUE queue.
type nfqueuePacket struct {
	id       uint32
	streamID uint32
	data     []byte
	// nq is the queue the packet arrived on; verdicts are sent back through it.
	nq *nfqueue.Nfqueue
}

// StreamID returns the conntrack ID extracted for the packet, or 0 if none
// could be extracted.
func (p *nfqueuePacket) StreamID() uint32 {
	return p.streamID
}

// Data returns the packet payload as delivered by the queue.
func (p *nfqueuePacket) Data() []byte {
	return p.data
}

// okBoolToInt converts a callback result into the integer returned to
// go-nfqueue: 0 for true, 1 for false.
func okBoolToInt(ok bool) int {
	if ok {
		return 0
	}
	return 1
}

// nftCheck returns nil when an "nft" executable can be found in PATH.
func nftCheck() error {
	_, err := exec.LookPath("nft")
	return err
}

// nftAdd feeds input to "nft -f -" on standard input.
func nftAdd(input string) error {
	cmd := exec.Command("nft", "-f", "-")
	cmd.Stdin = strings.NewReader(input)
	return cmd.Run()
}

// nftDelete runs "nft delete table <family> <table>".
func nftDelete(family, table string) error {
	cmd := exec.Command("nft", "delete", "table", family, table)
	return cmd.Run()
}

// nftTableSpec describes an nftables table together with the variable
// definitions its rules refer to.
type nftTableSpec struct {
	Defines       []string
	Family, Table string
	Chains        []nftChainSpec
}

// String renders the table as nft script text: one define per line, followed
// by the table with all of its chains.
func (t *nftTableSpec) String() string {
	var chains strings.Builder
	for i := range t.Chains {
		chains.WriteString(t.Chains[i].String())
	}
	// The defines must be separated from each other and from the table
	// statement by newlines, so the layout is spelled out with explicit
	// escapes.
	return fmt.Sprintf("\n%s\n\ntable %s %s {\n%s\n}\n",
		strings.Join(t.Defines, "\n"), t.Family, t.Table, chains.String())
}

// nftChainSpec describes one chain: its name, its header (type, hook,
// priority and policy) and its rules in evaluation order.
type nftChainSpec struct {
	Chain  string
	Header string
	Rules  []string
}

// String renders the chain as nft script text with one rule per line.
func (c *nftChainSpec) String() string {
	return fmt.Sprintf("\n  chain %s {\n    %s\n    %s\n  }\n",
		c.Chain, c.Header, strings.Join(c.Rules, "\n\x20\x20\x20\x20"))
}

// iptRule is a single iptables rule: the table and chain it belongs to and
// its rule specification arguments.
type iptRule struct {
	Table, Chain string
	RuleSpec     []string
}

// iptsBatchAppendUnique applies every rule to every handle in ipts using
// AppendUnique, stopping at the first error.
func iptsBatchAppendUnique(ipts []*iptables.IPTables, rules []iptRule) error {
	for _, r := range rules {
		for _, ipt := range ipts {
			if err := ipt.AppendUnique(r.Table, r.Chain, r.RuleSpec...); err != nil {
				return err
			}
		}
	}
	return nil
}

// iptsBatchDeleteIfExists removes every rule from every handle in ipts using
// DeleteIfExists, stopping at the first error.
func iptsBatchDeleteIfExists(ipts []*iptables.IPTables, rules []iptRule) error {
	for _, r := range rules {
		for _, ipt := range ipts {
			if err := ipt.DeleteIfExists(r.Table, r.Chain, r.RuleSpec...); err != nil {
				return err
			}
		}
	}
	return nil
}

// ctIDFromCtBytes extracts the conntrack ID (CTA_ID) from the raw conntrack
// attributes attached to a queued packet. It returns 0 when the attributes
// cannot be parsed, when no CTA_ID attribute is present, or when the
// attribute is too short to hold a 32-bit value.
func ctIDFromCtBytes(ct []byte) uint32 {
	const ctaID = 12 // CTA_ID

	ctAttrs, err := netlink.UnmarshalAttributes(ct)
	if err != nil {
		return 0
	}
	for _, attr := range ctAttrs {
		if attr.Type != ctaID {
			continue
		}
		// Uint32 panics on fewer than 4 bytes; treat a truncated attribute
		// as "no ID".
		if len(attr.Data) < 4 {
			return 0
		}
		return binary.BigEndian.Uint32(attr.Data)
	}
	return 0
}