summaryrefslogtreecommitdiff
path: root/libgo/go/database/sql/fakedb_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/database/sql/fakedb_test.go')
-rw-r--r--libgo/go/database/sql/fakedb_test.go598
1 files changed, 598 insertions, 0 deletions
diff --git a/libgo/go/database/sql/fakedb_test.go b/libgo/go/database/sql/fakedb_test.go
new file mode 100644
index 00000000000..b0d137cd715
--- /dev/null
+++ b/libgo/go/database/sql/fakedb_test.go
@@ -0,0 +1,598 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sql
+
+import (
+ "database/sql/driver"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+)
+
+var _ = log.Printf
+
+// fakeDriver is a fake database that implements Go's driver.Driver
+// interface, just for testing.
+//
+// It speaks a query language that's semantically similar to but
+// syntantically different and simpler than SQL. The syntax is as
+// follows:
+//
+// WIPE
+// CREATE|<tablename>|<col>=<type>,<col>=<type>,...
+// where types are: "string", [u]int{8,16,32,64}, "bool"
+// INSERT|<tablename>|col=val,col2=val2,col3=?
+// SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
+//
+// When opening a a fakeDriver's database, it starts empty with no
+// tables. All tables and data are stored in memory only.
+type fakeDriver struct {
+ mu sync.Mutex
+ openCount int
+ dbs map[string]*fakeDB
+}
+
+type fakeDB struct {
+ name string
+
+ mu sync.Mutex
+ free []*fakeConn
+ tables map[string]*table
+}
+
+type table struct {
+ mu sync.Mutex
+ colname []string
+ coltype []string
+ rows []*row
+}
+
+func (t *table) columnIndex(name string) int {
+ for n, nname := range t.colname {
+ if name == nname {
+ return n
+ }
+ }
+ return -1
+}
+
+type row struct {
+ cols []interface{} // must be same size as its table colname + coltype
+}
+
+func (r *row) clone() *row {
+ nrow := &row{cols: make([]interface{}, len(r.cols))}
+ copy(nrow.cols, r.cols)
+ return nrow
+}
+
+type fakeConn struct {
+ db *fakeDB // where to return ourselves to
+
+ currTx *fakeTx
+
+ // Stats for tests:
+ mu sync.Mutex
+ stmtsMade int
+ stmtsClosed int
+}
+
+func (c *fakeConn) incrStat(v *int) {
+ c.mu.Lock()
+ *v++
+ c.mu.Unlock()
+}
+
+type fakeTx struct {
+ c *fakeConn
+}
+
+type fakeStmt struct {
+ c *fakeConn
+ q string // just for debugging
+
+ cmd string
+ table string
+
+ closed bool
+
+ colName []string // used by CREATE, INSERT, SELECT (selected columns)
+ colType []string // used by CREATE
+ colValue []interface{} // used by INSERT (mix of strings and "?" for bound params)
+ placeholders int // used by INSERT/SELECT: number of ? params
+
+ whereCol []string // used by SELECT (all placeholders)
+
+ placeholderConverter []driver.ValueConverter // used by INSERT
+}
+
+var fdriver driver.Driver = &fakeDriver{}
+
+func init() {
+ Register("test", fdriver)
+}
+
+// Supports dsn forms:
+// <dbname>
+// <dbname>;<opts> (no currently supported options)
+func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
+ parts := strings.Split(dsn, ";")
+ if len(parts) < 1 {
+ return nil, errors.New("fakedb: no database name")
+ }
+ name := parts[0]
+
+ db := d.getDB(name)
+
+ d.mu.Lock()
+ d.openCount++
+ d.mu.Unlock()
+ return &fakeConn{db: db}, nil
+}
+
+func (d *fakeDriver) getDB(name string) *fakeDB {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ if d.dbs == nil {
+ d.dbs = make(map[string]*fakeDB)
+ }
+ db, ok := d.dbs[name]
+ if !ok {
+ db = &fakeDB{name: name}
+ d.dbs[name] = db
+ }
+ return db
+}
+
+func (db *fakeDB) wipe() {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+ db.tables = nil
+}
+
+func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+ if db.tables == nil {
+ db.tables = make(map[string]*table)
+ }
+ if _, exist := db.tables[name]; exist {
+ return fmt.Errorf("table %q already exists", name)
+ }
+ if len(columnNames) != len(columnTypes) {
+ return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d",
+ name, len(columnNames), len(columnTypes))
+ }
+ db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
+ return nil
+}
+
+// must be called with db.mu lock held
+func (db *fakeDB) table(table string) (*table, bool) {
+ if db.tables == nil {
+ return nil, false
+ }
+ t, ok := db.tables[table]
+ return t, ok
+}
+
+func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+ t, ok := db.table(table)
+ if !ok {
+ return
+ }
+ for n, cname := range t.colname {
+ if cname == column {
+ return t.coltype[n], true
+ }
+ }
+ return "", false
+}
+
+func (c *fakeConn) Begin() (driver.Tx, error) {
+ if c.currTx != nil {
+ return nil, errors.New("already in a transaction")
+ }
+ c.currTx = &fakeTx{c: c}
+ return c.currTx, nil
+}
+
+func (c *fakeConn) Close() error {
+ if c.currTx != nil {
+ return errors.New("can't close; in a Transaction")
+ }
+ if c.db == nil {
+ return errors.New("can't close; already closed")
+ }
+ c.db = nil
+ return nil
+}
+
+func checkSubsetTypes(args []interface{}) error {
+ for n, arg := range args {
+ switch arg.(type) {
+ case int64, float64, bool, nil, []byte, string, time.Time:
+ default:
+ return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg)
+ }
+ }
+ return nil
+}
+
+func (c *fakeConn) Exec(query string, args []interface{}) (driver.Result, error) {
+ // This is an optional interface, but it's implemented here
+ // just to check that all the args of of the proper types.
+ // ErrSkip is returned so the caller acts as if we didn't
+ // implement this at all.
+ err := checkSubsetTypes(args)
+ if err != nil {
+ return nil, err
+ }
+ return nil, driver.ErrSkip
+}
+
+func errf(msg string, args ...interface{}) error {
+ return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
+}
+
+// parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
+// (note that where where columns must always contain ? marks,
+// just a limitation for fakedb)
+func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
+ if len(parts) != 3 {
+ return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
+ }
+ stmt.table = parts[0]
+ stmt.colName = strings.Split(parts[1], ",")
+ for n, colspec := range strings.Split(parts[2], ",") {
+ if colspec == "" {
+ continue
+ }
+ nameVal := strings.Split(colspec, "=")
+ if len(nameVal) != 2 {
+ return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
+ }
+ column, value := nameVal[0], nameVal[1]
+ _, ok := c.db.columnType(stmt.table, column)
+ if !ok {
+ return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
+ }
+ if value != "?" {
+ return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
+ stmt.table, column)
+ }
+ stmt.whereCol = append(stmt.whereCol, column)
+ stmt.placeholders++
+ }
+ return stmt, nil
+}
+
+// parts are table|col=type,col2=type2
+func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
+ if len(parts) != 2 {
+ return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
+ }
+ stmt.table = parts[0]
+ for n, colspec := range strings.Split(parts[1], ",") {
+ nameType := strings.Split(colspec, "=")
+ if len(nameType) != 2 {
+ return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
+ }
+ stmt.colName = append(stmt.colName, nameType[0])
+ stmt.colType = append(stmt.colType, nameType[1])
+ }
+ return stmt, nil
+}
+
+// parts are table|col=?,col2=val
+func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
+ if len(parts) != 2 {
+ return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
+ }
+ stmt.table = parts[0]
+ for n, colspec := range strings.Split(parts[1], ",") {
+ nameVal := strings.Split(colspec, "=")
+ if len(nameVal) != 2 {
+ return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
+ }
+ column, value := nameVal[0], nameVal[1]
+ ctype, ok := c.db.columnType(stmt.table, column)
+ if !ok {
+ return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
+ }
+ stmt.colName = append(stmt.colName, column)
+
+ if value != "?" {
+ var subsetVal interface{}
+ // Convert to driver subset type
+ switch ctype {
+ case "string":
+ subsetVal = []byte(value)
+ case "blob":
+ subsetVal = []byte(value)
+ case "int32":
+ i, err := strconv.Atoi(value)
+ if err != nil {
+ return nil, errf("invalid conversion to int32 from %q", value)
+ }
+ subsetVal = int64(i) // int64 is a subset type, but not int32
+ default:
+ return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
+ }
+ stmt.colValue = append(stmt.colValue, subsetVal)
+ } else {
+ stmt.placeholders++
+ stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
+ stmt.colValue = append(stmt.colValue, "?")
+ }
+ }
+ return stmt, nil
+}
+
+func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
+ if c.db == nil {
+ panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
+ }
+ parts := strings.Split(query, "|")
+ if len(parts) < 1 {
+ return nil, errf("empty query")
+ }
+ cmd := parts[0]
+ parts = parts[1:]
+ stmt := &fakeStmt{q: query, c: c, cmd: cmd}
+ c.incrStat(&c.stmtsMade)
+ switch cmd {
+ case "WIPE":
+ // Nothing
+ case "SELECT":
+ return c.prepareSelect(stmt, parts)
+ case "CREATE":
+ return c.prepareCreate(stmt, parts)
+ case "INSERT":
+ return c.prepareInsert(stmt, parts)
+ default:
+ return nil, errf("unsupported command type %q", cmd)
+ }
+ return stmt, nil
+}
+
+func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
+ return s.placeholderConverter[idx]
+}
+
+func (s *fakeStmt) Close() error {
+ if !s.closed {
+ s.c.incrStat(&s.c.stmtsClosed)
+ s.closed = true
+ }
+ return nil
+}
+
+var errClosed = errors.New("fakedb: statement has been closed")
+
+func (s *fakeStmt) Exec(args []interface{}) (driver.Result, error) {
+ if s.closed {
+ return nil, errClosed
+ }
+ err := checkSubsetTypes(args)
+ if err != nil {
+ return nil, err
+ }
+
+ db := s.c.db
+ switch s.cmd {
+ case "WIPE":
+ db.wipe()
+ return driver.DDLSuccess, nil
+ case "CREATE":
+ if err := db.createTable(s.table, s.colName, s.colType); err != nil {
+ return nil, err
+ }
+ return driver.DDLSuccess, nil
+ case "INSERT":
+ return s.execInsert(args)
+ }
+ fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s)
+ return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd)
+}
+
+func (s *fakeStmt) execInsert(args []interface{}) (driver.Result, error) {
+ db := s.c.db
+ if len(args) != s.placeholders {
+ panic("error in pkg db; should only get here if size is correct")
+ }
+ db.mu.Lock()
+ t, ok := db.table(s.table)
+ db.mu.Unlock()
+ if !ok {
+ return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
+ }
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ cols := make([]interface{}, len(t.colname))
+ argPos := 0
+ for n, colname := range s.colName {
+ colidx := t.columnIndex(colname)
+ if colidx == -1 {
+ return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
+ }
+ var val interface{}
+ if strvalue, ok := s.colValue[n].(string); ok && strvalue == "?" {
+ val = args[argPos]
+ argPos++
+ } else {
+ val = s.colValue[n]
+ }
+ cols[colidx] = val
+ }
+
+ t.rows = append(t.rows, &row{cols: cols})
+ return driver.RowsAffected(1), nil
+}
+
+func (s *fakeStmt) Query(args []interface{}) (driver.Rows, error) {
+ if s.closed {
+ return nil, errClosed
+ }
+ err := checkSubsetTypes(args)
+ if err != nil {
+ return nil, err
+ }
+
+ db := s.c.db
+ if len(args) != s.placeholders {
+ panic("error in pkg db; should only get here if size is correct")
+ }
+
+ db.mu.Lock()
+ t, ok := db.table(s.table)
+ db.mu.Unlock()
+ if !ok {
+ return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
+ }
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ colIdx := make(map[string]int) // select column name -> column index in table
+ for _, name := range s.colName {
+ idx := t.columnIndex(name)
+ if idx == -1 {
+ return nil, fmt.Errorf("fakedb: unknown column name %q", name)
+ }
+ colIdx[name] = idx
+ }
+
+ mrows := []*row{}
+rows:
+ for _, trow := range t.rows {
+ // Process the where clause, skipping non-match rows. This is lazy
+ // and just uses fmt.Sprintf("%v") to test equality. Good enough
+ // for test code.
+ for widx, wcol := range s.whereCol {
+ idx := t.columnIndex(wcol)
+ if idx == -1 {
+ return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
+ }
+ tcol := trow.cols[idx]
+ if bs, ok := tcol.([]byte); ok {
+ // lazy hack to avoid sprintf %v on a []byte
+ tcol = string(bs)
+ }
+ if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) {
+ continue rows
+ }
+ }
+ mrow := &row{cols: make([]interface{}, len(s.colName))}
+ for seli, name := range s.colName {
+ mrow.cols[seli] = trow.cols[colIdx[name]]
+ }
+ mrows = append(mrows, mrow)
+ }
+
+ cursor := &rowsCursor{
+ pos: -1,
+ rows: mrows,
+ cols: s.colName,
+ }
+ return cursor, nil
+}
+
+func (s *fakeStmt) NumInput() int {
+ return s.placeholders
+}
+
+func (tx *fakeTx) Commit() error {
+ tx.c.currTx = nil
+ return nil
+}
+
+func (tx *fakeTx) Rollback() error {
+ tx.c.currTx = nil
+ return nil
+}
+
+type rowsCursor struct {
+ cols []string
+ pos int
+ rows []*row
+ closed bool
+
+ // a clone of slices to give out to clients, indexed by the
+ // the original slice's first byte address. we clone them
+ // just so we're able to corrupt them on close.
+ bytesClone map[*byte][]byte
+}
+
+func (rc *rowsCursor) Close() error {
+ if !rc.closed {
+ for _, bs := range rc.bytesClone {
+ bs[0] = 255 // first byte corrupted
+ }
+ }
+ rc.closed = true
+ return nil
+}
+
+func (rc *rowsCursor) Columns() []string {
+ return rc.cols
+}
+
+func (rc *rowsCursor) Next(dest []interface{}) error {
+ if rc.closed {
+ return errors.New("fakedb: cursor is closed")
+ }
+ rc.pos++
+ if rc.pos >= len(rc.rows) {
+ return io.EOF // per interface spec
+ }
+ for i, v := range rc.rows[rc.pos].cols {
+ // TODO(bradfitz): convert to subset types? naah, I
+ // think the subset types should only be input to
+ // driver, but the sql package should be able to handle
+ // a wider range of types coming out of drivers. all
+ // for ease of drivers, and to prevent drivers from
+ // messing up conversions or doing them differently.
+ dest[i] = v
+
+ if bs, ok := v.([]byte); ok {
+ if rc.bytesClone == nil {
+ rc.bytesClone = make(map[*byte][]byte)
+ }
+ clone, ok := rc.bytesClone[&bs[0]]
+ if !ok {
+ clone = make([]byte, len(bs))
+ copy(clone, bs)
+ rc.bytesClone[&bs[0]] = clone
+ }
+ dest[i] = clone
+ }
+ }
+ return nil
+}
+
+func converterForType(typ string) driver.ValueConverter {
+ switch typ {
+ case "bool":
+ return driver.Bool
+ case "int32":
+ return driver.Int32
+ case "string":
+ return driver.NotNull{driver.String}
+ case "nullstring":
+ return driver.Null{driver.String}
+ case "datetime":
+ return driver.DefaultParameterConverter
+ }
+ panic("invalid fakedb column type of " + typ)
+}