First Commit
Some checks failed
Quality check / Static analysis (push) Has been cancelled
Quality check / Tests (push) Has been cancelled

This commit is contained in:
Hayzam Sherif
2026-02-11 06:27:36 +05:30
commit 94e1e26cc3
56 changed files with 8530 additions and 0 deletions

47
.github/workflows/check.yaml vendored Normal file
View File

@@ -0,0 +1,47 @@
name: Quality check
on:
push:
branches:
- "*"
pull_request:
permissions:
contents: read
jobs:
static-analysis:
name: Static analysis
runs-on: ubuntu-latest
timeout-minutes: 10
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: 'stable'
- run: go vet ./...
- name: staticcheck
uses: dominikh/staticcheck-action@v1.3.0
with:
install-go: false
tests:
name: Tests
runs-on: ubuntu-latest
timeout-minutes: 10
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: 'stable'
- run: go test ./...

210
.gitignore vendored Normal file
View File

@@ -0,0 +1,210 @@
# Created by https://www.toptal.com/developers/gitignore/api/windows,macos,linux,go,goland+all,visualstudiocode
# Edit at https://www.toptal.com/developers/gitignore?templates=windows,macos,linux,go,goland+all,visualstudiocode
### Go ###
# If you prefer the allow list template instead of the deny list, see community template:
# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore
#
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib
# Test binary, built with `go test -c`
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# Dependency directories (remove the comment below to include it)
# vendor/
# Go workspace file
go.work
### GoLand+all ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf
# AWS User-specific
.idea/**/aws.xml
# Generated files
.idea/**/contentModel.xml
# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml
# Gradle
.idea/**/gradle.xml
.idea/**/libraries
# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr
# CMake
cmake-build-*/
# Mongo Explorer plugin
.idea/**/mongoSettings.xml
# File-based project format
*.iws
# IntelliJ
out/
# mpeltonen/sbt-idea plugin
.idea_modules/
# JIRA plugin
atlassian-ide-plugin.xml
# Cursive Clojure plugin
.idea/replstate.xml
# SonarLint plugin
.idea/sonarlint/
# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties
# Editor-based Rest Client
.idea/httpRequests
# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser
### GoLand+all Patch ###
# Ignore everything but code style settings and run configurations
# that are supposed to be shared within teams.
.idea/*
!.idea/codeStyles
!.idea/runConfigurations
### Linux ###
*~
# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*
# KDE directory preferences
.directory
# Linux trash folder which might appear on any partition or disk
.Trash-*
# .nfs files are created when an open file is removed but is still being accessed
.nfs*
### macOS ###
# General
.DS_Store
.AppleDouble
.LSOverride
# Icon must end with two \r
Icon
# Thumbnails
._*
# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent
# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk
### macOS Patch ###
# iCloud generated files
*.icloud
### VisualStudioCode ###
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets
# Local History for Visual Studio Code
.history/
# Built Visual Studio Code Extensions
*.vsix
### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide
### Windows ###
# Windows thumbnail cache files
Thumbs.db
Thumbs.db:encryptable
ehthumbs.db
ehthumbs_vista.db
# Dump file
*.stackdump
# Folder config file
[Dd]esktop.ini
# Recycle Bin used on file shares
$RECYCLE.BIN/
# Windows Installer files
*.cab
*.msi
*.msix
*.msm
*.msp
# Windows shortcuts
*.lnk
# End of https://www.toptal.com/developers/gitignore/api/windows,macos,linux,go,goland+all,visualstudiocode
# Internal tools not ready for public use yet
tools/flowseq/

373
LICENSE Normal file
View File

@@ -0,0 +1,373 @@
Mozilla Public License Version 2.0
==================================
1. Definitions
--------------
1.1. "Contributor"
means each individual or legal entity that creates, contributes to
the creation of, or owns Covered Software.
1.2. "Contributor Version"
means the combination of the Contributions of others (if any) used
by a Contributor and that particular Contributor's Contribution.
1.3. "Contribution"
means Covered Software of a particular Contributor.
1.4. "Covered Software"
means Source Code Form to which the initial Contributor has attached
the notice in Exhibit A, the Executable Form of such Source Code
Form, and Modifications of such Source Code Form, in each case
including portions thereof.
1.5. "Incompatible With Secondary Licenses"
means
(a) that the initial Contributor has attached the notice described
in Exhibit B to the Covered Software; or
(b) that the Covered Software was made available under the terms of
version 1.1 or earlier of the License, but not also under the
terms of a Secondary License.
1.6. "Executable Form"
means any form of the work other than Source Code Form.
1.7. "Larger Work"
means a work that combines Covered Software with other material, in
a separate file or files, that is not Covered Software.
1.8. "License"
means this document.
1.9. "Licensable"
means having the right to grant, to the maximum extent possible,
whether at the time of the initial grant or subsequently, any and
all of the rights conveyed by this License.
1.10. "Modifications"
means any of the following:
(a) any file in Source Code Form that results from an addition to,
deletion from, or modification of the contents of Covered
Software; or
(b) any new file in Source Code Form that contains any Covered
Software.
1.11. "Patent Claims" of a Contributor
means any patent claim(s), including without limitation, method,
process, and apparatus claims, in any patent Licensable by such
Contributor that would be infringed, but for the grant of the
License, by the making, using, selling, offering for sale, having
made, import, or transfer of either its Contributions or its
Contributor Version.
1.12. "Secondary License"
means either the GNU General Public License, Version 2.0, the GNU
Lesser General Public License, Version 2.1, the GNU Affero General
Public License, Version 3.0, or any later versions of those
licenses.
1.13. "Source Code Form"
means the form of the work preferred for making modifications.
1.14. "You" (or "Your")
means an individual or a legal entity exercising rights under this
License. For legal entities, "You" includes any entity that
controls, is controlled by, or is under common control with You. For
purposes of this definition, "control" means (a) the power, direct
or indirect, to cause the direction or management of such entity,
whether by contract or otherwise, or (b) ownership of more than
fifty percent (50%) of the outstanding shares or beneficial
ownership of such entity.
2. License Grants and Conditions
--------------------------------
2.1. Grants
Each Contributor hereby grants You a world-wide, royalty-free,
non-exclusive license:
(a) under intellectual property rights (other than patent or trademark)
Licensable by such Contributor to use, reproduce, make available,
modify, display, perform, distribute, and otherwise exploit its
Contributions, either on an unmodified basis, with Modifications, or
as part of a Larger Work; and
(b) under Patent Claims of such Contributor to make, use, sell, offer
for sale, have made, import, and otherwise transfer either its
Contributions or its Contributor Version.
2.2. Effective Date
The licenses granted in Section 2.1 with respect to any Contribution
become effective for each Contribution on the date the Contributor first
distributes such Contribution.
2.3. Limitations on Grant Scope
The licenses granted in this Section 2 are the only rights granted under
this License. No additional rights or licenses will be implied from the
distribution or licensing of Covered Software under this License.
Notwithstanding Section 2.1(b) above, no patent license is granted by a
Contributor:
(a) for any code that a Contributor has removed from Covered Software;
or
(b) for infringements caused by: (i) Your and any other third party's
modifications of Covered Software, or (ii) the combination of its
Contributions with other software (except as part of its Contributor
Version); or
(c) under Patent Claims infringed by Covered Software in the absence of
its Contributions.
This License does not grant any rights in the trademarks, service marks,
or logos of any Contributor (except as may be necessary to comply with
the notice requirements in Section 3.4).
2.4. Subsequent Licenses
No Contributor makes additional grants as a result of Your choice to
distribute the Covered Software under a subsequent version of this
License (see Section 10.2) or under the terms of a Secondary License (if
permitted under the terms of Section 3.3).
2.5. Representation
Each Contributor represents that the Contributor believes its
Contributions are its original creation(s) or it has sufficient rights
to grant the rights to its Contributions conveyed by this License.
2.6. Fair Use
This License is not intended to limit any rights You have under
applicable copyright doctrines of fair use, fair dealing, or other
equivalents.
2.7. Conditions
Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
in Section 2.1.
3. Responsibilities
-------------------
3.1. Distribution of Source Form
All distribution of Covered Software in Source Code Form, including any
Modifications that You create or to which You contribute, must be under
the terms of this License. You must inform recipients that the Source
Code Form of the Covered Software is governed by the terms of this
License, and how they can obtain a copy of this License. You may not
attempt to alter or restrict the recipients' rights in the Source Code
Form.
3.2. Distribution of Executable Form
If You distribute Covered Software in Executable Form then:
(a) such Covered Software must also be made available in Source Code
Form, as described in Section 3.1, and You must inform recipients of
the Executable Form how they can obtain a copy of such Source Code
Form by reasonable means in a timely manner, at a charge no more
than the cost of distribution to the recipient; and
(b) You may distribute such Executable Form under the terms of this
License, or sublicense it under different terms, provided that the
license for the Executable Form does not attempt to limit or alter
the recipients' rights in the Source Code Form under this License.
3.3. Distribution of a Larger Work
You may create and distribute a Larger Work under terms of Your choice,
provided that You also comply with the requirements of this License for
the Covered Software. If the Larger Work is a combination of Covered
Software with a work governed by one or more Secondary Licenses, and the
Covered Software is not Incompatible With Secondary Licenses, this
License permits You to additionally distribute such Covered Software
under the terms of such Secondary License(s), so that the recipient of
the Larger Work may, at their option, further distribute the Covered
Software under the terms of either this License or such Secondary
License(s).
3.4. Notices
You may not remove or alter the substance of any license notices
(including copyright notices, patent notices, disclaimers of warranty,
or limitations of liability) contained within the Source Code Form of
the Covered Software, except that You may alter any license notices to
the extent required to remedy known factual inaccuracies.
3.5. Application of Additional Terms
You may choose to offer, and to charge a fee for, warranty, support,
indemnity or liability obligations to one or more recipients of Covered
Software. However, You may do so only on Your own behalf, and not on
behalf of any Contributor. You must make it absolutely clear that any
such warranty, support, indemnity, or liability obligation is offered by
You alone, and You hereby agree to indemnify every Contributor for any
liability incurred by such Contributor as a result of warranty, support,
indemnity or liability terms You offer. You may include additional
disclaimers of warranty and limitations of liability specific to any
jurisdiction.
4. Inability to Comply Due to Statute or Regulation
---------------------------------------------------
If it is impossible for You to comply with any of the terms of this
License with respect to some or all of the Covered Software due to
statute, judicial order, or regulation then You must: (a) comply with
the terms of this License to the maximum extent possible; and (b)
describe the limitations and the code they affect. Such description must
be placed in a text file included with all distributions of the Covered
Software under this License. Except to the extent prohibited by statute
or regulation, such description must be sufficiently detailed for a
recipient of ordinary skill to be able to understand it.
5. Termination
--------------
5.1. The rights granted under this License will terminate automatically
if You fail to comply with any of its terms. However, if You become
compliant, then the rights granted under this License from a particular
Contributor are reinstated (a) provisionally, unless and until such
Contributor explicitly and finally terminates Your grants, and (b) on an
ongoing basis, if such Contributor fails to notify You of the
non-compliance by some reasonable means prior to 60 days after You have
come back into compliance. Moreover, Your grants from a particular
Contributor are reinstated on an ongoing basis if such Contributor
notifies You of the non-compliance by some reasonable means, this is the
first time You have received notice of non-compliance with this License
from such Contributor, and You become compliant prior to 30 days after
Your receipt of the notice.
5.2. If You initiate litigation against any entity by asserting a patent
infringement claim (excluding declaratory judgment actions,
counter-claims, and cross-claims) alleging that a Contributor Version
directly or indirectly infringes any patent, then the rights granted to
You by any and all Contributors for the Covered Software under Section
2.1 of this License shall terminate.
5.3. In the event of termination under Sections 5.1 or 5.2 above, all
end user license agreements (excluding distributors and resellers) which
have been validly granted by You or Your distributors under this License
prior to termination shall survive termination.
************************************************************************
* *
* 6. Disclaimer of Warranty *
* ------------------------- *
* *
* Covered Software is provided under this License on an "as is" *
* basis, without warranty of any kind, either expressed, implied, or *
* statutory, including, without limitation, warranties that the *
* Covered Software is free of defects, merchantable, fit for a *
* particular purpose or non-infringing. The entire risk as to the *
* quality and performance of the Covered Software is with You. *
* Should any Covered Software prove defective in any respect, You *
* (not any Contributor) assume the cost of any necessary servicing, *
* repair, or correction. This disclaimer of warranty constitutes an *
* essential part of this License. No use of any Covered Software is *
* authorized under this License except under this disclaimer. *
* *
************************************************************************
************************************************************************
* *
* 7. Limitation of Liability *
* -------------------------- *
* *
* Under no circumstances and under no legal theory, whether tort *
* (including negligence), contract, or otherwise, shall any *
* Contributor, or anyone who distributes Covered Software as *
* permitted above, be liable to You for any direct, indirect, *
* special, incidental, or consequential damages of any character *
* including, without limitation, damages for lost profits, loss of *
* goodwill, work stoppage, computer failure or malfunction, or any *
* and all other commercial damages or losses, even if such party *
* shall have been informed of the possibility of such damages. This *
* limitation of liability shall not apply to liability for death or *
* personal injury resulting from such party's negligence to the *
* extent applicable law prohibits such limitation. Some *
* jurisdictions do not allow the exclusion or limitation of *
* incidental or consequential damages, so this exclusion and *
* limitation may not apply to You. *
* *
************************************************************************
8. Litigation
-------------
Any litigation relating to this License may be brought only in the
courts of a jurisdiction where the defendant maintains its principal
place of business and such litigation shall be governed by laws of that
jurisdiction, without reference to its conflict-of-law provisions.
Nothing in this Section shall prevent a party's ability to bring
cross-claims or counter-claims.
9. Miscellaneous
----------------
This License represents the complete agreement concerning the subject
matter hereof. If any provision of this License is held to be
unenforceable, such provision shall be reformed only to the extent
necessary to make it enforceable. Any law or regulation which provides
that the language of a contract shall be construed against the drafter
shall not be used to construe this License against a Contributor.
10. Versions of the License
---------------------------
10.1. New Versions
Mozilla Foundation is the license steward. Except as provided in Section
10.3, no one other than the license steward has the right to modify or
publish new versions of this License. Each version will be given a
distinguishing version number.
10.2. Effect of New Versions
You may distribute the Covered Software under the terms of the version
of the License under which You originally received the Covered Software,
or under the terms of any subsequent version published by the license
steward.
10.3. Modified Versions
If you create software not governed by this License, and you want to
create a new license for such software, you may create and use a
modified version of this License if you rename the license and remove
any references to the name of the license steward (except to note that
such modified license differs from this License).
10.4. Distributing Source Code Form that is Incompatible With Secondary
Licenses
If You choose to distribute Source Code Form that is Incompatible With
Secondary Licenses under the terms of this version of the License, the
notice described in Exhibit B of this License must be attached.
Exhibit A - Source Code Form License Notice
-------------------------------------------
This Source Code Form is subject to the terms of the Mozilla Public
License, v. 2.0. If a copy of the MPL was not distributed with this
file, You can obtain one at http://mozilla.org/MPL/2.0/.
If it is not possible or desirable to put the notice in a particular
file, then You may include the notice in a location (such as a LICENSE
file in a relevant directory) where a recipient would be likely to look
for such a notice.
You may add additional accurate notices of copyright ownership.
Exhibit B - "Incompatible With Secondary Licenses" Notice
---------------------------------------------------------
This Source Code Form is "Incompatible With Secondary Licenses", as
defined by the Mozilla Public License, v. 2.0.

48
README.md Normal file
View File

@@ -0,0 +1,48 @@
# Mellaris
Mellaris is a Go library for network stream analysis and filtering.
## Usage
```go
package main
import (
"context"
"git.difuse.io/Difuse/Mellaris"
)
func main() {
cfg := mellaris.Config{
IO: mellaris.IOConfig{
// QueueSize, ReadBuffer, WriteBuffer, Local, RST
},
Workers: mellaris.WorkersConfig{
// Count, QueueSize, TCPMaxBufferedPagesTotal, TCPMaxBufferedPagesPerConn, UDPMaxStreams
},
Ruleset: mellaris.RulesetConfig{
GeoIp: "/path/to/geoip.dat",
GeoSite: "/path/to/geosite.dat",
},
}
app, err := mellaris.New(cfg, mellaris.Options{
RulesFile: "rules.yaml",
Analyzers: mellaris.DefaultAnalyzers(),
Modifiers: mellaris.DefaultModifiers(),
})
if err != nil {
panic(err)
}
defer app.Close()
_ = app.Run(context.Background())
}
```
Based on OpenGFW by apernet: https://github.com/apernet/OpenGFW
## LICENSE
[MPL-2.0](https://opensource.org/licenses/MPL-2.0)

131
analyzer/interface.go Normal file
View File

@@ -0,0 +1,131 @@
package analyzer
import (
"net"
"strings"
)
// Analyzer is the common interface implemented by every analyzer,
// regardless of transport protocol.
type Analyzer interface {
	// Name returns the name of the analyzer.
	Name() string
	// Limit returns the byte limit for this analyzer.
	// For example, an analyzer can return 1000 to indicate that it only ever needs
	// the first 1000 bytes of a stream to do its job. If the stream is still not
	// done after 1000 bytes, the engine will stop feeding it data and close it.
	// An analyzer can return 0 or a negative number to indicate that it does not
	// have a hard limit.
	// Note: for UDP streams, the engine always feeds entire packets, even if
	// the packet is larger than the remaining quota or the limit itself.
	Limit() int
}

// Logger is the leveled, printf-style logging interface the engine hands
// to analyzers when creating streams.
type Logger interface {
	// Debugf logs a formatted message at debug level.
	Debugf(format string, args ...interface{})
	// Infof logs a formatted message at info level.
	Infof(format string, args ...interface{})
	// Errorf logs a formatted message at error level.
	Errorf(format string, args ...interface{})
}
// TCPAnalyzer is implemented by analyzers that inspect TCP streams.
type TCPAnalyzer interface {
	Analyzer
	// NewTCP returns a new TCPStream.
	NewTCP(TCPInfo, Logger) TCPStream
}

// TCPInfo carries the connection 4-tuple handed to NewTCP for a new
// TCP stream.
type TCPInfo struct {
	// SrcIP is the source IP address.
	SrcIP net.IP
	// DstIP is the destination IP address.
	DstIP net.IP
	// SrcPort is the source port.
	SrcPort uint16
	// DstPort is the destination port.
	DstPort uint16
}

// TCPStream consumes reassembled TCP payload data for one connection.
type TCPStream interface {
	// Feed feeds a chunk of reassembled data to the stream.
	// It returns a prop update containing the information extracted from the stream (can be nil),
	// and whether the analyzer is "done" with this stream (i.e. no more data should be fed).
	// NOTE(review): rev presumably marks the reverse (server-to-client)
	// direction, start/end the stream boundaries, and skip the number of
	// bytes lost to reassembly gaps — confirm against the engine.
	Feed(rev, start, end bool, skip int, data []byte) (u *PropUpdate, done bool)
	// Close indicates that the stream is closed.
	// Either the connection is closed, or the stream has reached its byte limit.
	// Like Feed, it optionally returns a prop update.
	Close(limited bool) *PropUpdate
}
// UDPAnalyzer is implemented by analyzers that inspect UDP streams.
type UDPAnalyzer interface {
	Analyzer
	// NewUDP returns a new UDPStream.
	NewUDP(UDPInfo, Logger) UDPStream
}

// UDPInfo carries the connection 4-tuple handed to NewUDP for a new
// UDP stream.
type UDPInfo struct {
	// SrcIP is the source IP address.
	SrcIP net.IP
	// DstIP is the destination IP address.
	DstIP net.IP
	// SrcPort is the source port.
	SrcPort uint16
	// DstPort is the destination port.
	DstPort uint16
}

// UDPStream consumes individual UDP packets for one flow.
type UDPStream interface {
	// Feed feeds a new packet to the stream.
	// It returns a prop update containing the information extracted from the stream (can be nil),
	// and whether the analyzer is "done" with this stream (i.e. no more data should be fed).
	// NOTE(review): rev presumably marks the reverse direction — confirm
	// against the engine.
	Feed(rev bool, data []byte) (u *PropUpdate, done bool)
	// Close indicates that the stream is closed.
	// Either the connection is closed, or the stream has reached its byte limit.
	// Like Feed, it optionally returns a prop update.
	Close(limited bool) *PropUpdate
}
type (
	// PropMap holds the properties extracted by an analyzer. Values may
	// themselves be nested PropMaps, which is what makes the dotted-key
	// lookups in Get possible.
	PropMap map[string]any
	// CombinedPropMap maps an analyzer name to that analyzer's PropMap.
	CombinedPropMap map[string]PropMap
)

// Get returns the value of the property with the given key.
// The key can be a nested key, e.g. "foo.bar.baz".
// Returns nil if the key does not exist.
func (m PropMap) Get(key string) any {
	// strings.Split with a non-empty separator always yields at least one
	// element, so no empty-slice guard is needed here.
	var current any = m
	for _, k := range strings.Split(key, ".") {
		currentMap, ok := current.(PropMap)
		if !ok {
			// The path descends into something that is not a nested map.
			return nil
		}
		// Missing keys yield nil, which the next iteration (if any)
		// rejects via the type assertion above.
		current = currentMap[k]
	}
	return current
}

// Get returns the value of the property with the given analyzer & key.
// The key can be a nested key, e.g. "foo.bar.baz".
// Returns nil if the key does not exist.
func (cm CombinedPropMap) Get(an string, key string) any {
	m, ok := cm[an]
	if !ok {
		return nil
	}
	return m.Get(key)
}
// PropUpdateType tells the engine how to apply a PropUpdate's map M to a
// stream's existing properties. The application itself happens engine-side,
// not in this file.
type PropUpdateType int

const (
	PropUpdateNone PropUpdateType = iota // no property change
	// Presumed semantics (engine-side, confirm against the engine):
	// Merge folds M into the existing props, Replace swaps them out
	// wholesale, Delete removes the keys listed in M.
	PropUpdateMerge
	PropUpdateReplace
	PropUpdateDelete
)

// PropUpdate is the unit of property change returned by analyzer streams
// from Feed and Close.
type PropUpdate struct {
	Type PropUpdateType // how M should be applied
	M    PropMap        // the properties involved in the update
}

224
analyzer/internal/tls.go Normal file
View File

@@ -0,0 +1,224 @@
package internal
import (
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/analyzer/utils"
)
// TLS record types.
const (
	// RecordTypeHandshake is the TLS record content type for handshake
	// messages (0x16).
	RecordTypeHandshake = 0x16
)

// TLS handshake message types.
const (
	TypeClientHello = 0x01 // ClientHello handshake message
	TypeServerHello = 0x02 // ServerHello handshake message
)

// TLS extension numbers.
const (
	extServerName           = 0x0000 // server_name (SNI)
	extALPN                 = 0x0010 // application_layer_protocol_negotiation
	extSupportedVersions    = 0x002b // supported_versions
	extEncryptedClientHello = 0xfe0d // encrypted_client_hello (ECH)
)
// ParseTLSClientHelloMsgData parses the body of a TLS ClientHello handshake
// message and returns the extracted properties (version, random, session,
// ciphers, compression, plus any recognized extensions), or nil if the data
// is truncated or malformed.
func ParseTLSClientHelloMsgData(chBuf *utils.ByteBuffer) analyzer.PropMap {
	var ok bool
	m := make(analyzer.PropMap)
	// Version, random & session ID length combined are within 35 bytes,
	// so no need for bounds checking
	m["version"], _ = chBuf.GetUint16(false, true)
	m["random"], _ = chBuf.Get(32, true)
	sessionIDLen, _ := chBuf.GetByte(true)
	m["session"], ok = chBuf.Get(int(sessionIDLen), true)
	if !ok {
		// Not enough data for session ID
		return nil
	}
	cipherSuitesLen, ok := chBuf.GetUint16(false, true)
	if !ok {
		// Not enough data for cipher suites length
		return nil
	}
	if cipherSuitesLen%2 != 0 {
		// Cipher suites are 2 bytes each, so must be even
		return nil
	}
	ciphers := make([]uint16, cipherSuitesLen/2)
	for i := range ciphers {
		ciphers[i], ok = chBuf.GetUint16(false, true)
		if !ok {
			return nil
		}
	}
	m["ciphers"] = ciphers
	compressionMethodsLen, ok := chBuf.GetByte(true)
	if !ok {
		// Not enough data for compression methods length
		return nil
	}
	// Compression methods are 1 byte each, we just put a byte slice here
	m["compression"], ok = chBuf.Get(int(compressionMethodsLen), true)
	if !ok {
		// Not enough data for compression methods
		return nil
	}
	extsLen, ok := chBuf.GetUint16(false, true)
	if !ok {
		// No extensions field at all; still return what we have.
		return m
	}
	extBuf, ok := chBuf.GetSubBuffer(int(extsLen), true)
	if !ok {
		// Not enough data for extensions
		return nil
	}
	if !parseTLSExtensionList(extBuf, m) {
		return nil
	}
	return m
}

// ParseTLSServerHelloMsgData parses the body of a TLS ServerHello handshake
// message and returns the extracted properties (version, random, session,
// cipher, compression, plus any recognized extensions), or nil if the data
// is truncated or malformed.
func ParseTLSServerHelloMsgData(shBuf *utils.ByteBuffer) analyzer.PropMap {
	var ok bool
	m := make(analyzer.PropMap)
	// Version, random & session ID length combined are within 35 bytes,
	// so no need for bounds checking
	m["version"], _ = shBuf.GetUint16(false, true)
	m["random"], _ = shBuf.Get(32, true)
	sessionIDLen, _ := shBuf.GetByte(true)
	m["session"], ok = shBuf.Get(int(sessionIDLen), true)
	if !ok {
		// Not enough data for session ID
		return nil
	}
	cipherSuite, ok := shBuf.GetUint16(false, true)
	if !ok {
		// Not enough data for cipher suite
		return nil
	}
	m["cipher"] = cipherSuite
	compressionMethod, ok := shBuf.GetByte(true)
	if !ok {
		// Not enough data for compression method
		return nil
	}
	m["compression"] = compressionMethod
	extsLen, ok := shBuf.GetUint16(false, true)
	if !ok {
		// No extensions field at all; still return what we have.
		return m
	}
	extBuf, ok := shBuf.GetSubBuffer(int(extsLen), true)
	if !ok {
		// Not enough data for extensions
		return nil
	}
	if !parseTLSExtensionList(extBuf, m) {
		return nil
	}
	return m
}

// parseTLSExtensionList consumes extBuf as a sequence of TLS extensions
// (2-byte type, 2-byte length, then data) and hands each one to
// parseTLSExtensions, recording results into m. It returns false if the
// buffer is truncated or an extension is invalid. Shared by the ClientHello
// and ServerHello parsers, whose extension blocks have identical framing.
func parseTLSExtensionList(extBuf *utils.ByteBuffer, m analyzer.PropMap) bool {
	for extBuf.Len() > 0 {
		extType, ok := extBuf.GetUint16(false, true)
		if !ok {
			// Not enough data for extension type
			return false
		}
		extLen, ok := extBuf.GetUint16(false, true)
		if !ok {
			// Not enough data for extension length
			return false
		}
		extDataBuf, ok := extBuf.GetSubBuffer(int(extLen), true)
		if !ok || !parseTLSExtensions(extType, extDataBuf, m) {
			// Not enough data for extension data, or invalid extension
			return false
		}
	}
	return true
}

// parseTLSExtensions extracts the extensions this package understands
// (server_name, ALPN, supported_versions, ECH) into m. Unknown extension
// types are silently accepted. Returns false only when a recognized
// extension is truncated or malformed.
func parseTLSExtensions(extType uint16, extDataBuf *utils.ByteBuffer, m analyzer.PropMap) bool {
	switch extType {
	case extServerName:
		ok := extDataBuf.Skip(2) // Ignore list length, we only care about the first entry for now
		if !ok {
			// Not enough data for list length
			return false
		}
		sniType, ok := extDataBuf.GetByte(true)
		if !ok || sniType != 0 {
			// Not enough data for SNI type, or not hostname
			return false
		}
		sniLen, ok := extDataBuf.GetUint16(false, true)
		if !ok {
			// Not enough data for SNI length
			return false
		}
		m["sni"], ok = extDataBuf.GetString(int(sniLen), true)
		if !ok {
			// Not enough data for SNI
			return false
		}
	case extALPN:
		ok := extDataBuf.Skip(2) // Ignore list length, as we read until the end
		if !ok {
			// Not enough data for list length
			return false
		}
		var alpnList []string
		for extDataBuf.Len() > 0 {
			alpnLen, ok := extDataBuf.GetByte(true)
			if !ok {
				// Not enough data for ALPN length
				return false
			}
			alpn, ok := extDataBuf.GetString(int(alpnLen), true)
			if !ok {
				// Not enough data for ALPN
				return false
			}
			alpnList = append(alpnList, alpn)
		}
		m["alpn"] = alpnList
	case extSupportedVersions:
		if extDataBuf.Len() == 2 {
			// Server only selects one version
			m["supported_versions"], _ = extDataBuf.GetUint16(false, true)
		} else {
			// Client sends a list of versions
			ok := extDataBuf.Skip(1) // Ignore list length, as we read until the end
			if !ok {
				// Not enough data for list length
				return false
			}
			var versions []uint16
			for extDataBuf.Len() > 0 {
				ver, ok := extDataBuf.GetUint16(false, true)
				if !ok {
					// Not enough data for version
					return false
				}
				versions = append(versions, ver)
			}
			m["supported_versions"] = versions
		}
	case extEncryptedClientHello:
		// We can't parse ECH for now, just set a flag
		m["ech"] = true
	}
	return true
}

162
analyzer/tcp/fet.go Normal file
View File

@@ -0,0 +1,162 @@
package tcp
import (
	"math/bits"

	"git.difuse.io/Difuse/Mellaris/analyzer"
)
// Compile-time check that FETAnalyzer satisfies analyzer.TCPAnalyzer.
var _ analyzer.TCPAnalyzer = (*FETAnalyzer)(nil)

// FETAnalyzer stands for "Fully Encrypted Traffic" analyzer.
// It implements an algorithm to detect fully encrypted proxy protocols
// such as Shadowsocks, mentioned in the following paper:
// https://gfw.report/publications/usenixsecurity23/data/paper/paper.pdf
type FETAnalyzer struct{}

// Name returns the name of the analyzer.
func (a *FETAnalyzer) Name() string {
	return "fet"
}

// Limit returns the byte limit for this analyzer.
func (a *FETAnalyzer) Limit() int {
	// We only really look at the first packet
	return 8192
}

// NewTCP returns a new TCPStream that applies the FET heuristics.
func (a *FETAnalyzer) NewTCP(info analyzer.TCPInfo, logger analyzer.Logger) analyzer.TCPStream {
	return newFETStream(logger)
}
// fetStream evaluates the paper's five exemption heuristics (Ex1..Ex5)
// on the first chunk of payload it receives.
type fetStream struct {
	logger analyzer.Logger
}

// newFETStream returns a stream ready to classify a single TCP connection.
func newFETStream(logger analyzer.Logger) *fetStream {
	return &fetStream{logger: logger}
}

// Feed runs the five heuristics on the first non-empty chunk and reports the
// verdict ("yes" = looks fully encrypted, i.e. no exemption fired). It is
// always done after producing a result, and gives up if data was skipped.
func (s *fetStream) Feed(rev, start, end bool, skip int, data []byte) (u *analyzer.PropUpdate, done bool) {
	if skip != 0 {
		// A gap in the reassembled stream; stop analyzing this connection.
		return nil, true
	}
	if len(data) == 0 {
		// Nothing to evaluate yet; wait for the first payload.
		return nil, false
	}
	popAvg := averagePopCount(data)
	sixPrintable := isFirstSixPrintable(data)
	printableFrac := printablePercentage(data)
	longestRun := contiguousPrintable(data)
	tlsOrHTTP := isTLSorHTTP(data)
	// The connection is exempt (not flagged as fully encrypted) if any
	// of the five heuristics from the paper fires.
	exempt := popAvg <= 3.4 || popAvg >= 4.6 ||
		sixPrintable ||
		printableFrac > 0.5 ||
		longestRun > 20 ||
		tlsOrHTTP
	props := analyzer.PropMap{
		"ex1": popAvg,
		"ex2": sixPrintable,
		"ex3": printableFrac,
		"ex4": longestRun,
		"ex5": tlsOrHTTP,
		"yes": !exempt,
	}
	return &analyzer.PropUpdate{Type: analyzer.PropUpdateReplace, M: props}, true
}

// Close has nothing further to report; the verdict is emitted from Feed.
func (s *fetStream) Close(limited bool) *analyzer.PropUpdate {
	return nil
}
// popCount reports the number of set bits in b.
func popCount(b byte) int {
	n := 0
	// Kernighan's trick: each iteration clears the lowest set bit.
	for ; b != 0; b &= b - 1 {
		n++
	}
	return n
}

// averagePopCount returns the average popcount of the given bytes.
// This is the "Ex1" metric in the paper. An empty input yields 0.
func averagePopCount(bytes []byte) float32 {
	if len(bytes) == 0 {
		return 0
	}
	sum := 0
	for _, v := range bytes {
		sum += popCount(v)
	}
	return float32(sum) / float32(len(bytes))
}
// isFirstSixPrintable returns true if the first six bytes are printable ASCII.
// This is the "Ex2" metric in the paper. Inputs shorter than six bytes
// report false.
func isFirstSixPrintable(bytes []byte) bool {
	if len(bytes) < 6 {
		return false
	}
	for _, b := range bytes[:6] {
		// Printable ASCII spans 0x20 (space) through 0x7e ('~').
		if b < 0x20 || b > 0x7e {
			return false
		}
	}
	return true
}
// printablePercentage returns the fraction of bytes that are printable ASCII.
// This is the "Ex3" metric in the paper. An empty input yields 0.
func printablePercentage(bytes []byte) float32 {
	if len(bytes) == 0 {
		return 0
	}
	printable := 0
	for _, b := range bytes {
		// Printable ASCII spans 0x20 (space) through 0x7e ('~').
		if b >= 0x20 && b <= 0x7e {
			printable++
		}
	}
	return float32(printable) / float32(len(bytes))
}
// contiguousPrintable returns the length of the longest contiguous sequence
// of printable ASCII bytes.
// This is the "Ex4" metric in the paper.
func contiguousPrintable(bytes []byte) int {
	longest, run := 0, 0
	for _, b := range bytes {
		// Printable ASCII spans 0x20 (space) through 0x7e ('~').
		if b >= 0x20 && b <= 0x7e {
			run++
			// Track the maximum as we go instead of at run boundaries.
			if run > longest {
				longest = run
			}
		} else {
			run = 0
		}
	}
	return longest
}
// isTLSorHTTP returns true if the given bytes look like TLS or HTTP.
// This is the "Ex5" metric in the paper. Inputs shorter than three
// bytes report false.
func isTLSorHTTP(bytes []byte) bool {
	if len(bytes) < 3 {
		return false
	}
	// "We observe that the GFW exempts any connection whose first
	// three bytes match the following regular expression:
	// [\x16-\x17]\x03[\x00-\x09]" - from the paper in Section 4.3
	if (bytes[0] == 0x16 || bytes[0] == 0x17) && bytes[1] == 0x03 && bytes[2] <= 0x09 {
		return true
	}
	// First three letters of every standard HTTP request method.
	switch string(bytes[:3]) {
	case "GET", "HEA", "POS", "PUT", "DEL", "CON", "OPT", "TRA", "PAT":
		return true
	}
	return false
}

// isPrintable reports whether b is printable ASCII
// (0x20 space through 0x7e tilde).
func isPrintable(b byte) bool {
	return 0x20 <= b && b <= 0x7e
}

193
analyzer/tcp/http.go Normal file
View File

@@ -0,0 +1,193 @@
package tcp
import (
"bytes"
"strconv"
"strings"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/analyzer/utils"
)
var _ analyzer.TCPAnalyzer = (*HTTPAnalyzer)(nil)

// HTTPAnalyzer parses plaintext HTTP/1.x request and response start lines
// plus their header blocks.
type HTTPAnalyzer struct{}

// Name returns the analyzer's registry name.
func (a *HTTPAnalyzer) Name() string {
	return "http"
}

// Limit returns the maximum number of stream bytes this analyzer consumes.
func (a *HTTPAnalyzer) Limit() int {
	return 8192
}

// NewTCP creates the per-connection stream handler.
func (a *HTTPAnalyzer) NewTCP(info analyzer.TCPInfo, logger analyzer.Logger) analyzer.TCPStream {
	return newHTTPStream(logger)
}

// httpStream keeps independent parse state per direction:
// req* for client->server, resp* for server->client.
type httpStream struct {
	logger analyzer.Logger

	reqBuf     *utils.ByteBuffer
	reqMap     analyzer.PropMap
	reqUpdated bool
	reqLSM     *utils.LinearStateMachine
	reqDone    bool

	respBuf     *utils.ByteBuffer
	respMap     analyzer.PropMap
	respUpdated bool
	respLSM     *utils.LinearStateMachine
	respDone    bool
}

// newHTTPStream builds an httpStream whose state machines parse first the
// start line, then the header block, for each direction.
func newHTTPStream(logger analyzer.Logger) *httpStream {
	s := &httpStream{logger: logger, reqBuf: &utils.ByteBuffer{}, respBuf: &utils.ByteBuffer{}}
	s.reqLSM = utils.NewLinearStateMachine(
		s.parseRequestLine,
		s.parseRequestHeaders,
	)
	s.respLSM = utils.NewLinearStateMachine(
		s.parseResponseLine,
		s.parseResponseHeaders,
	)
	return s
}
// Feed appends new data to the buffer for its direction, advances that
// direction's state machine, and emits a merge update when new request or
// response properties were parsed. The stream is finished once either
// side's parse is cancelled or both sides are done.
func (s *httpStream) Feed(rev, start, end bool, skip int, data []byte) (u *analyzer.PropUpdate, d bool) {
	if skip != 0 {
		// Missed bytes; the state machines cannot resynchronize.
		return nil, true
	}
	if len(data) == 0 {
		return nil, false
	}
	var update *analyzer.PropUpdate
	var cancelled bool
	if rev {
		s.respBuf.Append(data)
		s.respUpdated = false
		cancelled, s.respDone = s.respLSM.Run()
		if s.respUpdated {
			update = &analyzer.PropUpdate{
				Type: analyzer.PropUpdateMerge,
				M:    analyzer.PropMap{"resp": s.respMap},
			}
			s.respUpdated = false
		}
	} else {
		s.reqBuf.Append(data)
		s.reqUpdated = false
		cancelled, s.reqDone = s.reqLSM.Run()
		if s.reqUpdated {
			update = &analyzer.PropUpdate{
				Type: analyzer.PropUpdateMerge,
				M:    analyzer.PropMap{"req": s.reqMap},
			}
			s.reqUpdated = false
		}
	}
	return update, cancelled || (s.reqDone && s.respDone)
}
// parseRequestLine parses the HTTP request line ("GET /path HTTP/1.1")
// into method, path and version. It pauses until a full CRLF-terminated
// line is buffered and cancels on malformed input.
func (s *httpStream) parseRequestLine() utils.LSMAction {
	// Find the end of the request line
	line, ok := s.reqBuf.GetUntil([]byte("\r\n"), true, true)
	if !ok {
		// No end of line yet, but maybe we just need more data
		return utils.LSMActionPause
	}
	fields := strings.Fields(string(line[:len(line)-2])) // Strip \r\n
	if len(fields) != 3 {
		// Invalid request line
		return utils.LSMActionCancel
	}
	method := fields[0]
	path := fields[1]
	version := fields[2]
	if !strings.HasPrefix(version, "HTTP/") {
		// Invalid version
		return utils.LSMActionCancel
	}
	s.reqMap = analyzer.PropMap{
		"method":  method,
		"path":    path,
		"version": version,
	}
	s.reqUpdated = true
	return utils.LSMActionNext
}
// parseResponseLine parses the HTTP status line ("HTTP/1.1 200 OK") into
// version and numeric status. It pauses until a full CRLF-terminated line
// is buffered and cancels on malformed input.
func (s *httpStream) parseResponseLine() utils.LSMAction {
	// Find the end of the response line
	line, ok := s.respBuf.GetUntil([]byte("\r\n"), true, true)
	if !ok {
		// No end of line yet, but maybe we just need more data
		return utils.LSMActionPause
	}
	fields := strings.Fields(string(line[:len(line)-2])) // Strip \r\n
	if len(fields) < 2 {
		// Invalid response line
		return utils.LSMActionCancel
	}
	version := fields[0]
	if !strings.HasPrefix(version, "HTTP/") {
		// Invalid version
		return utils.LSMActionCancel
	}
	// Fix: the Atoi error was previously discarded and only status == 0
	// was rejected. Atoi returns a clamped, non-zero value together with
	// ErrRange on overflow, so an absurdly long digit string used to
	// slip through. Checking the error rejects both non-numeric and
	// out-of-range status fields.
	status, err := strconv.Atoi(fields[1])
	if err != nil || status == 0 {
		// Invalid status code
		return utils.LSMActionCancel
	}
	s.respMap = analyzer.PropMap{
		"version": version,
		"status":  status,
	}
	s.respUpdated = true
	return utils.LSMActionNext
}
// parseHeaders consumes one CRLF-delimited header block (terminated by an
// empty line) from buf and returns it as a map with lowercased keys.
// It pauses until the full "\r\n\r\n" terminator is buffered and cancels
// if any header line lacks a ':' separator.
func (s *httpStream) parseHeaders(buf *utils.ByteBuffer) (utils.LSMAction, analyzer.PropMap) {
	// Find the end of headers
	headers, ok := buf.GetUntil([]byte("\r\n\r\n"), true, true)
	if !ok {
		// No end of headers yet, but maybe we just need more data
		return utils.LSMActionPause, nil
	}
	headers = headers[:len(headers)-4] // Strip \r\n\r\n
	headerMap := make(analyzer.PropMap)
	for _, line := range bytes.Split(headers, []byte("\r\n")) {
		// Split on the first ':' only; header values may contain colons.
		fields := bytes.SplitN(line, []byte(":"), 2)
		if len(fields) != 2 {
			// Invalid header
			return utils.LSMActionCancel, nil
		}
		key := string(bytes.TrimSpace(fields[0]))
		value := string(bytes.TrimSpace(fields[1]))
		// Normalize header keys to lowercase
		headerMap[strings.ToLower(key)] = value
	}
	return utils.LSMActionNext, headerMap
}
// parseRequestHeaders stores the parsed request headers under "headers"
// and flags the request map as updated.
func (s *httpStream) parseRequestHeaders() utils.LSMAction {
	action, headerMap := s.parseHeaders(s.reqBuf)
	if action == utils.LSMActionNext {
		s.reqMap["headers"] = headerMap
		s.reqUpdated = true
	}
	return action
}

// parseResponseHeaders stores the parsed response headers under "headers"
// and flags the response map as updated.
func (s *httpStream) parseResponseHeaders() utils.LSMAction {
	action, headerMap := s.parseHeaders(s.respBuf)
	if action == utils.LSMActionNext {
		s.respMap["headers"] = headerMap
		s.respUpdated = true
	}
	return action
}

// Close releases buffers and parse state; no final update is produced.
func (s *httpStream) Close(limited bool) *analyzer.PropUpdate {
	s.reqBuf.Reset()
	s.respBuf.Reset()
	s.reqMap = nil
	s.respMap = nil
	return nil
}

64
analyzer/tcp/http_test.go Normal file
View File

@@ -0,0 +1,64 @@
package tcp
import (
"reflect"
"strings"
"testing"
"git.difuse.io/Difuse/Mellaris/analyzer"
)
// TestHTTPParsing_Request feeds complete client-side payloads into a fresh
// httpStream and checks the "req" properties emitted by the first update.
func TestHTTPParsing_Request(t *testing.T) {
	testCases := map[string]analyzer.PropMap{
		"GET / HTTP/1.1\r\n": {
			"method": "GET", "path": "/", "version": "HTTP/1.1",
		},
		"POST /hello?a=1&b=2 HTTP/1.0\r\n": {
			"method": "POST", "path": "/hello?a=1&b=2", "version": "HTTP/1.0",
		},
		"PUT /world HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody": {
			"method": "PUT", "path": "/world", "version": "HTTP/1.1", "headers": analyzer.PropMap{"content-length": "4"},
		},
		"DELETE /goodbye HTTP/2.0\r\n": {
			"method": "DELETE", "path": "/goodbye", "version": "HTTP/2.0",
		},
	}
	for tc, want := range testCases {
		t.Run(strings.Split(tc, " ")[0], func(t *testing.T) {
			// Pin loop variables for parallel subtests (pre-Go 1.22
			// loop-capture semantics).
			tc, want := tc, want
			t.Parallel()
			u, _ := newHTTPStream(nil).Feed(false, false, false, 0, []byte(tc))
			got := u.M.Get("req")
			if !reflect.DeepEqual(got, want) {
				t.Errorf("\"%s\" parsed = %v, want %v", tc, got, want)
			}
		})
	}
}

// TestHTTPParsing_Response feeds complete server-side payloads (rev=true)
// and checks the "resp" properties emitted by the first update.
func TestHTTPParsing_Response(t *testing.T) {
	testCases := map[string]analyzer.PropMap{
		"HTTP/1.0 200 OK\r\nContent-Length: 4\r\n\r\nbody": {
			"version": "HTTP/1.0", "status": 200,
			"headers": analyzer.PropMap{"content-length": "4"},
		},
		"HTTP/2.0 204 No Content\r\n\r\n": {
			"version": "HTTP/2.0", "status": 204,
		},
	}
	for tc, want := range testCases {
		t.Run(strings.Split(tc, " ")[0], func(t *testing.T) {
			// Pin loop variables for parallel subtests (pre-Go 1.22
			// loop-capture semantics).
			tc, want := tc, want
			t.Parallel()
			u, _ := newHTTPStream(nil).Feed(true, false, false, 0, []byte(tc))
			got := u.M.Get("resp")
			if !reflect.DeepEqual(got, want) {
				t.Errorf("\"%s\" parsed = %v, want %v", tc, got, want)
			}
		})
	}
}

508
analyzer/tcp/socks.go Normal file
View File

@@ -0,0 +1,508 @@
package tcp
import (
"net"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/analyzer/utils"
)
const (
	// Detected protocol variant (internal values, not wire bytes).
	SocksInvalid = iota
	Socks4
	Socks4A
	Socks5

	// Version bytes on the wire.
	Socks4Version = 0x04
	Socks5Version = 0x05

	// SOCKS4 reply version field ("VN") is always zero.
	Socks4ReplyVN = 0x00

	// SOCKS4 commands.
	Socks4CmdTCPConnect = 0x01
	Socks4CmdTCPBind    = 0x02

	// SOCKS4 reply codes.
	Socks4ReqGranted        = 0x5A
	Socks4ReqRejectOrFailed = 0x5B
	Socks4ReqRejectIdentd   = 0x5C
	Socks4ReqRejectUser     = 0x5D

	// SOCKS5 commands.
	Socks5CmdTCPConnect   = 0x01
	Socks5CmdTCPBind      = 0x02
	Socks5CmdUDPAssociate = 0x03

	// SOCKS5 auth methods and sub-negotiation results.
	Socks5AuthNotRequired      = 0x00
	Socks5AuthPassword         = 0x02
	Socks5AuthNoMatchingMethod = 0xFF
	Socks5AuthSuccess          = 0x00
	Socks5AuthFailure          = 0x01

	// SOCKS5 address types.
	Socks5AddrTypeIPv4   = 0x01
	Socks5AddrTypeDomain = 0x03
	Socks5AddrTypeIPv6   = 0x04
)
// SocksAnalyzer detects SOCKS4 / SOCKS4a / SOCKS5 handshakes.
// Consistency fix: assert analyzer.TCPAnalyzer like every other TCP
// analyzer in this package (previously the weaker analyzer.Analyzer),
// so NewTCP is also verified at compile time.
var _ analyzer.TCPAnalyzer = (*SocksAnalyzer)(nil)

type SocksAnalyzer struct{}

// Name returns the analyzer's registry name.
func (a *SocksAnalyzer) Name() string {
	return "socks"
}

// Limit returns 0 (no limit): SOCKS4 user-ID and hostname fields are
// NUL-terminated, so the request length cannot be bounded up front.
func (a *SocksAnalyzer) Limit() int {
	// Socks4 length limit cannot be predicted
	return 0
}

// NewTCP creates the per-connection stream handler.
func (a *SocksAnalyzer) NewTCP(info analyzer.TCPInfo, logger analyzer.Logger) analyzer.TCPStream {
	return newSocksStream(logger)
}
// socksStream parses request and response halves independently; version
// and auth fields carry state discovered on one side that the other
// side's parser needs.
type socksStream struct {
	logger analyzer.Logger

	reqBuf     *utils.ByteBuffer
	reqMap     analyzer.PropMap
	reqUpdated bool
	reqLSM     *utils.LinearStateMachine
	reqDone    bool

	respBuf     *utils.ByteBuffer
	respMap     analyzer.PropMap
	respUpdated bool
	respLSM     *utils.LinearStateMachine
	respDone    bool

	// version is one of SocksInvalid/Socks4/Socks4A/Socks5, set while
	// parsing the request and consulted when validating the response.
	version int

	// SOCKS5 auth negotiation state shared between parse steps.
	authReqMethod  int
	authUsername   string
	authPassword   string
	authRespMethod int
}

// newSocksStream builds a socksStream; both state machines start by
// sniffing the version byte, which appends the protocol-specific steps.
func newSocksStream(logger analyzer.Logger) *socksStream {
	s := &socksStream{logger: logger, reqBuf: &utils.ByteBuffer{}, respBuf: &utils.ByteBuffer{}}
	s.reqLSM = utils.NewLinearStateMachine(
		s.parseSocksReqVersion,
	)
	s.respLSM = utils.NewLinearStateMachine(
		s.parseSocksRespVersion,
	)
	return s
}
// Feed appends new data to the buffer for its direction and advances the
// corresponding state machine. Request-side updates also carry the
// detected on-wire "version", since it is discovered mid-parse.
func (s *socksStream) Feed(rev, start, end bool, skip int, data []byte) (u *analyzer.PropUpdate, d bool) {
	if skip != 0 {
		// Missed bytes; the state machines cannot resynchronize.
		return nil, true
	}
	if len(data) == 0 {
		return nil, false
	}
	var update *analyzer.PropUpdate
	var cancelled bool
	if rev {
		s.respBuf.Append(data)
		s.respUpdated = false
		cancelled, s.respDone = s.respLSM.Run()
		if s.respUpdated {
			update = &analyzer.PropUpdate{
				Type: analyzer.PropUpdateMerge,
				M:    analyzer.PropMap{"resp": s.respMap},
			}
			s.respUpdated = false
		}
	} else {
		s.reqBuf.Append(data)
		s.reqUpdated = false
		cancelled, s.reqDone = s.reqLSM.Run()
		if s.reqUpdated {
			update = &analyzer.PropUpdate{
				Type: analyzer.PropUpdateMerge,
				M: analyzer.PropMap{
					"version": s.socksVersion(),
					"req":     s.reqMap,
				},
			}
			s.reqUpdated = false
		}
	}
	return update, cancelled || (s.reqDone && s.respDone)
}

// Close releases buffers and parse state; no final update is produced.
func (s *socksStream) Close(limited bool) *analyzer.PropUpdate {
	s.reqBuf.Reset()
	s.respBuf.Reset()
	s.reqMap = nil
	s.respMap = nil
	return nil
}
// parseSocksReqVersion consumes the first request byte, decides between
// SOCKS4 and SOCKS5, and appends the matching parse steps to the request
// state machine.
func (s *socksStream) parseSocksReqVersion() utils.LSMAction {
	socksVer, ok := s.reqBuf.GetByte(true)
	if !ok {
		return utils.LSMActionPause
	}
	if socksVer != Socks4Version && socksVer != Socks5Version {
		return utils.LSMActionCancel
	}
	s.reqMap = make(analyzer.PropMap)
	s.reqUpdated = true
	if socksVer == Socks4Version {
		s.version = Socks4
		s.reqLSM.AppendSteps(
			s.parseSocks4ReqIpAndPort,
			s.parseSocks4ReqUserId,
			s.parseSocks4ReqHostname,
		)
	} else {
		s.version = Socks5
		s.reqLSM.AppendSteps(
			s.parseSocks5ReqMethod,
			s.parseSocks5ReqAuth,
			s.parseSocks5ReqConnInfo,
		)
	}
	return utils.LSMActionNext
}

// parseSocksRespVersion validates the first response byte against the
// version learned from the request, then appends the matching response
// parse steps. SOCKS4/4a replies carry VN=0; SOCKS5 echoes 0x05.
func (s *socksStream) parseSocksRespVersion() utils.LSMAction {
	socksVer, ok := s.respBuf.GetByte(true)
	if !ok {
		return utils.LSMActionPause
	}
	// Cancel on any mismatch with the request-side version, or if no
	// request version was ever established.
	if (s.version == Socks4 || s.version == Socks4A) && socksVer != Socks4ReplyVN ||
		s.version == Socks5 && socksVer != Socks5Version || s.version == SocksInvalid {
		return utils.LSMActionCancel
	}
	if socksVer == Socks4ReplyVN {
		s.respLSM.AppendSteps(
			s.parseSocks4RespPacket,
		)
	} else {
		s.respLSM.AppendSteps(
			s.parseSocks5RespMethod,
			s.parseSocks5RespAuth,
			s.parseSocks5RespConnInfo,
		)
	}
	return utils.LSMActionNext
}
// parseSocks5ReqMethod parses the SOCKS5 method-selection message
// (NMETHODS followed by the METHODS list) and records the first auth
// method we can handle in s.authReqMethod.
func (s *socksStream) parseSocks5ReqMethod() utils.LSMAction {
	// Peek NMETHODS first so nothing is consumed until the full list
	// is buffered.
	nMethods, ok := s.reqBuf.GetByte(false)
	if !ok {
		return utils.LSMActionPause
	}
	methods, ok := s.reqBuf.Get(int(nMethods)+1, true)
	if !ok {
		return utils.LSMActionPause
	}
	// For convenience, we only take the first method we can process
	s.authReqMethod = Socks5AuthNoMatchingMethod
	for _, method := range methods[1:] {
		switch method {
		case Socks5AuthNotRequired:
			s.authReqMethod = Socks5AuthNotRequired
			return utils.LSMActionNext
		case Socks5AuthPassword:
			s.authReqMethod = Socks5AuthPassword
			return utils.LSMActionNext
		default:
			// TODO: more auth method to support
		}
	}
	return utils.LSMActionNext
}

// parseSocks5ReqAuth parses the client's auth exchange for the selected
// method. For username/password auth (RFC 1929) the wire format is:
// VER(1) ULEN(1) UNAME(ULEN) PLEN(1) PASSWD(PLEN).
func (s *socksStream) parseSocks5ReqAuth() utils.LSMAction {
	switch s.authReqMethod {
	case Socks5AuthNotRequired:
		s.reqMap["auth"] = analyzer.PropMap{"method": s.authReqMethod}
	case Socks5AuthPassword:
		// Peek VER and ULEN without consuming, so we can wait for the
		// complete message before committing.
		meta, ok := s.reqBuf.Get(2, false)
		if !ok {
			return utils.LSMActionPause
		}
		if meta[0] != 0x01 {
			return utils.LSMActionCancel
		}
		usernameLen := int(meta[1])
		// Peek far enough to read PLEN (at offset usernameLen+2).
		meta, ok = s.reqBuf.Get(usernameLen+3, false)
		if !ok {
			return utils.LSMActionPause
		}
		passwordLen := int(meta[usernameLen+2])
		// Now consume the whole message in one go.
		meta, ok = s.reqBuf.Get(usernameLen+passwordLen+3, true)
		if !ok {
			return utils.LSMActionPause
		}
		s.authUsername = string(meta[2 : usernameLen+2])
		s.authPassword = string(meta[usernameLen+3:])
		s.reqMap["auth"] = analyzer.PropMap{
			"method":   s.authReqMethod,
			"username": s.authUsername,
			"password": s.authPassword,
		}
	default:
		return utils.LSMActionCancel
	}
	s.reqUpdated = true
	return utils.LSMActionNext
}
// parseSocks5ReqConnInfo parses the SOCKS5 connect request
// (VER CMD RSV ATYP DST.ADDR DST.PORT) and records cmd, address type,
// address and port in the request map.
func (s *socksStream) parseSocks5ReqConnInfo() utils.LSMAction {
	/* preInfo struct
	+----+-----+-------+------+-------------+
	|VER | CMD |  RSV  | ATYP | DST.ADDR(1) |
	+----+-----+-------+------+-------------+
	*/
	// Peek the fixed prefix to learn the address type (and, for domain
	// addresses, the length byte) before consuming the whole packet.
	preInfo, ok := s.reqBuf.Get(5, false)
	if !ok {
		return utils.LSMActionPause
	}
	// verify socks version
	if preInfo[0] != Socks5Version {
		return utils.LSMActionCancel
	}
	var pktLen int
	switch int(preInfo[3]) {
	case Socks5AddrTypeIPv4:
		pktLen = 10
	case Socks5AddrTypeDomain:
		domainLen := int(preInfo[4])
		pktLen = 7 + domainLen
	case Socks5AddrTypeIPv6:
		pktLen = 22
	default:
		return utils.LSMActionCancel
	}
	pkt, ok := s.reqBuf.Get(pktLen, true)
	if !ok {
		return utils.LSMActionPause
	}
	// parse cmd
	cmd := int(pkt[1])
	if cmd != Socks5CmdTCPConnect && cmd != Socks5CmdTCPBind && cmd != Socks5CmdUDPAssociate {
		return utils.LSMActionCancel
	}
	s.reqMap["cmd"] = cmd
	// parse addr type
	addrType := int(pkt[3])
	var addr string
	switch addrType {
	case Socks5AddrTypeIPv4:
		addr = net.IPv4(pkt[4], pkt[5], pkt[6], pkt[7]).String()
	case Socks5AddrTypeDomain:
		addr = string(pkt[5 : 5+pkt[4]])
	case Socks5AddrTypeIPv6:
		addr = net.IP(pkt[4 : 4+net.IPv6len]).String()
	default:
		return utils.LSMActionCancel
	}
	s.reqMap["addr_type"] = addrType
	s.reqMap["addr"] = addr
	// parse port (big-endian, last two bytes of the packet)
	port := int(pkt[pktLen-2])<<8 | int(pkt[pktLen-1])
	s.reqMap["port"] = port
	s.reqUpdated = true
	return utils.LSMActionNext
}
// parseSocks5RespMethod reads the server's selected auth method (the
// version byte was already consumed by parseSocksRespVersion).
func (s *socksStream) parseSocks5RespMethod() utils.LSMAction {
	method, ok := s.respBuf.Get(1, true)
	if !ok {
		return utils.LSMActionPause
	}
	s.authRespMethod = int(method[0])
	s.respMap = make(analyzer.PropMap)
	return utils.LSMActionNext
}

// parseSocks5RespAuth parses the server's auth sub-negotiation reply.
// For username/password auth (RFC 1929): VER(1) STATUS(1).
func (s *socksStream) parseSocks5RespAuth() utils.LSMAction {
	switch s.authRespMethod {
	case Socks5AuthNotRequired:
		s.respMap["auth"] = analyzer.PropMap{"method": s.authRespMethod}
	case Socks5AuthPassword:
		authResp, ok := s.respBuf.Get(2, true)
		if !ok {
			return utils.LSMActionPause
		}
		if authResp[0] != 0x01 {
			return utils.LSMActionCancel
		}
		authStatus := int(authResp[1])
		s.respMap["auth"] = analyzer.PropMap{
			"method": s.authRespMethod,
			"status": authStatus,
		}
	default:
		return utils.LSMActionCancel
	}
	s.respUpdated = true
	return utils.LSMActionNext
}

// parseSocks5RespConnInfo parses the server's connect reply
// (VER REP RSV ATYP BND.ADDR BND.PORT) and records rep, address type,
// address and port in the response map.
func (s *socksStream) parseSocks5RespConnInfo() utils.LSMAction {
	/* preInfo struct
	+----+-----+-------+------+-------------+
	|VER | REP |  RSV  | ATYP | BND.ADDR(1) |
	+----+-----+-------+------+-------------+
	*/
	// Peek the fixed prefix to learn the address type (and, for domain
	// addresses, the length byte) before consuming the whole packet.
	preInfo, ok := s.respBuf.Get(5, false)
	if !ok {
		return utils.LSMActionPause
	}
	// verify socks version
	if preInfo[0] != Socks5Version {
		return utils.LSMActionCancel
	}
	var pktLen int
	switch int(preInfo[3]) {
	case Socks5AddrTypeIPv4:
		pktLen = 10
	case Socks5AddrTypeDomain:
		domainLen := int(preInfo[4])
		pktLen = 7 + domainLen
	case Socks5AddrTypeIPv6:
		pktLen = 22
	default:
		return utils.LSMActionCancel
	}
	pkt, ok := s.respBuf.Get(pktLen, true)
	if !ok {
		return utils.LSMActionPause
	}
	// parse rep
	rep := int(pkt[1])
	s.respMap["rep"] = rep
	// parse addr type
	addrType := int(pkt[3])
	var addr string
	switch addrType {
	case Socks5AddrTypeIPv4:
		addr = net.IPv4(pkt[4], pkt[5], pkt[6], pkt[7]).String()
	case Socks5AddrTypeDomain:
		addr = string(pkt[5 : 5+pkt[4]])
	case Socks5AddrTypeIPv6:
		addr = net.IP(pkt[4 : 4+net.IPv6len]).String()
	default:
		return utils.LSMActionCancel
	}
	s.respMap["addr_type"] = addrType
	s.respMap["addr"] = addr
	// parse port (big-endian, last two bytes of the packet)
	port := int(pkt[pktLen-2])<<8 | int(pkt[pktLen-1])
	s.respMap["port"] = port
	s.respUpdated = true
	return utils.LSMActionNext
}
// parseSocks4ReqIpAndPort parses the fixed part of a SOCKS4 request
// (version byte already consumed): CMD(1) DST.PORT(2) DST.IP(4).
func (s *socksStream) parseSocks4ReqIpAndPort() utils.LSMAction {
	/* Following field will be parsed in this state:
	+-----+----------+--------+
	| CMD | DST.PORT | DST.IP |
	+-----+----------+--------+
	*/
	pkt, ok := s.reqBuf.Get(7, true)
	if !ok {
		return utils.LSMActionPause
	}
	if pkt[0] != Socks4CmdTCPConnect && pkt[0] != Socks4CmdTCPBind {
		return utils.LSMActionCancel
	}
	dstPort := uint16(pkt[1])<<8 | uint16(pkt[2])
	dstIp := net.IPv4(pkt[3], pkt[4], pkt[5], pkt[6]).String()
	// Socks4a extension: DST.IP of the form 0.0.0.x signals that a
	// hostname follows the user ID. NOTE(review): the 4a convention
	// requires the last byte to be non-zero (0.0.0.0 is not a valid
	// marker), which is not checked here — confirm intent.
	if pkt[3] == 0 && pkt[4] == 0 && pkt[5] == 0 {
		s.version = Socks4A
	}
	s.reqMap["cmd"] = pkt[0]
	s.reqMap["addr"] = dstIp
	s.reqMap["addr_type"] = Socks5AddrTypeIPv4
	s.reqMap["port"] = dstPort
	s.reqUpdated = true
	return utils.LSMActionNext
}

// parseSocks4ReqUserId reads the NUL-terminated user ID field.
func (s *socksStream) parseSocks4ReqUserId() utils.LSMAction {
	userIdSlice, ok := s.reqBuf.GetUntil([]byte("\x00"), true, true)
	if !ok {
		return utils.LSMActionPause
	}
	userId := string(userIdSlice[:len(userIdSlice)-1]) // strip trailing NUL
	s.reqMap["auth"] = analyzer.PropMap{
		"user_id": userId,
	}
	s.reqUpdated = true
	return utils.LSMActionNext
}

// parseSocks4ReqHostname reads the NUL-terminated hostname that SOCKS4a
// appends after the user ID; plain SOCKS4 requests skip this step.
func (s *socksStream) parseSocks4ReqHostname() utils.LSMAction {
	// Only Socks4a support hostname
	if s.version != Socks4A {
		return utils.LSMActionNext
	}
	hostnameSlice, ok := s.reqBuf.GetUntil([]byte("\x00"), true, true)
	if !ok {
		return utils.LSMActionPause
	}
	hostname := string(hostnameSlice[:len(hostnameSlice)-1]) // strip trailing NUL
	s.reqMap["addr"] = hostname
	s.reqMap["addr_type"] = Socks5AddrTypeDomain
	s.reqUpdated = true
	return utils.LSMActionNext
}

// parseSocks4RespPacket parses the SOCKS4 reply (VN byte already
// consumed): CD(1) DST.PORT(2) DST.IP(4).
func (s *socksStream) parseSocks4RespPacket() utils.LSMAction {
	pkt, ok := s.respBuf.Get(7, true)
	if !ok {
		return utils.LSMActionPause
	}
	if pkt[0] != Socks4ReqGranted &&
		pkt[0] != Socks4ReqRejectOrFailed &&
		pkt[0] != Socks4ReqRejectIdentd &&
		pkt[0] != Socks4ReqRejectUser {
		return utils.LSMActionCancel
	}
	dstPort := uint16(pkt[1])<<8 | uint16(pkt[2])
	dstIp := net.IPv4(pkt[3], pkt[4], pkt[5], pkt[6]).String()
	s.respMap = analyzer.PropMap{
		"rep":       pkt[0],
		"addr":      dstIp,
		"addr_type": Socks5AddrTypeIPv4,
		"port":      dstPort,
	}
	s.respUpdated = true
	return utils.LSMActionNext
}
// socksVersion maps the internal protocol variant to the on-wire version
// byte (0x04 for SOCKS4/4a, 0x05 for SOCKS5), or SocksInvalid when the
// variant has not been determined yet.
func (s *socksStream) socksVersion() int {
	if s.version == Socks4 || s.version == Socks4A {
		return Socks4Version
	}
	if s.version == Socks5 {
		return Socks5Version
	}
	return SocksInvalid
}

147
analyzer/tcp/ssh.go Normal file
View File

@@ -0,0 +1,147 @@
package tcp
import (
"strings"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/analyzer/utils"
)
var _ analyzer.TCPAnalyzer = (*SSHAnalyzer)(nil)

// SSHAnalyzer parses the SSH protocol version exchange (RFC 4253,
// section 4.2) in both directions.
type SSHAnalyzer struct{}

// Name returns the analyzer's registry name.
func (a *SSHAnalyzer) Name() string {
	return "ssh"
}

// Limit returns the maximum number of stream bytes this analyzer consumes.
func (a *SSHAnalyzer) Limit() int {
	return 1024
}

// NewTCP creates the per-connection stream handler.
func (a *SSHAnalyzer) NewTCP(info analyzer.TCPInfo, logger analyzer.Logger) analyzer.TCPStream {
	return newSSHStream(logger)
}

// sshStream parses the client and server exchange lines independently.
type sshStream struct {
	logger analyzer.Logger

	clientBuf     *utils.ByteBuffer
	clientMap     analyzer.PropMap
	clientUpdated bool
	clientLSM     *utils.LinearStateMachine
	clientDone    bool

	serverBuf     *utils.ByteBuffer
	serverMap     analyzer.PropMap
	serverUpdated bool
	serverLSM     *utils.LinearStateMachine
	serverDone    bool
}

// newSSHStream builds an sshStream; each direction's state machine has a
// single step that parses its version exchange line.
func newSSHStream(logger analyzer.Logger) *sshStream {
	s := &sshStream{logger: logger, clientBuf: &utils.ByteBuffer{}, serverBuf: &utils.ByteBuffer{}}
	s.clientLSM = utils.NewLinearStateMachine(
		s.parseClientExchangeLine,
	)
	s.serverLSM = utils.NewLinearStateMachine(
		s.parseServerExchangeLine,
	)
	return s
}
// Feed appends new data to the buffer for its direction and advances the
// corresponding state machine, emitting a merge update with "client" or
// "server" properties when a banner was parsed.
func (s *sshStream) Feed(rev, start, end bool, skip int, data []byte) (u *analyzer.PropUpdate, done bool) {
	if skip != 0 {
		// Missed bytes; the state machines cannot resynchronize.
		return nil, true
	}
	if len(data) == 0 {
		return nil, false
	}
	var update *analyzer.PropUpdate
	var cancelled bool
	if rev {
		s.serverBuf.Append(data)
		s.serverUpdated = false
		cancelled, s.serverDone = s.serverLSM.Run()
		if s.serverUpdated {
			update = &analyzer.PropUpdate{
				Type: analyzer.PropUpdateMerge,
				M:    analyzer.PropMap{"server": s.serverMap},
			}
			s.serverUpdated = false
		}
	} else {
		s.clientBuf.Append(data)
		s.clientUpdated = false
		cancelled, s.clientDone = s.clientLSM.Run()
		if s.clientUpdated {
			update = &analyzer.PropUpdate{
				Type: analyzer.PropUpdateMerge,
				M:    analyzer.PropMap{"client": s.clientMap},
			}
			s.clientUpdated = false
		}
	}
	return update, cancelled || (s.clientDone && s.serverDone)
}
// parseExchangeLine parses the SSH Protocol Version Exchange string.
// See RFC 4253, section 4.2:
//
//	SSH-protoversion-softwareversion SP comments CR LF
//
// The "comments" part (along with the SP) is optional, and the comments
// themselves may contain spaces — so the line must be split at most once.
// Fix: the previous strings.Fields approach rejected valid banners whose
// comments contain spaces (e.g. "SSH-2.0-OpenSSH_9.6 Ubuntu 3ubuntu13").
func (s *sshStream) parseExchangeLine(buf *utils.ByteBuffer) (utils.LSMAction, analyzer.PropMap) {
	// Find the end of the line
	line, ok := buf.GetUntil([]byte("\r\n"), true, true)
	if !ok {
		// No end of line yet, but maybe we just need more data
		return utils.LSMActionPause, nil
	}
	str := string(line[:len(line)-2]) // Strip \r\n
	if !strings.HasPrefix(str, "SSH-") {
		// Not SSH
		return utils.LSMActionCancel, nil
	}
	// Split into identifier and optional comments at the first SP only.
	id, comments, hasComments := strings.Cut(str, " ")
	sshFields := strings.SplitN(id, "-", 3)
	if len(sshFields) != 3 {
		// Invalid SSH version format
		return utils.LSMActionCancel, nil
	}
	sMap := analyzer.PropMap{
		"protocol": sshFields[1],
		"software": sshFields[2],
	}
	if hasComments {
		sMap["comments"] = comments
	}
	return utils.LSMActionNext, sMap
}
// parseClientExchangeLine parses the client's banner into clientMap.
func (s *sshStream) parseClientExchangeLine() utils.LSMAction {
	action, sMap := s.parseExchangeLine(s.clientBuf)
	if action == utils.LSMActionNext {
		s.clientMap = sMap
		s.clientUpdated = true
	}
	return action
}

// parseServerExchangeLine parses the server's banner into serverMap.
func (s *sshStream) parseServerExchangeLine() utils.LSMAction {
	action, sMap := s.parseExchangeLine(s.serverBuf)
	if action == utils.LSMActionNext {
		s.serverMap = sMap
		s.serverUpdated = true
	}
	return action
}

// Close releases buffers and parse state; no final update is produced.
func (s *sshStream) Close(limited bool) *analyzer.PropUpdate {
	s.clientBuf.Reset()
	s.serverBuf.Reset()
	s.clientMap = nil
	s.serverMap = nil
	return nil
}

226
analyzer/tcp/tls.go Normal file
View File

@@ -0,0 +1,226 @@
package tcp
import (
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/analyzer/internal"
"git.difuse.io/Difuse/Mellaris/analyzer/utils"
)
var _ analyzer.TCPAnalyzer = (*TLSAnalyzer)(nil)

// TLSAnalyzer parses TLS ClientHello and ServerHello handshake messages.
type TLSAnalyzer struct{}

// Name returns the analyzer's registry name.
func (a *TLSAnalyzer) Name() string {
	return "tls"
}

// Limit returns the maximum number of stream bytes this analyzer consumes.
func (a *TLSAnalyzer) Limit() int {
	return 8192
}

// NewTCP creates the per-connection stream handler.
func (a *TLSAnalyzer) NewTCP(info analyzer.TCPInfo, logger analyzer.Logger) analyzer.TCPStream {
	return newTLSStream(logger)
}

// tlsStream tracks each direction with its own buffer and state machine.
// clientHelloLen/serverHelloLen cache the handshake lengths found during
// header preprocessing for the data-parsing stage.
type tlsStream struct {
	logger analyzer.Logger

	reqBuf     *utils.ByteBuffer
	reqMap     analyzer.PropMap
	reqUpdated bool
	reqLSM     *utils.LinearStateMachine
	reqDone    bool

	respBuf     *utils.ByteBuffer
	respMap     analyzer.PropMap
	respUpdated bool
	respLSM     *utils.LinearStateMachine
	respDone    bool

	clientHelloLen int
	serverHelloLen int
}

// newTLSStream builds a tlsStream whose state machines first validate
// the hello headers, then parse the hello bodies.
func newTLSStream(logger analyzer.Logger) *tlsStream {
	s := &tlsStream{logger: logger, reqBuf: &utils.ByteBuffer{}, respBuf: &utils.ByteBuffer{}}
	s.reqLSM = utils.NewLinearStateMachine(
		s.tlsClientHelloPreprocess,
		s.parseClientHelloData,
	)
	s.respLSM = utils.NewLinearStateMachine(
		s.tlsServerHelloPreprocess,
		s.parseServerHelloData,
	)
	return s
}
// Feed appends new data to the buffer for its direction and advances the
// corresponding hello-parsing state machine, merging parsed properties
// under "req" (ClientHello) or "resp" (ServerHello).
func (s *tlsStream) Feed(rev, start, end bool, skip int, data []byte) (u *analyzer.PropUpdate, done bool) {
	if skip != 0 {
		// Missed bytes; the state machines cannot resynchronize.
		return nil, true
	}
	if len(data) == 0 {
		return nil, false
	}
	var update *analyzer.PropUpdate
	var cancelled bool
	if rev {
		s.respBuf.Append(data)
		s.respUpdated = false
		cancelled, s.respDone = s.respLSM.Run()
		if s.respUpdated {
			update = &analyzer.PropUpdate{
				Type: analyzer.PropUpdateMerge,
				M:    analyzer.PropMap{"resp": s.respMap},
			}
			s.respUpdated = false
		}
	} else {
		s.reqBuf.Append(data)
		s.reqUpdated = false
		cancelled, s.reqDone = s.reqLSM.Run()
		if s.reqUpdated {
			update = &analyzer.PropUpdate{
				Type: analyzer.PropUpdateMerge,
				M:    analyzer.PropMap{"req": s.reqMap},
			}
			s.reqUpdated = false
		}
	}
	return update, cancelled || (s.reqDone && s.respDone)
}
// tlsClientHelloPreprocess validates the ClientHello record and handshake
// headers.
//
// During validation, message header and first handshake header may be removed
// from `s.reqBuf`.
func (s *tlsStream) tlsClientHelloPreprocess() utils.LSMAction {
	// headers size: content type (1 byte) + legacy protocol version (2 bytes) +
	// + content length (2 bytes) + message type (1 byte) +
	// + handshake length (3 bytes)
	const headersSize = 9
	// minimal data size: protocol version (2 bytes) + random (32 bytes) +
	// + session ID (1 byte) + cipher suites (4 bytes) +
	// + compression methods (2 bytes) + no extensions
	const minDataSize = 41
	header, ok := s.reqBuf.Get(headersSize, true)
	if !ok {
		// not a full header yet
		return utils.LSMActionPause
	}
	if header[0] != internal.RecordTypeHandshake || header[5] != internal.TypeClientHello {
		return utils.LSMActionCancel
	}
	// Handshake length is a 24-bit big-endian integer.
	s.clientHelloLen = int(header[6])<<16 | int(header[7])<<8 | int(header[8])
	if s.clientHelloLen < minDataSize {
		return utils.LSMActionCancel
	}
	// TODO(review): the record-length field (header[3:5]) is never
	// validated, so a record carrying more than one handshake message,
	// or a hello spanning multiple records, is not handled. Sketch left
	// by the original author:
	// const messageHeaderSize = 4
	// fullMessageLen := int(header[3])<<8 | int(header[4])
	// msgNo := fullMessageLen / int(messageHeaderSize+s.clientHelloLen)
	// if msgNo != 1 {
	//   // ... reject or iterate?
	// }
	return utils.LSMActionNext
}
// tlsServerHelloPreprocess validates the ServerHello record and handshake
// headers.
//
// During validation, message header and first handshake header may be removed
// from `s.respBuf` (the previous comment said reqBuf; this step reads the
// response buffer).
func (s *tlsStream) tlsServerHelloPreprocess() utils.LSMAction {
	// header size: content type (1 byte) + legacy protocol version (2 byte) +
	// + content length (2 byte) + message type (1 byte) +
	// + handshake length (3 byte)
	const headersSize = 9
	// minimal data size: server version (2 byte) + random (32 byte) +
	// + session ID (>=1 byte) + cipher suite (2 byte) +
	// + compression method (1 byte) + no extensions
	const minDataSize = 38
	header, ok := s.respBuf.Get(headersSize, true)
	if !ok {
		// not a full header yet
		return utils.LSMActionPause
	}
	if header[0] != internal.RecordTypeHandshake || header[5] != internal.TypeServerHello {
		return utils.LSMActionCancel
	}
	// Handshake length is a 24-bit big-endian integer.
	s.serverHelloLen = int(header[6])<<16 | int(header[7])<<8 | int(header[8])
	if s.serverHelloLen < minDataSize {
		return utils.LSMActionCancel
	}
	// TODO(review): the record-length field (header[3:5]) is never
	// validated, so a record carrying more than one handshake message,
	// or a hello spanning multiple records, is not handled. Sketch left
	// by the original author:
	// const messageHeaderSize = 4
	// fullMessageLen := int(header[3])<<8 | int(header[4])
	// msgNo := fullMessageLen / int(messageHeaderSize+s.serverHelloLen)
	// if msgNo != 1 {
	//   // ... reject or iterate?
	// }
	return utils.LSMActionNext
}
// parseClientHelloData converts valid ClientHello message data (without
// headers) into `analyzer.PropMap`.
//
// Parsing error may leave `s.reqBuf` in an unusable state.
func (s *tlsStream) parseClientHelloData() utils.LSMAction {
	chBuf, ok := s.reqBuf.GetSubBuffer(s.clientHelloLen, true)
	if !ok {
		// Not a full client hello yet
		return utils.LSMActionPause
	}
	m := internal.ParseTLSClientHelloMsgData(chBuf)
	if m == nil {
		// Malformed hello body.
		return utils.LSMActionCancel
	}
	// Guard clauses above replace the original if/else ladder.
	s.reqMap = m
	s.reqUpdated = true
	return utils.LSMActionNext
}
// parseServerHelloData converts valid ServerHello message data (without
// headers) into `analyzer.PropMap`.
//
// Parsing error may leave `s.respBuf` in an unusable state.
func (s *tlsStream) parseServerHelloData() utils.LSMAction {
	shBuf, ok := s.respBuf.GetSubBuffer(s.serverHelloLen, true)
	if !ok {
		// Not a full server hello yet
		return utils.LSMActionPause
	}
	m := internal.ParseTLSServerHelloMsgData(shBuf)
	if m == nil {
		// Malformed hello body.
		return utils.LSMActionCancel
	}
	// Guard clauses above replace the original if/else ladder.
	s.respMap = m
	s.respUpdated = true
	return utils.LSMActionNext
}
// Close releases buffers and parse state; no final update is produced.
func (s *tlsStream) Close(limited bool) *analyzer.PropUpdate {
	s.reqBuf.Reset()
	s.respBuf.Reset()
	s.reqMap = nil
	s.respMap = nil
	return nil
}

69
analyzer/tcp/tls_test.go Normal file
View File

@@ -0,0 +1,69 @@
package tcp
import (
"reflect"
"testing"
"git.difuse.io/Difuse/Mellaris/analyzer"
)
// TestTlsStreamParsing_ClientHello feeds a complete TLS 1.2 ClientHello
// record and checks the parsed "req" properties.
func TestTlsStreamParsing_ClientHello(t *testing.T) {
	// example packet taken from <https://tls12.xargs.org/#client-hello/annotated>
	clientHello := []byte{
		0x16, 0x03, 0x01, 0x00, 0xa5, 0x01, 0x00, 0x00, 0xa1, 0x03, 0x03, 0x00,
		0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c,
		0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18,
		0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x00, 0x00, 0x20, 0xcc, 0xa8,
		0xcc, 0xa9, 0xc0, 0x2f, 0xc0, 0x30, 0xc0, 0x2b, 0xc0, 0x2c, 0xc0, 0x13,
		0xc0, 0x09, 0xc0, 0x14, 0xc0, 0x0a, 0x00, 0x9c, 0x00, 0x9d, 0x00, 0x2f,
		0x00, 0x35, 0xc0, 0x12, 0x00, 0x0a, 0x01, 0x00, 0x00, 0x58, 0x00, 0x00,
		0x00, 0x18, 0x00, 0x16, 0x00, 0x00, 0x13, 0x65, 0x78, 0x61, 0x6d, 0x70,
		0x6c, 0x65, 0x2e, 0x75, 0x6c, 0x66, 0x68, 0x65, 0x69, 0x6d, 0x2e, 0x6e,
		0x65, 0x74, 0x00, 0x05, 0x00, 0x05, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00,
		0x0a, 0x00, 0x0a, 0x00, 0x08, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, 0x00,
		0x19, 0x00, 0x0b, 0x00, 0x02, 0x01, 0x00, 0x00, 0x0d, 0x00, 0x12, 0x00,
		0x10, 0x04, 0x01, 0x04, 0x03, 0x05, 0x01, 0x05, 0x03, 0x06, 0x01, 0x06,
		0x03, 0x02, 0x01, 0x02, 0x03, 0xff, 0x01, 0x00, 0x01, 0x00, 0x00, 0x12,
		0x00, 0x00,
	}
	want := analyzer.PropMap{
		"ciphers":     []uint16{52392, 52393, 49199, 49200, 49195, 49196, 49171, 49161, 49172, 49162, 156, 157, 47, 53, 49170, 10},
		"compression": []uint8{0},
		"random":      []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31},
		"session":     []uint8{},
		"sni":         "example.ulfheim.net",
		"version":     uint16(771),
	}
	s := newTLSStream(nil)
	u, _ := s.Feed(false, false, false, 0, clientHello)
	got := u.M.Get("req")
	if !reflect.DeepEqual(got, want) {
		t.Errorf("%d B parsed = %v, want %v", len(clientHello), got, want)
	}
}

// TestTlsStreamParsing_ServerHello feeds a complete TLS 1.2 ServerHello
// record (rev=true) and checks the parsed "resp" properties.
func TestTlsStreamParsing_ServerHello(t *testing.T) {
	// example packet taken from <https://tls12.xargs.org/#server-hello/annotated>
	serverHello := []byte{
		0x16, 0x03, 0x03, 0x00, 0x31, 0x02, 0x00, 0x00, 0x2d, 0x03, 0x03, 0x70,
		0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c,
		0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88,
		0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, 0x00, 0xc0, 0x13, 0x00, 0x00,
		0x05, 0xff, 0x01, 0x00, 0x01, 0x00,
	}
	want := analyzer.PropMap{
		"cipher":      uint16(49171),
		"compression": uint8(0),
		"random":      []uint8{112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143},
		"session":     []uint8{},
		"version":     uint16(771),
	}
	s := newTLSStream(nil)
	u, _ := s.Feed(true, false, false, 0, serverHello)
	got := u.M.Get("resp")
	if !reflect.DeepEqual(got, want) {
		t.Errorf("%d B parsed = %v, want %v", len(serverHello), got, want)
	}
}

517
analyzer/tcp/trojan.go Normal file
View File

@@ -0,0 +1,517 @@
package tcp
import (
"bytes"
"git.difuse.io/Difuse/Mellaris/analyzer"
)
// Compile-time check that TrojanAnalyzer satisfies analyzer.TCPAnalyzer.
var _ analyzer.TCPAnalyzer = (*TrojanAnalyzer)(nil)

// CCS stands for "Change Cipher Spec"
var ccsPattern = []byte{20, 3, 3, 0, 1, 1}

// TrojanAnalyzer uses length-based heuristics to detect Trojan traffic based on
// its "TLS-in-TLS" nature. The heuristics are trained using a decision tree with
// about 20k Trojan samples and 30k non-Trojan samples. The tree is then converted
// to code using a custom tool and inlined here (isTrojanSeq function).
// Accuracy: 1% false positive rate, 10% false negative rate.
// We do NOT recommend directly blocking all positive connections, as this may
// break legitimate TLS connections.
type TrojanAnalyzer struct{}

// Name implements analyzer.TCPAnalyzer.
func (a *TrojanAnalyzer) Name() string {
	return "trojan"
}

// Limit implements analyzer.TCPAnalyzer: the byte budget after which the
// engine stops feeding this analyzer.
func (a *TrojanAnalyzer) Limit() int {
	return 512000
}

// NewTCP implements analyzer.TCPAnalyzer.
func (a *TrojanAnalyzer) NewTCP(info analyzer.TCPInfo, logger analyzer.Logger) analyzer.TCPStream {
	return newTrojanStream(logger)
}
// trojanStream holds per-connection state for the Trojan length heuristic.
type trojanStream struct {
	logger analyzer.Logger
	first  bool // gates the one-time TLS sanity check at the top of Feed
	count  bool // true once the client's Change Cipher Spec record was seen
	rev    bool // direction of the burst currently being accumulated
	// seq accumulates the byte counts of the four direction-alternating
	// bursts following the CCS record; seqIndex is the burst being filled.
	seq      [4]int
	seqIndex int
}
// newTrojanStream creates a trojanStream ready to inspect its first segment.
//
// first must start out true: Feed performs its one-time TLS sanity check
// (early-exiting on non-TLS traffic) only while s.first is set, and nothing
// else in this file ever sets it. The previous version left first at its
// zero value (false), so the validity check in Feed was dead code and
// non-TLS streams were never rejected up front.
func newTrojanStream(logger analyzer.Logger) *trojanStream {
	return &trojanStream{logger: logger, first: true}
}
// Feed implements analyzer.TCPStream. It waits for the client's Change
// Cipher Spec record, then accumulates the byte counts of the next four
// direction-alternating bursts and classifies them with isTrojanSeq.
func (s *trojanStream) Feed(rev, start, end bool, skip int, data []byte) (u *analyzer.PropUpdate, done bool) {
	// A gap in the stream makes length accounting meaningless; give up.
	if skip != 0 {
		return nil, true
	}
	if len(data) == 0 {
		return nil, false
	}
	if s.first {
		s.first = false
		// Stop if it's not a valid TLS connection
		// (first client bytes: record type 0x16-0x17 followed by 0x03 and a
		// low second version byte — presumably the TLS legacy version pair).
		if !(!rev && len(data) >= 3 && data[0] >= 0x16 && data[0] <= 0x17 &&
			data[1] == 0x03 && data[2] <= 0x09) {
			return nil, true
		}
	}
	if !rev && !s.count && len(data) >= 6 && bytes.Equal(data[:6], ccsPattern) {
		// Client Change Cipher Spec encountered, start counting
		s.count = true
	}
	if s.count {
		if rev == s.rev {
			// Same direction as last time, just update the number
			s.seq[s.seqIndex] += len(data)
		} else {
			// Different direction, bump the index
			s.seqIndex += 1
			if s.seqIndex == 4 {
				// Four bursts collected: emit the verdict and finish.
				return &analyzer.PropUpdate{
					Type: analyzer.PropUpdateReplace,
					M: analyzer.PropMap{
						"seq": s.seq,
						"yes": isTrojanSeq(s.seq),
					},
				}, true
			}
			s.seq[s.seqIndex] += len(data)
			s.rev = rev
		}
	}
	return nil, false
}
// Close implements analyzer.TCPStream; no final update is emitted.
func (s *trojanStream) Close(limited bool) *analyzer.PropUpdate {
	return nil
}
// isTrojanSeq classifies a connection from the byte counts of the four
// direction-alternating bursts that follow the client's Change Cipher Spec
// record (collected by trojanStream.Feed; directions alternate starting with
// client->server). The nested conditionals are a machine-generated decision
// tree (see the TrojanAnalyzer comment) and are intentionally left exactly
// as emitted by the generator — including branches whose two leaves return
// the same value — so the code stays regenerable from the training tool.
func isTrojanSeq(seq [4]int) bool {
	length1 := seq[0]
	length2 := seq[1]
	length3 := seq[2]
	length4 := seq[3]
	if length2 <= 2431 {
		if length2 <= 157 {
			if length1 <= 156 {
				if length3 <= 108 {
					return false
				} else {
					return false
				}
			} else {
				if length1 <= 892 {
					if length3 <= 40 {
						return false
					} else {
						if length3 <= 788 {
							if length4 <= 185 {
								if length1 <= 411 {
									return true
								} else {
									return false
								}
							} else {
								if length2 <= 112 {
									return false
								} else {
									return true
								}
							}
						} else {
							if length3 <= 1346 {
								if length1 <= 418 {
									return false
								} else {
									return true
								}
							} else {
								return false
							}
						}
					}
				} else {
					if length2 <= 120 {
						if length2 <= 63 {
							return false
						} else {
							if length4 <= 653 {
								return false
							} else {
								return false
							}
						}
					} else {
						return false
					}
				}
			}
		} else {
			if length1 <= 206 {
				if length1 <= 185 {
					if length1 <= 171 {
						return false
					} else {
						if length4 <= 211 {
							return false
						} else {
							return false
						}
					}
				} else {
					if length2 <= 251 {
						return true
					} else {
						return false
					}
				}
			} else {
				if length2 <= 286 {
					if length1 <= 1123 {
						if length3 <= 70 {
							return false
						} else {
							if length1 <= 659 {
								if length3 <= 370 {
									return true
								} else {
									return false
								}
							} else {
								if length4 <= 272 {
									return false
								} else {
									return true
								}
							}
						}
					} else {
						if length4 <= 537 {
							if length2 <= 276 {
								if length3 <= 1877 {
									return false
								} else {
									return false
								}
							} else {
								return false
							}
						} else {
							if length1 <= 1466 {
								if length1 <= 1435 {
									return false
								} else {
									return true
								}
							} else {
								if length2 <= 193 {
									return false
								} else {
									return false
								}
							}
						}
					}
				} else {
					if length1 <= 284 {
						if length1 <= 277 {
							if length2 <= 726 {
								return false
							} else {
								if length2 <= 768 {
									return true
								} else {
									return false
								}
							}
						} else {
							if length2 <= 782 {
								if length4 <= 783 {
									return true
								} else {
									return false
								}
							} else {
								return false
							}
						}
					} else {
						if length2 <= 492 {
							if length2 <= 396 {
								if length2 <= 322 {
									return false
								} else {
									return false
								}
							} else {
								if length4 <= 971 {
									return false
								} else {
									return true
								}
							}
						} else {
							if length2 <= 2128 {
								if length2 <= 1418 {
									return false
								} else {
									return false
								}
							} else {
								if length3 <= 103 {
									return false
								} else {
									return false
								}
							}
						}
					}
				}
			}
		}
	} else {
		if length2 <= 6232 {
			if length3 <= 85 {
				if length2 <= 3599 {
					return false
				} else {
					if length1 <= 613 {
						return false
					} else {
						return false
					}
				}
			} else {
				if length3 <= 220 {
					if length4 <= 1173 {
						if length1 <= 874 {
							if length4 <= 337 {
								if length4 <= 68 {
									return true
								} else {
									return true
								}
							} else {
								if length1 <= 667 {
									return true
								} else {
									return true
								}
							}
						} else {
							if length3 <= 108 {
								if length1 <= 1930 {
									return true
								} else {
									return true
								}
							} else {
								if length2 <= 5383 {
									return false
								} else {
									return true
								}
							}
						}
					} else {
						return false
					}
				} else {
					if length1 <= 664 {
						if length3 <= 411 {
							if length3 <= 383 {
								if length4 <= 346 {
									return true
								} else {
									return false
								}
							} else {
								if length1 <= 445 {
									return true
								} else {
									return false
								}
							}
						} else {
							if length2 <= 3708 {
								if length4 <= 307 {
									return true
								} else {
									return false
								}
							} else {
								if length2 <= 4656 {
									return false
								} else {
									return false
								}
							}
						}
					} else {
						if length1 <= 1055 {
							if length3 <= 580 {
								if length1 <= 724 {
									return true
								} else {
									return false
								}
							} else {
								if length1 <= 678 {
									return false
								} else {
									return true
								}
							}
						} else {
							if length2 <= 5352 {
								if length3 <= 1586 {
									return false
								} else {
									return false
								}
							} else {
								if length4 <= 2173 {
									return true
								} else {
									return false
								}
							}
						}
					}
				}
			}
		} else {
			if length2 <= 9408 {
				if length1 <= 670 {
					if length4 <= 76 {
						if length3 <= 175 {
							return true
						} else {
							return true
						}
					} else {
						if length2 <= 9072 {
							if length3 <= 314 {
								if length3 <= 179 {
									return false
								} else {
									return false
								}
							} else {
								if length4 <= 708 {
									return false
								} else {
									return false
								}
							}
						} else {
							return true
						}
					}
				} else {
					if length1 <= 795 {
						if length2 <= 6334 {
							if length2 <= 6288 {
								return true
							} else {
								return false
							}
						} else {
							if length4 <= 6404 {
								if length2 <= 8194 {
									return true
								} else {
									return true
								}
							} else {
								if length2 <= 8924 {
									return false
								} else {
									return true
								}
							}
						}
					} else {
						if length3 <= 732 {
							if length1 <= 1397 {
								if length3 <= 179 {
									return false
								} else {
									return false
								}
							} else {
								if length1 <= 1976 {
									return false
								} else {
									return false
								}
							}
						} else {
							if length1 <= 2840 {
								if length1 <= 2591 {
									return false
								} else {
									return true
								}
							} else {
								return false
							}
						}
					}
				}
			} else {
				if length4 <= 30 {
					return false
				} else {
					if length2 <= 13314 {
						if length4 <= 1786 {
							if length2 <= 13018 {
								if length4 <= 869 {
									return false
								} else {
									return false
								}
							} else {
								return true
							}
						} else {
							if length3 <= 775 {
								return false
							} else {
								return false
							}
						}
					} else {
						if length4 <= 73 {
							return false
						} else {
							if length3 <= 640 {
								if length3 <= 237 {
									return false
								} else {
									return false
								}
							} else {
								if length2 <= 43804 {
									return false
								} else {
									return false
								}
							}
						}
					}
				}
			}
		}
	}
}

265
analyzer/udp/dns.go Normal file
View File

@@ -0,0 +1,265 @@
package udp
import (
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/analyzer/utils"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
const (
	// Number of consecutive non-DNS datagrams after which a UDP stream
	// is declared not-DNS and released (see dnsUDPStream.Feed).
	dnsUDPInvalidCountThreshold = 4
)

// DNSAnalyzer is for both DNS over UDP and TCP.
var (
	_ analyzer.UDPAnalyzer = (*DNSAnalyzer)(nil)
	_ analyzer.TCPAnalyzer = (*DNSAnalyzer)(nil)
)

type DNSAnalyzer struct{}

// Name implements analyzer.UDPAnalyzer / analyzer.TCPAnalyzer.
func (a *DNSAnalyzer) Name() string {
	return "dns"
}

// Limit implements analyzer.UDPAnalyzer / analyzer.TCPAnalyzer.
func (a *DNSAnalyzer) Limit() int {
	// DNS is a stateless protocol, with unlimited amount
	// of back-and-forth exchanges. Don't limit it here.
	return 0
}

// NewUDP implements analyzer.UDPAnalyzer.
func (a *DNSAnalyzer) NewUDP(info analyzer.UDPInfo, logger analyzer.Logger) analyzer.UDPStream {
	return &dnsUDPStream{logger: logger}
}
// NewTCP implements analyzer.TCPAnalyzer. DNS over TCP is length-prefixed,
// so each direction gets a buffer and a two-state machine that alternates
// between reading the 2-byte length prefix and the message body.
func (a *DNSAnalyzer) NewTCP(info analyzer.TCPInfo, logger analyzer.Logger) analyzer.TCPStream {
	s := &dnsTCPStream{logger: logger, reqBuf: &utils.ByteBuffer{}, respBuf: &utils.ByteBuffer{}}
	s.reqLSM = utils.NewLinearStateMachine(
		s.getReqMessageLength,
		s.getReqMessage,
	)
	s.respLSM = utils.NewLinearStateMachine(
		s.getRespMessageLength,
		s.getRespMessage,
	)
	return s
}
// dnsUDPStream parses each UDP datagram as a standalone DNS message.
type dnsUDPStream struct {
	logger analyzer.Logger
	// invalidCount counts consecutive datagrams that failed to parse as DNS.
	invalidCount int
}
// Feed parses one UDP datagram as a DNS message and publishes its parsed
// properties. A stream is declared "done" (so non-DNS traffic can be
// offloaded) after too many consecutive datagrams fail to parse.
func (s *dnsUDPStream) Feed(rev bool, data []byte) (u *analyzer.PropUpdate, done bool) {
	msg := parseDNSMessage(data)
	if msg != nil {
		// Valid DNS message: reset the invalid streak and publish it.
		s.invalidCount = 0
		update := &analyzer.PropUpdate{
			Type: analyzer.PropUpdateReplace,
			M:    msg,
		}
		return update, false
	}
	// Not a DNS message. Give up once the invalid streak reaches the
	// threshold so the flow can be offloaded.
	s.invalidCount++
	return nil, s.invalidCount >= dnsUDPInvalidCountThreshold
}
// Close implements analyzer.UDPStream; no final update is emitted.
func (s *dnsUDPStream) Close(limited bool) *analyzer.PropUpdate {
	return nil
}
// dnsTCPStream reassembles length-prefixed DNS messages in both directions.
type dnsTCPStream struct {
	logger analyzer.Logger

	reqBuf     *utils.ByteBuffer        // buffered client->server bytes
	reqMap     analyzer.PropMap         // last parsed request message
	reqUpdated bool                     // set by getReqMessage when reqMap changes
	reqLSM     *utils.LinearStateMachine // length state -> message state
	reqDone    bool

	respBuf     *utils.ByteBuffer        // buffered server->client bytes
	respMap     analyzer.PropMap         // last parsed response message
	respUpdated bool                     // set by getRespMessage when respMap changes
	respLSM     *utils.LinearStateMachine
	respDone    bool

	reqMsgLen  int // length of the request message currently expected
	respMsgLen int // length of the response message currently expected
}
// Feed implements analyzer.TCPStream. Each direction appends its bytes to
// its buffer and advances its state machine; when a full DNS message has
// been parsed the corresponding map is published as a replace update.
func (s *dnsTCPStream) Feed(rev, start, end bool, skip int, data []byte) (u *analyzer.PropUpdate, done bool) {
	// A sequence gap breaks the length-prefixed framing; give up.
	if skip != 0 {
		return nil, true
	}
	if len(data) == 0 {
		return nil, false
	}
	var update *analyzer.PropUpdate
	var cancelled bool
	if rev {
		// Server -> client data.
		s.respBuf.Append(data)
		s.respUpdated = false
		cancelled, s.respDone = s.respLSM.Run()
		if s.respUpdated {
			update = &analyzer.PropUpdate{
				Type: analyzer.PropUpdateReplace,
				M:    s.respMap,
			}
			s.respUpdated = false
		}
	} else {
		// Client -> server data.
		s.reqBuf.Append(data)
		s.reqUpdated = false
		cancelled, s.reqDone = s.reqLSM.Run()
		if s.reqUpdated {
			update = &analyzer.PropUpdate{
				Type: analyzer.PropUpdateReplace,
				M:    s.reqMap,
			}
			s.reqUpdated = false
		}
	}
	return update, cancelled || (s.reqDone && s.respDone)
}
// Close implements analyzer.TCPStream. Buffers and maps are released so the
// garbage collector can reclaim them; no final update is emitted.
func (s *dnsTCPStream) Close(limited bool) *analyzer.PropUpdate {
	s.reqBuf.Reset()
	s.respBuf.Reset()
	s.reqMap = nil
	s.respMap = nil
	return nil
}
// getReqMessageLength reads the 2-byte big-endian length prefix of the next
// client->server DNS message, pausing until both bytes are available.
func (s *dnsTCPStream) getReqMessageLength() utils.LSMAction {
	prefix, ok := s.reqBuf.Get(2, true)
	if !ok {
		// Not enough buffered bytes yet.
		return utils.LSMActionPause
	}
	s.reqMsgLen = int(prefix[0])<<8 | int(prefix[1])
	return utils.LSMActionNext
}
// getRespMessageLength reads the 2-byte big-endian length prefix of the next
// server->client DNS message, pausing until both bytes are available.
func (s *dnsTCPStream) getRespMessageLength() utils.LSMAction {
	prefix, ok := s.respBuf.Get(2, true)
	if !ok {
		// Not enough buffered bytes yet.
		return utils.LSMActionPause
	}
	s.respMsgLen = int(prefix[0])<<8 | int(prefix[1])
	return utils.LSMActionNext
}
// getReqMessage consumes one client->server DNS message of the length read
// by getReqMessageLength, publishes it via reqMap, and loops back to the
// length state for the next message.
func (s *dnsTCPStream) getReqMessage() utils.LSMAction {
	raw, ok := s.reqBuf.Get(s.reqMsgLen, true)
	if !ok {
		// Wait for the rest of the message body.
		return utils.LSMActionPause
	}
	msg := parseDNSMessage(raw)
	if msg == nil {
		// Not a parseable DNS message; abandon this stream.
		return utils.LSMActionCancel
	}
	s.reqMap = msg
	s.reqUpdated = true
	return utils.LSMActionReset
}
// getRespMessage consumes one server->client DNS message of the length read
// by getRespMessageLength, publishes it via respMap, and loops back to the
// length state for the next message.
func (s *dnsTCPStream) getRespMessage() utils.LSMAction {
	raw, ok := s.respBuf.Get(s.respMsgLen, true)
	if !ok {
		// Wait for the rest of the message body.
		return utils.LSMActionPause
	}
	msg := parseDNSMessage(raw)
	if msg == nil {
		// Not a parseable DNS message; abandon this stream.
		return utils.LSMActionCancel
	}
	s.respMap = msg
	s.respUpdated = true
	return utils.LSMActionReset
}
// parseDNSMessage decodes msg as a single DNS message and flattens its
// header flags plus all four record sections into a PropMap.
// It returns nil when msg is not a decodable DNS message.
func parseDNSMessage(msg []byte) analyzer.PropMap {
	dns := &layers.DNS{}
	err := dns.DecodeFromBytes(msg, gopacket.NilDecodeFeedback)
	if err != nil {
		// Not a DNS packet
		return nil
	}
	// Header fields and flags.
	m := analyzer.PropMap{
		"id":     dns.ID,
		"qr":     dns.QR,
		"opcode": dns.OpCode,
		"aa":     dns.AA,
		"tc":     dns.TC,
		"rd":     dns.RD,
		"ra":     dns.RA,
		"z":      dns.Z,
		"rcode":  dns.ResponseCode,
	}
	// Each section is only added when non-empty, keeping the map compact.
	if len(dns.Questions) > 0 {
		mQuestions := make([]analyzer.PropMap, len(dns.Questions))
		for i, q := range dns.Questions {
			mQuestions[i] = analyzer.PropMap{
				"name":  string(q.Name),
				"type":  q.Type,
				"class": q.Class,
			}
		}
		m["questions"] = mQuestions
	}
	if len(dns.Answers) > 0 {
		mAnswers := make([]analyzer.PropMap, len(dns.Answers))
		for i, rr := range dns.Answers {
			mAnswers[i] = dnsRRToPropMap(rr)
		}
		m["answers"] = mAnswers
	}
	if len(dns.Authorities) > 0 {
		mAuthorities := make([]analyzer.PropMap, len(dns.Authorities))
		for i, rr := range dns.Authorities {
			mAuthorities[i] = dnsRRToPropMap(rr)
		}
		m["authorities"] = mAuthorities
	}
	if len(dns.Additionals) > 0 {
		mAdditionals := make([]analyzer.PropMap, len(dns.Additionals))
		for i, rr := range dns.Additionals {
			mAdditionals[i] = dnsRRToPropMap(rr)
		}
		m["additionals"] = mAdditionals
	}
	return m
}
// dnsRRToPropMap converts one DNS resource record into a PropMap, adding a
// type-specific field (a, aaaa, ns, cname, ptr, txt, mx) for supported types.
func dnsRRToPropMap(rr layers.DNSResourceRecord) analyzer.PropMap {
	m := analyzer.PropMap{
		"name":  string(rr.Name),
		"type":  rr.Type,
		"class": rr.Class,
		"ttl":   rr.TTL,
	}
	switch rr.Type {
	// These are not everything, but is
	// all we decided to support for now.
	case layers.DNSTypeA:
		m["a"] = rr.IP.String()
	case layers.DNSTypeAAAA:
		m["aaaa"] = rr.IP.String()
	case layers.DNSTypeNS:
		m["ns"] = string(rr.NS)
	case layers.DNSTypeCNAME:
		m["cname"] = string(rr.CNAME)
	case layers.DNSTypePTR:
		m["ptr"] = string(rr.PTR)
	case layers.DNSTypeTXT:
		m["txt"] = utils.ByteSlicesToStrings(rr.TXTs)
	case layers.DNSTypeMX:
		m["mx"] = string(rr.MX.Name)
	}
	return m
}

View File

@@ -0,0 +1,31 @@
Author:: Cuong Manh Le <cuong.manhle.vn@gmail.com>
Copyright:: Copyright (c) 2023, Cuong Manh Le
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following
disclaimer in the documentation and/or other materials provided
with the distribution.
* Neither the name of the @organization@ nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL LE MANH CUONG
BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -0,0 +1 @@
The code here is from https://github.com/cuonglm/quicsni with various modifications.

View File

@@ -0,0 +1,105 @@
package quic
import (
"bytes"
"encoding/binary"
"errors"
"io"
"github.com/quic-go/quic-go/quicvarint"
)
// The Header represents a QUIC header.
type Header struct {
	Type             uint8  // packet type bits from the first byte
	Version          uint32 // QUIC version field
	SrcConnectionID  []byte
	DestConnectionID []byte
	Length           int64  // payload length (varint) from the length field
	Token            []byte // token; present on Initial packets only
}
// ParseInitialHeader parses the initial packet of a QUIC connection and
// returns the parsed header together with the number of bytes consumed so
// far (i.e. the offset at which the protected payload begins).
func ParseInitialHeader(data []byte) (*Header, int64, error) {
	r := bytes.NewReader(data)
	hdr, err := parseLongHeader(r)
	if err != nil {
		return nil, 0, err
	}
	// Bytes consumed = total length minus what is still unread.
	consumed := int64(len(data)) - int64(r.Len())
	return hdr, consumed, nil
}
// parseLongHeader parses a QUIC long header from b: first byte, version,
// destination and source connection IDs, optional token (Initial packets
// only), and the payload length. It leaves b positioned at the start of
// the protected payload.
func parseLongHeader(b *bytes.Reader) (*Header, error) {
	typeByte, err := b.ReadByte()
	if err != nil {
		return nil, err
	}
	h := &Header{}
	ver, err := beUint32(b)
	if err != nil {
		return nil, err
	}
	h.Version = ver
	// The fixed bit (0x40) must be set on all packets except version
	// negotiation (version 0).
	if h.Version != 0 && typeByte&0x40 == 0 {
		return nil, errors.New("not a QUIC packet")
	}
	destConnIDLen, err := b.ReadByte()
	if err != nil {
		return nil, err
	}
	h.DestConnectionID = make([]byte, int(destConnIDLen))
	if err := readConnectionID(b, h.DestConnectionID); err != nil {
		return nil, err
	}
	srcConnIDLen, err := b.ReadByte()
	if err != nil {
		return nil, err
	}
	h.SrcConnectionID = make([]byte, int(srcConnIDLen))
	if err := readConnectionID(b, h.SrcConnectionID); err != nil {
		return nil, err
	}
	// The Initial packet type encoding differs between v1 (0b00) and v2 (0b01).
	initialPacketType := byte(0b00)
	if h.Version == V2 {
		initialPacketType = 0b01
	}
	if (typeByte >> 4 & 0b11) == initialPacketType {
		// Initial packets carry a variable-length token before the length field.
		tokenLen, err := quicvarint.Read(b)
		if err != nil {
			return nil, err
		}
		if tokenLen > uint64(b.Len()) {
			return nil, io.EOF
		}
		h.Token = make([]byte, tokenLen)
		if _, err := io.ReadFull(b, h.Token); err != nil {
			return nil, err
		}
	}
	pl, err := quicvarint.Read(b)
	if err != nil {
		return nil, err
	}
	h.Length = int64(pl)
	return h, err
}
func readConnectionID(r io.Reader, cid []byte) error {
_, err := io.ReadFull(r, cid)
if err == io.ErrUnexpectedEOF {
return io.EOF
}
return nil
}
func beUint32(r io.Reader) (uint32, error) {
b := make([]byte, 4)
if _, err := io.ReadFull(r, b); err != nil {
return 0, err
}
return binary.BigEndian.Uint32(b), nil
}

View File

@@ -0,0 +1,193 @@
package quic
import (
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/sha256"
"crypto/tls"
"encoding/binary"
"errors"
"fmt"
"hash"
"golang.org/x/crypto/chacha20"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/hkdf"
)
// NewProtectionKey creates a new ProtectionKey for the given TLS cipher
// suite, traffic secret, and QUIC version v (V1 or V2).
func NewProtectionKey(suite uint16, secret []byte, v uint32) (*ProtectionKey, error) {
	return newProtectionKey(suite, secret, v)
}

// NewInitialProtectionKey is like NewProtectionKey, but the returned protection key
// is used for encrypt/decrypt Initial Packet only.
//
// See: https://datatracker.ietf.org/doc/html/draft-ietf-quic-tls-32#name-initial-secrets
func NewInitialProtectionKey(secret []byte, v uint32) (*ProtectionKey, error) {
	// Initial packets always use AES-128-GCM-SHA256 regardless of negotiation.
	return NewProtectionKey(tls.TLS_AES_128_GCM_SHA256, secret, v)
}

// NewPacketProtector creates a new PacketProtector.
func NewPacketProtector(key *ProtectionKey) *PacketProtector {
	return &PacketProtector{key: key}
}

// PacketProtector is used for protecting a QUIC packet.
//
// See: https://www.rfc-editor.org/rfc/rfc9001.html#name-packet-protection
type PacketProtector struct {
	key *ProtectionKey
}
// UnProtect decrypts a QUIC packet: it removes header protection, decodes
// the packet number, and AEAD-opens the payload in place.
//
// pnOffset is the byte offset of the packet number field; pnMax is the
// largest packet number seen so far (used to reconstruct the full number).
// The decrypted payload is returned; packet's header bytes are modified.
func (pp *PacketProtector) UnProtect(packet []byte, pnOffset, pnMax int64) ([]byte, error) {
	// Need at least a 4-byte sample offset plus a 16-byte sample.
	if isLongHeader(packet[0]) && int64(len(packet)) < pnOffset+4+16 {
		return nil, errors.New("packet with long header is too small")
	}
	// https://www.rfc-editor.org/rfc/rfc9001.html#name-header-protection-sample
	sampleOffset := pnOffset + 4
	sample := packet[sampleOffset : sampleOffset+16]
	// https://www.rfc-editor.org/rfc/rfc9001.html#name-header-protection-applicati
	mask := pp.key.headerProtection(sample)
	if isLongHeader(packet[0]) {
		// Long header: 4 bits masked
		packet[0] ^= mask[0] & 0x0f
	} else {
		// Short header: 5 bits masked
		packet[0] ^= mask[0] & 0x1f
	}
	// The low 2 bits of the (unmasked) first byte encode pn length - 1.
	pnLen := packet[0]&0x3 + 1
	pn := int64(0)
	for i := uint8(0); i < pnLen; i++ {
		// Unmask each packet-number byte while accumulating the value.
		packet[pnOffset:][i] ^= mask[1+i]
		pn = (pn << 8) | int64(packet[pnOffset:][i])
	}
	pn = decodePacketNumber(pnMax, pn, pnLen)
	// The now-unprotected header is the AEAD associated data.
	hdr := packet[:pnOffset+int64(pnLen)]
	payload := packet[pnOffset:][pnLen:]
	dec, err := pp.key.aead.Open(payload[:0], pp.key.nonce(pn), payload, hdr)
	if err != nil {
		return nil, fmt.Errorf("decryption failed: %w", err)
	}
	return dec, nil
}
// ProtectionKey is the key used to protect a QUIC packet.
type ProtectionKey struct {
	aead             cipher.AEAD                         // payload protection
	headerProtection func(sample []byte) (mask []byte)   // derives the 5-byte header mask from a ciphertext sample
	iv               []byte                              // AEAD IV, XORed with the packet number to form the nonce
}
// nonce derives the per-packet AEAD nonce for packet number pn.
//
// https://datatracker.ietf.org/doc/html/draft-ietf-quic-tls-32#name-aead-usage
//
// "The 62 bits of the reconstructed QUIC packet number in network byte order are
// left-padded with zeros to the size of the IV. The exclusive OR of the padded
// packet number and the IV forms the AEAD nonce."
func (pk *ProtectionKey) nonce(pn int64) []byte {
	out := make([]byte, len(pk.iv))
	binary.BigEndian.PutUint64(out[len(out)-8:], uint64(pn))
	for i, ivByte := range pk.iv {
		out[i] ^= ivByte
	}
	return out
}
// newProtectionKey derives the AEAD key, IV, and header-protection function
// for the given cipher suite, traffic secret, and QUIC version v. Only
// AES-128-GCM and ChaCha20-Poly1305 are supported.
func newProtectionKey(suite uint16, secret []byte, v uint32) (*ProtectionKey, error) {
	switch suite {
	case tls.TLS_AES_128_GCM_SHA256:
		key := hkdfExpandLabel(crypto.SHA256.New, secret, keyLabel(v), nil, 16)
		// Key sizes are fixed by the derivation above, so cipher
		// construction is not expected to fail; panic on programmer error.
		c, err := aes.NewCipher(key)
		if err != nil {
			panic(err)
		}
		aead, err := cipher.NewGCM(c)
		if err != nil {
			panic(err)
		}
		iv := hkdfExpandLabel(crypto.SHA256.New, secret, ivLabel(v), nil, aead.NonceSize())
		hpKey := hkdfExpandLabel(crypto.SHA256.New, secret, headerProtectionLabel(v), nil, 16)
		hp, err := aes.NewCipher(hpKey)
		if err != nil {
			panic(err)
		}
		k := &ProtectionKey{}
		k.aead = aead
		// https://datatracker.ietf.org/doc/html/draft-ietf-quic-tls-32#name-aes-based-header-protection
		k.headerProtection = func(sample []byte) []byte {
			// Mask = AES-ECB(hpKey, sample).
			mask := make([]byte, hp.BlockSize())
			hp.Encrypt(mask, sample)
			return mask
		}
		k.iv = iv
		return k, nil
	case tls.TLS_CHACHA20_POLY1305_SHA256:
		key := hkdfExpandLabel(crypto.SHA256.New, secret, keyLabel(v), nil, chacha20poly1305.KeySize)
		aead, err := chacha20poly1305.New(key)
		if err != nil {
			return nil, err
		}
		iv := hkdfExpandLabel(crypto.SHA256.New, secret, ivLabel(v), nil, aead.NonceSize())
		hpKey := hkdfExpandLabel(sha256.New, secret, headerProtectionLabel(v), nil, chacha20.KeySize)
		k := &ProtectionKey{}
		k.aead = aead
		// https://datatracker.ietf.org/doc/html/draft-ietf-quic-tls-32#name-chacha20-based-header-prote
		k.headerProtection = func(sample []byte) []byte {
			// First 4 sample bytes are the counter, remaining 12 the nonce.
			nonce := sample[4:16]
			c, err := chacha20.NewUnauthenticatedCipher(hpKey, nonce)
			if err != nil {
				panic(err)
			}
			c.SetCounter(binary.LittleEndian.Uint32(sample[:4]))
			mask := make([]byte, 5)
			c.XORKeyStream(mask, mask)
			return mask
		}
		k.iv = iv
		return k, nil
	}
	return nil, errors.New("not supported cipher suite")
}
// decodePacketNumber reconstructs a full packet number from its truncated
// wire encoding, after header protection has been removed.
//
// largest is the highest packet number received so far, truncated is the
// value read off the wire, and nbits is the encoded length in bytes (1-4,
// matching pnLen in UnProtect). Implements the sample algorithm from
// draft-ietf-quic-transport-32, Appendix A.
func decodePacketNumber(largest, truncated int64, nbits uint8) int64 {
	var (
		expected = largest + 1
		window   = int64(1) << (nbits * 8)
		half     = window / 2
	)
	// Keep the high bits of the expected number, splice in the wire bits.
	candidate := (expected &^ (window - 1)) | truncated
	if candidate <= expected-half && candidate < (1<<62)-window {
		return candidate + window
	}
	if candidate > expected+half && candidate >= window {
		return candidate - window
	}
	return candidate
}
// hkdfExpandLabel implements the TLS 1.3 HKDF-Expand-Label construction:
// the label is prefixed with "tls13 " and packed with the context into the
// HkdfLabel structure before expansion.
// Copied from crypto/tls/key_schedule.go.
func hkdfExpandLabel(hash func() hash.Hash, secret []byte, label string, context []byte, length int) []byte {
	var hkdfLabel cryptobyte.Builder
	hkdfLabel.AddUint16(uint16(length))
	hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
		b.AddBytes([]byte("tls13 "))
		b.AddBytes([]byte(label))
	})
	hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
		b.AddBytes(context)
	})
	out := make([]byte, length)
	n, err := hkdf.Expand(hash, secret, hkdfLabel.BytesOrPanic()).Read(out)
	if err != nil || n != length {
		// Derivation over fixed-size inputs cannot legitimately fail.
		panic("quic: HKDF-Expand-Label invocation failed unexpectedly")
	}
	return out
}

View File

@@ -0,0 +1,94 @@
package quic
import (
"bytes"
"crypto"
"crypto/tls"
"encoding/hex"
"strings"
"testing"
"unicode"
"golang.org/x/crypto/hkdf"
)
// TestInitialPacketProtector_UnProtect decrypts the draft-32 "Server
// Initial" test vector and checks the plaintext matches the spec.
func TestInitialPacketProtector_UnProtect(t *testing.T) {
	// https://datatracker.ietf.org/doc/html/draft-ietf-quic-tls-32#name-server-initial
	protect := mustHexDecodeString(`
	c7ff0000200008f067a5502a4262b500 4075fb12ff07823a5d24534d906ce4c7
	6782a2167e3479c0f7f6395dc2c91676 302fe6d70bb7cbeb117b4ddb7d173498
	44fd61dae200b8338e1b932976b61d91 e64a02e9e0ee72e3a6f63aba4ceeeec5
	be2f24f2d86027572943533846caa13e 6f163fb257473d0eda5047360fd4a47e
	fd8142fafc0f76
	`)
	unProtect := mustHexDecodeString(`
	02000000000600405a020000560303ee fce7f7b37ba1d1632e96677825ddf739
	88cfc79825df566dc5430b9a045a1200 130100002e00330024001d00209d3c94
	0d89690b84d08a60993c144eca684d10 81287c834d5311bcf32bb9da1a002b00
	020304
	`)
	// Client-chosen destination connection ID from the spec.
	connID := mustHexDecodeString(`8394c8f03e515708`)
	packet := append([]byte{}, protect...)
	hdr, offset, err := ParseInitialHeader(packet)
	if err != nil {
		t.Fatal(err)
	}
	// Derive the server Initial secret from the connection ID and
	// version-specific salt, then build the Initial protection key.
	initialSecret := hkdf.Extract(crypto.SHA256.New, connID, getSalt(hdr.Version))
	serverSecret := hkdfExpandLabel(crypto.SHA256.New, initialSecret, "server in", []byte{}, crypto.SHA256.Size())
	key, err := NewInitialProtectionKey(serverSecret, hdr.Version)
	if err != nil {
		t.Fatal(err)
	}
	pp := NewPacketProtector(key)
	// The spec vector uses packet number 1 as pnMax context.
	got, err := pp.UnProtect(protect, offset, 1)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(got, unProtect) {
		t.Error("UnProtect returns wrong result")
	}
}
// TestPacketProtectorShortHeader_UnProtect decrypts the draft-32
// ChaCha20-Poly1305 short-header test vector.
func TestPacketProtectorShortHeader_UnProtect(t *testing.T) {
	// https://datatracker.ietf.org/doc/html/draft-ietf-quic-tls-32#name-chacha20-poly1305-short-hea
	protect := mustHexDecodeString(`4cfe4189655e5cd55c41f69080575d7999c25a5bfb`)
	unProtect := mustHexDecodeString(`01`)
	hdr := mustHexDecodeString(`4200bff4`)
	secret := mustHexDecodeString(`9ac312a7f877468ebe69422748ad00a1 5443f18203a07d6060f688f30f21632b`)
	k, err := NewProtectionKey(tls.TLS_CHACHA20_POLY1305_SHA256, secret, V1)
	if err != nil {
		t.Fatal(err)
	}
	// Packet number length is encoded in the low 2 bits of the first byte;
	// the pn field sits at the end of the unprotected header.
	pnLen := int(hdr[0]&0x03) + 1
	offset := len(hdr) - pnLen
	pp := NewPacketProtector(k)
	got, err := pp.UnProtect(protect, int64(offset), 654360564)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(got, unProtect) {
		t.Error("UnProtect returns wrong result")
	}
}
// mustHexDecodeString decodes a hex string, ignoring any embedded
// whitespace, and panics on malformed input. Test helper only.
func mustHexDecodeString(s string) []byte {
	decoded, err := hex.DecodeString(normalizeHex(s))
	if err != nil {
		panic(err)
	}
	return decoded
}
// normalizeHex strips all Unicode whitespace from s, producing a string
// suitable for hex.DecodeString.
func normalizeHex(s string) string {
	var b strings.Builder
	b.Grow(len(s))
	for _, r := range s {
		if !unicode.IsSpace(r) {
			b.WriteRune(r)
		}
	}
	return b.String()
}

View File

@@ -0,0 +1,122 @@
package quic
import (
"bytes"
"crypto"
"errors"
"fmt"
"io"
"sort"
"github.com/quic-go/quic-go/quicvarint"
"golang.org/x/crypto/hkdf"
)
// ReadCryptoPayload parses a client Initial packet, removes its protection
// using the version-specific Initial secrets, and returns the reassembled
// contents of its CRYPTO frames (the TLS ClientHello bytes).
func ReadCryptoPayload(packet []byte) ([]byte, error) {
	hdr, offset, err := ParseInitialHeader(packet)
	if err != nil {
		return nil, err
	}
	// Some sanity checks
	if hdr.Version != V1 && hdr.Version != V2 {
		return nil, fmt.Errorf("unsupported version: %x", hdr.Version)
	}
	if offset == 0 || hdr.Length == 0 {
		return nil, errors.New("invalid packet")
	}
	// Initial secrets are derived from the client's destination connection
	// ID and the version-specific salt.
	initialSecret := hkdf.Extract(crypto.SHA256.New, hdr.DestConnectionID, getSalt(hdr.Version))
	clientSecret := hkdfExpandLabel(crypto.SHA256.New, initialSecret, "client in", []byte{}, crypto.SHA256.Size())
	key, err := NewInitialProtectionKey(clientSecret, hdr.Version)
	if err != nil {
		return nil, fmt.Errorf("NewInitialProtectionKey: %w", err)
	}
	pp := NewPacketProtector(key)
	// https://datatracker.ietf.org/doc/html/draft-ietf-quic-tls-32#name-client-initial
	//
	// "The unprotected header includes the connection ID and a 4-byte packet number encoding for a packet number of 2"
	if int64(len(packet)) < offset+hdr.Length {
		return nil, fmt.Errorf("packet is too short: %d < %d", len(packet), offset+hdr.Length)
	}
	unProtectedPayload, err := pp.UnProtect(packet[:offset+hdr.Length], offset, 2)
	if err != nil {
		return nil, err
	}
	frs, err := extractCryptoFrames(bytes.NewReader(unProtectedPayload))
	if err != nil {
		return nil, err
	}
	data := assembleCryptoFrames(frs)
	if data == nil {
		return nil, errors.New("unable to assemble crypto frames")
	}
	return data, nil
}
// QUIC frame types we care about inside an Initial packet.
const (
	paddingFrameType = 0x00
	pingFrameType    = 0x01
	cryptoFrameType  = 0x06
)

// cryptoFrame is one CRYPTO frame: a chunk of the TLS handshake byte
// stream starting at Offset.
type cryptoFrame struct {
	Offset int64
	Data   []byte
}
// extractCryptoFrames walks the frames of a decrypted Initial payload and
// collects all CRYPTO frames. PADDING and PING frames are skipped; any
// other frame type is treated as an error.
func extractCryptoFrames(r *bytes.Reader) ([]cryptoFrame, error) {
	var frames []cryptoFrame
	for r.Len() > 0 {
		typ, err := quicvarint.Read(r)
		if err != nil {
			return nil, err
		}
		if typ == paddingFrameType || typ == pingFrameType {
			continue
		}
		if typ != cryptoFrameType {
			return nil, fmt.Errorf("encountered unexpected frame type: %d", typ)
		}
		// CRYPTO frame: varint offset, varint length, then the data bytes.
		var frame cryptoFrame
		offset, err := quicvarint.Read(r)
		if err != nil {
			return nil, err
		}
		frame.Offset = int64(offset)
		dataLen, err := quicvarint.Read(r)
		if err != nil {
			return nil, err
		}
		frame.Data = make([]byte, dataLen)
		if _, err := io.ReadFull(r, frame.Data); err != nil {
			return nil, err
		}
		frames = append(frames, frame)
	}
	return frames, nil
}
// assembleCryptoFrames stitches multiple crypto frames into a single
// contiguous byte slice (if possible). It returns nil when the frames
// cannot be assembled, i.e. when — after sorting by offset — they are not
// contiguous.
func assembleCryptoFrames(frames []cryptoFrame) []byte {
	switch len(frames) {
	case 0:
		return nil
	case 1:
		return frames[0].Data
	}
	// Order by offset so contiguity can be checked pairwise.
	sort.Slice(frames, func(i, j int) bool { return frames[i].Offset < frames[j].Offset })
	for i := 1; i < len(frames); i++ {
		prevEnd := frames[i-1].Offset + int64(len(frames[i-1].Data))
		if frames[i].Offset != prevEnd {
			// Gap or overlap between adjacent frames; cannot reassemble.
			return nil
		}
	}
	// Concatenate into one buffer sized by the final frame's end.
	last := frames[len(frames)-1]
	out := make([]byte, last.Offset+int64(len(last.Data)))
	for _, fr := range frames {
		copy(out[fr.Offset:], fr.Data)
	}
	return out
}

View File

@@ -0,0 +1,59 @@
package quic
// QUIC version numbers and the version-specific HKDF labels used when
// deriving packet-protection keys, IVs, and header-protection keys.
const (
	V1 uint32 = 0x1
	V2 uint32 = 0x6b3343cf

	hkdfLabelKeyV1 = "quic key"
	hkdfLabelKeyV2 = "quicv2 key"
	hkdfLabelIVV1  = "quic iv"
	hkdfLabelIVV2  = "quicv2 iv"
	hkdfLabelHPV1  = "quic hp"
	hkdfLabelHPV2  = "quicv2 hp"
)

// Version-specific salts for deriving the Initial secrets.
var (
	quicSaltOld = []byte{0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61, 0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99}
	// https://www.rfc-editor.org/rfc/rfc9001.html#name-initial-secrets
	quicSaltV1 = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a}
	// https://www.ietf.org/archive/id/draft-ietf-quic-v2-10.html#name-initial-salt-2
	quicSaltV2 = []byte{0x0d, 0xed, 0xe3, 0xde, 0xf7, 0x00, 0xa6, 0xdb, 0x81, 0x93, 0x81, 0xbe, 0x6e, 0x26, 0x9d, 0xcb, 0xf9, 0xbd, 0x2e, 0xd9}
)
// isLongHeader reports whether b is the first byte of a long header packet
// (the high bit, 0x80, is set).
func isLongHeader(b byte) bool {
	return b&0x80 != 0
}
// getSalt returns the version-specific Initial salt; unknown or legacy
// versions fall back to the old draft salt.
func getSalt(v uint32) []byte {
	if v == V1 {
		return quicSaltV1
	}
	if v == V2 {
		return quicSaltV2
	}
	return quicSaltOld
}
func keyLabel(v uint32) string {
kl := hkdfLabelKeyV1
if v == V2 {
kl = hkdfLabelKeyV2
}
return kl
}
func ivLabel(v uint32) string {
ivl := hkdfLabelIVV1
if v == V2 {
ivl = hkdfLabelIVV2
}
return ivl
}
func headerProtectionLabel(v uint32) string {
if v == V2 {
return hkdfLabelHPV2
}
return hkdfLabelHPV1
}

384
analyzer/udp/openvpn.go Normal file
View File

@@ -0,0 +1,384 @@
package udp
import (
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/analyzer/utils"
)
// Compile-time checks that OpenVPNAnalyzer serves both transports.
var (
	_ analyzer.UDPAnalyzer = (*OpenVPNAnalyzer)(nil)
	_ analyzer.TCPAnalyzer = (*OpenVPNAnalyzer)(nil)
)

var (
	_ analyzer.UDPStream = (*openvpnUDPStream)(nil)
	_ analyzer.TCPStream = (*openvpnTCPStream)(nil)
)

// Ref paper:
// https://www.usenix.org/system/files/sec22fall_xue-diwen.pdf
// OpenVPN Opcodes definitions from:
// https://github.com/OpenVPN/openvpn/blob/master/src/openvpn/ssl_pkt.h
const (
	OpenVPNControlHardResetClientV1 = 1
	OpenVPNControlHardResetServerV1 = 2
	OpenVPNControlSoftResetV1       = 3
	OpenVPNControlV1                = 4
	OpenVPNAckV1                    = 5
	OpenVPNDataV1                   = 6
	OpenVPNControlHardResetClientV2 = 7
	OpenVPNControlHardResetServerV2 = 8
	OpenVPNDataV2                   = 9
	OpenVPNControlHardResetClientV3 = 10
	OpenVPNControlWkcV1             = 11
)

const (
	OpenVPNMinPktLen          = 6
	OpenVPNTCPPktDefaultLimit = 256
	OpenVPNUDPPktDefaultLimit = 256
)

type OpenVPNAnalyzer struct{}

// Name implements analyzer.UDPAnalyzer / analyzer.TCPAnalyzer.
func (a *OpenVPNAnalyzer) Name() string {
	return "openvpn"
}

// Limit implements analyzer.UDPAnalyzer / analyzer.TCPAnalyzer; unlimited.
func (a *OpenVPNAnalyzer) Limit() int {
	return 0
}

// NewUDP implements analyzer.UDPAnalyzer.
func (a *OpenVPNAnalyzer) NewUDP(info analyzer.UDPInfo, logger analyzer.Logger) analyzer.UDPStream {
	return newOpenVPNUDPStream(logger)
}

// NewTCP implements analyzer.TCPAnalyzer.
func (a *OpenVPNAnalyzer) NewTCP(info analyzer.TCPInfo, logger analyzer.Logger) analyzer.TCPStream {
	return newOpenVPNTCPStream(logger)
}
// openvpnPkt is the parsed prefix of one OpenVPN packet; only the opcode
// matters for the heuristic.
type openvpnPkt struct {
	pktLen uint16 // 16 bits, TCP proto only
	opcode byte   // 5 bits
	_keyId byte   // 3 bits, not used
	// We don't care about the rest of the packet
	// payload []byte
}
// openvpnStream holds the state shared by the UDP and TCP variants: one
// linear state machine per direction, packet counters, and the
// transport-specific packet parser callbacks.
type openvpnStream struct {
	logger analyzer.Logger

	reqUpdated bool // parseReq counted a client packet this run
	reqLSM     *utils.LinearStateMachine
	reqDone    bool

	respUpdated bool // parseResp counted a server packet this run
	respLSM     *utils.LinearStateMachine
	respDone    bool

	rxPktCnt int // server->client packets counted
	txPktCnt int // client->server packets counted
	pktLimit int // transport-specific packet budget

	// Transport-specific parsers that yield the next packet per direction.
	reqPktParse  func() (*openvpnPkt, utils.LSMAction)
	respPktParse func() (*openvpnPkt, utils.LSMAction)

	lastOpcode byte // last handshake opcode seen, used to order the reset exchange
}
// parseCtlHardResetClient expects the connection's first client packet to
// be one of the client hard-reset opcodes; anything else cancels the LSM.
func (o *openvpnStream) parseCtlHardResetClient() utils.LSMAction {
	pkt, action := o.reqPktParse()
	if action != utils.LSMActionNext {
		return action
	}
	switch pkt.opcode {
	case OpenVPNControlHardResetClientV1,
		OpenVPNControlHardResetClientV2,
		OpenVPNControlHardResetClientV3:
		// Remember which reset variant we saw for the server-side check.
		o.lastOpcode = pkt.opcode
		return utils.LSMActionNext
	}
	return utils.LSMActionCancel
}
// parseCtlHardResetServer expects a server hard-reset reply, but only after
// a client hard-reset has already been recorded in lastOpcode.
func (o *openvpnStream) parseCtlHardResetServer() utils.LSMAction {
	switch o.lastOpcode {
	case OpenVPNControlHardResetClientV1,
		OpenVPNControlHardResetClientV2,
		OpenVPNControlHardResetClientV3:
		// Client reset seen first, as required; proceed to parse the reply.
	default:
		return utils.LSMActionCancel
	}
	pkt, action := o.respPktParse()
	if action != utils.LSMActionNext {
		return action
	}
	if pkt.opcode == OpenVPNControlHardResetServerV1 || pkt.opcode == OpenVPNControlHardResetServerV2 {
		o.lastOpcode = pkt.opcode
		return utils.LSMActionNext
	}
	return utils.LSMActionCancel
}
// parseReq counts steady-state client packets. Any known post-handshake
// opcode increments the TX counter and pauses (this step never completes,
// so counting continues for the lifetime of the stream).
func (o *openvpnStream) parseReq() utils.LSMAction {
	pkt, action := o.reqPktParse()
	if action != utils.LSMActionNext {
		return action
	}
	switch pkt.opcode {
	case OpenVPNControlSoftResetV1,
		OpenVPNControlV1,
		OpenVPNAckV1,
		OpenVPNDataV1,
		OpenVPNDataV2,
		OpenVPNControlWkcV1:
		o.txPktCnt++
		o.reqUpdated = true
		return utils.LSMActionPause
	default:
		return utils.LSMActionCancel
	}
}
// parseResp mirrors parseReq for the server direction, incrementing the RX
// counter for every known post-handshake opcode.
func (o *openvpnStream) parseResp() utils.LSMAction {
	pkt, action := o.respPktParse()
	if action != utils.LSMActionNext {
		return action
	}
	switch pkt.opcode {
	case OpenVPNControlSoftResetV1,
		OpenVPNControlV1,
		OpenVPNAckV1,
		OpenVPNDataV1,
		OpenVPNDataV2,
		OpenVPNControlWkcV1:
		o.rxPktCnt++
		o.respUpdated = true
		return utils.LSMActionPause
	default:
		return utils.LSMActionCancel
	}
}
// openvpnUDPStream detects OpenVPN over UDP, one datagram at a time.
type openvpnUDPStream struct {
	openvpnStream
	curPkt []byte // datagram handed to parsePkt by Feed; nil once consumed
	// We don't introduce `invalidCount` here to decrease the false positive rate
	// invalidCount int
}
// newOpenVPNUDPStream builds a UDP OpenVPN stream. Both directions share
// the same datagram parser, and each direction gets its own two-step LSM:
// hard-reset handshake first, then steady-state packet counting.
func newOpenVPNUDPStream(logger analyzer.Logger) *openvpnUDPStream {
	stream := &openvpnUDPStream{
		openvpnStream: openvpnStream{
			logger:   logger,
			pktLimit: OpenVPNUDPPktDefaultLimit,
		},
	}
	stream.reqPktParse = stream.parsePkt
	stream.respPktParse = stream.parsePkt
	stream.reqLSM = utils.NewLinearStateMachine(stream.parseCtlHardResetClient, stream.parseReq)
	stream.respLSM = utils.NewLinearStateMachine(stream.parseCtlHardResetServer, stream.parseResp)
	return stream
}
// Feed processes one UDP datagram in the given direction. It hands the
// datagram to the per-direction state machine (via curPkt) and emits the
// current packet counters whenever a valid OpenVPN packet was recognized.
// The second return value is true once detection is finished: the LSM was
// cancelled, both directions completed, or the packet limit was exceeded.
func (o *openvpnUDPStream) Feed(rev bool, data []byte) (u *analyzer.PropUpdate, d bool) {
	if len(data) == 0 {
		return nil, false
	}
	var update *analyzer.PropUpdate
	var cancelled bool
	// parsePkt reads (and clears) curPkt during the LSM run below.
	o.curPkt = data
	if rev {
		o.respUpdated = false
		cancelled, o.respDone = o.respLSM.Run()
		if o.respUpdated {
			update = &analyzer.PropUpdate{
				Type: analyzer.PropUpdateReplace,
				M: analyzer.PropMap{"rx_pkt_cnt": o.rxPktCnt, "tx_pkt_cnt": o.txPktCnt},
			}
			o.respUpdated = false
		}
	} else {
		o.reqUpdated = false
		cancelled, o.reqDone = o.reqLSM.Run()
		if o.reqUpdated {
			update = &analyzer.PropUpdate{
				Type: analyzer.PropUpdateReplace,
				M: analyzer.PropMap{"rx_pkt_cnt": o.rxPktCnt, "tx_pkt_cnt": o.txPktCnt},
			}
			o.reqUpdated = false
		}
	}
	return update, cancelled || (o.reqDone && o.respDone) || o.rxPktCnt+o.txPktCnt > o.pktLimit
}

// Close implements analyzer.UDPStream; no final properties are emitted.
func (o *openvpnUDPStream) Close(limited bool) *analyzer.PropUpdate {
	return nil
}
// parsePkt consumes the datagram staged by Feed and extracts the opcode
// (top 5 bits) and key ID (low 3 bits) from its first byte. It pauses when
// no datagram is pending and cancels on an unknown opcode.
func (o *openvpnUDPStream) parsePkt() (p *openvpnPkt, action utils.LSMAction) {
	if o.curPkt == nil {
		return nil, utils.LSMActionPause
	}
	if !OpenVPNCheckForValidOpcode(o.curPkt[0] >> 3) {
		return nil, utils.LSMActionCancel
	}
	first := o.curPkt[0]
	o.curPkt = nil // each datagram is parsed at most once
	return &openvpnPkt{
		opcode: first >> 3,
		_keyId: first & 0x07,
	}, utils.LSMActionNext
}
// openvpnTCPStream detects OpenVPN over TCP, reassembling the
// length-prefixed framing with one byte buffer per direction.
type openvpnTCPStream struct {
	openvpnStream
	reqBuf  *utils.ByteBuffer // client->server bytes awaiting a full packet
	respBuf *utils.ByteBuffer // server->client bytes awaiting a full packet
}
// newOpenVPNTCPStream builds a TCP OpenVPN stream with one reassembly
// buffer per direction. The shared parser is specialized per direction via
// small closures; each direction runs the same two-step LSM as UDP.
func newOpenVPNTCPStream(logger analyzer.Logger) *openvpnTCPStream {
	stream := &openvpnTCPStream{
		openvpnStream: openvpnStream{
			logger:   logger,
			pktLimit: OpenVPNTCPPktDefaultLimit,
		},
		reqBuf:  &utils.ByteBuffer{},
		respBuf: &utils.ByteBuffer{},
	}
	stream.reqPktParse = func() (*openvpnPkt, utils.LSMAction) {
		return stream.parsePkt(false)
	}
	stream.respPktParse = func() (*openvpnPkt, utils.LSMAction) {
		return stream.parsePkt(true)
	}
	stream.reqLSM = utils.NewLinearStateMachine(stream.parseCtlHardResetClient, stream.parseReq)
	stream.respLSM = utils.NewLinearStateMachine(stream.parseCtlHardResetServer, stream.parseResp)
	return stream
}
// Feed processes reassembled TCP segments for one direction. It appends
// the bytes to that direction's buffer, advances its state machine, and
// reports the packet counters whenever a valid OpenVPN packet was parsed.
// Returns done=true when detection finished (gap in the stream, LSM
// cancelled, both sides complete, or packet limit exceeded).
func (o *openvpnTCPStream) Feed(rev, start, end bool, skip int, data []byte) (u *analyzer.PropUpdate, d bool) {
	if skip != 0 {
		// A reassembly gap makes further length-prefixed parsing unreliable.
		return nil, true
	}
	if len(data) == 0 {
		return nil, false
	}
	var update *analyzer.PropUpdate
	var cancelled bool
	if rev {
		o.respBuf.Append(data)
		o.respUpdated = false
		cancelled, o.respDone = o.respLSM.Run()
		if o.respUpdated {
			update = &analyzer.PropUpdate{
				Type: analyzer.PropUpdateReplace,
				M:    analyzer.PropMap{"rx_pkt_cnt": o.rxPktCnt, "tx_pkt_cnt": o.txPktCnt},
			}
			o.respUpdated = false
		}
	} else {
		o.reqBuf.Append(data)
		o.reqUpdated = false
		cancelled, o.reqDone = o.reqLSM.Run()
		if o.reqUpdated {
			// Use PropUpdateReplace here (was PropUpdateMerge) to match the
			// response branch above and the UDP implementation: the map
			// always carries both counters, so a full replace is correct.
			update = &analyzer.PropUpdate{
				Type: analyzer.PropUpdateReplace,
				M:    analyzer.PropMap{"rx_pkt_cnt": o.rxPktCnt, "tx_pkt_cnt": o.txPktCnt},
			}
			o.reqUpdated = false
		}
	}
	return update, cancelled || (o.reqDone && o.respDone) || o.rxPktCnt+o.txPktCnt > o.pktLimit
}
// Close releases both reassembly buffers; no final properties are emitted.
func (o *openvpnTCPStream) Close(limited bool) *analyzer.PropUpdate {
	o.reqBuf.Reset()
	o.respBuf.Reset()
	return nil
}
// Parse OpenVPN TCP packet.
// TCP framing is a big-endian uint16 length prefix followed by the packet
// body. The prefix is only peeked at first; nothing is consumed from the
// buffer until the entire packet has arrived, so a partial packet simply
// pauses the state machine until more data is fed.
func (o *openvpnTCPStream) parsePkt(rev bool) (p *openvpnPkt, action utils.LSMAction) {
	var buffer *utils.ByteBuffer
	if rev {
		buffer = o.respBuf
	} else {
		buffer = o.reqBuf
	}
	// Parse packet length (peek, big-endian; not consumed yet)
	pktLen, ok := buffer.GetUint16(false, false)
	if !ok {
		return nil, utils.LSMActionPause
	}
	if pktLen < OpenVPNMinPktLen {
		return nil, utils.LSMActionCancel
	}
	// Peek the first body byte (offset 2, right after the length prefix)
	// to validate the opcode before waiting for the full packet.
	pktOp, ok := buffer.Get(3, false)
	if !ok {
		return nil, utils.LSMActionPause
	}
	if !OpenVPNCheckForValidOpcode(pktOp[2] >> 3) {
		return nil, utils.LSMActionCancel
	}
	// Consume the length prefix plus the whole packet body at once.
	pkt, ok := buffer.Get(int(pktLen)+2, true)
	if !ok {
		return nil, utils.LSMActionPause
	}
	pkt = pkt[2:]
	// Parse packet header
	p = &openvpnPkt{}
	p.pktLen = pktLen
	p.opcode = pkt[0] >> 3
	p._keyId = pkt[0] & 0x07
	return p, utils.LSMActionNext
}
// OpenVPNCheckForValidOpcode reports whether opcode is one of the known
// OpenVPN control/ack/data opcodes.
func OpenVPNCheckForValidOpcode(opcode byte) bool {
	switch opcode {
	case OpenVPNControlHardResetClientV1,
		OpenVPNControlHardResetServerV1,
		OpenVPNControlSoftResetV1,
		OpenVPNControlV1,
		OpenVPNAckV1,
		OpenVPNDataV1,
		OpenVPNControlHardResetClientV2,
		OpenVPNControlHardResetServerV2,
		OpenVPNDataV2,
		OpenVPNControlHardResetClientV3,
		OpenVPNControlWkcV1:
		return true
	default:
		return false
	}
}

81
analyzer/udp/quic.go Normal file
View File

@@ -0,0 +1,81 @@
package udp
import (
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/analyzer/internal"
"git.difuse.io/Difuse/Mellaris/analyzer/udp/internal/quic"
"git.difuse.io/Difuse/Mellaris/analyzer/utils"
)
const (
	// Give up on the flow after this many unparsable datagrams.
	quicInvalidCountThreshold = 4
)

// Interface conformance checks.
var (
	_ analyzer.UDPAnalyzer = (*QUICAnalyzer)(nil)
	_ analyzer.UDPStream   = (*quicStream)(nil)
)

// QUICAnalyzer extracts the TLS ClientHello from QUIC Initial packets.
type QUICAnalyzer struct{}

// Name returns the analyzer identifier used in rule expressions.
func (a *QUICAnalyzer) Name() string {
	return "quic"
}

// Limit returns 0: this analyzer imposes no byte limit of its own.
func (a *QUICAnalyzer) Limit() int {
	return 0
}

// NewUDP creates the per-flow QUIC detection state.
func (a *QUICAnalyzer) NewUDP(info analyzer.UDPInfo, logger analyzer.Logger) analyzer.UDPStream {
	return &quicStream{logger: logger}
}

// quicStream tracks how many datagrams on the flow failed to parse as QUIC.
type quicStream struct {
	logger       analyzer.Logger
	invalidCount int
}
// Feed attempts to decrypt a client-direction QUIC Initial packet and parse
// the TLS ClientHello from its CRYPTO frames. On success it reports the
// parsed hello under "req" and finishes; every failure counts toward the
// invalid threshold via markInvalid, after which the flow is abandoned.
func (s *quicStream) Feed(rev bool, data []byte) (u *analyzer.PropUpdate, done bool) {
	// minimal data size: protocol version (2 bytes) + random (32 bytes) +
	// + session ID (1 byte) + cipher suites (4 bytes) +
	// + compression methods (2 bytes) + no extensions
	const minDataSize = 41
	if rev {
		// We don't support server direction for now
		return nil, s.markInvalid()
	}
	pl, err := quic.ReadCryptoPayload(data)
	if err != nil || len(pl) < 4 {
		// The length re-check guards against a payload too short to hold
		// the handshake type byte plus the 24-bit length field read below.
		return nil, s.markInvalid()
	}
	if pl[0] != internal.TypeClientHello {
		return nil, s.markInvalid()
	}
	// 24-bit big-endian TLS handshake message length.
	chLen := int(pl[1])<<16 | int(pl[2])<<8 | int(pl[3])
	if chLen < minDataSize {
		return nil, s.markInvalid()
	}
	m := internal.ParseTLSClientHelloMsgData(&utils.ByteBuffer{Buf: pl[4:]})
	if m == nil {
		return nil, s.markInvalid()
	}
	return &analyzer.PropUpdate{
		Type: analyzer.PropUpdateMerge,
		M:    analyzer.PropMap{"req": m},
	}, true
}

// markInvalid counts one unparsable datagram and reports whether the
// invalid-count threshold has been reached (i.e. the stream is done).
func (s *quicStream) markInvalid() bool {
	s.invalidCount++
	return s.invalidCount >= quicInvalidCountThreshold
}
// Close implements analyzer.UDPStream; no final properties are emitted.
func (s *quicStream) Close(limited bool) *analyzer.PropUpdate {
	return nil
}

58
analyzer/udp/quic_test.go Normal file
View File

@@ -0,0 +1,58 @@
package udp
import (
"reflect"
"testing"
"git.difuse.io/Difuse/Mellaris/analyzer"
)
// TestQuicStreamParsing_ClientHello feeds a known-good QUIC Initial packet
// (padded to the 1200-byte minimum datagram size) through quicStream.Feed
// and checks every extracted ClientHello property.
func TestQuicStreamParsing_ClientHello(t *testing.T) {
	// example packet taken from <https://quic.xargs.org/#client-initial-packet/annotated>
	clientHello := make([]byte, 1200)
	clientInitial := []byte{
		0xcd, 0x00, 0x00, 0x00, 0x01, 0x08, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05,
		0x06, 0x07, 0x05, 0x63, 0x5f, 0x63, 0x69, 0x64, 0x00, 0x41, 0x03, 0x98,
		0x1c, 0x36, 0xa7, 0xed, 0x78, 0x71, 0x6b, 0xe9, 0x71, 0x1b, 0xa4, 0x98,
		0xb7, 0xed, 0x86, 0x84, 0x43, 0xbb, 0x2e, 0x0c, 0x51, 0x4d, 0x4d, 0x84,
		0x8e, 0xad, 0xcc, 0x7a, 0x00, 0xd2, 0x5c, 0xe9, 0xf9, 0xaf, 0xa4, 0x83,
		0x97, 0x80, 0x88, 0xde, 0x83, 0x6b, 0xe6, 0x8c, 0x0b, 0x32, 0xa2, 0x45,
		0x95, 0xd7, 0x81, 0x3e, 0xa5, 0x41, 0x4a, 0x91, 0x99, 0x32, 0x9a, 0x6d,
		0x9f, 0x7f, 0x76, 0x0d, 0xd8, 0xbb, 0x24, 0x9b, 0xf3, 0xf5, 0x3d, 0x9a,
		0x77, 0xfb, 0xb7, 0xb3, 0x95, 0xb8, 0xd6, 0x6d, 0x78, 0x79, 0xa5, 0x1f,
		0xe5, 0x9e, 0xf9, 0x60, 0x1f, 0x79, 0x99, 0x8e, 0xb3, 0x56, 0x8e, 0x1f,
		0xdc, 0x78, 0x9f, 0x64, 0x0a, 0xca, 0xb3, 0x85, 0x8a, 0x82, 0xef, 0x29,
		0x30, 0xfa, 0x5c, 0xe1, 0x4b, 0x5b, 0x9e, 0xa0, 0xbd, 0xb2, 0x9f, 0x45,
		0x72, 0xda, 0x85, 0xaa, 0x3d, 0xef, 0x39, 0xb7, 0xef, 0xaf, 0xff, 0xa0,
		0x74, 0xb9, 0x26, 0x70, 0x70, 0xd5, 0x0b, 0x5d, 0x07, 0x84, 0x2e, 0x49,
		0xbb, 0xa3, 0xbc, 0x78, 0x7f, 0xf2, 0x95, 0xd6, 0xae, 0x3b, 0x51, 0x43,
		0x05, 0xf1, 0x02, 0xaf, 0xe5, 0xa0, 0x47, 0xb3, 0xfb, 0x4c, 0x99, 0xeb,
		0x92, 0xa2, 0x74, 0xd2, 0x44, 0xd6, 0x04, 0x92, 0xc0, 0xe2, 0xe6, 0xe2,
		0x12, 0xce, 0xf0, 0xf9, 0xe3, 0xf6, 0x2e, 0xfd, 0x09, 0x55, 0xe7, 0x1c,
		0x76, 0x8a, 0xa6, 0xbb, 0x3c, 0xd8, 0x0b, 0xbb, 0x37, 0x55, 0xc8, 0xb7,
		0xeb, 0xee, 0x32, 0x71, 0x2f, 0x40, 0xf2, 0x24, 0x51, 0x19, 0x48, 0x70,
		0x21, 0xb4, 0xb8, 0x4e, 0x15, 0x65, 0xe3, 0xca, 0x31, 0x96, 0x7a, 0xc8,
		0x60, 0x4d, 0x40, 0x32, 0x17, 0x0d, 0xec, 0x28, 0x0a, 0xee, 0xfa, 0x09,
		0x5d, 0x08, 0xb3, 0xb7, 0x24, 0x1e, 0xf6, 0x64, 0x6a, 0x6c, 0x86, 0xe5,
		0xc6, 0x2c, 0xe0, 0x8b, 0xe0, 0x99,
	}
	// The remainder of clientHello stays zero (QUIC PADDING frames).
	copy(clientHello, clientInitial)
	// Expected ClientHello properties decoded from the fixture above.
	want := analyzer.PropMap{
		"alpn":        []string{"ping/1.0"},
		"ciphers":     []uint16{4865, 4866, 4867},
		"compression": []uint8{0},
		"random":      []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31},
		"session":     []uint8{},
		"sni":         "example.ulfheim.net",
		"supported_versions": []uint16{772},
		"version":            uint16(771),
	}
	s := quicStream{}
	u, _ := s.Feed(false, clientHello)
	got := u.M.Get("req")
	if !reflect.DeepEqual(got, want) {
		t.Errorf("%d B parsed = %v, want %v", len(clientHello), got, want)
	}
}

217
analyzer/udp/wireguard.go Normal file
View File

@@ -0,0 +1,217 @@
package udp
import (
"container/ring"
"encoding/binary"
"slices"
"sync"
"git.difuse.io/Difuse/Mellaris/analyzer"
)
// Interface conformance checks.
var (
	_ analyzer.UDPAnalyzer = (*WireGuardAnalyzer)(nil)
	_ analyzer.UDPStream   = (*wireGuardUDPStream)(nil)
)

const (
	// Give up on the flow after this many consecutive invalid datagrams.
	wireguardUDPInvalidCountThreshold = 4
	// Capacity of the ring buffer of recently seen sender indexes.
	wireguardRememberedIndexCount = 6
	// Property key carrying the raw message type byte.
	wireguardPropKeyMessageType = "message_type"
)

// WireGuard message types (first byte of every datagram).
const (
	wireguardTypeHandshakeInitiation = 1
	wireguardTypeHandshakeResponse   = 2
	wireguardTypeData                = 4
	wireguardTypeCookieReply         = 3
)

// Fixed (or minimum) on-the-wire sizes per message type.
const (
	wireguardSizeHandshakeInitiation = 148
	wireguardSizeHandshakeResponse   = 92
	wireguardMinSizePacketData       = 32 // 16 bytes header + 16 bytes AEAD overhead
	wireguardSizePacketCookieReply   = 64
)
// WireGuardAnalyzer detects WireGuard traffic on UDP flows.
type WireGuardAnalyzer struct{}

// Name returns the analyzer identifier used in rule expressions.
func (a *WireGuardAnalyzer) Name() string {
	return "wireguard"
}

// Limit returns 0: this analyzer imposes no byte limit of its own.
func (a *WireGuardAnalyzer) Limit() int {
	return 0
}

// NewUDP creates the per-flow WireGuard detection state.
func (a *WireGuardAnalyzer) NewUDP(info analyzer.UDPInfo, logger analyzer.Logger) analyzer.UDPStream {
	return newWireGuardUDPStream(logger)
}

// wireGuardUDPStream tracks invalid datagrams plus a small ring buffer of
// recently seen sender indexes used to cross-match receiver indexes.
type wireGuardUDPStream struct {
	logger                analyzer.Logger
	invalidCount          int        // consecutive non-WireGuard datagrams
	rememberedIndexes     *ring.Ring // head is the next write slot
	rememberedIndexesLock sync.RWMutex
}
// newWireGuardUDPStream allocates the per-flow state with an empty ring of
// wireguardRememberedIndexCount remembered sender indexes.
func newWireGuardUDPStream(logger analyzer.Logger) *wireGuardUDPStream {
	s := &wireGuardUDPStream{logger: logger}
	s.rememberedIndexes = ring.New(wireguardRememberedIndexCount)
	return s
}
// Feed classifies one datagram. Unparsable datagrams increment the invalid
// counter and the flow is abandoned at the threshold; a valid datagram
// resets the counter. A handshake initiation replaces all properties (it
// starts a new session); every other message type merges into the map.
func (s *wireGuardUDPStream) Feed(rev bool, data []byte) (u *analyzer.PropUpdate, done bool) {
	m := s.parseWireGuardPacket(rev, data)
	if m == nil {
		s.invalidCount++
		return nil, s.invalidCount >= wireguardUDPInvalidCountThreshold
	}
	s.invalidCount = 0 // Reset invalid count on valid WireGuard packet
	// Assertion is safe: parseWireGuardPacket always stores a byte here.
	messageType := m[wireguardPropKeyMessageType].(byte)
	propUpdateType := analyzer.PropUpdateMerge
	if messageType == wireguardTypeHandshakeInitiation {
		propUpdateType = analyzer.PropUpdateReplace
	}
	return &analyzer.PropUpdate{
		Type: propUpdateType,
		M: m,
	}, false
}

// Close implements analyzer.UDPStream; no final properties are emitted.
func (s *wireGuardUDPStream) Close(limited bool) *analyzer.PropUpdate {
	return nil
}
// parseWireGuardPacket classifies a datagram by its WireGuard message type
// byte and delegates to the matching per-type parser. It returns nil when
// the datagram does not look like WireGuard at all.
func (s *wireGuardUDPStream) parseWireGuardPacket(rev bool, data []byte) analyzer.PropMap {
	if len(data) < 4 {
		return nil
	}
	// Bytes 1-3 are reserved and must all be zero for every message type.
	if slices.Max(data[1:4]) != 0 {
		return nil
	}
	msgType := data[0]
	var key string
	var fields analyzer.PropMap
	switch msgType {
	case wireguardTypeHandshakeInitiation:
		key, fields = "handshake_initiation", s.parseWireGuardHandshakeInitiation(rev, data)
	case wireguardTypeHandshakeResponse:
		key, fields = "handshake_response", s.parseWireGuardHandshakeResponse(rev, data)
	case wireguardTypeData:
		key, fields = "packet_data", s.parseWireGuardPacketData(rev, data)
	case wireguardTypeCookieReply:
		key, fields = "packet_cookie_reply", s.parseWireGuardPacketCookieReply(rev, data)
	}
	if fields == nil {
		return nil
	}
	return analyzer.PropMap{
		wireguardPropKeyMessageType: msgType,
		key:                         fields,
	}
}
// parseWireGuardHandshakeInitiation parses a handshake initiation message
// (fixed 148 bytes) and remembers its sender index for later matching.
func (s *wireGuardUDPStream) parseWireGuardHandshakeInitiation(rev bool, data []byte) analyzer.PropMap {
	if len(data) != wireguardSizeHandshakeInitiation {
		return nil
	}
	sender := binary.LittleEndian.Uint32(data[4:8])
	s.putSenderIndex(rev, sender)
	return analyzer.PropMap{"sender_index": sender}
}
// parseWireGuardHandshakeResponse parses a handshake response message
// (fixed 92 bytes), remembers its sender index, and checks whether its
// receiver index matches a sender index seen in the opposite direction.
func (s *wireGuardUDPStream) parseWireGuardHandshakeResponse(rev bool, data []byte) analyzer.PropMap {
	if len(data) != wireguardSizeHandshakeResponse {
		return nil
	}
	sender := binary.LittleEndian.Uint32(data[4:8])
	s.putSenderIndex(rev, sender)
	receiver := binary.LittleEndian.Uint32(data[8:12])
	return analyzer.PropMap{
		"sender_index":           sender,
		"receiver_index":         receiver,
		"receiver_index_matched": s.matchReceiverIndex(rev, receiver),
	}
}
// parseWireGuardPacketData parses a transport data message: at least 32
// bytes and a multiple of 16 (WireGuard zero-pads the plaintext), carrying
// a receiver index and a little-endian nonce counter.
func (s *wireGuardUDPStream) parseWireGuardPacketData(rev bool, data []byte) analyzer.PropMap {
	if len(data) < wireguardMinSizePacketData {
		return nil
	}
	// WireGuard zero padding the packet to make the length a multiple of 16
	if len(data)%16 != 0 {
		return nil
	}
	receiver := binary.LittleEndian.Uint32(data[4:8])
	return analyzer.PropMap{
		"receiver_index":         receiver,
		"receiver_index_matched": s.matchReceiverIndex(rev, receiver),
		"counter":                binary.LittleEndian.Uint64(data[8:16]),
	}
}
// parseWireGuardPacketCookieReply parses a cookie reply message (fixed 64
// bytes) and matches its receiver index against remembered sender indexes.
func (s *wireGuardUDPStream) parseWireGuardPacketCookieReply(rev bool, data []byte) analyzer.PropMap {
	if len(data) != wireguardSizePacketCookieReply {
		return nil
	}
	receiver := binary.LittleEndian.Uint32(data[4:8])
	return analyzer.PropMap{
		"receiver_index":         receiver,
		"receiver_index_matched": s.matchReceiverIndex(rev, receiver),
	}
}
// wireGuardIndex is one remembered sender index together with the
// direction it was observed in.
type wireGuardIndex struct {
	SenderIndex uint32
	Reverse     bool
}

// putSenderIndex records a sender index in the ring buffer, overwriting
// the oldest entry. The ring head is always the next write slot (we write,
// then step backwards), so the newest entry sits at head.Next().
func (s *wireGuardUDPStream) putSenderIndex(rev bool, senderIndex uint32) {
	s.rememberedIndexesLock.Lock()
	defer s.rememberedIndexesLock.Unlock()
	s.rememberedIndexes.Value = &wireGuardIndex{
		SenderIndex: senderIndex,
		Reverse:     rev,
	}
	s.rememberedIndexes = s.rememberedIndexes.Prev()
}
// matchReceiverIndex reports whether receiverIndex equals a sender index
// previously recorded for the opposite direction. Entries are ordered
// newest-first starting at Next() (see putSenderIndex); a nil Value marks
// the never-written tail of the ring, so iteration stops there. The head
// slot itself (the next write position) is intentionally not examined.
func (s *wireGuardUDPStream) matchReceiverIndex(rev bool, receiverIndex uint32) bool {
	s.rememberedIndexesLock.RLock()
	defer s.rememberedIndexesLock.RUnlock()
	var found bool
	ris := s.rememberedIndexes
	for it := ris.Next(); it != ris; it = it.Next() {
		if it.Value == nil {
			break
		}
		wgidx := it.Value.(*wireGuardIndex)
		if wgidx.Reverse == !rev && wgidx.SenderIndex == receiverIndex {
			found = true
			break
		}
	}
	return found
}

View File

@@ -0,0 +1,99 @@
package utils
import "bytes"
// ByteBuffer is a simple growable byte buffer supporting peek (consume ==
// false) and consuming reads. The zero value is ready to use.
type ByteBuffer struct {
	Buf []byte
}

// Append adds data to the end of the buffer.
func (b *ByteBuffer) Append(data []byte) {
	b.Buf = append(b.Buf, data...)
}

// Len returns the number of buffered bytes.
func (b *ByteBuffer) Len() int {
	return len(b.Buf)
}

// Index returns the offset of the first occurrence of sep, or -1.
func (b *ByteBuffer) Index(sep []byte) int {
	return bytes.Index(b.Buf, sep)
}
// Get returns the first length bytes, optionally consuming them. ok is
// false when fewer than length bytes are buffered. The returned slice
// aliases the internal buffer, so callers must not hold it across Appends.
func (b *ByteBuffer) Get(length int, consume bool) (data []byte, ok bool) {
	if length > len(b.Buf) {
		return nil, false
	}
	head := b.Buf[:length]
	if consume {
		b.Buf = b.Buf[length:]
	}
	return head, true
}
// GetString is Get with the result copied into a string.
func (b *ByteBuffer) GetString(length int, consume bool) (string, bool) {
	raw, ok := b.Get(length, consume)
	if !ok {
		return "", false
	}
	return string(raw), true
}

// GetByte reads a single byte, optionally consuming it.
func (b *ByteBuffer) GetByte(consume bool) (byte, bool) {
	raw, ok := b.Get(1, consume)
	if !ok {
		return 0, false
	}
	return raw[0], true
}
// GetUint16 decodes a 16-bit unsigned integer in the requested byte order.
func (b *ByteBuffer) GetUint16(littleEndian, consume bool) (uint16, bool) {
	raw, ok := b.Get(2, consume)
	if !ok {
		return 0, false
	}
	lo, hi := raw[1], raw[0] // big-endian layout by default
	if littleEndian {
		lo, hi = raw[0], raw[1]
	}
	return uint16(lo) | uint16(hi)<<8, true
}

// GetUint32 decodes a 32-bit unsigned integer in the requested byte order.
func (b *ByteBuffer) GetUint32(littleEndian, consume bool) (uint32, bool) {
	raw, ok := b.Get(4, consume)
	if !ok {
		return 0, false
	}
	var v uint32
	if littleEndian {
		for i := 3; i >= 0; i-- {
			v = v<<8 | uint32(raw[i])
		}
	} else {
		for i := 0; i < 4; i++ {
			v = v<<8 | uint32(raw[i])
		}
	}
	return v, true
}
// GetUntil returns the bytes up to the first occurrence of sep, optionally
// including the separator itself. ok is false when sep is not present.
// When consume is true only the returned span is removed; with includeSep
// false the separator stays in the buffer.
func (b *ByteBuffer) GetUntil(sep []byte, includeSep, consume bool) (data []byte, ok bool) {
	pos := b.Index(sep)
	if pos < 0 {
		return nil, false
	}
	if includeSep {
		pos += len(sep)
	}
	return b.Get(pos, consume)
}

// GetSubBuffer wraps the next length bytes in their own ByteBuffer.
func (b *ByteBuffer) GetSubBuffer(length int, consume bool) (sub *ByteBuffer, ok bool) {
	raw, ok := b.Get(length, consume)
	if !ok {
		return nil, false
	}
	return &ByteBuffer{Buf: raw}, true
}
// Skip discards the next length bytes; false when not enough are buffered.
func (b *ByteBuffer) Skip(length int) bool {
	if length > len(b.Buf) {
		return false
	}
	b.Buf = b.Buf[length:]
	return true
}

// Reset drops all buffered data, releasing the backing array.
func (b *ByteBuffer) Reset() {
	b.Buf = nil
}

54
analyzer/utils/lsm.go Normal file
View File

@@ -0,0 +1,54 @@
package utils
// LSMAction is the result of executing one state-machine step.
type LSMAction int

const (
	LSMActionPause LSMAction = iota // not enough data; retry this step on the next Run
	LSMActionNext                   // step succeeded; advance to the next step
	LSMActionReset                  // restart from the first step
	LSMActionCancel                 // definitive mismatch; the machine stops for good
)

// LinearStateMachine runs a fixed sequence of steps and is resumable
// across multiple Run calls.
type LinearStateMachine struct {
	Steps []func() LSMAction
	index     int  // index of the next step to execute
	cancelled bool // latched once a step returns LSMActionCancel
}

// NewLinearStateMachine creates a machine over the given ordered steps.
func NewLinearStateMachine(steps ...func() LSMAction) *LinearStateMachine {
	return &LinearStateMachine{
		Steps: steps,
	}
}
// Run runs the state machine until it pauses, finishes or is cancelled.
// It is resumable: a paused machine picks up at the same step on the next
// call. A finished or cancelled machine keeps returning done == true (with
// cancelled reflecting how it ended) without re-executing any step.
func (lsm *LinearStateMachine) Run() (cancelled bool, done bool) {
	if lsm.index >= len(lsm.Steps) {
		return lsm.cancelled, true
	}
	for lsm.index < len(lsm.Steps) {
		action := lsm.Steps[lsm.index]()
		switch action {
		case LSMActionPause:
			return false, false
		case LSMActionNext:
			lsm.index++
		case LSMActionReset:
			lsm.index = 0
		case LSMActionCancel:
			lsm.cancelled = true
			return true, true
		}
	}
	return false, true
}
// AppendSteps adds more steps after the existing ones.
func (lsm *LinearStateMachine) AppendSteps(steps ...func() LSMAction) {
	lsm.Steps = append(lsm.Steps, steps...)
}

// Reset rewinds the machine to its initial state, clearing cancellation.
func (lsm *LinearStateMachine) Reset() {
	lsm.index = 0
	lsm.cancelled = false
}

9
analyzer/utils/string.go Normal file
View File

@@ -0,0 +1,9 @@
package utils
// ByteSlicesToStrings converts a slice of byte slices into a slice of
// strings, preserving order. The result always has the same length as the
// input; a nil input yields an empty (non-nil) slice.
func ByteSlicesToStrings(bss [][]byte) []string {
	out := make([]string, len(bss))
	for i := range bss {
		out[i] = string(bss[i])
	}
	return out
}

203
app.go Normal file
View File

@@ -0,0 +1,203 @@
package mellaris
import (
"context"
"errors"
"fmt"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/engine"
gfwio "git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/modifier"
"git.difuse.io/Difuse/Mellaris/ruleset"
)
// App owns the Mellaris engine and ruleset lifecycle.
type App struct {
	engine        engine.Engine
	io            gfwio.PacketIO         // packet source/sink; closed by Close
	rulesetConfig *ruleset.BuiltinConfig // reused when recompiling rules
	analyzers     []analyzer.Analyzer
	modifiers     []modifier.Modifier
	rulesFile     string // YAML path for ReloadRules; empty for inline rules
}
// New builds an App from config and options.
//
// IO ownership: when cfg.IO.PacketIO is nil, New creates an NFQueue-backed
// PacketIO and closes it if any later construction step fails; a
// caller-supplied PacketIO is never closed on failure (but App.Close
// closes whichever PacketIO the App ends up holding).
func New(cfg Config, opts Options) (*App, error) {
	rules, rulesFile, err := resolveRules(opts)
	if err != nil {
		return nil, err
	}
	analyzers := normalizeAnalyzers(opts.Analyzers)
	modifiers := normalizeModifiers(opts.Modifiers)
	// Nil loggers default to no-ops so callers may omit them.
	engineLogger := opts.EngineLogger
	if engineLogger == nil {
		engineLogger = noopEngineLogger{}
	}
	rulesetLogger := opts.RulesetLogger
	if rulesetLogger == nil {
		rulesetLogger = noopRulesetLogger{}
	}
	packetIO := cfg.IO.PacketIO
	ownsIO := false
	if packetIO == nil {
		packetIO, err = gfwio.NewNFQueuePacketIO(gfwio.NFQueuePacketIOConfig{
			QueueSize: cfg.IO.QueueSize,
			ReadBuffer: cfg.IO.ReadBuffer,
			WriteBuffer: cfg.IO.WriteBuffer,
			Local: cfg.IO.Local,
			RST: cfg.IO.RST,
		})
		if err != nil {
			return nil, ConfigError{Field: "io", Err: err}
		}
		ownsIO = true
	}
	// cleanup releases the IO we created here if a later step fails.
	cleanup := func() {
		if ownsIO {
			_ = packetIO.Close()
		}
	}
	// Kept on the App so UpdateRules/ReloadRules can recompile later.
	rsConfig := &ruleset.BuiltinConfig{
		Logger: rulesetLogger,
		GeoSiteFilename: cfg.Ruleset.GeoSite,
		GeoIpFilename: cfg.Ruleset.GeoIp,
		ProtectedDialContext: packetIO.ProtectedDialContext,
	}
	rs, err := ruleset.CompileExprRules(rules, analyzers, modifiers, rsConfig)
	if err != nil {
		cleanup()
		return nil, err
	}
	engCfg := engine.Config{
		Logger: engineLogger,
		IO: packetIO,
		Ruleset: rs,
		Workers: cfg.Workers.Count,
		WorkerQueueSize: cfg.Workers.QueueSize,
		WorkerTCPMaxBufferedPagesTotal: cfg.Workers.TCPMaxBufferedPagesTotal,
		WorkerTCPMaxBufferedPagesPerConn: cfg.Workers.TCPMaxBufferedPagesPerConn,
		WorkerUDPMaxStreams: cfg.Workers.UDPMaxStreams,
	}
	eng, err := engine.NewEngine(engCfg)
	if err != nil {
		cleanup()
		return nil, err
	}
	return &App{
		engine: eng,
		io: packetIO,
		rulesetConfig: rsConfig,
		analyzers: analyzers,
		modifiers: modifiers,
		rulesFile: rulesFile,
	}, nil
}
// Run starts the engine and blocks until it exits or ctx is cancelled.
func (a *App) Run(ctx context.Context) error {
	return a.engine.Run(ctx)
}

// Close releases the underlying PacketIO. It is safe to call on a nil
// *App or one whose IO was never set.
func (a *App) Close() error {
	if a == nil || a.io == nil {
		return nil
	}
	return a.io.Close()
}
// ReloadRules re-reads the YAML rules file the App was constructed with
// and applies the result to the running engine. It fails when the App was
// built from inline rules (no file path to reload from).
func (a *App) ReloadRules() error {
	if a.rulesFile == "" {
		return ConfigError{Field: "rules", Err: errors.New("rules file not set")}
	}
	loaded, err := ruleset.ExprRulesFromYAML(a.rulesFile)
	if err != nil {
		return fmt.Errorf("load rules file %q: %w", a.rulesFile, err)
	}
	return a.UpdateRules(loaded)
}
// UpdateRules compiles the provided rules against the App's analyzers,
// modifiers and ruleset config, then swaps them into the running engine.
func (a *App) UpdateRules(rules []ruleset.ExprRule) error {
	compiled, err := ruleset.CompileExprRules(rules, a.analyzers, a.modifiers, a.rulesetConfig)
	if err != nil {
		return err
	}
	return a.engine.UpdateRuleset(compiled)
}
// Engine returns the underlying engine instance.
func (a *App) Engine() engine.Engine {
	return a.engine
}
// resolveRules picks the rule source from opts: exactly one of RulesFile
// or Rules must be set. It returns the rules plus the originating file
// path (empty for inline rules) so ReloadRules knows where to re-read.
func resolveRules(opts Options) ([]ruleset.ExprRule, string, error) {
	switch {
	case opts.RulesFile != "" && len(opts.Rules) > 0:
		return nil, "", ConfigError{Field: "rules", Err: errors.New("use either RulesFile or Rules")}
	case opts.RulesFile != "":
		rules, err := ruleset.ExprRulesFromYAML(opts.RulesFile)
		if err != nil {
			return nil, opts.RulesFile, fmt.Errorf("load rules file %q: %w", opts.RulesFile, err)
		}
		return rules, opts.RulesFile, nil
	case len(opts.Rules) > 0:
		return opts.Rules, "", nil
	default:
		return nil, "", ConfigError{Field: "rules", Err: errors.New("no rules provided")}
	}
}
// normalizeAnalyzers returns the built-in defaults when in is nil;
// otherwise a defensive copy so later mutations of the caller's slice
// cannot affect the App.
func normalizeAnalyzers(in []analyzer.Analyzer) []analyzer.Analyzer {
	if in == nil {
		return DefaultAnalyzers()
	}
	cloned := make([]analyzer.Analyzer, len(in))
	copy(cloned, in)
	return cloned
}

// normalizeModifiers mirrors normalizeAnalyzers for modifiers.
func normalizeModifiers(in []modifier.Modifier) []modifier.Modifier {
	if in == nil {
		return DefaultModifiers()
	}
	cloned := make([]modifier.Modifier, len(in))
	copy(cloned, in)
	return cloned
}
// noopEngineLogger discards all engine events; used when Options supplies
// no EngineLogger.
type noopEngineLogger struct{}

func (noopEngineLogger) WorkerStart(id int) {}
func (noopEngineLogger) WorkerStop(id int) {}
func (noopEngineLogger) TCPStreamNew(workerID int, info ruleset.StreamInfo) {}
func (noopEngineLogger) TCPStreamPropUpdate(info ruleset.StreamInfo, close bool) {}
func (noopEngineLogger) TCPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool) {
}
func (noopEngineLogger) UDPStreamNew(workerID int, info ruleset.StreamInfo) {}
func (noopEngineLogger) UDPStreamPropUpdate(info ruleset.StreamInfo, close bool) {}
func (noopEngineLogger) UDPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool) {
}
func (noopEngineLogger) ModifyError(info ruleset.StreamInfo, err error) {}
func (noopEngineLogger) AnalyzerDebugf(streamID int64, name string, format string, args ...interface{}) {
}
func (noopEngineLogger) AnalyzerInfof(streamID int64, name string, format string, args ...interface{}) {
}
func (noopEngineLogger) AnalyzerErrorf(streamID int64, name string, format string, args ...interface{}) {
}

// noopRulesetLogger discards all ruleset events; used when Options
// supplies no RulesetLogger.
type noopRulesetLogger struct{}

func (noopRulesetLogger) Log(info ruleset.StreamInfo, name string) {}
func (noopRulesetLogger) MatchError(info ruleset.StreamInfo, name string, err error) {}

60
config.go Normal file
View File

@@ -0,0 +1,60 @@
package mellaris
import (
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/engine"
gfwio "git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/modifier"
"git.difuse.io/Difuse/Mellaris/ruleset"
)
// Config defines IO, worker, and ruleset settings for the engine.
type Config struct {
	IO      IOConfig      `mapstructure:"io" yaml:"io"`
	Workers WorkersConfig `mapstructure:"workers" yaml:"workers"`
	Ruleset RulesetConfig `mapstructure:"ruleset" yaml:"ruleset"`
}

// IOConfig configures packet IO.
type IOConfig struct {
	QueueSize   uint32 `mapstructure:"queueSize" yaml:"queueSize"` // NFQueue queue size
	ReadBuffer  int    `mapstructure:"rcvBuf" yaml:"rcvBuf"`       // socket receive buffer, bytes
	WriteBuffer int    `mapstructure:"sndBuf" yaml:"sndBuf"`       // socket send buffer, bytes
	Local       bool   `mapstructure:"local" yaml:"local"`
	RST         bool   `mapstructure:"rst" yaml:"rst"`
	// PacketIO overrides NFQueue creation when set.
	// When provided, App.Close will call PacketIO.Close.
	PacketIO gfwio.PacketIO `mapstructure:"-" yaml:"-"`
}

// WorkersConfig configures engine worker behavior.
type WorkersConfig struct {
	Count                      int `mapstructure:"count" yaml:"count"` // zero/negative means auto (CPU count)
	QueueSize                  int `mapstructure:"queueSize" yaml:"queueSize"`
	TCPMaxBufferedPagesTotal   int `mapstructure:"tcpMaxBufferedPagesTotal" yaml:"tcpMaxBufferedPagesTotal"`
	TCPMaxBufferedPagesPerConn int `mapstructure:"tcpMaxBufferedPagesPerConn" yaml:"tcpMaxBufferedPagesPerConn"`
	UDPMaxStreams              int `mapstructure:"udpMaxStreams" yaml:"udpMaxStreams"`
}

// RulesetConfig configures built-in rule helpers.
type RulesetConfig struct {
	GeoIp   string `mapstructure:"geoip" yaml:"geoip"`     // GeoIP database file path
	GeoSite string `mapstructure:"geosite" yaml:"geosite"` // GeoSite database file path
}

// Options configures rules, analyzers, modifiers, and logging.
type Options struct {
	// RulesFile is a YAML rules file on disk. Mutually exclusive with Rules.
	RulesFile string
	// Rules provides inline expression rules. Mutually exclusive with RulesFile.
	Rules []ruleset.ExprRule
	// Analyzers and Modifiers default to built-ins when nil.
	Analyzers []analyzer.Analyzer
	Modifiers []modifier.Modifier
	// EngineLogger and RulesetLogger default to no-op loggers when nil.
	EngineLogger  engine.Logger
	RulesetLogger ruleset.Logger
}

32
defaults.go Normal file
View File

@@ -0,0 +1,32 @@
package mellaris
import (
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/analyzer/tcp"
"git.difuse.io/Difuse/Mellaris/analyzer/udp"
"git.difuse.io/Difuse/Mellaris/modifier"
modUDP "git.difuse.io/Difuse/Mellaris/modifier/udp"
)
// DefaultAnalyzers returns the built-in analyzer set. A fresh slice is
// allocated on every call, so callers may modify the result freely.
func DefaultAnalyzers() []analyzer.Analyzer {
	return []analyzer.Analyzer{
		&tcp.FETAnalyzer{},
		&tcp.HTTPAnalyzer{},
		&tcp.SocksAnalyzer{},
		&tcp.SSHAnalyzer{},
		&tcp.TLSAnalyzer{},
		&tcp.TrojanAnalyzer{},
		&udp.DNSAnalyzer{},
		&udp.OpenVPNAnalyzer{},
		&udp.QUICAnalyzer{},
		&udp.WireGuardAnalyzer{},
	}
}

// DefaultModifiers returns the built-in modifier set. A fresh slice is
// allocated on every call, so callers may modify the result freely.
func DefaultModifiers() []modifier.Modifier {
	return []modifier.Modifier{
		&modUDP.DNSModifier{},
	}
}

BIN
docs/logo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 43 KiB

115
engine/engine.go Normal file
View File

@@ -0,0 +1,115 @@
package engine
import (
"context"
"runtime"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
// Interface conformance check.
var _ Engine = (*engine)(nil)

// engine fans incoming packets out to a fixed set of workers.
type engine struct {
	logger  Logger
	io      io.PacketIO
	workers []*worker // packets are sharded across these by stream ID
}
// NewEngine builds an Engine with config.Workers workers, defaulting to
// the number of CPU cores when the configured count is zero or negative.
func NewEngine(config Config) (Engine, error) {
	count := config.Workers
	if count <= 0 {
		count = runtime.NumCPU()
	}
	workers := make([]*worker, count)
	for i := range workers {
		w, err := newWorker(workerConfig{
			ID:                         i,
			ChanSize:                   config.WorkerQueueSize,
			Logger:                     config.Logger,
			Ruleset:                    config.Ruleset,
			TCPMaxBufferedPagesTotal:   config.WorkerTCPMaxBufferedPagesTotal,
			TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn,
			UDPMaxStreams:              config.WorkerUDPMaxStreams,
		})
		if err != nil {
			return nil, err
		}
		workers[i] = w
	}
	return &engine{
		logger:  config.Logger,
		io:      config.IO,
		workers: workers,
	}, nil
}
// UpdateRuleset propagates the new ruleset to every worker, stopping at
// the first failure.
func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
	for _, w := range e.workers {
		err := w.UpdateRuleset(r)
		if err != nil {
			return err
		}
	}
	return nil
}
// Run starts the workers, registers the packet callback with the IO layer,
// and blocks until the IO layer reports an error or ctx is cancelled.
// The derived ioCtx (cancelled on return) stops both workers and IO.
func (e *engine) Run(ctx context.Context) error {
	ioCtx, ioCancel := context.WithCancel(ctx)
	defer ioCancel() // Stop workers & IO
	// Start workers
	for _, w := range e.workers {
		go w.Run(ioCtx)
	}
	// Register IO callback
	errChan := make(chan error, 1)
	err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
		if err != nil {
			// Buffered channel: the send never blocks the IO callback.
			errChan <- err
			return false
		}
		return e.dispatch(p)
	})
	if err != nil {
		return err
	}
	// Block until IO errors or context is cancelled
	select {
	case err := <-errChan:
		return err
	case <-ctx.Done():
		return nil
	}
}
// dispatch decodes a raw IP packet and hands it to a worker selected by
// stream ID, so all packets of one flow land on the same worker. It
// returns true to keep the IO callback loop running.
func (e *engine) dispatch(p io.Packet) bool {
	data := p.Data()
	if len(data) == 0 {
		// Guard against an empty payload: indexing data[0] below would
		// panic inside the IO callback. Accept the stream and move on.
		_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
		return true
	}
	var layerType gopacket.LayerType
	// The IP version lives in the high nibble of the first byte.
	switch data[0] >> 4 {
	case 4:
		layerType = layers.LayerTypeIPv4
	case 6:
		layerType = layers.LayerTypeIPv6
	default:
		// Unsupported network layer
		_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
		return true
	}
	// Load balance by stream ID
	index := p.StreamID() % uint32(len(e.workers))
	packet := gopacket.NewPacket(data, layerType, gopacket.DecodeOptions{Lazy: true, NoCopy: true})
	e.workers[index].Feed(&workerPacket{
		StreamID: p.StreamID(),
		Packet:   packet,
		SetVerdict: func(v io.Verdict, b []byte) error {
			return e.io.SetVerdict(p, v, b)
		},
	})
	return true
}

49
engine/interface.go Normal file
View File

@@ -0,0 +1,49 @@
package engine
import (
"context"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
)
// Engine is the main engine for Mellaris.
type Engine interface {
	// UpdateRuleset updates the ruleset.
	UpdateRuleset(ruleset.Ruleset) error
	// Run runs the engine, until an error occurs or the context is cancelled.
	Run(context.Context) error
}

// Config is the configuration for the engine.
type Config struct {
	Logger  Logger
	IO      io.PacketIO
	Ruleset ruleset.Ruleset
	Workers                          int // Number of workers. Zero or negative means auto (number of CPU cores).
	WorkerQueueSize                  int
	WorkerTCPMaxBufferedPagesTotal   int
	WorkerTCPMaxBufferedPagesPerConn int
	WorkerUDPMaxStreams              int
}

// Logger is the combined logging interface for the engine, workers and analyzers.
// Implementations must be safe for concurrent use; all methods are
// notification hooks and their return values are ignored.
type Logger interface {
	WorkerStart(id int)
	WorkerStop(id int)
	TCPStreamNew(workerID int, info ruleset.StreamInfo)
	TCPStreamPropUpdate(info ruleset.StreamInfo, close bool)
	TCPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool)
	UDPStreamNew(workerID int, info ruleset.StreamInfo)
	UDPStreamPropUpdate(info ruleset.StreamInfo, close bool)
	UDPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool)
	ModifyError(info ruleset.StreamInfo, err error)
	AnalyzerDebugf(streamID int64, name string, format string, args ...interface{})
	AnalyzerInfof(streamID int64, name string, format string, args ...interface{})
	AnalyzerErrorf(streamID int64, name string, format string, args ...interface{})
}

229
engine/tcp.go Normal file
View File

@@ -0,0 +1,229 @@
package engine
import (
"net"
"sync"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/gopacket/reassembly"
)
// tcpVerdict is a subset of io.Verdict for TCP streams.
// We don't allow modifying or dropping a single packet
// for TCP streams for now, as it doesn't make much sense.
type tcpVerdict io.Verdict

const (
	tcpVerdictAccept       = tcpVerdict(io.VerdictAccept)       // accept this packet only
	tcpVerdictAcceptStream = tcpVerdict(io.VerdictAcceptStream) // accept the rest of the stream
	tcpVerdictDropStream   = tcpVerdict(io.VerdictDropStream)   // drop the rest of the stream
)

// tcpContext carries per-packet capture metadata plus the verdict chosen
// during reassembly.
type tcpContext struct {
	*gopacket.PacketMetadata
	Verdict tcpVerdict
}

// GetCaptureInfo implements reassembly.AssemblerContext.
func (ctx *tcpContext) GetCaptureInfo() gopacket.CaptureInfo {
	return ctx.CaptureInfo
}

// tcpStreamFactory creates one tcpStream per TCP connection. Its ruleset
// may be swapped at runtime under RulesetMutex.
type tcpStreamFactory struct {
	WorkerID     int
	Logger       Logger
	Node         *snowflake.Node // generates unique stream IDs
	RulesetMutex sync.RWMutex
	Ruleset      ruleset.Ruleset
}
// New implements reassembly.StreamFactory. The assembler invokes it for
// every TCP connection it has not seen before. It assigns the stream a
// snowflake ID, snapshots the current ruleset, instantiates one entry per
// TCP-capable analyzer the ruleset selects, and returns the stream state.
func (f *tcpStreamFactory) New(ipFlow, tcpFlow gopacket.Flow, tcp *layers.TCP, ac reassembly.AssemblerContext) reassembly.Stream {
	streamID := f.Node.Generate().Int64()
	srcIP := net.IP(ipFlow.Src().Raw())
	dstIP := net.IP(ipFlow.Dst().Raw())
	srcPort, dstPort := uint16(tcp.SrcPort), uint16(tcp.DstPort)
	info := ruleset.StreamInfo{
		ID:       streamID,
		Protocol: ruleset.ProtocolTCP,
		SrcIP:    srcIP,
		DstIP:    dstIP,
		SrcPort:  srcPort,
		DstPort:  dstPort,
		Props:    make(analyzer.CombinedPropMap),
	}
	f.Logger.TCPStreamNew(f.WorkerID, info)
	// Snapshot the ruleset under the read lock so a concurrent
	// UpdateRuleset cannot race with stream creation.
	f.RulesetMutex.RLock()
	rs := f.Ruleset
	f.RulesetMutex.RUnlock()
	tcpAnalyzers := analyzersToTCPAnalyzers(rs.Analyzers(info))
	// One entry per analyzer; each gets its own logger and byte quota.
	entries := make([]*tcpStreamEntry, 0, len(tcpAnalyzers))
	for _, a := range tcpAnalyzers {
		aLogger := &analyzerLogger{
			StreamID: streamID,
			Name:     a.Name(),
			Logger:   f.Logger,
		}
		aInfo := analyzer.TCPInfo{
			SrcIP:   srcIP,
			DstIP:   dstIP,
			SrcPort: srcPort,
			DstPort: dstPort,
		}
		entries = append(entries, &tcpStreamEntry{
			Name:     a.Name(),
			Stream:   a.NewTCP(aInfo, aLogger),
			HasLimit: a.Limit() > 0,
			Quota:    a.Limit(),
		})
	}
	return &tcpStream{
		info:          info,
		virgin:        true,
		logger:        f.Logger,
		ruleset:       rs,
		activeEntries: entries,
	}
}
// UpdateRuleset atomically swaps in a new ruleset for future streams.
// Streams already in flight keep the ruleset they were created with.
func (f *tcpStreamFactory) UpdateRuleset(r ruleset.Ruleset) error {
	f.RulesetMutex.Lock()
	f.Ruleset = r
	f.RulesetMutex.Unlock()
	return nil
}
// tcpStream is the per-connection state driven by the reassembler.
type tcpStream struct {
	info    ruleset.StreamInfo
	virgin  bool // true if no packets have been processed
	logger  Logger
	ruleset ruleset.Ruleset
	// activeEntries are analyzers still consuming data; doneEntries have
	// finished (returned done, hit their quota, or been closed).
	activeEntries []*tcpStreamEntry
	doneEntries   []*tcpStreamEntry
	// lastVerdict is replayed for packets that arrive after analysis ends.
	lastVerdict tcpVerdict
}

// tcpStreamEntry pairs one analyzer instance with its remaining byte quota.
type tcpStreamEntry struct {
	Name     string
	Stream   analyzer.TCPStream
	HasLimit bool // false when the analyzer has no byte limit
	Quota    int  // remaining bytes before the analyzer is force-closed
}
// Accept tells the assembler whether to keep feeding this stream.
func (s *tcpStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool {
	// Keep processing while any analyzer is active. A virgin stream is
	// also accepted so every stream matches against the ruleset at least
	// once, even with no active entries — the ruleset may have built-in
	// properties that need to be matched.
	if s.virgin || len(s.activeEntries) > 0 {
		return true
	}
	// Nothing left to analyze: replay the last verdict and stop.
	ac.(*tcpContext).Verdict = s.lastVerdict
	return false
}
// ReassembledSG feeds a reassembled chunk to every active analyzer,
// folds resulting property updates into the stream info, and — whenever
// properties changed (or on the very first chunk) — re-matches the
// ruleset and records the verdict into the tcpContext.
func (s *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.AssemblerContext) {
	dir, start, end, skip := sg.Info()
	rev := dir == reassembly.TCPDirServerToClient
	avail, _ := sg.Lengths()
	data := sg.Fetch(avail)
	updated := false
	for i := len(s.activeEntries) - 1; i >= 0; i-- {
		// Important: reverse order so we can remove entries
		entry := s.activeEntries[i]
		update, closeUpdate, done := s.feedEntry(entry, rev, start, end, skip, data)
		// Apply both the feed update and any quota-close update.
		up1 := processPropUpdate(s.info.Props, entry.Name, update)
		up2 := processPropUpdate(s.info.Props, entry.Name, closeUpdate)
		updated = updated || up1 || up2
		if done {
			s.activeEntries = append(s.activeEntries[:i], s.activeEntries[i+1:]...)
			s.doneEntries = append(s.doneEntries, entry)
		}
	}
	ctx := ac.(*tcpContext)
	if updated || s.virgin {
		s.virgin = false
		s.logger.TCPStreamPropUpdate(s.info, false)
		// Match properties against ruleset
		result := s.ruleset.Match(s.info)
		action := result.Action
		// Maybe means "keep watching"; Modify cannot be honored for TCP
		// (no per-packet rewrite), so both leave the stream open.
		if action != ruleset.ActionMaybe && action != ruleset.ActionModify {
			verdict := actionToTCPVerdict(action)
			s.lastVerdict = verdict
			ctx.Verdict = verdict
			s.logger.TCPStreamAction(s.info, action, false)
			// Verdict issued, no need to process any more packets
			s.closeActiveEntries()
		}
	}
	if len(s.activeEntries) == 0 && ctx.Verdict == tcpVerdictAccept {
		// All entries are done but no verdict issued, accept stream
		s.lastVerdict = tcpVerdictAcceptStream
		ctx.Verdict = tcpVerdictAcceptStream
		s.logger.TCPStreamAction(s.info, ruleset.ActionAllow, true)
	}
}
// ReassemblyComplete is called by the assembler when the connection ends;
// it closes any remaining analyzers. Returning true lets the assembler
// remove the connection immediately.
func (s *tcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool {
	s.closeActiveEntries()
	return true
}
// closeActiveEntries signals a non-quota close to every active analyzer,
// folds any final property updates into the stream info, and retires all
// entries to doneEntries.
func (s *tcpStream) closeActiveEntries() {
	propsChanged := false
	for _, e := range s.activeEntries {
		if processPropUpdate(s.info.Props, e.Name, e.Stream.Close(false)) {
			propsChanged = true
		}
	}
	if propsChanged {
		s.logger.TCPStreamPropUpdate(s.info, true)
	}
	s.doneEntries = append(s.doneEntries, s.activeEntries...)
	s.activeEntries = nil
}
// feedEntry pushes data into one analyzer, enforcing its byte quota.
// It returns the analyzer's property update, an optional update produced
// by a quota-triggered close, and whether the entry is finished.
func (s *tcpStream) feedEntry(entry *tcpStreamEntry, rev, start, end bool, skip int, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) {
	if !entry.HasLimit {
		update, done = entry.Stream.Feed(rev, start, end, skip, data)
		return
	}
	// Truncate the chunk to the remaining quota before feeding.
	feed := data
	if len(feed) > entry.Quota {
		feed = feed[:entry.Quota]
	}
	update, done = entry.Stream.Feed(rev, start, end, skip, feed)
	entry.Quota -= len(feed)
	if entry.Quota <= 0 {
		// Quota exhausted: force-close the analyzer and retire it.
		closeUpdate = entry.Stream.Close(true)
		done = true
	}
	return
}
// analyzersToTCPAnalyzers filters a generic analyzer list down to those
// that implement analyzer.TCPAnalyzer, preserving order.
func analyzersToTCPAnalyzers(ans []analyzer.Analyzer) []analyzer.TCPAnalyzer {
	out := make([]analyzer.TCPAnalyzer, 0, len(ans))
	for _, a := range ans {
		t, ok := a.(analyzer.TCPAnalyzer)
		if !ok {
			continue
		}
		out = append(out, t)
	}
	return out
}
// actionToTCPVerdict maps a ruleset action onto the TCP verdict subset.
// TCP has no per-packet drop or modify, so block-like actions drop the
// whole stream and everything else (including Modify, which TCP cannot
// honor, and any unexpected value) accepts it.
func actionToTCPVerdict(a ruleset.Action) tcpVerdict {
	if a == ruleset.ActionBlock || a == ruleset.ActionDrop {
		return tcpVerdictDropStream
	}
	return tcpVerdictAcceptStream
}

299
engine/udp.go Normal file
View File

@@ -0,0 +1,299 @@
package engine
import (
"errors"
"net"
"sync"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/modifier"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
lru "github.com/hashicorp/golang-lru/v2"
)
// udpVerdict is a subset of io.Verdict for UDP streams.
// For UDP, we support all verdicts.
type udpVerdict io.Verdict

const (
	udpVerdictAccept       = udpVerdict(io.VerdictAccept)       // accept packet, keep analyzing
	udpVerdictAcceptModify = udpVerdict(io.VerdictAcceptModify) // accept with replaced payload
	udpVerdictAcceptStream = udpVerdict(io.VerdictAcceptStream) // accept and stop analyzing
	udpVerdictDrop         = udpVerdict(io.VerdictDrop)         // drop this packet only
	udpVerdictDropStream   = udpVerdict(io.VerdictDropStream)   // drop packet and block the stream
)

// errInvalidModifier is reported when a matched modifier instance is not
// usable for UDP (does not implement modifier.UDPModifierInstance).
var errInvalidModifier = errors.New("invalid modifier")

// udpContext carries the verdict and an optional replacement payload
// (set on udpVerdictAcceptModify) back to the worker.
type udpContext struct {
	Verdict udpVerdict
	Packet  []byte
}

// udpStreamFactory creates udpStream state for new flows.
// RulesetMutex guards Ruleset against concurrent UpdateRuleset calls.
type udpStreamFactory struct {
	WorkerID int
	Logger   Logger
	Node     *snowflake.Node // per-worker snowflake node for stream IDs

	RulesetMutex sync.RWMutex
	Ruleset      ruleset.Ruleset
}
// New creates the per-flow state for a UDP stream the manager has not
// seen before: it assigns a snowflake ID, snapshots the current ruleset,
// and instantiates one entry per UDP-capable analyzer it selects.
func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) *udpStream {
	streamID := f.Node.Generate().Int64()
	srcIP := net.IP(ipFlow.Src().Raw())
	dstIP := net.IP(ipFlow.Dst().Raw())
	srcPort, dstPort := uint16(udp.SrcPort), uint16(udp.DstPort)
	info := ruleset.StreamInfo{
		ID:       streamID,
		Protocol: ruleset.ProtocolUDP,
		SrcIP:    srcIP,
		DstIP:    dstIP,
		SrcPort:  srcPort,
		DstPort:  dstPort,
		Props:    make(analyzer.CombinedPropMap),
	}
	f.Logger.UDPStreamNew(f.WorkerID, info)
	// Snapshot the ruleset under the read lock so a concurrent
	// UpdateRuleset cannot race with stream creation.
	f.RulesetMutex.RLock()
	rs := f.Ruleset
	f.RulesetMutex.RUnlock()
	udpAnalyzers := analyzersToUDPAnalyzers(rs.Analyzers(info))
	// One entry per analyzer; each gets its own logger and byte quota.
	entries := make([]*udpStreamEntry, 0, len(udpAnalyzers))
	for _, a := range udpAnalyzers {
		aLogger := &analyzerLogger{
			StreamID: streamID,
			Name:     a.Name(),
			Logger:   f.Logger,
		}
		aInfo := analyzer.UDPInfo{
			SrcIP:   srcIP,
			DstIP:   dstIP,
			SrcPort: srcPort,
			DstPort: dstPort,
		}
		entries = append(entries, &udpStreamEntry{
			Name:     a.Name(),
			Stream:   a.NewUDP(aInfo, aLogger),
			HasLimit: a.Limit() > 0,
			Quota:    a.Limit(),
		})
	}
	return &udpStream{
		info:          info,
		virgin:        true,
		logger:        f.Logger,
		ruleset:       rs,
		activeEntries: entries,
	}
}
// UpdateRuleset atomically swaps in a new ruleset for future streams.
// Streams already in flight keep the ruleset they were created with.
func (f *udpStreamFactory) UpdateRuleset(r ruleset.Ruleset) error {
	f.RulesetMutex.Lock()
	f.Ruleset = r
	f.RulesetMutex.Unlock()
	return nil
}
// udpStreamManager tracks live UDP streams in an LRU cache keyed by the
// stream ID supplied by the packet IO layer.
type udpStreamManager struct {
	factory *udpStreamFactory
	streams *lru.Cache[uint32, *udpStreamValue]
}

// udpStreamValue couples a stream with the flows it was created from, so
// a reused stream ID can be detected and the stale stream replaced.
type udpStreamValue struct {
	Stream  *udpStream
	IPFlow  gopacket.Flow
	UDPFlow gopacket.Flow
}
// Match reports whether the given flows belong to this stream and, if so,
// whether they run in the reverse direction of the original flows.
func (v *udpStreamValue) Match(ipFlow, udpFlow gopacket.Flow) (ok, rev bool) {
	if v.IPFlow == ipFlow && v.UDPFlow == udpFlow {
		return true, false
	}
	if v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse() {
		return true, true
	}
	return false, false
}
// newUDPStreamManager builds a stream manager backed by an LRU cache of
// at most maxStreams tracked UDP streams.
func newUDPStreamManager(factory *udpStreamFactory, maxStreams int) (*udpStreamManager, error) {
	cache, err := lru.New[uint32, *udpStreamValue](maxStreams)
	if err != nil {
		return nil, err
	}
	m := &udpStreamManager{
		factory: factory,
		streams: cache,
	}
	return m, nil
}
// MatchWithContext routes one UDP datagram to its stream, creating or
// replacing the stream as needed, and lets the stream set the verdict
// (and optionally a modified payload) on uc.
func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) {
	rev := false
	value, ok := m.streams.Get(streamID)
	if !ok {
		// New stream
		value = &udpStreamValue{
			Stream:  m.factory.New(ipFlow, udp.TransportFlow(), udp, uc),
			IPFlow:  ipFlow,
			UDPFlow: udp.TransportFlow(),
		}
		m.streams.Add(streamID, value)
	} else {
		// Stream ID exists, but is it really the same stream?
		// (Match also determines the datagram's direction.)
		ok, rev = value.Match(ipFlow, udp.TransportFlow())
		if !ok {
			// It's not - close the old stream & replace it with a new one
			value.Stream.Close()
			value = &udpStreamValue{
				Stream:  m.factory.New(ipFlow, udp.TransportFlow(), udp, uc),
				IPFlow:  ipFlow,
				UDPFlow: udp.TransportFlow(),
			}
			m.streams.Add(streamID, value)
		}
	}
	// Accept replays the last verdict when analysis is already finished.
	if value.Stream.Accept(udp, rev, uc) {
		value.Stream.Feed(udp, rev, uc)
	}
}
// udpStream is the per-flow analysis state for a UDP stream.
type udpStream struct {
	info    ruleset.StreamInfo
	virgin  bool // true if no packets have been processed
	logger  Logger
	ruleset ruleset.Ruleset
	// activeEntries are analyzers still consuming data; doneEntries have
	// finished (returned done, hit their quota, or been closed).
	activeEntries []*udpStreamEntry
	doneEntries   []*udpStreamEntry
	// lastVerdict is replayed for packets that arrive after analysis ends.
	lastVerdict udpVerdict
}

// udpStreamEntry pairs one analyzer instance with its remaining byte quota.
type udpStreamEntry struct {
	Name     string
	Stream   analyzer.UDPStream
	HasLimit bool // false when the analyzer has no byte limit
	Quota    int  // remaining bytes before the analyzer is force-closed
}
// Accept reports whether this datagram should be fed to the analyzers.
func (s *udpStream) Accept(udp *layers.UDP, rev bool, uc *udpContext) bool {
	// Keep processing while any analyzer is active. A virgin stream is
	// also accepted so every stream matches against the ruleset at least
	// once, even with no active entries — the ruleset may have built-in
	// properties that need to be matched.
	if s.virgin || len(s.activeEntries) > 0 {
		return true
	}
	// Nothing left to analyze: replay the last verdict and stop.
	uc.Verdict = s.lastVerdict
	return false
}
// Feed pushes one datagram's payload through all active analyzers, folds
// property updates into the stream info, and — whenever properties changed
// (or on the very first packet) — re-matches the ruleset, runs any matched
// UDP modifier, and records the verdict into uc.
func (s *udpStream) Feed(udp *layers.UDP, rev bool, uc *udpContext) {
	updated := false
	for i := len(s.activeEntries) - 1; i >= 0; i-- {
		// Important: reverse order so we can remove entries
		entry := s.activeEntries[i]
		update, closeUpdate, done := s.feedEntry(entry, rev, udp.Payload)
		// Apply both the feed update and any quota-close update.
		up1 := processPropUpdate(s.info.Props, entry.Name, update)
		up2 := processPropUpdate(s.info.Props, entry.Name, closeUpdate)
		updated = updated || up1 || up2
		if done {
			s.activeEntries = append(s.activeEntries[:i], s.activeEntries[i+1:]...)
			s.doneEntries = append(s.doneEntries, entry)
		}
	}
	if updated || s.virgin {
		s.virgin = false
		s.logger.UDPStreamPropUpdate(s.info, false)
		// Match properties against ruleset
		result := s.ruleset.Match(s.info)
		action := result.Action
		if action == ruleset.ActionModify {
			// Call the modifier instance
			udpMI, ok := result.ModInstance.(modifier.UDPModifierInstance)
			if !ok {
				// Not for UDP, fallback to maybe
				s.logger.ModifyError(s.info, errInvalidModifier)
				action = ruleset.ActionMaybe
			} else {
				var err error
				uc.Packet, err = udpMI.Process(udp.Payload)
				if err != nil {
					// Modifier error, fallback to maybe
					s.logger.ModifyError(s.info, err)
					action = ruleset.ActionMaybe
				}
			}
		}
		if action != ruleset.ActionMaybe {
			// final means the verdict settles the whole stream, so the
			// analyzers are no longer needed.
			verdict, final := actionToUDPVerdict(action)
			s.lastVerdict = verdict
			uc.Verdict = verdict
			s.logger.UDPStreamAction(s.info, action, false)
			if final {
				s.closeActiveEntries()
			}
		}
	}
	if len(s.activeEntries) == 0 && uc.Verdict == udpVerdictAccept {
		// All entries are done but no verdict issued, accept stream
		s.lastVerdict = udpVerdictAcceptStream
		uc.Verdict = udpVerdictAcceptStream
		s.logger.UDPStreamAction(s.info, ruleset.ActionAllow, true)
	}
}
// Close shuts down the stream, closing any analyzers still active.
// Called when the stream manager evicts or replaces this stream.
func (s *udpStream) Close() {
	s.closeActiveEntries()
}
// closeActiveEntries signals a non-quota close to every active analyzer,
// folds any final property updates into the stream info, and retires all
// entries to doneEntries.
func (s *udpStream) closeActiveEntries() {
	propsChanged := false
	for _, e := range s.activeEntries {
		if processPropUpdate(s.info.Props, e.Name, e.Stream.Close(false)) {
			propsChanged = true
		}
	}
	if propsChanged {
		s.logger.UDPStreamPropUpdate(s.info, true)
	}
	s.doneEntries = append(s.doneEntries, s.activeEntries...)
	s.activeEntries = nil
}
// feedEntry pushes one datagram into a single analyzer and applies its
// byte quota. Unlike TCP, the whole datagram is always fed; the quota
// only decides when to retire the analyzer afterwards.
func (s *udpStream) feedEntry(entry *udpStreamEntry, rev bool, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) {
	update, done = entry.Stream.Feed(rev, data)
	if !entry.HasLimit {
		return
	}
	entry.Quota -= len(data)
	if entry.Quota <= 0 {
		// Quota exhausted: force-close the analyzer and retire it.
		closeUpdate = entry.Stream.Close(true)
		done = true
	}
	return
}
// analyzersToUDPAnalyzers filters a generic analyzer list down to those
// that implement analyzer.UDPAnalyzer, preserving order.
func analyzersToUDPAnalyzers(ans []analyzer.Analyzer) []analyzer.UDPAnalyzer {
	out := make([]analyzer.UDPAnalyzer, 0, len(ans))
	for _, a := range ans {
		u, ok := a.(analyzer.UDPAnalyzer)
		if !ok {
			continue
		}
		out = append(out, u)
	}
	return out
}
// actionToUDPVerdict maps a ruleset action to a UDP verdict. final is
// true when the verdict settles the whole stream (no further analysis
// is needed); ActionMaybe and any unexpected value accept the packet but
// keep the stream open.
func actionToUDPVerdict(a ruleset.Action) (v udpVerdict, final bool) {
	switch a {
	case ruleset.ActionAllow:
		return udpVerdictAcceptStream, true
	case ruleset.ActionBlock:
		return udpVerdictDropStream, true
	case ruleset.ActionDrop:
		return udpVerdictDrop, false
	case ruleset.ActionModify:
		return udpVerdictAcceptModify, false
	default:
		// ActionMaybe and anything unexpected.
		return udpVerdictAccept, false
	}
}

50
engine/utils.go Normal file
View File

@@ -0,0 +1,50 @@
package engine
import "git.difuse.io/Difuse/Mellaris/analyzer"
// Compile-time check that analyzerLogger satisfies analyzer.Logger.
var _ analyzer.Logger = (*analyzerLogger)(nil)

// analyzerLogger adapts the engine Logger to the analyzer.Logger
// interface, tagging every message with a stream ID and analyzer name.
type analyzerLogger struct {
	StreamID int64
	Name     string
	Logger   Logger
}

// Debugf forwards a debug-level message to the engine logger.
func (l *analyzerLogger) Debugf(format string, args ...interface{}) {
	l.Logger.AnalyzerDebugf(l.StreamID, l.Name, format, args...)
}

// Infof forwards an info-level message to the engine logger.
func (l *analyzerLogger) Infof(format string, args ...interface{}) {
	l.Logger.AnalyzerInfof(l.StreamID, l.Name, format, args...)
}

// Errorf forwards an error-level message to the engine logger.
func (l *analyzerLogger) Errorf(format string, args ...interface{}) {
	l.Logger.AnalyzerErrorf(l.StreamID, l.Name, format, args...)
}
// processPropUpdate applies a single analyzer's property update to the
// combined property map under that analyzer's name, reporting whether
// the map was modified.
func processPropUpdate(cpm analyzer.CombinedPropMap, name string, update *analyzer.PropUpdate) (updated bool) {
	if update == nil {
		return false
	}
	switch update.Type {
	case analyzer.PropUpdateMerge:
		// Merge into the analyzer's existing map, creating it on demand.
		dst := cpm[name]
		if dst == nil {
			dst = make(analyzer.PropMap, len(update.M))
			cpm[name] = dst
		}
		for k, v := range update.M {
			dst[k] = v
		}
		return true
	case analyzer.PropUpdateReplace:
		cpm[name] = update.M
		return true
	case analyzer.PropUpdateDelete:
		delete(cpm, name)
		return true
	default:
		// PropUpdateNone and unknown types are ignored.
		return false
	}
}

185
engine/worker.go Normal file
View File

@@ -0,0 +1,185 @@
package engine
import (
"context"
"git.difuse.io/Difuse/Mellaris/io"
"git.difuse.io/Difuse/Mellaris/ruleset"
"github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/gopacket/reassembly"
)
// Defaults used by workerConfig.fillDefaults when a knob is unset.
const (
	defaultChanSize                         = 64
	defaultTCPMaxBufferedPagesTotal         = 4096
	defaultTCPMaxBufferedPagesPerConnection = 64
	defaultUDPMaxStreams                    = 4096
)

// workerPacket is one unit of work queued to a worker: the packet, its
// stream ID, and the callback used to deliver the verdict (plus an
// optional replacement packet) back to the IO layer.
type workerPacket struct {
	StreamID   uint32
	Packet     gopacket.Packet
	SetVerdict func(io.Verdict, []byte) error
}

// worker owns one packet-processing pipeline: a TCP reassembler and a
// UDP stream manager, fed from packetChan by Feed and drained by Run.
type worker struct {
	id         int
	packetChan chan *workerPacket
	logger     Logger

	tcpStreamFactory *tcpStreamFactory
	tcpStreamPool    *reassembly.StreamPool
	tcpAssembler     *reassembly.Assembler

	udpStreamFactory *udpStreamFactory
	udpStreamManager *udpStreamManager

	// modSerializeBuffer is reused to serialize modified UDP packets.
	modSerializeBuffer gopacket.SerializeBuffer
}

// workerConfig configures a single worker; zero/negative values are
// replaced with package defaults by fillDefaults.
type workerConfig struct {
	ID                         int
	ChanSize                   int
	Logger                     Logger
	Ruleset                    ruleset.Ruleset
	TCPMaxBufferedPagesTotal   int
	TCPMaxBufferedPagesPerConn int
	UDPMaxStreams              int
}
// fillDefaults replaces non-positive tuning values with package defaults.
func (c *workerConfig) fillDefaults() {
	defaultize := func(v *int, def int) {
		if *v <= 0 {
			*v = def
		}
	}
	defaultize(&c.ChanSize, defaultChanSize)
	defaultize(&c.TCPMaxBufferedPagesTotal, defaultTCPMaxBufferedPagesTotal)
	defaultize(&c.TCPMaxBufferedPagesPerConn, defaultTCPMaxBufferedPagesPerConnection)
	defaultize(&c.UDPMaxStreams, defaultUDPMaxStreams)
}
// newWorker wires up a complete worker: a snowflake node for stream IDs,
// the TCP stream factory/pool/assembler, and the UDP stream manager.
// Returns an error if the snowflake node or UDP LRU cannot be created.
func newWorker(config workerConfig) (*worker, error) {
	config.fillDefaults()
	// The worker ID doubles as the snowflake node ID so stream IDs from
	// different workers never collide.
	sfNode, err := snowflake.NewNode(int64(config.ID))
	if err != nil {
		return nil, err
	}
	tcpSF := &tcpStreamFactory{
		WorkerID: config.ID,
		Logger:   config.Logger,
		Node:     sfNode,
		Ruleset:  config.Ruleset,
	}
	tcpStreamPool := reassembly.NewStreamPool(tcpSF)
	tcpAssembler := reassembly.NewAssembler(tcpStreamPool)
	// Cap reassembly buffering to bound memory usage.
	tcpAssembler.MaxBufferedPagesTotal = config.TCPMaxBufferedPagesTotal
	tcpAssembler.MaxBufferedPagesPerConnection = config.TCPMaxBufferedPagesPerConn
	udpSF := &udpStreamFactory{
		WorkerID: config.ID,
		Logger:   config.Logger,
		Node:     sfNode,
		Ruleset:  config.Ruleset,
	}
	udpSM, err := newUDPStreamManager(udpSF, config.UDPMaxStreams)
	if err != nil {
		return nil, err
	}
	return &worker{
		id:                 config.ID,
		packetChan:         make(chan *workerPacket, config.ChanSize),
		logger:             config.Logger,
		tcpStreamFactory:   tcpSF,
		tcpStreamPool:      tcpStreamPool,
		tcpAssembler:       tcpAssembler,
		udpStreamFactory:   udpSF,
		udpStreamManager:   udpSM,
		modSerializeBuffer: gopacket.NewSerializeBuffer(),
	}, nil
}
// Feed enqueues a packet for this worker; blocks when the queue is full.
func (w *worker) Feed(p *workerPacket) {
	w.packetChan <- p
}

// Run processes queued packets until the context is cancelled or a nil
// packet is received (used as a close signal). Each packet's verdict is
// delivered via its SetVerdict callback; delivery errors are ignored.
func (w *worker) Run(ctx context.Context) {
	w.logger.WorkerStart(w.id)
	defer w.logger.WorkerStop(w.id)
	for {
		select {
		case <-ctx.Done():
			return
		case wPkt := <-w.packetChan:
			if wPkt == nil {
				// Closed
				return
			}
			v, b := w.handle(wPkt.StreamID, wPkt.Packet)
			// Best-effort verdict delivery; the IO layer owns errors.
			_ = wPkt.SetVerdict(v, b)
		}
	}
}
// UpdateRuleset propagates a new ruleset to both stream factories,
// stopping at the first error.
func (w *worker) UpdateRuleset(r ruleset.Ruleset) error {
	err := w.tcpStreamFactory.UpdateRuleset(r)
	if err == nil {
		err = w.udpStreamFactory.UpdateRuleset(r)
	}
	return err
}
// handle dispatches one packet to the TCP or UDP pipeline and returns
// the verdict plus, for accepted-with-modify UDP packets, the fully
// re-serialized packet bytes.
func (w *worker) handle(streamID uint32, p gopacket.Packet) (io.Verdict, []byte) {
	netLayer, trLayer := p.NetworkLayer(), p.TransportLayer()
	if netLayer == nil || trLayer == nil {
		// Invalid packet
		return io.VerdictAccept, nil
	}
	ipFlow := netLayer.NetworkFlow()
	switch tr := trLayer.(type) {
	case *layers.TCP:
		return w.handleTCP(ipFlow, p.Metadata(), tr), nil
	case *layers.UDP:
		v, modPayload := w.handleUDP(streamID, ipFlow, tr)
		if v == io.VerdictAcceptModify && modPayload != nil {
			// Splice the modified payload in and re-serialize the whole
			// packet with recomputed lengths and checksums.
			tr.Payload = modPayload
			_ = tr.SetNetworkLayerForChecksum(netLayer)
			_ = w.modSerializeBuffer.Clear()
			err := gopacket.SerializePacket(w.modSerializeBuffer,
				gopacket.SerializeOptions{
					FixLengths:       true,
					ComputeChecksums: true,
				}, p)
			if err != nil {
				// Just accept without modification for now
				return io.VerdictAccept, nil
			}
			return v, w.modSerializeBuffer.Bytes()
		}
		return v, nil
	default:
		// Unsupported protocol
		return io.VerdictAccept, nil
	}
}
// handleTCP runs one TCP segment through the reassembler and returns the
// verdict the stream recorded in the context (accept by default).
func (w *worker) handleTCP(ipFlow gopacket.Flow, pMeta *gopacket.PacketMetadata, tcp *layers.TCP) io.Verdict {
	tCtx := &tcpContext{
		PacketMetadata: pMeta,
		Verdict:        tcpVerdictAccept,
	}
	w.tcpAssembler.AssembleWithContext(ipFlow, tcp, tCtx)
	return io.Verdict(tCtx.Verdict)
}

// handleUDP runs one UDP datagram through the stream manager and returns
// the verdict plus the (possibly nil) modified payload.
func (w *worker) handleUDP(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP) (io.Verdict, []byte) {
	uCtx := &udpContext{
		Verdict: udpVerdictAccept,
	}
	w.udpStreamManager.MatchWithContext(streamID, ipFlow, udp, uCtx)
	return io.Verdict(uCtx.Verdict), uCtx.Packet
}

17
errors.go Normal file
View File

@@ -0,0 +1,17 @@
package mellaris
import "fmt"
// ConfigError indicates a configuration issue.
type ConfigError struct {
	Field string // name of the offending configuration field
	Err   error  // underlying cause; may be nil
}

// Error implements the error interface.
// The cause is formatted with %v (not %s) so a nil Err renders as
// "<nil>" instead of fmt's "%!s(<nil>)" artifact.
func (e ConfigError) Error() string {
	return fmt.Sprintf("invalid config: %s: %v", e.Field, e.Err)
}

// Unwrap returns the underlying error for use with errors.Is/errors.As.
func (e ConfigError) Unwrap() error {
	return e.Err
}

31
go.mod Normal file
View File

@@ -0,0 +1,31 @@
module git.difuse.io/Difuse/Mellaris
go 1.21
require (
github.com/bwmarrin/snowflake v0.3.0
github.com/coreos/go-iptables v0.7.0
github.com/expr-lang/expr v1.16.3
github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf
github.com/google/gopacket v1.1.20-0.20220810144506-32ee38206866
github.com/hashicorp/golang-lru/v2 v2.0.7
github.com/mdlayher/netlink v1.6.0
github.com/quic-go/quic-go v0.41.0
golang.org/x/crypto v0.19.0
golang.org/x/sys v0.17.0
google.golang.org/protobuf v1.31.0
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/josharian/native v1.0.0 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/mdlayher/socket v0.1.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
golang.org/x/net v0.19.0 // indirect
golang.org/x/sync v0.5.0 // indirect
golang.org/x/tools v0.13.0 // indirect
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
)

103
go.sum Normal file
View File

@@ -0,0 +1,103 @@
github.com/bwmarrin/snowflake v0.3.0 h1:xm67bEhkKh6ij1790JB83OujPR5CzNe8QuQqAgISZN0=
github.com/bwmarrin/snowflake v0.3.0/go.mod h1:NdZxfVWX+oR6y2K0o6qAYv6gIOP9rjG0/E9WsDpxqwE=
github.com/coreos/go-iptables v0.7.0 h1:XWM3V+MPRr5/q51NuWSgU0fqMad64Zyxs8ZUoMsamr8=
github.com/coreos/go-iptables v0.7.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/expr-lang/expr v1.16.3 h1:NLldf786GffptcXNxxJx5dQ+FzeWDKChBDqOOwyK8to=
github.com/expr-lang/expr v1.16.3/go.mod h1:uCkhfG+x7fcZ5A5sXHKuQ07jGZRl6J0FCAaf2k4PtVQ=
github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf h1:NqGS3vTHzVENbIfd87cXZwdpO6MB2R1PjHMJLi4Z3ow=
github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf/go.mod h1:eSnAor2YCfMCVYrVNEhkLGN/r1L+J4uDjc0EUy0tfq4=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gopacket v1.1.20-0.20220810144506-32ee38206866 h1:NaJi58bCZZh0jjPw78EqDZekPEfhlzYE01C5R+zh1tE=
github.com/google/gopacket v1.1.20-0.20220810144506-32ee38206866/go.mod h1:riddUzxTSBpJXk3qBHtYr4qOhFhT6k/1c0E3qkQjQpA=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/josharian/native v1.0.0 h1:Ts/E8zCSEsG17dUqv7joXJFybuMLjQfWE04tsBODTxk=
github.com/josharian/native v1.0.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mdlayher/netlink v1.6.0 h1:rOHX5yl7qnlpiVkFWoqccueppMtXzeziFjWAjLg6sz0=
github.com/mdlayher/netlink v1.6.0/go.mod h1:0o3PlBmGst1xve7wQ7j/hwpNaFaH4qCRyWCdcZk8/vA=
github.com/mdlayher/socket v0.1.1 h1:q3uOGirUPfAV2MUoaC7BavjQ154J7+JOkTWyiV+intI=
github.com/mdlayher/socket v0.1.1/go.mod h1:mYV5YIZAfHh4dzDVzI8x8tWLWCliuX8Mon5Awbj+qDs=
github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/quic-go/quic-go v0.41.0 h1:aD8MmHfgqTURWNJy48IYFg2OnxwHT3JL7ahGs73lb4k=
github.com/quic-go/quic-go v0.41.0/go.mod h1:qCkNjqczPEvgsOnxZ0eCD14lv+B2LHlFAB++CNOh9hA=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE=
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU=
github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210928044308-7d9f5e0b762b/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

56
io/interface.go Normal file
View File

@@ -0,0 +1,56 @@
package io
import (
"context"
"net"
)
// Verdict is the engine's decision for a packet and/or its stream.
type Verdict int

const (
	// VerdictAccept accepts the packet, but continues to process the stream.
	VerdictAccept Verdict = iota
	// VerdictAcceptModify is like VerdictAccept, but replaces the packet with a new one.
	VerdictAcceptModify
	// VerdictAcceptStream accepts the packet and stops processing the stream.
	VerdictAcceptStream
	// VerdictDrop drops the packet, but does not block the stream.
	VerdictDrop
	// VerdictDropStream drops the packet and blocks the stream.
	VerdictDropStream
)

// Packet represents an IP packet.
type Packet interface {
	// StreamID is the ID of the stream the packet belongs to.
	StreamID() uint32
	// Data is the raw packet data, starting with the IP header.
	Data() []byte
}

// PacketCallback is called for each packet received.
// Return false to "unregister" and stop receiving packets.
type PacketCallback func(Packet, error) bool

// PacketIO abstracts the packet interception backend (e.g. nfqueue).
type PacketIO interface {
	// Register registers a callback to be called for each packet received.
	// The callback should be called in one or more separate goroutines,
	// and stop when the context is cancelled.
	Register(context.Context, PacketCallback) error
	// SetVerdict sets the verdict for a packet.
	SetVerdict(Packet, Verdict, []byte) error
	// ProtectedDialContext is like net.DialContext, but the connection is "protected"
	// in the sense that the packets sent/received through the connection must bypass
	// the packet IO and not be processed by the callback.
	ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error)
	// Close closes the packet IO.
	Close() error
}
// ErrInvalidPacket wraps the reason a packet could not be handled.
type ErrInvalidPacket struct {
	Err error
}

// Error implements the error interface. A nil Err yields the bare message
// instead of a nil-pointer panic.
func (e *ErrInvalidPacket) Error() string {
	if e.Err == nil {
		return "invalid packet"
	}
	return "invalid packet: " + e.Err.Error()
}

// Unwrap exposes the underlying error to errors.Is / errors.As.
func (e *ErrInvalidPacket) Unwrap() error {
	return e.Err
}

431
io/nfqueue.go Normal file
View File

@@ -0,0 +1,431 @@
package io
import (
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"os/exec"
"strconv"
"strings"
"syscall"
"github.com/coreos/go-iptables/iptables"
"github.com/florianl/go-nfqueue"
"github.com/mdlayher/netlink"
"golang.org/x/sys/unix"
)
const (
	// nfqueueNum is the NFQUEUE number the firewall rules divert packets to.
	nfqueueNum = 100
	// nfqueueMaxPacketLen caps how many bytes of each packet the kernel
	// copies to userspace (64 KiB covers any IP packet).
	nfqueueMaxPacketLen = 0xFFFF
	// nfqueueDefaultQueueSize is used when the config leaves QueueSize at 0.
	nfqueueDefaultQueueSize = 128

	// Connection marks used to hand per-stream verdicts back to the kernel,
	// so decided streams stop round-tripping through userspace.
	nfqueueConnMarkAccept = 1001
	nfqueueConnMarkDrop   = 1002

	// nftables address family and table name for all rules this package installs.
	nftFamily = "inet"
	nftTable  = "mellaris"
)
// generateNftRules builds the nftables table spec that diverts traffic to
// the NFQUEUE and honors the accept/drop connection marks. Local mode hooks
// INPUT/OUTPUT instead of FORWARD; TCP reset rejection is only supported in
// forward mode.
func generateNftRules(local, rst bool) (*nftTableSpec, error) {
	if local && rst {
		return nil, errors.New("tcp rst is not supported in local mode")
	}
	table := &nftTableSpec{
		Family: nftFamily,
		Table:  nftTable,
		Defines: []string{
			fmt.Sprintf("define ACCEPT_CTMARK=%d", nfqueueConnMarkAccept),
			fmt.Sprintf("define DROP_CTMARK=%d", nfqueueConnMarkDrop),
			fmt.Sprintf("define QUEUE_NUM=%d", nfqueueNum),
		},
	}
	if local {
		table.Chains = []nftChainSpec{
			{Chain: "INPUT", Header: "type filter hook input priority filter; policy accept;"},
			{Chain: "OUTPUT", Header: "type filter hook output priority filter; policy accept;"},
		}
	} else {
		table.Chains = []nftChainSpec{
			{Chain: "FORWARD", Header: "type filter hook forward priority filter; policy accept;"},
		}
	}
	for i := range table.Chains {
		chain := &table.Chains[i]
		chain.Rules = []string{
			"meta mark $ACCEPT_CTMARK ct mark set $ACCEPT_CTMARK", // Bypass protected connections
			"ct mark $ACCEPT_CTMARK counter accept",
		}
		if rst {
			chain.Rules = append(chain.Rules, "ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset")
		}
		chain.Rules = append(chain.Rules,
			"ct mark $DROP_CTMARK counter drop",
			"counter queue num $QUEUE_NUM bypass",
		)
	}
	return table, nil
}
// generateIptRules builds the iptables equivalent of generateNftRules for
// systems without the nft binary. Rule order within each chain matters:
// protected-connection bypass first, then stream verdicts, then the queue.
func generateIptRules(local, rst bool) ([]iptRule, error) {
	if local && rst {
		return nil, errors.New("tcp rst is not supported in local mode")
	}
	chains := []string{"FORWARD"}
	if local {
		chains = []string{"INPUT", "OUTPUT"}
	}
	// Render the marks once instead of per rule.
	acceptMark := strconv.Itoa(nfqueueConnMarkAccept)
	dropMark := strconv.Itoa(nfqueueConnMarkDrop)
	queueNum := strconv.Itoa(nfqueueNum)
	rules := make([]iptRule, 0, 4*len(chains))
	for _, chain := range chains {
		// Bypass protected connections
		rules = append(rules, iptRule{"filter", chain, []string{"-m", "mark", "--mark", acceptMark, "-j", "CONNMARK", "--set-mark", acceptMark}})
		rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", acceptMark, "-j", "ACCEPT"}})
		if rst {
			rules = append(rules, iptRule{"filter", chain, []string{"-p", "tcp", "-m", "connmark", "--mark", dropMark, "-j", "REJECT", "--reject-with", "tcp-reset"}})
		}
		rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", dropMark, "-j", "DROP"}})
		rules = append(rules, iptRule{"filter", chain, []string{"-j", "NFQUEUE", "--queue-num", queueNum, "--queue-bypass"}})
	}
	return rules, nil
}
// Compile-time check that nfqueuePacketIO satisfies PacketIO.
var _ PacketIO = (*nfqueuePacketIO)(nil)

// errNotNFQueuePacket is wrapped in ErrInvalidPacket when SetVerdict is
// handed a Packet that did not originate from this PacketIO.
var errNotNFQueuePacket = errors.New("not an NFQueue packet")

// nfqueuePacketIO implements PacketIO on top of Linux NFQUEUE.
type nfqueuePacketIO struct {
	n     *nfqueue.Nfqueue
	local bool // hook INPUT/OUTPUT (this host) instead of FORWARD
	rst   bool // reject dropped TCP streams with RST instead of a silent drop
	rSet  bool // whether the nftables/iptables rules have been set
	// iptables not nil = use iptables instead of nftables
	ipt4 *iptables.IPTables
	ipt6 *iptables.IPTables
	// protectedDialer sets SO_MARK on its sockets so their traffic bypasses
	// the queue (see ProtectedDialContext).
	protectedDialer *net.Dialer
}
// NFQueuePacketIOConfig configures NewNFQueuePacketIO.
type NFQueuePacketIOConfig struct {
	QueueSize   uint32 // kernel queue length; 0 means nfqueueDefaultQueueSize
	ReadBuffer  int    // netlink socket read buffer in bytes; 0 keeps the OS default
	WriteBuffer int    // netlink socket write buffer in bytes; 0 keeps the OS default
	Local       bool   // hook INPUT/OUTPUT (this host) instead of FORWARD (router)
	RST         bool   // send TCP RST for dropped streams (forward mode only)
}
// NewNFQueuePacketIO opens NFQUEUE queue nfqueueNum and returns a PacketIO
// reading from it. nftables is preferred for the redirection rules; iptables
// is used when the nft binary is not in PATH. Rules themselves are installed
// lazily in Register. NOTE(review): requires CAP_NET_ADMIN — confirm the
// deployment runs privileged.
func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
	if config.QueueSize == 0 {
		config.QueueSize = nfqueueDefaultQueueSize
	}
	var ipt4, ipt6 *iptables.IPTables
	var err error
	if nftCheck() != nil {
		// We prefer nftables, but if it's not available, fall back to iptables
		ipt4, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
		if err != nil {
			return nil, err
		}
		ipt6, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
		if err != nil {
			return nil, err
		}
	}
	n, err := nfqueue.Open(&nfqueue.Config{
		NfQueue:      nfqueueNum,
		MaxPacketLen: nfqueueMaxPacketLen,
		MaxQueueLen:  config.QueueSize,
		// Copy the full payload and attach conntrack info (needed for
		// stream IDs in ctIDFromCtBytes).
		Copymode: nfqueue.NfQnlCopyPacket,
		Flags:    nfqueue.NfQaCfgFlagConntrack,
	})
	if err != nil {
		return nil, err
	}
	if config.ReadBuffer > 0 {
		err = n.Con.SetReadBuffer(config.ReadBuffer)
		if err != nil {
			// Don't leak the netlink socket on partial failure.
			_ = n.Close()
			return nil, err
		}
	}
	if config.WriteBuffer > 0 {
		err = n.Con.SetWriteBuffer(config.WriteBuffer)
		if err != nil {
			_ = n.Close()
			return nil, err
		}
	}
	return &nfqueuePacketIO{
		n:     n,
		local: config.Local,
		rst:   config.RST,
		ipt4:  ipt4,
		ipt6:  ipt6,
		protectedDialer: &net.Dialer{
			// Mark outgoing sockets with the accept mark so the firewall
			// rules bypass the queue for protected connections.
			Control: func(network, address string, c syscall.RawConn) error {
				var err error
				cErr := c.Control(func(fd uintptr) {
					err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, nfqueueConnMarkAccept)
				})
				if cErr != nil {
					return cErr
				}
				return err
			},
		},
	}, nil
}
// Register installs cb as the queue's packet handler and, on first use,
// installs the nftables/iptables redirection rules. go-nfqueue invokes the
// handlers from its own receive goroutine until ctx is cancelled.
func (n *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error {
	err := n.n.RegisterWithErrorFunc(ctx,
		func(a nfqueue.Attribute) int {
			if ok, verdict := n.packetAttributeSanityCheck(a); !ok {
				if a.PacketID != nil {
					_ = n.n.SetVerdict(*a.PacketID, verdict)
				}
				return 0
			}
			p := &nfqueuePacket{
				id:       *a.PacketID,
				streamID: ctIDFromCtBytes(*a.Ct),
				data:     *a.Payload,
			}
			// cb returning false maps to a non-zero return, which stops
			// the receive loop.
			return okBoolToInt(cb(p, nil))
		},
		func(e error) int {
			if opErr := (*netlink.OpError)(nil); errors.As(e, &opErr) {
				if errors.Is(opErr.Err, unix.ENOBUFS) {
					// Kernel buffer temporarily full, ignore
					return 0
				}
			}
			return okBoolToInt(cb(nil, e))
		})
	if err != nil {
		return err
	}
	// Install firewall rules only after the handler is registered, so no
	// packets get queued with nobody listening.
	if !n.rSet {
		if n.ipt4 != nil {
			err = n.setupIpt(n.local, n.rst, false)
		} else {
			err = n.setupNft(n.local, n.rst, false)
		}
		if err != nil {
			return err
		}
		n.rSet = true
	}
	return nil
}
// packetAttributeSanityCheck validates the netlink attributes of a queued
// packet. ok=false means the packet must not reach the callback; verdict is
// then the verdict to apply, or -1 when none can be set because the packet
// ID itself is missing.
func (n *nfqueuePacketIO) packetAttributeSanityCheck(a nfqueue.Attribute) (ok bool, verdict int) {
	switch {
	case a.PacketID == nil:
		// Re-injecting to NFQUEUE is not possible without an ID.
		return false, -1
	case a.Payload == nil || len(*a.Payload) < 20:
		// 20 bytes is the minimum possible size of an IP packet.
		return false, nfqueue.NfDrop
	case a.Ct == nil:
		// Multicast packets may arrive without conntrack, but only in
		// local mode; anywhere else a missing conntrack is dropped.
		if n.local {
			return false, nfqueue.NfAccept
		}
		return false, nfqueue.NfDrop
	default:
		return true, -1
	}
}
// SetVerdict resolves a queued packet according to v. newPacket is used only
// with VerdictAcceptModify, replacing the packet's payload. Stream verdicts
// set a CONNMARK so the kernel handles the rest of the connection without
// queueing it back to userspace.
func (n *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) error {
	nP, ok := p.(*nfqueuePacket)
	if !ok {
		return &ErrInvalidPacket{Err: errNotNFQueuePacket}
	}
	switch v {
	case VerdictAccept:
		return n.n.SetVerdict(nP.id, nfqueue.NfAccept)
	case VerdictAcceptModify:
		return n.n.SetVerdictModPacket(nP.id, nfqueue.NfAccept, newPacket)
	case VerdictAcceptStream:
		return n.n.SetVerdictWithConnMark(nP.id, nfqueue.NfAccept, nfqueueConnMarkAccept)
	case VerdictDrop:
		return n.n.SetVerdict(nP.id, nfqueue.NfDrop)
	case VerdictDropStream:
		return n.n.SetVerdictWithConnMark(nP.id, nfqueue.NfDrop, nfqueueConnMarkDrop)
	default:
		// Invalid verdict, ignore for now
		return nil
	}
}
// ProtectedDialContext dials through protectedDialer, whose sockets carry
// the accept SO_MARK and therefore bypass the NFQUEUE rules entirely.
func (n *nfqueuePacketIO) ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error) {
	return n.protectedDialer.DialContext(ctx, network, address)
}

// Close removes the firewall rules (best effort — removal errors are
// ignored) and closes the underlying netlink connection.
func (n *nfqueuePacketIO) Close() error {
	if n.rSet {
		if n.ipt4 != nil {
			_ = n.setupIpt(n.local, n.rst, true)
		} else {
			_ = n.setupNft(n.local, n.rst, true)
		}
		n.rSet = false
	}
	return n.n.Close()
}
// setupNft installs (or, with remove=true, deletes) the nftables rules.
func (n *nfqueuePacketIO) setupNft(local, rst, remove bool) error {
	table, err := generateNftRules(local, rst)
	if err != nil {
		return err
	}
	if remove {
		return nftDelete(nftFamily, nftTable)
	}
	// Delete first to make sure no leftover rules from a previous run.
	_ = nftDelete(nftFamily, nftTable)
	return nftAdd(table.String())
}
// setupIpt installs (or, with remove=true, deletes) the iptables rules on
// both the IPv4 and IPv6 tables.
func (n *nfqueuePacketIO) setupIpt(local, rst, remove bool) error {
	rules, err := generateIptRules(local, rst)
	if err != nil {
		return err
	}
	ipts := []*iptables.IPTables{n.ipt4, n.ipt6}
	if remove {
		return iptsBatchDeleteIfExists(ipts, rules)
	}
	return iptsBatchAppendUnique(ipts, rules)
}
// Compile-time check that nfqueuePacket satisfies Packet.
var _ Packet = (*nfqueuePacket)(nil)

// nfqueuePacket is a packet received from NFQUEUE: the kernel packet ID
// (needed to set a verdict), the conntrack ID used as the stream ID, and
// the raw payload.
type nfqueuePacket struct {
	id       uint32
	streamID uint32
	data     []byte
}

// StreamID returns the conntrack ID of the packet's connection.
func (p *nfqueuePacket) StreamID() uint32 {
	return p.streamID
}

// Data returns the raw packet bytes, starting with the IP header.
func (p *nfqueuePacket) Data() []byte {
	return p.data
}
// okBoolToInt converts a callback's "keep receiving" flag to the integer
// convention used by go-nfqueue handlers: 0 continues the receive loop,
// non-zero stops it.
func okBoolToInt(ok bool) int {
	if ok {
		return 0
	}
	return 1
}
// nftCheck reports whether the nft binary is available in PATH; a nil
// return means nftables can be used.
func nftCheck() error {
	_, err := exec.LookPath("nft")
	return err
}
// nftAdd feeds an nft script to "nft -f -" via stdin.
func nftAdd(input string) error {
	cmd := exec.Command("nft", "-f", "-")
	cmd.Stdin = strings.NewReader(input)
	return cmd.Run()
}

// nftDelete removes the whole table. Errors (e.g. the table not existing)
// are returned and left to the caller to handle or ignore.
func nftDelete(family, table string) error {
	cmd := exec.Command("nft", "delete", "table", family, table)
	return cmd.Run()
}
// nftTableSpec is a renderable nft table: "define" lines followed by a
// table block containing chains (see String).
type nftTableSpec struct {
	Defines       []string
	Family, Table string
	Chains        []nftChainSpec
}
// String renders the spec as an nft script: the define lines first, then
// the table block containing every chain (each chain renders with its own
// leading newline, so they are concatenated without a separator).
func (t *nftTableSpec) String() string {
	var chains strings.Builder
	for _, c := range t.Chains {
		chains.WriteString(c.String())
	}
	return fmt.Sprintf(`
%s
table %s %s {
%s
}
`, strings.Join(t.Defines, "\n"), t.Family, t.Table, chains.String())
}
// nftChainSpec is one chain inside an nft table: its name, the chain
// header ("type filter hook ..."), and its rules in priority order.
type nftChainSpec struct {
	Chain  string
	Header string
	Rules  []string
}

// String renders the chain block. Rules are joined with a newline plus
// four spaces ("\x20" is an explicit space) to indent the generated script.
func (c *nftChainSpec) String() string {
	return fmt.Sprintf(`
chain %s {
%s
%s
}
`, c.Chain, c.Header, strings.Join(c.Rules, "\n\x20\x20\x20\x20"))
}
// iptRule is a single iptables rule: target table, chain, and the rule
// specification arguments as passed to the iptables binary.
type iptRule struct {
	Table, Chain string
	RuleSpec     []string
}
// iptsBatchAppendUnique appends every rule to every provided iptables
// handle (v4 and v6), skipping rules that already exist. Rule order is
// preserved; the first failure aborts the batch.
func iptsBatchAppendUnique(ipts []*iptables.IPTables, rules []iptRule) error {
	for _, rule := range rules {
		for _, handle := range ipts {
			if err := handle.AppendUnique(rule.Table, rule.Chain, rule.RuleSpec...); err != nil {
				return err
			}
		}
	}
	return nil
}
// iptsBatchDeleteIfExists removes every rule from every provided iptables
// handle, tolerating rules that are already gone. The first hard failure
// aborts the batch.
func iptsBatchDeleteIfExists(ipts []*iptables.IPTables, rules []iptRule) error {
	for _, rule := range rules {
		for _, handle := range ipts {
			if err := handle.DeleteIfExists(rule.Table, rule.Chain, rule.RuleSpec...); err != nil {
				return err
			}
		}
	}
	return nil
}
// ctIDFromCtBytes extracts the conntrack ID from the raw conntrack info
// attached to an NFQueue packet. Returns 0 when the attributes cannot be
// parsed or no valid CTA_ID is present, so such packets share stream 0.
func ctIDFromCtBytes(ct []byte) uint32 {
	// CTA_ID from linux/netfilter/nfnetlink_conntrack.h.
	const ctaID = 12
	ctAttrs, err := netlink.UnmarshalAttributes(ct)
	if err != nil {
		return 0
	}
	for _, attr := range ctAttrs {
		// Length guard: a malformed attribute shorter than 4 bytes would
		// make binary.BigEndian.Uint32 panic.
		if attr.Type == ctaID && len(attr.Data) >= 4 {
			return binary.BigEndian.Uint32(attr.Data)
		}
	}
	return 0
}

32
modifier/interface.go Normal file
View File

@@ -0,0 +1,32 @@
package modifier
// Modifier is a named factory producing packet-modifier instances from a
// free-form argument map.
type Modifier interface {
	// Name returns the name of the modifier.
	Name() string
	// New returns a new modifier instance.
	New(args map[string]interface{}) (Instance, error)
}

// Instance is a marker interface; concrete capabilities are expressed by
// sub-interfaces such as UDPModifierInstance.
type Instance interface{}

// UDPModifierInstance modifies UDP payloads.
type UDPModifierInstance interface {
	Instance
	// Process takes a UDP packet and returns a modified UDP packet.
	Process(data []byte) ([]byte, error)
}
// ErrInvalidPacket wraps the reason a payload could not be processed.
type ErrInvalidPacket struct {
	Err error
}

// Error implements the error interface. A nil Err yields the bare message
// instead of a nil-pointer panic.
func (e *ErrInvalidPacket) Error() string {
	if e.Err == nil {
		return "invalid packet"
	}
	return "invalid packet: " + e.Err.Error()
}

// Unwrap exposes the underlying error to errors.Is / errors.As.
func (e *ErrInvalidPacket) Unwrap() error {
	return e.Err
}

// ErrInvalidArgs wraps the reason a modifier's argument map was rejected.
type ErrInvalidArgs struct {
	Err error
}

// Error implements the error interface. A nil Err yields the bare message
// instead of a nil-pointer panic.
func (e *ErrInvalidArgs) Error() string {
	if e.Err == nil {
		return "invalid args"
	}
	return "invalid args: " + e.Err.Error()
}

// Unwrap exposes the underlying error to errors.Is / errors.As.
func (e *ErrInvalidArgs) Unwrap() error {
	return e.Err
}

96
modifier/udp/dns.go Normal file
View File

@@ -0,0 +1,96 @@
package udp
import (
"errors"
"net"
"git.difuse.io/Difuse/Mellaris/modifier"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
)
// Compile-time check that DNSModifier satisfies modifier.Modifier.
var _ modifier.Modifier = (*DNSModifier)(nil)

var (
	errInvalidIP           = errors.New("invalid ip")
	errNotValidDNSResponse = errors.New("not a valid dns response")
	errEmptyDNSQuestion    = errors.New("empty dns question")
)

// DNSModifier rewrites DNS responses so that A/AAAA questions resolve to
// configured replacement addresses.
type DNSModifier struct{}

// Name returns the modifier's registry name.
func (m *DNSModifier) Name() string {
	return "dns"
}
// New builds a dnsModifierInstance from the args map. Recognized keys:
// "a" (IPv4 literal) and "aaaa" (IPv6 literal); keys that are absent or
// not strings are silently skipped, unparseable addresses are rejected.
// NOTE(review): To16 also maps IPv4 literals, so an IPv4 address is
// currently accepted for "aaaa" — confirm whether that is intended.
func (m *DNSModifier) New(args map[string]interface{}) (modifier.Instance, error) {
	inst := &dnsModifierInstance{}
	if s, ok := args["a"].(string); ok {
		ip := net.ParseIP(s).To4()
		if ip == nil {
			return nil, &modifier.ErrInvalidArgs{Err: errInvalidIP}
		}
		inst.A = ip
	}
	if s, ok := args["aaaa"].(string); ok {
		ip := net.ParseIP(s).To16()
		if ip == nil {
			return nil, &modifier.ErrInvalidArgs{Err: errInvalidIP}
		}
		inst.AAAA = ip
	}
	return inst, nil
}
// Compile-time check that dnsModifierInstance can modify UDP payloads.
var _ modifier.UDPModifierInstance = (*dnsModifierInstance)(nil)

// dnsModifierInstance holds the replacement addresses; a nil field leaves
// responses for that record type unmodified.
type dnsModifierInstance struct {
	A    net.IP
	AAAA net.IP
}
// Process rewrites a DNS response payload: when the first question is an
// A or AAAA query and a replacement address is configured, the answer
// section is replaced with a single record for that address. Lengths and
// checksums are recomputed during serialization.
func (i *dnsModifierInstance) Process(data []byte) ([]byte, error) {
	dns := &layers.DNS{}
	err := dns.DecodeFromBytes(data, gopacket.NilDecodeFeedback)
	if err != nil {
		return nil, &modifier.ErrInvalidPacket{Err: err}
	}
	// Only successful responses are rewritten; queries and error responses
	// are reported back as invalid packets.
	if !dns.QR || dns.ResponseCode != layers.DNSResponseCodeNoErr {
		return nil, &modifier.ErrInvalidPacket{Err: errNotValidDNSResponse}
	}
	if len(dns.Questions) == 0 {
		return nil, &modifier.ErrInvalidPacket{Err: errEmptyDNSQuestion}
	}
	// In practice, most if not all DNS clients only send one question
	// per packet, so we don't care about the rest for now.
	q := dns.Questions[0]
	switch q.Type {
	case layers.DNSTypeA:
		if i.A != nil {
			dns.Answers = []layers.DNSResourceRecord{{
				Name:  q.Name,
				Type:  layers.DNSTypeA,
				Class: layers.DNSClassIN,
				IP:    i.A,
			}}
		}
	case layers.DNSTypeAAAA:
		if i.AAAA != nil {
			dns.Answers = []layers.DNSResourceRecord{{
				Name:  q.Name,
				Type:  layers.DNSTypeAAAA,
				Class: layers.DNSClassIN,
				IP:    i.AAAA,
			}}
		}
	}
	buf := gopacket.NewSerializeBuffer() // Modifiers must be safe for concurrent use, so we can't reuse the buffer
	err = gopacket.SerializeLayers(buf, gopacket.SerializeOptions{
		FixLengths:       true,
		ComputeChecksums: true,
	}, dns)
	return buf.Bytes(), err
}

18
ruleset/builtins/cidr.go Normal file
View File

@@ -0,0 +1,18 @@
package builtins
import (
"net"
)
// MatchCIDR reports whether the textual address ip falls inside cidr.
// Returns false for unparseable addresses or a nil network, so untrusted
// input can be passed directly without risking a nil dereference.
func MatchCIDR(ip string, cidr *net.IPNet) bool {
	if cidr == nil {
		return false
	}
	ipAddr := net.ParseIP(ip)
	if ipAddr == nil {
		return false
	}
	return cidr.Contains(ipAddr)
}

// CompileCIDR parses a CIDR string (e.g. "10.0.0.0/8") into the *net.IPNet
// form expected by MatchCIDR.
func CompileCIDR(cidr string) (*net.IPNet, error) {
	_, ipNet, err := net.ParseCIDR(cidr)
	return ipNet, err
}

View File

@@ -0,0 +1,128 @@
package geo
import (
	"fmt"
	"io"
	"net/http"
	"os"
	"time"

	"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo/v2geo"
)
const (
	// Default database filenames (relative to the working directory) and
	// the CDN mirrors used when automatic download is enabled.
	geoipFilename   = "geoip.dat"
	geoipURL        = "https://cdn.jsdelivr.net/gh/Loyalsoldier/v2ray-rules-dat@release/geoip.dat"
	geositeFilename = "geosite.dat"
	geositeURL      = "https://cdn.jsdelivr.net/gh/Loyalsoldier/v2ray-rules-dat@release/geosite.dat"

	geoDefaultUpdateInterval = 7 * 24 * time.Hour // 7 days
)
// Compile-time check that V2GeoLoader satisfies GeoLoader.
var _ GeoLoader = (*V2GeoLoader)(nil)

// V2GeoLoader provides the on-demand GeoIP/MatchGeoSite database
// loading functionality required by the ACL engine.
// Empty filenames = automatic download from built-in URLs.
type V2GeoLoader struct {
	GeoIPFilename   string
	GeoSiteFilename string
	UpdateInterval  time.Duration // 0 means geoDefaultUpdateInterval

	// Progress callbacks, invoked unconditionally around downloads; both
	// must be non-nil (NewDefaultGeoLoader installs no-ops).
	DownloadFunc    func(filename, url string)
	DownloadErrFunc func(err error)

	// Parsed databases, cached after the first successful load.
	geoipMap   map[string]*v2geo.GeoIP
	geositeMap map[string]*v2geo.GeoSite
}
// NewDefaultGeoLoader builds a V2GeoLoader with no-op progress callbacks.
// Empty filenames enable automatic download to the default paths.
func NewDefaultGeoLoader(geoSiteFilename, geoIpFilename string) *V2GeoLoader {
	loader := &V2GeoLoader{
		GeoIPFilename:   geoIpFilename,
		GeoSiteFilename: geoSiteFilename,
	}
	loader.DownloadFunc = func(string, string) {}
	loader.DownloadErrFunc = func(error) {}
	return loader
}
// shouldDownload reports whether the database file at filename is missing
// or older than the configured update interval.
func (l *V2GeoLoader) shouldDownload(filename string) bool {
	info, err := os.Stat(filename)
	if err != nil {
		// Missing file, permission error, etc.: treat any stat failure as
		// "needs download". The previous os.IsNotExist-only check would
		// dereference a nil FileInfo on other stat errors.
		return true
	}
	interval := l.UpdateInterval
	if interval == 0 {
		interval = geoDefaultUpdateInterval
	}
	return time.Since(info.ModTime()) > interval
}
// download fetches url into filename. DownloadFunc is invoked before the
// transfer and DownloadErrFunc after it (with nil on success, preserving
// the existing callback contract). A non-200 response is treated as an
// error so a CDN error page is never written over the database file.
func (l *V2GeoLoader) download(filename, url string) error {
	l.DownloadFunc(filename, url)
	err := func() error {
		resp, err := http.Get(url)
		if err != nil {
			return err
		}
		defer resp.Body.Close()
		if resp.StatusCode != http.StatusOK {
			return fmt.Errorf("downloading %s: unexpected status %s", url, resp.Status)
		}
		f, err := os.Create(filename)
		if err != nil {
			return err
		}
		defer f.Close()
		_, err = io.Copy(f, resp.Body)
		return err
	}()
	l.DownloadErrFunc(err)
	return err
}
// LoadGeoIP returns the country-code → GeoIP map, loading the database on
// first use and caching the result. With no explicit filename configured,
// the default path is used and kept fresh via automatic download.
func (l *V2GeoLoader) LoadGeoIP() (map[string]*v2geo.GeoIP, error) {
	if l.geoipMap != nil {
		return l.geoipMap, nil
	}
	filename := l.GeoIPFilename
	if filename == "" {
		// Auto-download applies only when no file was configured.
		filename = geoipFilename
		if l.shouldDownload(filename) {
			if err := l.download(filename, geoipURL); err != nil {
				return nil, err
			}
		}
	}
	m, err := v2geo.LoadGeoIP(filename)
	if err != nil {
		return nil, err
	}
	l.geoipMap = m
	return m, nil
}
// LoadGeoSite returns the list-name → GeoSite map, loading the database on
// first use and caching the result. With no explicit filename configured,
// the default path is used and kept fresh via automatic download.
func (l *V2GeoLoader) LoadGeoSite() (map[string]*v2geo.GeoSite, error) {
	if l.geositeMap != nil {
		return l.geositeMap, nil
	}
	filename := l.GeoSiteFilename
	if filename == "" {
		// Auto-download applies only when no file was configured.
		filename = geositeFilename
		if l.shouldDownload(filename) {
			if err := l.download(filename, geositeURL); err != nil {
				return nil, err
			}
		}
	}
	m, err := v2geo.LoadGeoSite(filename)
	if err != nil {
		return nil, err
	}
	l.geositeMap = m
	return m, nil
}

View File

@@ -0,0 +1,113 @@
package geo
import (
"net"
"strings"
"sync"
)
// GeoMatcher lazily compiles and caches per-condition host matchers backed
// by a GeoLoader. Each cache is guarded by its own mutex, so GeoIP and
// geosite lookups don't block one another.
type GeoMatcher struct {
	geoLoader       GeoLoader
	geoSiteMatcher  map[string]hostMatcher
	siteMatcherLock sync.Mutex
	geoIpMatcher    map[string]hostMatcher
	ipMatcherLock   sync.Mutex
}
// NewGeoMatcher builds a GeoMatcher over a default V2GeoLoader with empty
// matcher caches.
func NewGeoMatcher(geoSiteFilename, geoIpFilename string) *GeoMatcher {
	m := &GeoMatcher{
		geoSiteMatcher: map[string]hostMatcher{},
		geoIpMatcher:   map[string]hostMatcher{},
	}
	m.geoLoader = NewDefaultGeoLoader(geoSiteFilename, geoIpFilename)
	return m
}
// MatchGeoIp reports whether ip belongs to the country named by condition
// (a case-insensitive GeoIP country code). Matchers are compiled lazily
// and cached per condition; load or compile failures report "no match".
func (g *GeoMatcher) MatchGeoIp(ip, condition string) bool {
	g.ipMatcherLock.Lock()
	defer g.ipMatcherLock.Unlock()
	// Normalize BEFORE the cache lookup: matchers are stored under the
	// lowercased key, so looking up the raw condition made e.g. "US" miss
	// the cache and rebuild the matcher on every call.
	condition = strings.ToLower(condition)
	matcher, ok := g.geoIpMatcher[condition]
	if !ok {
		country := condition
		if len(country) == 0 {
			return false
		}
		gMap, err := g.geoLoader.LoadGeoIP()
		if err != nil {
			return false
		}
		list, ok := gMap[country]
		if !ok || list == nil {
			return false
		}
		matcher, err = newGeoIPMatcher(list)
		if err != nil {
			return false
		}
		g.geoIpMatcher[condition] = matcher
	}
	parsedIP := net.ParseIP(ip)
	if parsedIP == nil {
		return false
	}
	if ipv4 := parsedIP.To4(); ipv4 != nil {
		return matcher.Match(HostInfo{IPv4: ipv4})
	}
	if ipv6 := parsedIP.To16(); ipv6 != nil {
		return matcher.Match(HostInfo{IPv6: ipv6})
	}
	return false
}
// MatchGeoSite reports whether the hostname site matches the geosite
// condition ("name" optionally followed by "@attr" filters), compiled
// lazily and cached per condition. Load or compile failures report
// "no match".
func (g *GeoMatcher) MatchGeoSite(site, condition string) bool {
	g.siteMatcherLock.Lock()
	defer g.siteMatcherLock.Unlock()
	// Normalize BEFORE the cache lookup: matchers are stored under the
	// lowercased key, so a raw lookup rebuilt the matcher on every call
	// for mixed-case conditions.
	condition = strings.ToLower(condition)
	matcher, ok := g.geoSiteMatcher[condition]
	if !ok {
		name, attrs := parseGeoSiteName(condition)
		if len(name) == 0 {
			return false
		}
		gMap, err := g.geoLoader.LoadGeoSite()
		if err != nil {
			return false
		}
		list, ok := gMap[name]
		if !ok || list == nil {
			return false
		}
		matcher, err = newGeositeMatcher(list, attrs)
		if err != nil {
			return false
		}
		g.geoSiteMatcher[condition] = matcher
	}
	return matcher.Match(HostInfo{Name: site})
}
// LoadGeoSite eagerly loads the geosite database (e.g. at startup) so the
// first MatchGeoSite call does not pay the load/download cost.
func (g *GeoMatcher) LoadGeoSite() error {
	_, err := g.geoLoader.LoadGeoSite()
	return err
}

// LoadGeoIP eagerly loads the GeoIP database, mirroring LoadGeoSite.
func (g *GeoMatcher) LoadGeoIP() error {
	_, err := g.geoLoader.LoadGeoIP()
	return err
}
// parseGeoSiteName splits a geosite condition like "google@ads@cn" into
// the base list name ("google") and its attribute filters (["ads" "cn"]).
// Each part is whitespace-trimmed; empty attributes (from a trailing or
// doubled '@') are dropped, since an empty attribute can never match an
// entry and would make the whole condition unsatisfiable.
func parseGeoSiteName(s string) (string, []string) {
	parts := strings.Split(s, "@")
	base := strings.TrimSpace(parts[0])
	attrs := make([]string, 0, len(parts)-1)
	for _, attr := range parts[1:] {
		if attr = strings.TrimSpace(attr); attr != "" {
			attrs = append(attrs, attr)
		}
	}
	return base, attrs
}

View File

@@ -0,0 +1,27 @@
package geo
import (
"fmt"
"net"
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo/v2geo"
)
// HostInfo describes the host being matched against geo rules: a domain
// name and/or resolved addresses. Unknown fields are left zero.
type HostInfo struct {
	Name string
	IPv4 net.IP
	IPv6 net.IP
}

// String renders the host as "name|ipv4|ipv6" for logging; nil IPs print
// as "<nil>".
func (h HostInfo) String() string {
	return fmt.Sprintf("%s|%s|%s", h.Name, h.IPv4, h.IPv6)
}
// GeoLoader supplies the parsed GeoIP/GeoSite databases, keyed by
// lowercased country code / list name.
type GeoLoader interface {
	LoadGeoIP() (map[string]*v2geo.GeoIP, error)
	LoadGeoSite() (map[string]*v2geo.GeoSite, error)
}

// hostMatcher is a compiled matcher for a single geo condition.
type hostMatcher interface {
	Match(HostInfo) bool
}

View File

@@ -0,0 +1,213 @@
package geo
import (
"bytes"
"errors"
"net"
"regexp"
"sort"
"strings"
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo/v2geo"
)
// Compile-time check that geoipMatcher satisfies hostMatcher.
var _ hostMatcher = (*geoipMatcher)(nil)

// geoipMatcher matches addresses against sorted CIDR lists; Inverse flips
// the result (taken from the GeoIP list's InverseMatch flag).
type geoipMatcher struct {
	N4      []*net.IPNet // sorted
	N6      []*net.IPNet // sorted
	Inverse bool
}
// matchIP tries to match the given IP address with the corresponding IPNets.
// Note that this function does NOT handle the Inverse flag.
//
// The binary search probes Contains at the midpoint and otherwise narrows
// by comparing network addresses. NOTE(review): this is only correct when
// the list is sorted AND the networks are non-overlapping (as GeoIP data
// is expected to be) — verify if the data source ever changes.
func (m *geoipMatcher) matchIP(ip net.IP) bool {
	var n []*net.IPNet
	if ip4 := ip.To4(); ip4 != nil {
		// N4 stores IPv4 addresses in 4-byte form.
		// Make sure we use it here too, otherwise bytes.Compare will fail.
		ip = ip4
		n = m.N4
	} else {
		n = m.N6
	}
	left, right := 0, len(n)-1
	for left <= right {
		mid := (left + right) / 2
		if n[mid].Contains(ip) {
			return true
		} else if bytes.Compare(n[mid].IP, ip) < 0 {
			left = mid + 1
		} else {
			right = mid - 1
		}
	}
	return false
}
// Match applies the CIDR lists to whichever of the host's addresses are
// set, honoring the Inverse flag.
func (m *geoipMatcher) Match(host HostInfo) bool {
	matched := (host.IPv4 != nil && m.matchIP(host.IPv4)) ||
		(host.IPv6 != nil && m.matchIP(host.IPv6))
	// XOR with the inverse flag: a hit returns !Inverse, a miss Inverse.
	return matched != m.Inverse
}
// newGeoIPMatcher compiles a v2geo GeoIP list into a geoipMatcher,
// splitting entries into IPv4/IPv6 groups (by raw address length) and
// sorting each by network address so matchIP can binary-search.
func newGeoIPMatcher(list *v2geo.GeoIP) (*geoipMatcher, error) {
	n4 := make([]*net.IPNet, 0)
	n6 := make([]*net.IPNet, 0)
	for _, cidr := range list.Cidr {
		if len(cidr.Ip) == 4 {
			// IPv4
			n4 = append(n4, &net.IPNet{
				IP:   cidr.Ip,
				Mask: net.CIDRMask(int(cidr.Prefix), 32),
			})
		} else if len(cidr.Ip) == 16 {
			// IPv6
			n6 = append(n6, &net.IPNet{
				IP:   cidr.Ip,
				Mask: net.CIDRMask(int(cidr.Prefix), 128),
			})
		} else {
			return nil, errors.New("invalid IP length")
		}
	}
	// Sort the IPNets, so we can do binary search later.
	sort.Slice(n4, func(i, j int) bool {
		return bytes.Compare(n4[i].IP, n4[j].IP) < 0
	})
	sort.Slice(n6, func(i, j int) bool {
		return bytes.Compare(n6[i].IP, n6[j].IP) < 0
	})
	return &geoipMatcher{
		N4:      n4,
		N6:      n6,
		Inverse: list.InverseMatch,
	}, nil
}
// Compile-time check that geositeMatcher satisfies hostMatcher.
var _ hostMatcher = (*geositeMatcher)(nil)

// geositeDomainType is the compiled form of v2geo.Domain_Type.
type geositeDomainType int

const (
	geositeDomainPlain geositeDomainType = iota // substring match
	geositeDomainRegex                          // regular-expression match
	geositeDomainRoot                           // the domain itself or any subdomain
	geositeDomainFull                           // exact match
)
// geositeDomain is one compiled entry of a geosite list.
type geositeDomain struct {
	Type  geositeDomainType
	Value string          // pattern text; unset for regex entries
	Regex *regexp.Regexp  // compiled pattern; set for regex entries only
	Attrs map[string]bool // attribute set of the entry
}

// geositeMatcher matches a hostname against a compiled geosite list.
type geositeMatcher struct {
	Domains []geositeDomain
	// Attributes are matched using "and" logic - if you have multiple attributes here,
	// a domain must have all of those attributes to be considered a match.
	Attrs []string
}
// matchDomain reports whether a single geosite entry matches the host,
// after first applying the matcher's attribute filters (AND logic: every
// requested attribute must be present on the entry).
func (m *geositeMatcher) matchDomain(domain geositeDomain, host HostInfo) bool {
	if len(m.Attrs) > 0 {
		if len(domain.Attrs) == 0 {
			return false
		}
		for _, attr := range m.Attrs {
			if !domain.Attrs[attr] {
				return false
			}
		}
	}
	switch domain.Type {
	case geositeDomainPlain:
		return strings.Contains(host.Name, domain.Value)
	case geositeDomainFull:
		return host.Name == domain.Value
	case geositeDomainRoot:
		return host.Name == domain.Value || strings.HasSuffix(host.Name, "."+domain.Value)
	case geositeDomainRegex:
		// A nil Regex (should not happen for regex entries) never matches.
		return domain.Regex != nil && domain.Regex.MatchString(host.Name)
	default:
		return false
	}
}
// Match reports whether any entry in the compiled list matches the host.
func (m *geositeMatcher) Match(host HostInfo) bool {
	for i := range m.Domains {
		if m.matchDomain(m.Domains[i], host) {
			return true
		}
	}
	return false
}
// newGeositeMatcher compiles a v2geo GeoSite list into a geositeMatcher.
// attrs are the attribute filters parsed from the condition string; a
// regex entry failing to compile aborts the whole list.
func newGeositeMatcher(list *v2geo.GeoSite, attrs []string) (*geositeMatcher, error) {
	domains := make([]geositeDomain, len(list.Domain))
	for i, d := range list.Domain {
		attrSet := domainAttributeToMap(d.Attribute)
		switch d.Type {
		case v2geo.Domain_Plain:
			domains[i] = geositeDomain{Type: geositeDomainPlain, Value: d.Value, Attrs: attrSet}
		case v2geo.Domain_Regex:
			re, err := regexp.Compile(d.Value)
			if err != nil {
				return nil, err
			}
			domains[i] = geositeDomain{Type: geositeDomainRegex, Regex: re, Attrs: attrSet}
		case v2geo.Domain_Full:
			domains[i] = geositeDomain{Type: geositeDomainFull, Value: d.Value, Attrs: attrSet}
		case v2geo.Domain_RootDomain:
			domains[i] = geositeDomain{Type: geositeDomainRoot, Value: d.Value, Attrs: attrSet}
		default:
			return nil, errors.New("unsupported domain type")
		}
	}
	return &geositeMatcher{Domains: domains, Attrs: attrs}, nil
}
// domainAttributeToMap flattens a domain's attribute list into a set.
// Supposedly there are also int attributes, but nobody seems to use them,
// so only attribute presence is recorded (everything maps to true).
func domainAttributeToMap(attrs []*v2geo.Domain_Attribute) map[string]bool {
	set := make(map[string]bool, len(attrs))
	for _, a := range attrs {
		set[a.Key] = true
	}
	return set
}

View File

@@ -0,0 +1,44 @@
package v2geo
import (
"os"
"strings"
"google.golang.org/protobuf/proto"
)
// LoadGeoIP loads a GeoIP data file (a serialized GeoIPList protobuf) and
// converts it to a map. The keys (country codes) are all normalized to
// lowercase.
func LoadGeoIP(filename string) (map[string]*GeoIP, error) {
	raw, err := os.ReadFile(filename)
	if err != nil {
		return nil, err
	}
	list := &GeoIPList{}
	if err := proto.Unmarshal(raw, list); err != nil {
		return nil, err
	}
	byCode := make(map[string]*GeoIP, len(list.Entry))
	for _, entry := range list.Entry {
		byCode[strings.ToLower(entry.CountryCode)] = entry
	}
	return byCode, nil
}
// LoadGeoSite loads a GeoSite data file (a serialized GeoSiteList
// protobuf) and converts it to a map. The keys (site list names) are all
// normalized to lowercase.
func LoadGeoSite(filename string) (map[string]*GeoSite, error) {
	raw, err := os.ReadFile(filename)
	if err != nil {
		return nil, err
	}
	list := &GeoSiteList{}
	if err := proto.Unmarshal(raw, list); err != nil {
		return nil, err
	}
	byName := make(map[string]*GeoSite, len(list.Entry))
	for _, entry := range list.Entry {
		byName[strings.ToLower(entry.CountryCode)] = entry
	}
	return byName, nil
}

View File

@@ -0,0 +1,745 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.31.0
// protoc v4.24.4
// source: v2geo.proto
package v2geo
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// Type of domain value.
type Domain_Type int32
const (
// The value is used as is.
Domain_Plain Domain_Type = 0
// The value is used as a regular expression.
Domain_Regex Domain_Type = 1
// The value is a root domain.
Domain_RootDomain Domain_Type = 2
// The value is a domain.
Domain_Full Domain_Type = 3
)
// Enum value maps for Domain_Type.
var (
Domain_Type_name = map[int32]string{
0: "Plain",
1: "Regex",
2: "RootDomain",
3: "Full",
}
Domain_Type_value = map[string]int32{
"Plain": 0,
"Regex": 1,
"RootDomain": 2,
"Full": 3,
}
)
func (x Domain_Type) Enum() *Domain_Type {
p := new(Domain_Type)
*p = x
return p
}
func (x Domain_Type) String() string {
return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
}
func (Domain_Type) Descriptor() protoreflect.EnumDescriptor {
return file_v2geo_proto_enumTypes[0].Descriptor()
}
func (Domain_Type) Type() protoreflect.EnumType {
return &file_v2geo_proto_enumTypes[0]
}
func (x Domain_Type) Number() protoreflect.EnumNumber {
return protoreflect.EnumNumber(x)
}
// Deprecated: Use Domain_Type.Descriptor instead.
func (Domain_Type) EnumDescriptor() ([]byte, []int) {
return file_v2geo_proto_rawDescGZIP(), []int{0, 0}
}
// Domain for routing decision.
type Domain struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// Domain matching type.
Type Domain_Type `protobuf:"varint,1,opt,name=type,proto3,enum=Domain_Type" json:"type,omitempty"`
// Domain value.
Value string `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
// Attributes of this domain. May be used for filtering.
Attribute []*Domain_Attribute `protobuf:"bytes,3,rep,name=attribute,proto3" json:"attribute,omitempty"`
}
func (x *Domain) Reset() {
*x = Domain{}
if protoimpl.UnsafeEnabled {
mi := &file_v2geo_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *Domain) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Domain) ProtoMessage() {}
func (x *Domain) ProtoReflect() protoreflect.Message {
mi := &file_v2geo_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Domain.ProtoReflect.Descriptor instead.
func (*Domain) Descriptor() ([]byte, []int) {
return file_v2geo_proto_rawDescGZIP(), []int{0}
}
func (x *Domain) GetType() Domain_Type {
if x != nil {
return x.Type
}
return Domain_Plain
}
func (x *Domain) GetValue() string {
if x != nil {
return x.Value
}
return ""
}
func (x *Domain) GetAttribute() []*Domain_Attribute {
if x != nil {
return x.Attribute
}
return nil
}
// IP for routing decision, in CIDR form.
type CIDR struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// IP address, should be either 4 or 16 bytes.
Ip []byte `protobuf:"bytes,1,opt,name=ip,proto3" json:"ip,omitempty"`
// Number of leading ones in the network mask.
Prefix uint32 `protobuf:"varint,2,opt,name=prefix,proto3" json:"prefix,omitempty"`
}
func (x *CIDR) Reset() {
*x = CIDR{}
if protoimpl.UnsafeEnabled {
mi := &file_v2geo_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *CIDR) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*CIDR) ProtoMessage() {}
func (x *CIDR) ProtoReflect() protoreflect.Message {
mi := &file_v2geo_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use CIDR.ProtoReflect.Descriptor instead.
func (*CIDR) Descriptor() ([]byte, []int) {
return file_v2geo_proto_rawDescGZIP(), []int{1}
}
func (x *CIDR) GetIp() []byte {
if x != nil {
return x.Ip
}
return nil
}
func (x *CIDR) GetPrefix() uint32 {
if x != nil {
return x.Prefix
}
return 0
}
type GeoIP struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
CountryCode string `protobuf:"bytes,1,opt,name=country_code,json=countryCode,proto3" json:"country_code,omitempty"`
Cidr []*CIDR `protobuf:"bytes,2,rep,name=cidr,proto3" json:"cidr,omitempty"`
InverseMatch bool `protobuf:"varint,3,opt,name=inverse_match,json=inverseMatch,proto3" json:"inverse_match,omitempty"`
// resource_hash instruct simplified config converter to load domain from geo file.
ResourceHash []byte `protobuf:"bytes,4,opt,name=resource_hash,json=resourceHash,proto3" json:"resource_hash,omitempty"`
Code string `protobuf:"bytes,5,opt,name=code,proto3" json:"code,omitempty"`
}
func (x *GeoIP) Reset() {
*x = GeoIP{}
if protoimpl.UnsafeEnabled {
mi := &file_v2geo_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *GeoIP) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GeoIP) ProtoMessage() {}
func (x *GeoIP) ProtoReflect() protoreflect.Message {
mi := &file_v2geo_proto_msgTypes[2]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GeoIP.ProtoReflect.Descriptor instead.
func (*GeoIP) Descriptor() ([]byte, []int) {
return file_v2geo_proto_rawDescGZIP(), []int{2}
}
func (x *GeoIP) GetCountryCode() string {
if x != nil {
return x.CountryCode
}
return ""
}
func (x *GeoIP) GetCidr() []*CIDR {
if x != nil {
return x.Cidr
}
return nil
}
func (x *GeoIP) GetInverseMatch() bool {
if x != nil {
return x.InverseMatch
}
return false
}
func (x *GeoIP) GetResourceHash() []byte {
if x != nil {
return x.ResourceHash
}
return nil
}
func (x *GeoIP) GetCode() string {
if x != nil {
return x.Code
}
return ""
}
type GeoIPList struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Entry []*GeoIP `protobuf:"bytes,1,rep,name=entry,proto3" json:"entry,omitempty"`
}
func (x *GeoIPList) Reset() {
*x = GeoIPList{}
if protoimpl.UnsafeEnabled {
mi := &file_v2geo_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *GeoIPList) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GeoIPList) ProtoMessage() {}
func (x *GeoIPList) ProtoReflect() protoreflect.Message {
mi := &file_v2geo_proto_msgTypes[3]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GeoIPList.ProtoReflect.Descriptor instead.
func (*GeoIPList) Descriptor() ([]byte, []int) {
return file_v2geo_proto_rawDescGZIP(), []int{3}
}
func (x *GeoIPList) GetEntry() []*GeoIP {
if x != nil {
return x.Entry
}
return nil
}
// GeoSite is a collection of domains grouped under a country code.
type GeoSite struct {
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields

	CountryCode string    `protobuf:"bytes,1,opt,name=country_code,json=countryCode,proto3" json:"country_code,omitempty"`
	Domain      []*Domain `protobuf:"bytes,2,rep,name=domain,proto3" json:"domain,omitempty"`
	// resource_hash instruct simplified config converter to load domain from geo file.
	ResourceHash []byte `protobuf:"bytes,3,opt,name=resource_hash,json=resourceHash,proto3" json:"resource_hash,omitempty"`
	Code         string `protobuf:"bytes,4,opt,name=code,proto3" json:"code,omitempty"`
}

// Reset clears the message to its zero state.
func (x *GeoSite) Reset() {
	*x = GeoSite{}
	if protoimpl.UnsafeEnabled {
		mi := &file_v2geo_proto_msgTypes[4]
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		ms.StoreMessageInfo(mi)
	}
}

// String renders the message in the protobuf text format.
func (x *GeoSite) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage marks GeoSite as a protobuf message.
func (*GeoSite) ProtoMessage() {}

// ProtoReflect returns the reflective view of the message.
func (x *GeoSite) ProtoReflect() protoreflect.Message {
	mi := &file_v2geo_proto_msgTypes[4]
	if protoimpl.UnsafeEnabled && x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use GeoSite.ProtoReflect.Descriptor instead.
func (*GeoSite) Descriptor() ([]byte, []int) {
	return file_v2geo_proto_rawDescGZIP(), []int{4}
}

// GetCountryCode returns the country code, or "" for a nil receiver.
func (x *GeoSite) GetCountryCode() string {
	if x != nil {
		return x.CountryCode
	}
	return ""
}

// GetDomain returns the domain list, or nil for a nil receiver.
func (x *GeoSite) GetDomain() []*Domain {
	if x != nil {
		return x.Domain
	}
	return nil
}

// GetResourceHash returns the resource hash, or nil for a nil receiver.
func (x *GeoSite) GetResourceHash() []byte {
	if x != nil {
		return x.ResourceHash
	}
	return nil
}

// GetCode returns the code, or "" for a nil receiver.
func (x *GeoSite) GetCode() string {
	if x != nil {
		return x.Code
	}
	return ""
}
// GeoSiteList is a list of GeoSite entries.
type GeoSiteList struct {
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields

	Entry []*GeoSite `protobuf:"bytes,1,rep,name=entry,proto3" json:"entry,omitempty"`
}

// Reset clears the message to its zero state.
func (x *GeoSiteList) Reset() {
	*x = GeoSiteList{}
	if protoimpl.UnsafeEnabled {
		mi := &file_v2geo_proto_msgTypes[5]
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		ms.StoreMessageInfo(mi)
	}
}

// String renders the message in the protobuf text format.
func (x *GeoSiteList) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage marks GeoSiteList as a protobuf message.
func (*GeoSiteList) ProtoMessage() {}

// ProtoReflect returns the reflective view of the message.
func (x *GeoSiteList) ProtoReflect() protoreflect.Message {
	mi := &file_v2geo_proto_msgTypes[5]
	if protoimpl.UnsafeEnabled && x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use GeoSiteList.ProtoReflect.Descriptor instead.
func (*GeoSiteList) Descriptor() ([]byte, []int) {
	return file_v2geo_proto_rawDescGZIP(), []int{5}
}

// GetEntry returns the list entries, or nil for a nil receiver.
func (x *GeoSiteList) GetEntry() []*GeoSite {
	if x != nil {
		return x.Entry
	}
	return nil
}
// Domain_Attribute is an arbitrary key with an optional typed (bool or int64) value.
type Domain_Attribute struct {
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields

	Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
	// Types that are assignable to TypedValue:
	//
	//	*Domain_Attribute_BoolValue
	//	*Domain_Attribute_IntValue
	TypedValue isDomain_Attribute_TypedValue `protobuf_oneof:"typed_value"`
}

// Reset clears the message to its zero state.
func (x *Domain_Attribute) Reset() {
	*x = Domain_Attribute{}
	if protoimpl.UnsafeEnabled {
		mi := &file_v2geo_proto_msgTypes[6]
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		ms.StoreMessageInfo(mi)
	}
}

// String renders the message in the protobuf text format.
func (x *Domain_Attribute) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage marks Domain_Attribute as a protobuf message.
func (*Domain_Attribute) ProtoMessage() {}

// ProtoReflect returns the reflective view of the message.
func (x *Domain_Attribute) ProtoReflect() protoreflect.Message {
	mi := &file_v2geo_proto_msgTypes[6]
	if protoimpl.UnsafeEnabled && x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use Domain_Attribute.ProtoReflect.Descriptor instead.
func (*Domain_Attribute) Descriptor() ([]byte, []int) {
	return file_v2geo_proto_rawDescGZIP(), []int{0, 0}
}

// GetKey returns the attribute key, or "" for a nil receiver.
func (x *Domain_Attribute) GetKey() string {
	if x != nil {
		return x.Key
	}
	return ""
}

// GetTypedValue returns the oneof wrapper, or nil for a nil receiver.
func (m *Domain_Attribute) GetTypedValue() isDomain_Attribute_TypedValue {
	if m != nil {
		return m.TypedValue
	}
	return nil
}

// GetBoolValue returns the bool value; false if the oneof holds a different type.
func (x *Domain_Attribute) GetBoolValue() bool {
	if x, ok := x.GetTypedValue().(*Domain_Attribute_BoolValue); ok {
		return x.BoolValue
	}
	return false
}

// GetIntValue returns the int value; 0 if the oneof holds a different type.
func (x *Domain_Attribute) GetIntValue() int64 {
	if x, ok := x.GetTypedValue().(*Domain_Attribute_IntValue); ok {
		return x.IntValue
	}
	return 0
}

// isDomain_Attribute_TypedValue is implemented by the typed_value oneof wrappers.
type isDomain_Attribute_TypedValue interface {
	isDomain_Attribute_TypedValue()
}

// Domain_Attribute_BoolValue wraps a bool for the typed_value oneof.
type Domain_Attribute_BoolValue struct {
	BoolValue bool `protobuf:"varint,2,opt,name=bool_value,json=boolValue,proto3,oneof"`
}

// Domain_Attribute_IntValue wraps an int64 for the typed_value oneof.
type Domain_Attribute_IntValue struct {
	IntValue int64 `protobuf:"varint,3,opt,name=int_value,json=intValue,proto3,oneof"`
}

func (*Domain_Attribute_BoolValue) isDomain_Attribute_TypedValue() {}

func (*Domain_Attribute_IntValue) isDomain_Attribute_TypedValue() {}
// File_v2geo_proto is the compiled file descriptor for v2geo.proto.
var File_v2geo_proto protoreflect.FileDescriptor

// file_v2geo_proto_rawDesc is the wire-encoded FileDescriptorProto for v2geo.proto.
var file_v2geo_proto_rawDesc = []byte{
	0x0a, 0x0b, 0x76, 0x32, 0x67, 0x65, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x97, 0x02,
	0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x20, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65,
	0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0c, 0x2e, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x2e,
	0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61,
	0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65,
	0x12, 0x2f, 0x0a, 0x09, 0x61, 0x74, 0x74, 0x72, 0x69, 0x62, 0x75, 0x74, 0x65, 0x18, 0x03, 0x20,
	0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x2e, 0x41, 0x74, 0x74,
	0x72, 0x69, 0x62, 0x75, 0x74, 0x65, 0x52, 0x09, 0x61, 0x74, 0x74, 0x72, 0x69, 0x62, 0x75, 0x74,
	0x65, 0x1a, 0x6c, 0x0a, 0x09, 0x41, 0x74, 0x74, 0x72, 0x69, 0x62, 0x75, 0x74, 0x65, 0x12, 0x10,
	0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79,
	0x12, 0x1f, 0x0a, 0x0a, 0x62, 0x6f, 0x6f, 0x6c, 0x5f, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02,
	0x20, 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x09, 0x62, 0x6f, 0x6f, 0x6c, 0x56, 0x61, 0x6c, 0x75,
	0x65, 0x12, 0x1d, 0x0a, 0x09, 0x69, 0x6e, 0x74, 0x5f, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x03,
	0x20, 0x01, 0x28, 0x03, 0x48, 0x00, 0x52, 0x08, 0x69, 0x6e, 0x74, 0x56, 0x61, 0x6c, 0x75, 0x65,
	0x42, 0x0d, 0x0a, 0x0b, 0x74, 0x79, 0x70, 0x65, 0x64, 0x5f, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x22,
	0x36, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x6c, 0x61, 0x69, 0x6e,
	0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x52, 0x65, 0x67, 0x65, 0x78, 0x10, 0x01, 0x12, 0x0e, 0x0a,
	0x0a, 0x52, 0x6f, 0x6f, 0x74, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x10, 0x02, 0x12, 0x08, 0x0a,
	0x04, 0x46, 0x75, 0x6c, 0x6c, 0x10, 0x03, 0x22, 0x2e, 0x0a, 0x04, 0x43, 0x49, 0x44, 0x52, 0x12,
	0x0e, 0x0a, 0x02, 0x69, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x02, 0x69, 0x70, 0x12,
	0x16, 0x0a, 0x06, 0x70, 0x72, 0x65, 0x66, 0x69, 0x78, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52,
	0x06, 0x70, 0x72, 0x65, 0x66, 0x69, 0x78, 0x22, 0xa3, 0x01, 0x0a, 0x05, 0x47, 0x65, 0x6f, 0x49,
	0x50, 0x12, 0x21, 0x0a, 0x0c, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x72, 0x79, 0x5f, 0x63, 0x6f, 0x64,
	0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x72, 0x79,
	0x43, 0x6f, 0x64, 0x65, 0x12, 0x19, 0x0a, 0x04, 0x63, 0x69, 0x64, 0x72, 0x18, 0x02, 0x20, 0x03,
	0x28, 0x0b, 0x32, 0x05, 0x2e, 0x43, 0x49, 0x44, 0x52, 0x52, 0x04, 0x63, 0x69, 0x64, 0x72, 0x12,
	0x23, 0x0a, 0x0d, 0x69, 0x6e, 0x76, 0x65, 0x72, 0x73, 0x65, 0x5f, 0x6d, 0x61, 0x74, 0x63, 0x68,
	0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0c, 0x69, 0x6e, 0x76, 0x65, 0x72, 0x73, 0x65, 0x4d,
	0x61, 0x74, 0x63, 0x68, 0x12, 0x23, 0x0a, 0x0d, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65,
	0x5f, 0x68, 0x61, 0x73, 0x68, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0c, 0x72, 0x65, 0x73,
	0x6f, 0x75, 0x72, 0x63, 0x65, 0x48, 0x61, 0x73, 0x68, 0x12, 0x12, 0x0a, 0x04, 0x63, 0x6f, 0x64,
	0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x63, 0x6f, 0x64, 0x65, 0x22, 0x29, 0x0a,
	0x09, 0x47, 0x65, 0x6f, 0x49, 0x50, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x05, 0x65, 0x6e,
	0x74, 0x72, 0x79, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x06, 0x2e, 0x47, 0x65, 0x6f, 0x49,
	0x50, 0x52, 0x05, 0x65, 0x6e, 0x74, 0x72, 0x79, 0x22, 0x86, 0x01, 0x0a, 0x07, 0x47, 0x65, 0x6f,
	0x53, 0x69, 0x74, 0x65, 0x12, 0x21, 0x0a, 0x0c, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x72, 0x79, 0x5f,
	0x63, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x63, 0x6f, 0x75, 0x6e,
	0x74, 0x72, 0x79, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x1f, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69,
	0x6e, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x07, 0x2e, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e,
	0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x23, 0x0a, 0x0d, 0x72, 0x65, 0x73, 0x6f,
	0x75, 0x72, 0x63, 0x65, 0x5f, 0x68, 0x61, 0x73, 0x68, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52,
	0x0c, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x48, 0x61, 0x73, 0x68, 0x12, 0x12, 0x0a,
	0x04, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x63, 0x6f, 0x64,
	0x65, 0x22, 0x2d, 0x0a, 0x0b, 0x47, 0x65, 0x6f, 0x53, 0x69, 0x74, 0x65, 0x4c, 0x69, 0x73, 0x74,
	0x12, 0x1e, 0x0a, 0x05, 0x65, 0x6e, 0x74, 0x72, 0x79, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32,
	0x08, 0x2e, 0x47, 0x65, 0x6f, 0x53, 0x69, 0x74, 0x65, 0x52, 0x05, 0x65, 0x6e, 0x74, 0x72, 0x79,
	0x42, 0x09, 0x5a, 0x07, 0x2e, 0x2f, 0x76, 0x32, 0x67, 0x65, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f,
	0x74, 0x6f, 0x33,
}

var (
	file_v2geo_proto_rawDescOnce sync.Once
	file_v2geo_proto_rawDescData = file_v2geo_proto_rawDesc
)

// file_v2geo_proto_rawDescGZIP gzip-compresses the raw descriptor exactly once
// (on first use) and returns the compressed bytes thereafter.
func file_v2geo_proto_rawDescGZIP() []byte {
	file_v2geo_proto_rawDescOnce.Do(func() {
		file_v2geo_proto_rawDescData = protoimpl.X.CompressGZIP(file_v2geo_proto_rawDescData)
	})
	return file_v2geo_proto_rawDescData
}
var file_v2geo_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_v2geo_proto_msgTypes = make([]protoimpl.MessageInfo, 7)

// file_v2geo_proto_goTypes maps descriptor indices to their Go types.
var file_v2geo_proto_goTypes = []interface{}{
	(Domain_Type)(0),         // 0: Domain.Type
	(*Domain)(nil),           // 1: Domain
	(*CIDR)(nil),             // 2: CIDR
	(*GeoIP)(nil),            // 3: GeoIP
	(*GeoIPList)(nil),        // 4: GeoIPList
	(*GeoSite)(nil),          // 5: GeoSite
	(*GeoSiteList)(nil),      // 6: GeoSiteList
	(*Domain_Attribute)(nil), // 7: Domain.Attribute
}

// file_v2geo_proto_depIdxs records, for each typed field, the index into
// file_v2geo_proto_goTypes of the field's type.
var file_v2geo_proto_depIdxs = []int32{
	0, // 0: Domain.type:type_name -> Domain.Type
	7, // 1: Domain.attribute:type_name -> Domain.Attribute
	2, // 2: GeoIP.cidr:type_name -> CIDR
	3, // 3: GeoIPList.entry:type_name -> GeoIP
	1, // 4: GeoSite.domain:type_name -> Domain
	5, // 5: GeoSiteList.entry:type_name -> GeoSite
	6, // [6:6] is the sub-list for method output_type
	6, // [6:6] is the sub-list for method input_type
	6, // [6:6] is the sub-list for extension type_name
	6, // [6:6] is the sub-list for extension extendee
	0, // [0:6] is the sub-list for field type_name
}
func init() { file_v2geo_proto_init() }

// file_v2geo_proto_init registers the file's messages with the protobuf
// runtime and builds File_v2geo_proto. It is idempotent: once the file
// descriptor is built, subsequent calls return immediately.
func file_v2geo_proto_init() {
	if File_v2geo_proto != nil {
		return
	}
	// Without unsafe access, field exporters let the runtime reach each
	// message's unexported bookkeeping fields.
	if !protoimpl.UnsafeEnabled {
		file_v2geo_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*Domain); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
		file_v2geo_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*CIDR); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
		file_v2geo_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*GeoIP); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
		file_v2geo_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*GeoIPList); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
		file_v2geo_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*GeoSite); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
		file_v2geo_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*GeoSiteList); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
		file_v2geo_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*Domain_Attribute); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
	}
	// Register the wrapper types for the Domain.Attribute typed_value oneof.
	file_v2geo_proto_msgTypes[6].OneofWrappers = []interface{}{
		(*Domain_Attribute_BoolValue)(nil),
		(*Domain_Attribute_IntValue)(nil),
	}
	type x struct{}
	out := protoimpl.TypeBuilder{
		File: protoimpl.DescBuilder{
			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
			RawDescriptor: file_v2geo_proto_rawDesc,
			NumEnums:      1,
			NumMessages:   7,
			NumExtensions: 0,
			NumServices:   0,
		},
		GoTypes:           file_v2geo_proto_goTypes,
		DependencyIndexes: file_v2geo_proto_depIdxs,
		EnumInfos:         file_v2geo_proto_enumTypes,
		MessageInfos:      file_v2geo_proto_msgTypes,
	}.Build()
	File_v2geo_proto = out.File
	// Release build-time tables so they can be garbage collected.
	file_v2geo_proto_rawDesc = nil
	file_v2geo_proto_goTypes = nil
	file_v2geo_proto_depIdxs = nil
}

View File

@@ -0,0 +1,76 @@
syntax = "proto3";

option go_package = "./v2geo";

// This file is copied from
// https://github.com/v2fly/v2ray-core/blob/master/app/router/routercommon/common.proto
// with some modifications.

// Domain for routing decision.
message Domain {
  // Type of domain value.
  enum Type {
    // The value is used as is.
    Plain = 0;
    // The value is used as a regular expression.
    Regex = 1;
    // The value is a root domain.
    RootDomain = 2;
    // The value is a full domain, matched exactly.
    Full = 3;
  }
  // Domain matching type.
  Type type = 1;
  // Domain value.
  string value = 2;
  // Attribute is an arbitrary key with an optional typed value.
  message Attribute {
    string key = 1;
    oneof typed_value {
      bool bool_value = 2;
      int64 int_value = 3;
    }
  }
  // Attributes of this domain. May be used for filtering.
  repeated Attribute attribute = 3;
}

// IP for routing decision, in CIDR form.
message CIDR {
  // IP address, should be either 4 or 16 bytes.
  bytes ip = 1;
  // Number of leading ones in the network mask.
  uint32 prefix = 2;
}

// GeoIP is a named collection of CIDR ranges.
message GeoIP {
  string country_code = 1;
  repeated CIDR cidr = 2;
  bool inverse_match = 3;
  // resource_hash instruct simplified config converter to load domain from geo file.
  bytes resource_hash = 4;
  string code = 5;
}

message GeoIPList {
  repeated GeoIP entry = 1;
}

// GeoSite is a named collection of domains.
message GeoSite {
  string country_code = 1;
  repeated Domain domain = 2;
  // resource_hash instruct simplified config converter to load domain from geo file.
  bytes resource_hash = 3;
  string code = 4;
}

message GeoSiteList {
  repeated GeoSite entry = 1;
}

379
ruleset/expr.go Normal file
View File

@@ -0,0 +1,379 @@
package ruleset
import (
"context"
"fmt"
"net"
"os"
"reflect"
"strings"
"time"
"github.com/expr-lang/expr/builtin"
"github.com/expr-lang/expr"
"github.com/expr-lang/expr/ast"
"github.com/expr-lang/expr/conf"
"github.com/expr-lang/expr/vm"
"gopkg.in/yaml.v3"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/modifier"
"git.difuse.io/Difuse/Mellaris/ruleset/builtins"
"git.difuse.io/Difuse/Mellaris/ruleset/builtins/geo"
)
// ExprRule is the external representation of an expression rule.
type ExprRule struct {
	Name     string        `yaml:"name"`     // rule name, used in logs and error messages
	Action   string        `yaml:"action"`   // "allow", "block", "drop" or "modify"; may be empty for log-only rules
	Log      bool          `yaml:"log"`      // whether matches of this rule should be logged
	Modifier ModifierEntry `yaml:"modifier"` // modifier configuration, consulted when Action is "modify"
	Expr     string        `yaml:"expr"`     // boolean expression evaluated against each stream
}

// ModifierEntry selects a modifier by name and carries its arguments.
type ModifierEntry struct {
	Name string                 `yaml:"name"`
	Args map[string]interface{} `yaml:"args"`
}
// ExprRulesFromYAML loads a list of expression rules from a YAML file.
// The file must contain a YAML sequence of rule objects.
func ExprRulesFromYAML(file string) ([]ExprRule, error) {
	bs, err := os.ReadFile(file)
	if err != nil {
		// os.ReadFile errors already carry the file name.
		return nil, err
	}
	var rules []ExprRule
	if err := yaml.Unmarshal(bs, &rules); err != nil {
		// Add the file name: yaml errors only reference line/column.
		return nil, fmt.Errorf("parsing ruleset file %q: %w", file, err)
	}
	return rules, nil
}
// compiledExprRule is the internal, compiled representation of an expression rule.
type compiledExprRule struct {
	Name        string
	Action      *Action // fallthrough if nil
	Log         bool
	ModInstance modifier.Instance // set only when Action is ActionModify
	Program     *vm.Program       // the compiled expression
}

// Compile-time assertion that exprRuleset satisfies Ruleset.
var _ Ruleset = (*exprRuleset)(nil)

// exprRuleset is a Ruleset backed by compiled expression rules,
// evaluated in order until one produces a verdict.
type exprRuleset struct {
	Rules  []compiledExprRule
	Ans    []analyzer.Analyzer // analyzers referenced by at least one rule
	Logger Logger
}
// Analyzers implements Ruleset. The analyzer set is determined at compile
// time from the rules' identifiers, so the per-stream info is not consulted.
func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer {
	return r.Ans
}
// Match implements Ruleset. Rules are evaluated in order; the first rule
// whose expression is true and that carries an action decides the result.
// Matching rules without an action only log. A rule whose evaluation fails
// is reported to the logger and skipped.
func (r *exprRuleset) Match(info StreamInfo) MatchResult {
	env := streamInfoToExprEnv(info)
	for _, cr := range r.Rules {
		out, err := vm.Run(cr.Program, env)
		if err != nil {
			// A broken rule must not abort matching; report and move on.
			r.Logger.MatchError(info, cr.Name, err)
			continue
		}
		matched, isBool := out.(bool)
		if !isBool || !matched {
			continue
		}
		if cr.Log {
			r.Logger.Log(info, cr.Name)
		}
		if cr.Action == nil {
			// Log-only rule: fall through to the remaining rules.
			continue
		}
		return MatchResult{
			Action:      *cr.Action,
			ModInstance: cr.ModInstance,
		}
	}
	// No rule produced a verdict.
	return MatchResult{Action: ActionMaybe}
}
// CompileExprRules compiles a list of expression rules into a ruleset.
// It returns an error if any of the rules are invalid, or if any of the analyzers
// used by the rules are unknown (not provided in the analyzer list).
func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier.Modifier, config *BuiltinConfig) (Ruleset, error) {
	var compiledRules []compiledExprRule
	fullAnMap := analyzersToMap(ans)
	fullModMap := modifiersToMap(mods)
	// depAnMap accumulates only the analyzers actually referenced by some
	// rule, so streams don't pay for analyzers no rule uses.
	depAnMap := make(map[string]analyzer.Analyzer)
	funcMap := buildFunctionMap(config)
	// Compile all rules and build a map of analyzers that are used by the rules.
	for _, rule := range rules {
		if rule.Action == "" && !rule.Log {
			return nil, fmt.Errorf("rule %q must have at least one of action or log", rule.Name)
		}
		var action *Action
		if rule.Action != "" {
			a, ok := actionStringToAction(rule.Action)
			if !ok {
				return nil, fmt.Errorf("rule %q has invalid action %q", rule.Name, rule.Action)
			}
			action = &a
		}
		// The visitor records every identifier so analyzer dependencies can
		// be resolved; the patcher rewrites built-in function calls in place.
		visitor := &idVisitor{Variables: make(map[string]bool), Identifiers: make(map[string]bool)}
		patcher := &idPatcher{FuncMap: funcMap}
		program, err := expr.Compile(rule.Expr,
			func(c *conf.Config) {
				// Non-strict: unknown identifiers don't fail compilation.
				// Every rule expression must evaluate to a bool.
				c.Strict = false
				c.Expect = reflect.Bool
				c.Visitors = append(c.Visitors, visitor, patcher)
				for name, f := range funcMap {
					c.Functions[name] = &builtin.Function{
						Name:  name,
						Func:  f.Func,
						Types: f.Types,
					}
				}
			},
		)
		if err != nil {
			return nil, fmt.Errorf("rule %q has invalid expression: %w", rule.Name, err)
		}
		// Patch failures are reported out-of-band via the patcher's Err field.
		if patcher.Err != nil {
			return nil, fmt.Errorf("rule %q failed to patch expression: %w", rule.Name, patcher.Err)
		}
		for name := range visitor.Identifiers {
			// Skip built-in analyzers & user-defined variables
			if isBuiltInAnalyzer(name) || visitor.Variables[name] {
				continue
			}
			if f, ok := funcMap[name]; ok {
				// Built-in function, initialize if necessary
				if f.InitFunc != nil {
					if err := f.InitFunc(); err != nil {
						return nil, fmt.Errorf("rule %q failed to initialize function %q: %w", rule.Name, name, err)
					}
				}
			} else if a, ok := fullAnMap[name]; ok {
				// Analyzer, add to dependency map
				depAnMap[name] = a
			}
		}
		cr := compiledExprRule{
			Name:    rule.Name,
			Action:  action,
			Log:     rule.Log,
			Program: program,
		}
		// Modify rules additionally need a modifier instance built from the
		// rule's modifier entry.
		if action != nil && *action == ActionModify {
			mod, ok := fullModMap[rule.Modifier.Name]
			if !ok {
				return nil, fmt.Errorf("rule %q uses unknown modifier %q", rule.Name, rule.Modifier.Name)
			}
			modInst, err := mod.New(rule.Modifier.Args)
			if err != nil {
				return nil, fmt.Errorf("rule %q failed to create modifier instance: %w", rule.Name, err)
			}
			cr.ModInstance = modInst
		}
		compiledRules = append(compiledRules, cr)
	}
	// Convert the analyzer map to a list.
	var depAns []analyzer.Analyzer
	for _, a := range depAnMap {
		depAns = append(depAns, a)
	}
	return &exprRuleset{
		Rules:  compiledRules,
		Ans:    depAns,
		Logger: config.Logger,
	}, nil
}
// streamInfoToExprEnv builds the variable environment visible to rule
// expressions: the built-in "id", "proto", "ip" and "port" fields plus one
// entry per analyzer that produced at least one property.
func streamInfoToExprEnv(info StreamInfo) map[string]interface{} {
	env := make(map[string]interface{}, 4+len(info.Props))
	env["id"] = info.ID
	env["proto"] = info.Protocol.String()
	env["ip"] = map[string]string{
		"src": info.SrcIP.String(),
		"dst": info.DstIP.String(),
	}
	env["port"] = map[string]uint16{
		"src": info.SrcPort,
		"dst": info.DstPort,
	}
	for name, props := range info.Props {
		// Analyzers with empty property maps are omitted entirely.
		if len(props) > 0 {
			env[name] = props
		}
	}
	return env
}
// isBuiltInAnalyzer reports whether name refers to one of the implicit
// stream fields ("id", "proto", "ip", "port") that are always present in
// the expression environment and never correspond to a real analyzer.
func isBuiltInAnalyzer(name string) bool {
	builtinNames := [...]string{"id", "proto", "ip", "port"}
	for _, b := range builtinNames {
		if name == b {
			return true
		}
	}
	return false
}
// actionStringToAction maps a case-insensitive action name to its Action
// value. The boolean reports whether the name was recognized; unrecognized
// names yield (ActionMaybe, false).
func actionStringToAction(action string) (Action, bool) {
	lookup := map[string]Action{
		"allow":  ActionAllow,
		"block":  ActionBlock,
		"drop":   ActionDrop,
		"modify": ActionModify,
	}
	a, ok := lookup[strings.ToLower(action)]
	if !ok {
		return ActionMaybe, false
	}
	return a, true
}
// analyzersToMap indexes a list of analyzers by their Name() for easier
// lookup when compiling rules.
func analyzersToMap(ans []analyzer.Analyzer) map[string]analyzer.Analyzer {
	byName := make(map[string]analyzer.Analyzer, len(ans))
	for i := range ans {
		byName[ans[i].Name()] = ans[i]
	}
	return byName
}
// modifiersToMap indexes a list of modifiers by their Name() for easier
// lookup when compiling rules.
func modifiersToMap(mods []modifier.Modifier) map[string]modifier.Modifier {
	byName := make(map[string]modifier.Modifier, len(mods))
	for i := range mods {
		byName[mods[i].Name()] = mods[i]
	}
	return byName
}
// idVisitor is a visitor that collects all identifiers in an expression.
// This is for determining which analyzers are used by the expression.
type idVisitor struct {
	Variables   map[string]bool // names declared by the expression itself (variable declarators)
	Identifiers map[string]bool // every identifier referenced by the expression
}
// Visit implements ast.Visitor. It records variable declarations and
// identifier references so the compiler can later distinguish user-defined
// variables from analyzer/function names.
func (v *idVisitor) Visit(node *ast.Node) {
	// A type switch replaces the original if/else type-assertion chain;
	// the semantics are identical.
	switch n := (*node).(type) {
	case *ast.VariableDeclaratorNode:
		v.Variables[n.Name] = true
	case *ast.IdentifierNode:
		v.Identifiers[n.Value] = true
	}
}
// idPatcher patches the AST during expr compilation, replacing certain values with
// their internal representations for better runtime performance.
type idPatcher struct {
	FuncMap map[string]*Function
	// Err records the first error returned by a PatchFunc; the expr API
	// offers no error channel from visitors, so callers check it after
	// compilation.
	Err error
}
// Visit implements ast.Visitor. For calls to known built-in functions it
// invokes the function's PatchFunc (if any) to rewrite the call's arguments
// in place; the first patch error is recorded in p.Err.
func (p *idPatcher) Visit(node *ast.Node) {
	// Bind the type assertion directly instead of the original
	// switch-then-reassert pattern; behavior is unchanged.
	callNode, ok := (*node).(*ast.CallNode)
	if !ok {
		return
	}
	if callNode.Callee == nil {
		// Ignore invalid call nodes
		return
	}
	f, ok := p.FuncMap[callNode.Callee.String()]
	if !ok || f.PatchFunc == nil {
		return
	}
	if err := f.PatchFunc(&callNode.Arguments); err != nil {
		p.Err = err
	}
}
// Function describes a built-in function exposed to rule expressions.
type Function struct {
	InitFunc  func() error                     // optional one-time setup, run when a rule references the function
	PatchFunc func(args *[]ast.Node) error     // optional compile-time rewrite of the call's arguments
	Func      func(params ...any) (any, error) // the runtime implementation
	Types     []reflect.Type                   // signature(s) advertised to the expr type checker
}
// buildFunctionMap constructs the built-in functions available to rule
// expressions: geoip/geosite matching, CIDR matching, and DNS lookup.
func buildFunctionMap(config *BuiltinConfig) map[string]*Function {
	geoMatcher := geo.NewGeoMatcher(config.GeoSiteFilename, config.GeoIpFilename)
	return map[string]*Function{
		"geoip": {
			// Geo databases load lazily: InitFunc runs only when a compiled
			// rule actually references the function.
			InitFunc:  geoMatcher.LoadGeoIP,
			PatchFunc: nil,
			Func: func(params ...any) (any, error) {
				return geoMatcher.MatchGeoIp(params[0].(string), params[1].(string)), nil
			},
			Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoIp)},
		},
		"geosite": {
			InitFunc:  geoMatcher.LoadGeoSite,
			PatchFunc: nil,
			Func: func(params ...any) (any, error) {
				return geoMatcher.MatchGeoSite(params[0].(string), params[1].(string)), nil
			},
			Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)},
		},
		"cidr": {
			InitFunc: nil,
			// Pre-compile the constant CIDR string into a *net.IPNet at rule
			// compile time so no parsing happens per packet at runtime.
			PatchFunc: func(args *[]ast.Node) error {
				cidrStringNode, ok := (*args)[1].(*ast.StringNode)
				if !ok {
					return fmt.Errorf("cidr: invalid argument type")
				}
				cidr, err := builtins.CompileCIDR(cidrStringNode.Value)
				if err != nil {
					return err
				}
				(*args)[1] = &ast.ConstantNode{Value: cidr}
				return nil
			},
			Func: func(params ...any) (any, error) {
				return builtins.MatchCIDR(params[0].(string), params[1].(*net.IPNet)), nil
			},
			Types: []reflect.Type{reflect.TypeOf(builtins.MatchCIDR)},
		},
		"lookup": {
			InitFunc: nil,
			// Inject a *net.Resolver as a hidden trailing argument at compile
			// time. If the rule supplies the optional server string, the
			// resolver dials that address instead of the system default.
			PatchFunc: func(args *[]ast.Node) error {
				var serverStr *ast.StringNode
				if len(*args) > 1 {
					// Has the optional server argument
					var ok bool
					serverStr, ok = (*args)[1].(*ast.StringNode)
					if !ok {
						return fmt.Errorf("lookup: invalid argument type")
					}
				}
				r := &net.Resolver{
					Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
						if serverStr != nil {
							address = serverStr.Value
						}
						return config.ProtectedDialContext(ctx, network, address)
					},
				}
				if len(*args) > 1 {
					(*args)[1] = &ast.ConstantNode{Value: r}
				} else {
					*args = append(*args, &ast.ConstantNode{Value: r})
				}
				return nil
			},
			Func: func(params ...any) (any, error) {
				// Hard 4-second budget per lookup so a slow resolver cannot
				// stall rule evaluation indefinitely.
				ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
				defer cancel()
				return params[1].(*net.Resolver).LookupHost(ctx, params[0].(string))
			},
			Types: []reflect.Type{
				reflect.TypeOf((func(string, *net.Resolver) []string)(nil)),
			},
		},
	}
}

108
ruleset/interface.go Normal file
View File

@@ -0,0 +1,108 @@
package ruleset
import (
"context"
"net"
"strconv"
"git.difuse.io/Difuse/Mellaris/analyzer"
"git.difuse.io/Difuse/Mellaris/modifier"
)
// Action is the verdict a ruleset returns for a stream.
type Action int

const (
	// ActionMaybe indicates that the ruleset hasn't seen anything worth blocking based on
	// current information, but that may change if volatile fields change in the future.
	ActionMaybe Action = iota
	// ActionAllow indicates that the stream should be allowed regardless of future changes.
	ActionAllow
	// ActionBlock indicates that the stream should be blocked.
	ActionBlock
	// ActionDrop indicates that the current packet should be dropped,
	// but the stream should be allowed to continue.
	// Only valid for UDP streams. Equivalent to ActionBlock for TCP streams.
	ActionDrop
	// ActionModify indicates that the current packet should be modified,
	// and the stream should be allowed to continue.
	// Only valid for UDP streams. Equivalent to ActionMaybe for TCP streams.
	ActionModify
)

// String returns the lowercase name of the action, or "unknown" for
// out-of-range values.
func (a Action) String() string {
	names := [...]string{"maybe", "allow", "block", "drop", "modify"}
	if a < 0 || int(a) >= len(names) {
		return "unknown"
	}
	return names[a]
}
// Protocol is the transport protocol of a stream.
type Protocol int

// String returns "tcp", "udp", or "unknown" for any other value.
func (p Protocol) String() string {
	if p == ProtocolTCP {
		return "tcp"
	}
	if p == ProtocolUDP {
		return "udp"
	}
	return "unknown"
}

const (
	ProtocolTCP Protocol = iota
	ProtocolUDP
)
// StreamInfo is a snapshot of everything known about a single network stream.
type StreamInfo struct {
	ID               int64
	Protocol         Protocol
	SrcIP, DstIP     net.IP
	SrcPort, DstPort uint16
	// Props holds per-analyzer property maps, keyed by analyzer name.
	Props analyzer.CombinedPropMap
}
// SrcString returns the stream's source address in "host:port" form.
func (i StreamInfo) SrcString() string {
	return net.JoinHostPort(i.SrcIP.String(), strconv.FormatUint(uint64(i.SrcPort), 10))
}

// DstString returns the stream's destination address in "host:port" form.
func (i StreamInfo) DstString() string {
	return net.JoinHostPort(i.DstIP.String(), strconv.FormatUint(uint64(i.DstPort), 10))
}
// MatchResult is the outcome of matching one stream against a ruleset.
type MatchResult struct {
	Action Action
	// ModInstance carries the modifier to apply; set only when Action is ActionModify.
	ModInstance modifier.Instance
}

// Ruleset decides what to do with network streams.
type Ruleset interface {
	// Analyzers returns the list of analyzers to use for a stream.
	// It must be safe for concurrent use by multiple workers.
	Analyzers(StreamInfo) []analyzer.Analyzer
	// Match matches a stream against the ruleset and returns the result.
	// It must be safe for concurrent use by multiple workers.
	Match(StreamInfo) MatchResult
}

// Logger is the logging interface for the ruleset.
type Logger interface {
	// Log reports that the named rule matched the given stream.
	Log(info StreamInfo, name string)
	// MatchError reports that evaluating the named rule against the stream failed.
	MatchError(info StreamInfo, name string, err error)
}

// BuiltinConfig configures the built-in expression functions
// (geoip/geosite matching and DNS lookup).
type BuiltinConfig struct {
	Logger          Logger
	GeoSiteFilename string
	GeoIpFilename   string
	// ProtectedDialContext dials connections for built-in functions (e.g. the
	// DNS resolver used by lookup()). NOTE(review): presumably this dialer
	// bypasses the engine's own interception — confirm with its callers.
	ProtectedDialContext func(ctx context.Context, network, address string) (net.Conn, error)
}