diff --git a/ssdb/ssdb.go b/ssdb/ssdb.go index 21bef9e..206d240 100644 --- a/ssdb/ssdb.go +++ b/ssdb/ssdb.go @@ -8,8 +8,41 @@ import ( ) type Client struct { - sock *net.TCPConn + sock chan *net.TCPConn recv_buf bytes.Buffer + _sock *net.TCPConn +} + +type ConnectionPoolWrapper struct { + size int + conn chan *Client +} + +func InitPool(ip string, port int, size int) (*ConnectionPoolWrapper, error) { + + cpm := new(ConnectionPoolWrapper) + + cpm.conn = make(chan *Client, size) + for x := 0; x < size; x++ { + conn, err := Connect(ip, port) + if err != nil { + return cpm, err + } + + // If the init function succeeded, add the connection to the channel + cpm.conn <- conn + } + cpm.size = size + return cpm, nil + +} + +func (p *ConnectionPoolWrapper) GetConnection() *Client { + return <-p.conn +} + +func (p *ConnectionPoolWrapper) ReleaseConnection(conn *Client) { + p.conn <- conn } func Connect(ip string, port int) (*Client, error) { @@ -22,11 +55,24 @@ func Connect(ip string, port int) (*Client, error) { return nil, err } var c Client - c.sock = sock + + c.sock = make(chan *net.TCPConn, 1) + c.sock <- sock + return &c, nil } func (c *Client) Do(args ...interface{}) ([]string, error) { + + c._sock = <- c.sock + defer func () { + c.sock <- c._sock + }() + + return c.do(args...) +} + +func (c *Client) do(args ...interface{}) ([]string, error) { err := c.send(args) if err != nil { return nil, err @@ -40,19 +86,29 @@ func (c *Client) Set(key string, val string) (interface{}, error) { if err != nil { return nil, err } - if len(resp) == 2 && resp[0] == "ok" { + if len(resp) > 0 && resp[0] == "ok" { return true, nil } return nil, fmt.Errorf("bad response") } + +func (c *ConnectionPoolWrapper) Set(key string, val string) (interface{}, error) { + + db := c.GetConnection() + defer c.ReleaseConnection(db) + + return db.Set(key, val) +} + // TODO: Will somebody write addition semantic methods? func (c *Client) Get(key string) (interface{}, error) { resp, err := c.Do("get", key) if err != nil { return nil, err } - if len(resp) == 2 && resp[0] == "ok" { + if len(resp) > 0 && resp[0] == "ok" { + // return resp[1], nil return resp[1], nil } if resp[0] == "not_found" { @@ -61,6 +117,28 @@ func (c *Client) Get(key string) (interface{}, error) { return nil, fmt.Errorf("bad response") } +func (c *Client) Info() (interface{}, error) { + resp, err := c.Do("info") + if err != nil { + return nil, err + } + if len(resp) > 0 && resp[0] == "ok" { + return resp, nil + } + if resp[0] == "not_found" { + return nil, nil + } + return nil, fmt.Errorf("bad response") +} + +func (c *ConnectionPoolWrapper) Get(key string) (interface{}, error) { + + db := c.GetConnection() + defer c.ReleaseConnection(db) + + return db.Get(key) +} + func (c *Client) Del(key string) (interface{}, error) { resp, err := c.Do("del", key) if err != nil { @@ -74,11 +152,21 @@ func (c *Client) Del(key string) (interface{}, error) { return nil, fmt.Errorf("bad response:resp:%v:", resp) } +func (c *ConnectionPoolWrapper) Del(key string) (interface{}, error) { + + db := c.GetConnection() + defer c.ReleaseConnection(db) + + return db.Del(key) +} + func (c *Client) Send(args ...interface{}) error { return c.send(args); } func (c *Client) send(args []interface{}) error { + + var sock = c._sock var buf bytes.Buffer for _, arg := range args { var s string @@ -118,7 +206,9 @@ func (c *Client) send(args []interface{}) error { buf.WriteByte('\n') } buf.WriteByte('\n') - _, err := c.sock.Write(buf.Bytes()) + + _, err := sock.Write(buf.Bytes()) + return err } @@ -127,14 +217,19 @@ func (c *Client) Recv() ([]string, error) { } func (c *Client) recv() ([]string, error) { + + var sock = c._sock + var tmp [1]byte for { resp := c.parse() if resp == nil || len(resp) > 0 { return resp, nil } - n, err := c.sock.Read(tmp[0:]) + n, err := sock.Read(tmp[0:]) + if err != nil { + return nil, err } c.recv_buf.Write(tmp[0:n]) @@ -184,5 +279,32 @@ func (c *Client) parse() []string { // Close The Client Connection func (c *Client) Close() error { - return c.sock.Close() + + sock := <- c.sock + + defer func () { + + c.sock <- sock + + }() + + return sock.Close() + +} + + +func (cpm *ConnectionPoolWrapper) Close() error { + + for { + + select { + case db := <- cpm.conn: + db.Close() + default: + return nil + } + + } + + return nil } diff --git a/ssdb/ssdb_test.go b/ssdb/ssdb_test.go new file mode 100644 index 0000000..1ef3e83 --- /dev/null +++ b/ssdb/ssdb_test.go @@ -0,0 +1,192 @@ +package ssdb + +import "testing" + +var ip = "ssdb" +var port = 16379 + +func TestConnect(t *testing.T) { + + db, err := Connect("zup", port) + if err == nil { + t.Error("connect to bad host did not return err") + } + if db != nil { + t.Error("connect to bad host returned non nil db") + } + + db, err = Connect("ssdb", 0) + if err == nil { + t.Error("connect to bad port did not return err") + } + if db != nil { + t.Error("connect to bad port returned non nil db") + } + + db, err = Connect(ip, port) + if err != nil { + t.Error("Failed to connect") + } + defer db.Close() + + db, err = Connect(ip, port) + if err != nil { + t.Error("Close:second connect raised an error:%v:", err) + } + +} + + +func TestClose(t *testing.T) { + + db, err := Connect(ip, port) + if err != nil { + t.Error("Close:connect returned an err:%v:", err) + } + if db == nil { + t.Error("Close:connect returned a nil db") + } + + db.Close() + if err != nil { + t.Error("Close:returned an error:%v:", err) + } + + val, err := db.Set("a", "xxx") + if val == true { + t.Error("Close:Set:val returned true after db closed") + } + if val != nil { + t.Error("Close:Set:val returned non-nil after db closed") + } + if err == nil { + t.Error("Close:Set returned no error after db closed") + } + +} + +func TestInitPool(t *testing.T) { + + cpm, err := InitPool(ip, port, 11) + if err != nil { + t.Error("failed to init pool") + } + if cpm == nil { + t.Error("failed to init pool") + } + +} + +func TestSet(t *testing.T) { + + db, err := Connect(ip, port) + if err != nil { + t.Error("Failed to connect") + } + //defer db.Close() + + val, err := db.Set("a", "xxx") + if val != true { + t.Error("Set val returned false") + } + if err != nil { + t.Error("Set err returned not nil err") + } + + // add negative testts + db.Close() + + val, err = db.Set("a", "xxx") + if val == true { + t.Error("Set val on closed db returned true") + } + if err == nil { + t.Error("Set err on closed db returned not nil err") + } + + +} + +func TestGet(t *testing.T) { + + db, err := Connect(ip, port) + if err != nil { + t.Error("Failed to connect") + } + defer db.Close() + + val, err := db.Set("a", "xxx") + if val != true { + t.Error("Set val returned false") + } + if err != nil { + t.Error("Set err returned not nil err") + } + + val, err = db.Get("a") + if val == nil { + t.Error("Get returned nil") + } + if err != nil { + t.Error("Get returned err") + } + if val != "xxx" { + t.Error("Get did not return a") + } + +} + +func TestDel(t *testing.T) { + + db, err := Connect(ip, port) + if err != nil { + t.Error("Failed to connect") + } + defer db.Close() + + val, err := db.Set("a", "xxx") + if val != true { + t.Error("Set val returned false") + } + if err != nil { + t.Error("Set err returned not nil err") + } + + val, err = db.Get("a") + if val != "xxx" { + t.Error("Get did not return xxx") + } + if val == nil { + t.Error("Get returned nil") + } + if err != nil { + t.Error("Get returned err") + } + + val, err = db.Del("a") + if val != true { + t.Error("Del returned false") + } + if err != nil { + t.Error("Del returned err") + } + + val, err = db.Get("a") + if val == "xxx" { + t.Error("Get returned xxx after Del") + } + if val != nil { + t.Error("Get returned non-nil") + } + if err != nil { + t.Error("Get returned err") + } + + val, err = db.Del("a") + if val != true { + t.Error("Del returned non-nil:%v:", val) + } + if err != nil { + t.Error("Get returned err") + } +} \ No newline at end of file diff --git a/test.go b/test.go index 4bc95b9..c88ed69 100644 --- a/test.go +++ b/test.go @@ -12,6 +12,7 @@ func main() { port := 8888 db, err := ssdb.Connect(ip, port) if err != nil { + fmt.Errorf("ssdb.Connect:err:%v:\n", err) os.Exit(1) } @@ -22,14 +23,42 @@ func main() { keys = append(keys, "c"); keys = append(keys, "d"); val, err = db.Do("multi_get", "a", "b", keys); - fmt.Printf("%s\n", val); + fmt.Printf(":%s:\n", val); + if err != nil { + os.Exit(1) + } + + val, err = db.Do("info"); + fmt.Printf("called info:%s:\n", val); + if err != nil { + os.Exit(1) + } + + val, err = db.Info(); + fmt.Printf("called info again:%s:\n", val); + if err != nil { + os.Exit(1) + } + + val, err = db.Do("keys", "", "", 100); + fmt.Printf("called keys:%s:\n", val); + if err != nil { + os.Exit(1) + } + + val, err = db.Do("scan", "", "", 100); + fmt.Printf("called scan:%s:\n", val); + if err != nil { + os.Exit(1) + } db.Set("a", "xxx") val, err = db.Get("a") - fmt.Printf("%s\n", val) - db.Del("a") + fmt.Printf("Got:%v:\n", val) + _, err = db.Del("a") + fmt.Printf("deleted it:err:%v:\n", err) val, err = db.Get("a") - fmt.Printf("%s\n", val) + fmt.Printf("got it again:val:%v:err:%v:\n", val, err) fmt.Printf("----\n"); @@ -49,6 +78,26 @@ func main() { fmt.Printf(" %s : %3s\n", resp[i], resp[i+1]) } + fmt.Printf("call:init pool\n") + cpm, err := ssdb.InitPool(ip, port, 11) + fmt.Printf("done call:init pool\n") + + if err != nil { + fmt.Printf("failed to init pool:%v:", err) + os.Exit(1) + } + + // cpm.Put + fmt.Printf("cpm.Set:a:xxx\n") + cpm.Set("a", "xxx") + + fmt.Printf("cpm.Get:a:\n") + val, err = cpm.Get("a") + fmt.Printf("cpm.Get:return:val:%s:err:%v:\n", val, err) + + cpm.Close() + + //_ = db.Send("dump", "", "", "-1"); _ = db.Send("sync140"); // receive multi responses on one request