diff --git a/README.md b/README.md index c6780b25..ad3834b9 100644 --- a/README.md +++ b/README.md @@ -130,6 +130,7 @@ Common configurations: |mysql.password||MySQL Password| |mysql.db|"test"|MySQL Database| |tidb.cluster_index|true|Whether to use cluster index, for TiDB only| +|tidb.instances|""|Comma-seperated address list of tidb instances (eg: `tidb-0:4000,tidb-1:4000`)| ### TiKV diff --git a/db/mysql/db.go b/db/mysql/db.go index 1f2a20e9..bb34e2a6 100644 --- a/db/mysql/db.go +++ b/db/mysql/db.go @@ -16,16 +16,18 @@ package mysql import ( "bytes" "context" + "crypto/sha1" "database/sql" + "database/sql/driver" + "encoding/hex" "fmt" "strings" + "sync/atomic" + "github.com/go-sql-driver/mysql" + "github.com/magiconair/properties" "github.com/pingcap/go-ycsb/pkg/prop" "github.com/pingcap/go-ycsb/pkg/util" - - // mysql package - _ "github.com/go-sql-driver/mysql" - "github.com/magiconair/properties" "github.com/pingcap/go-ycsb/pkg/ycsb" ) @@ -40,8 +42,38 @@ const ( // TODO: support batch and auto commit tidbClusterIndex = "tidb.cluster_index" + tidbInstances = "tidb.instances" ) +type muxDriver struct { + cursor uint64 + instances []string + internal driver.Driver +} + +func (drv *muxDriver) Open(name string) (driver.Conn, error) { + k := atomic.AddUint64(&drv.cursor, 1) + return drv.internal.Open(drv.instances[int(k)%len(drv.instances)]) +} + +func openTiDBInstances(addrs []string, user string, pass string, db string) (*sql.DB, error) { + instances := make([]string, len(addrs)) + hash := sha1.New() + for i, addr := range addrs { + hash.Write([]byte("+" + addr)) + instances[i] = fmt.Sprintf("%s:%s@tcp(%s)/%s", user, pass, addr, db) + } + digest := hash.Sum(nil) + driver := "tidb:" + hex.EncodeToString(digest[:]) + for _, n := range sql.Drivers() { + if n == driver { + return sql.Open(driver, "") + } + } + sql.Register(driver, &muxDriver{instances: instances, internal: &mysql.MySQLDriver{}}) + return sql.Open(driver, "") +} + type mysqlCreator struct { name string } @@ -75,10 +107,25 @@ func (c mysqlCreator) Create(p *properties.Properties) (ycsb.DB, error) { user := p.GetString(mysqlUser, "root") password := p.GetString(mysqlPassword, "") dbName := p.GetString(mysqlDBName, "test") - - dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", user, password, host, port, dbName) - var err error - db, err := sql.Open("mysql", dsn) + tidbList := p.GetString(tidbInstances, "") + + var ( + db *sql.DB + err error + tidbs []string + ) + for _, tidb := range strings.Split(tidbList, ",") { + tidb = strings.TrimSpace(tidb) + if len(tidb) > 0 { + tidbs = append(tidbs, tidb) + } + } + if len(tidbs) > 0 { + db, err = openTiDBInstances(tidbs, user, password, dbName) + } else { + dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", user, password, host, port, dbName) + db, err = sql.Open("mysql", dsn) + } if err != nil { return nil, err }