From 00366b224b2e28861b80f677e8aa604c5d08dae3 Mon Sep 17 00:00:00 2001
From: Kelo <meetkelo@outlook.com>
Date: Sat, 29 Oct 2022 16:27:26 +0800
Subject: [PATCH] optimize: reduce disk writes

---
 service/db/boltdb.go  | 43 +++++++++++++++++++++++++++++++----
 service/db/listOp.go  | 48 +++++++++++++++++++++------------------
 service/db/plainOp.go | 52 ++++++++++++++++++++++++-------------------
 service/db/setOp.go   | 20 +++++++++--------
 4 files changed, 105 insertions(+), 58 deletions(-)

--- a/db/boltdb.go
+++ b/db/boltdb.go
@@ -1,13 +1,14 @@
 package db
 
 import (
-	"go.etcd.io/bbolt"
-	"github.com/v2rayA/v2rayA/conf"
-	"github.com/v2rayA/v2rayA/pkg/util/copyfile"
-	"github.com/v2rayA/v2rayA/pkg/util/log"
 	"os"
 	"path/filepath"
 	"sync"
+
+	"github.com/v2rayA/v2rayA/conf"
+	"github.com/v2rayA/v2rayA/pkg/util/copyfile"
+	"github.com/v2rayA/v2rayA/pkg/util/log"
+	"go.etcd.io/bbolt"
 )
 
 var once sync.Once
@@ -46,3 +47,37 @@ func DB() *bbolt.DB {
 	once.Do(initDB)
 	return db
 }
+
+// Transaction runs fn inside a writable transaction. fn should return a dirty flag.
+// If the dirty flag is true and there is no error then the transaction is committed.
+// Otherwise, the transaction is rolled back.
+func Transaction(db *bbolt.DB, fn func(*bbolt.Tx) (bool, error)) error {
+	tx, err := db.Begin(true)
+	if err != nil {
+		return err
+	}
+	defer tx.Rollback()
+	dirty, err := fn(tx)
+	if err != nil {
+		_ = tx.Rollback()
+		return err
+	}
+	if !dirty {
+		return nil
+	}
+	return tx.Commit()
+}
+
+// CreateBucketIfNotExists returns the named bucket; if it does not exist, it is created and the dirty flag is set.
+func CreateBucketIfNotExists(tx *bbolt.Tx, name []byte, dirty *bool) (*bbolt.Bucket, error) {
+	bkt := tx.Bucket(name)
+	if bkt != nil {
+		return bkt, nil
+	}
+	bkt, err := tx.CreateBucket(name)
+	if err != nil {
+		return nil, err
+	}
+	*dirty = true
+	return bkt, nil
+}
--- a/db/listOp.go
+++ b/db/listOp.go
@@ -2,13 +2,14 @@ package db
 
 import (
 	"fmt"
-	"go.etcd.io/bbolt"
-	jsoniter "github.com/json-iterator/go"
-	"github.com/tidwall/gjson"
-	"github.com/tidwall/sjson"
 	"reflect"
 	"sort"
 	"strconv"
+
+	jsoniter "github.com/json-iterator/go"
+	"github.com/tidwall/gjson"
+	"github.com/tidwall/sjson"
+	"go.etcd.io/bbolt"
 )
 
 func ListSet(bucket string, key string, index int, val interface{}) (err error) {
@@ -31,20 +32,21 @@ func ListSet(bucket string, key string,
 }
 
 func ListGet(bucket string, key string, index int) (b []byte, err error) {
-	err = DB().Update(func(tx *bbolt.Tx) error {
-		if bkt, err := tx.CreateBucketIfNotExists([]byte(bucket)); err != nil {
-			return err
+	err = Transaction(DB(), func(tx *bbolt.Tx) (bool, error) {
+		dirty := false
+		if bkt, err := CreateBucketIfNotExists(tx, []byte(bucket), &dirty); err != nil {
+			return dirty, err
 		} else {
 			v := bkt.Get([]byte(key))
 			if v == nil {
-				return fmt.Errorf("ListGet: can't get element from an empty list")
+				return dirty, fmt.Errorf("ListGet: can't get element from an empty list")
 			}
 			r := gjson.GetBytes(v, strconv.Itoa(index))
 			if r.Exists() {
 				b = []byte(r.Raw)
-				return nil
+				return dirty, nil
 			} else {
-				return fmt.Errorf("ListGet: no such element")
+				return dirty, fmt.Errorf("ListGet: no such element")
 			}
 		}
 	})
@@ -79,24 +81,25 @@ func ListAppend(bucket string, key strin
 }
 
 func ListGetAll(bucket string, key string) (list [][]byte, err error) {
-	err = DB().Update(func(tx *bbolt.Tx) error {
-		if bkt, err := tx.CreateBucketIfNotExists([]byte(bucket)); err != nil {
-			return err
+	err = Transaction(DB(), func(tx *bbolt.Tx) (bool, error) {
+		dirty := false
+		if bkt, err := CreateBucketIfNotExists(tx, []byte(bucket), &dirty); err != nil {
+			return dirty, err
 		} else {
 			b := bkt.Get([]byte(key))
 			if b == nil {
-				return nil
+				return dirty, nil
 			}
 			parsed := gjson.ParseBytes(b)
 			if !parsed.IsArray() {
-				return fmt.Errorf("ListGetAll: is not array")
+				return dirty, fmt.Errorf("ListGetAll: is not array")
 			}
 			results := parsed.Array()
 			for _, r := range results {
 				list = append(list, []byte(r.Raw))
 			}
 		}
-		return nil
+		return dirty, nil
 	})
 	return list, err
 }
@@ -143,21 +146,22 @@ func ListRemove(bucket, key string, inde
 }
 
 func ListLen(bucket string, key string) (length int, err error) {
-	err = DB().Update(func(tx *bbolt.Tx) error {
-		if bkt, err := tx.CreateBucketIfNotExists([]byte(bucket)); err != nil {
-			return err
+	err = Transaction(DB(), func(tx *bbolt.Tx) (bool, error) {
+		dirty := false
+		if bkt, err := CreateBucketIfNotExists(tx, []byte(bucket), &dirty); err != nil {
+			return dirty, err
 		} else {
 			b := bkt.Get([]byte(key))
 			if b == nil {
-				return nil
+				return dirty, nil
 			}
 			parsed := gjson.ParseBytes(b)
 			if !parsed.IsArray() {
-				return fmt.Errorf("ListLen: is not array")
+				return dirty, fmt.Errorf("ListLen: is not array")
 			}
 			length = len(parsed.Array())
 		}
-		return nil
+		return dirty, nil
 	})
 	return length, err
 }
--- a/db/plainOp.go
+++ b/db/plainOp.go
@@ -2,50 +2,54 @@ package db
 
 import (
 	"fmt"
-	"go.etcd.io/bbolt"
+
 	jsoniter "github.com/json-iterator/go"
 	"github.com/v2rayA/v2rayA/common"
 	"github.com/v2rayA/v2rayA/pkg/util/log"
+	"go.etcd.io/bbolt"
 )
 
 func Get(bucket string, key string, val interface{}) (err error) {
-	return DB().Update(func(tx *bbolt.Tx) error {
-		if bkt, err := tx.CreateBucketIfNotExists([]byte(bucket)); err != nil {
-			return err
+	return Transaction(DB(), func(tx *bbolt.Tx) (bool, error) {
+		dirty := false
+		if bkt, err := CreateBucketIfNotExists(tx, []byte(bucket), &dirty); err != nil {
+			return dirty, err
 		} else {
 			if v := bkt.Get([]byte(key)); v == nil {
-				return fmt.Errorf("Get: key is not found")
+				return dirty, fmt.Errorf("Get: key is not found")
 			} else {
-				return jsoniter.Unmarshal(v, val)
+				return dirty, jsoniter.Unmarshal(v, val)
 			}
 		}
 	})
 }
 
 func GetRaw(bucket string, key string) (b []byte, err error) {
-	err = DB().Update(func(tx *bbolt.Tx) error {
-		if bkt, err := tx.CreateBucketIfNotExists([]byte(bucket)); err != nil {
-			return err
+	err = Transaction(DB(), func(tx *bbolt.Tx) (bool, error) {
+		dirty := false
+		if bkt, err := CreateBucketIfNotExists(tx, []byte(bucket), &dirty); err != nil {
+			return dirty, err
 		} else {
 			v := bkt.Get([]byte(key))
 			if v == nil {
-				return fmt.Errorf("GetRaw: key is not found")
+				return dirty, fmt.Errorf("GetRaw: key is not found")
 			}
 			b = common.BytesCopy(v)
-			return nil
+			return dirty, nil
 		}
 	})
 	return b, err
 }
 
 func Exists(bucket string, key string) (exists bool) {
-	if err := DB().Update(func(tx *bbolt.Tx) error {
-		if bkt, err := tx.CreateBucketIfNotExists([]byte(bucket)); err != nil {
-			return err
+	if err := Transaction(DB(), func(tx *bbolt.Tx) (bool, error) {
+		dirty := false
+		if bkt, err := CreateBucketIfNotExists(tx, []byte(bucket), &dirty); err != nil {
+			return dirty, err
 		} else {
 			v := bkt.Get([]byte(key))
 			exists = v != nil
-			return nil
+			return dirty, nil
 		}
 	}); err != nil {
 		log.Warn("%v", err)
@@ -55,23 +59,25 @@ func Exists(bucket string, key string) (
 }
 
 func GetBucketLen(bucket string) (length int, err error) {
-	err = DB().Update(func(tx *bbolt.Tx) error {
-		if bkt, err := tx.CreateBucketIfNotExists([]byte(bucket)); err != nil {
-			return err
+	err = Transaction(DB(), func(tx *bbolt.Tx) (bool, error) {
+		dirty := false
+		if bkt, err := CreateBucketIfNotExists(tx, []byte(bucket), &dirty); err != nil {
+			return dirty, err
 		} else {
 			length = bkt.Stats().KeyN
 		}
-		return nil
+		return dirty, nil
 	})
 	return length, err
 }
 
 func GetBucketKeys(bucket string) (keys []string, err error) {
-	err = DB().Update(func(tx *bbolt.Tx) error {
-		if bkt, err := tx.CreateBucketIfNotExists([]byte(bucket)); err != nil {
-			return err
+	err = Transaction(DB(), func(tx *bbolt.Tx) (bool, error) {
+		dirty := false
+		if bkt, err := CreateBucketIfNotExists(tx, []byte(bucket), &dirty); err != nil {
+			return dirty, err
 		} else {
-			return bkt.ForEach(func(k, v []byte) error {
+			return dirty, bkt.ForEach(func(k, v []byte) error {
 				keys = append(keys, string(k))
 				return nil
 			})
--- a/db/setOp.go
+++ b/db/setOp.go
@@ -4,8 +4,9 @@ import (
 	"bytes"
 	"crypto/sha256"
 	"encoding/gob"
-	"go.etcd.io/bbolt"
+
 	"github.com/v2rayA/v2rayA/common"
+	"go.etcd.io/bbolt"
 )
 
 type set map[[32]byte]interface{}
@@ -28,26 +29,27 @@ func toSha256(val interface{}) (hash [32
 }
 
 func setOp(bucket string, key string, f func(m set) (readonly bool, err error)) (err error) {
-	return DB().Update(func(tx *bbolt.Tx) error {
-		if bkt, err := tx.CreateBucketIfNotExists([]byte(bucket)); err != nil {
-			return err
+	return Transaction(DB(), func(tx *bbolt.Tx) (bool, error) {
+		dirty := false
+		if bkt, err := CreateBucketIfNotExists(tx, []byte(bucket), &dirty); err != nil {
+			return dirty, err
 		} else {
 			var m set
 			v := bkt.Get([]byte(key))
 			if v == nil {
 				m = make(set)
 			} else if err := gob.NewDecoder(bytes.NewReader(v)).Decode(&m); err != nil {
-				return err
+				return dirty, err
 			}
 			if readonly, err := f(m); err != nil {
-				return err
+				return dirty, err
 			} else if readonly {
-				return nil
+				return dirty, nil
 			}
 			if b, err := common.ToBytes(m); err != nil {
-				return err
+				return dirty, err
 			} else {
-				return bkt.Put([]byte(key), b)
+				return true, bkt.Put([]byte(key), b)
 			}
 		}
 	})