Skip to content

Commit 04a39d1

Browse files
authored
Merge pull request jmoiron#270 from wyattjoh/master
Add experimental support for 1.8's new Context based database/sql functions
2 parents bc8b1f8 + 909e40f commit 04a39d1

File tree

5 files changed

+1923
-1
lines changed

5 files changed

+1923
-1
lines changed

named_context.go

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
// +build go1.8
2+
3+
package sqlx
4+
5+
import (
6+
"context"
7+
"database/sql"
8+
)
9+
10+
// A union interface of contextPreparer and binder, required to be able to
11+
// prepare named statements with context (as the bindtype must be determined).
12+
type namedPreparerContext interface {
13+
PreparerContext
14+
binder
15+
}
16+
17+
func prepareNamedContext(ctx context.Context, p namedPreparerContext, query string) (*NamedStmt, error) {
18+
bindType := BindType(p.DriverName())
19+
q, args, err := compileNamedQuery([]byte(query), bindType)
20+
if err != nil {
21+
return nil, err
22+
}
23+
stmt, err := PreparexContext(ctx, p, q)
24+
if err != nil {
25+
return nil, err
26+
}
27+
return &NamedStmt{
28+
QueryString: q,
29+
Params: args,
30+
Stmt: stmt,
31+
}, nil
32+
}
33+
34+
// ExecContext executes a named statement using the struct passed.
35+
// Any named placeholder parameters are replaced with fields from arg.
36+
func (n *NamedStmt) ExecContext(ctx context.Context, arg interface{}) (sql.Result, error) {
37+
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
38+
if err != nil {
39+
return *new(sql.Result), err
40+
}
41+
return n.Stmt.ExecContext(ctx, args...)
42+
}
43+
44+
// QueryContext executes a named statement using the struct argument, returning rows.
45+
// Any named placeholder parameters are replaced with fields from arg.
46+
func (n *NamedStmt) QueryContext(ctx context.Context, arg interface{}) (*sql.Rows, error) {
47+
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
48+
if err != nil {
49+
return nil, err
50+
}
51+
return n.Stmt.QueryContext(ctx, args...)
52+
}
53+
54+
// QueryRowContext executes a named statement against the database. Because sqlx cannot
55+
// create a *sql.Row with an error condition pre-set for binding errors, sqlx
56+
// returns a *sqlx.Row instead.
57+
// Any named placeholder parameters are replaced with fields from arg.
58+
func (n *NamedStmt) QueryRowContext(ctx context.Context, arg interface{}) *Row {
59+
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
60+
if err != nil {
61+
return &Row{err: err}
62+
}
63+
return n.Stmt.QueryRowxContext(ctx, args...)
64+
}
65+
66+
// MustExecContext execs a NamedStmt, panicing on error
67+
// Any named placeholder parameters are replaced with fields from arg.
68+
func (n *NamedStmt) MustExecContext(ctx context.Context, arg interface{}) sql.Result {
69+
res, err := n.ExecContext(ctx, arg)
70+
if err != nil {
71+
panic(err)
72+
}
73+
return res
74+
}
75+
76+
// QueryxContext using this NamedStmt
77+
// Any named placeholder parameters are replaced with fields from arg.
78+
func (n *NamedStmt) QueryxContext(ctx context.Context, arg interface{}) (*Rows, error) {
79+
r, err := n.QueryContext(ctx, arg)
80+
if err != nil {
81+
return nil, err
82+
}
83+
return &Rows{Rows: r, Mapper: n.Stmt.Mapper, unsafe: isUnsafe(n)}, err
84+
}
85+
86+
// QueryRowxContext this NamedStmt. Because of limitations with QueryRow, this is
87+
// an alias for QueryRow.
88+
// Any named placeholder parameters are replaced with fields from arg.
89+
func (n *NamedStmt) QueryRowxContext(ctx context.Context, arg interface{}) *Row {
90+
return n.QueryRowContext(ctx, arg)
91+
}
92+
93+
// SelectContext using this NamedStmt
94+
// Any named placeholder parameters are replaced with fields from arg.
95+
func (n *NamedStmt) SelectContext(ctx context.Context, dest interface{}, arg interface{}) error {
96+
rows, err := n.QueryxContext(ctx, arg)
97+
if err != nil {
98+
return err
99+
}
100+
// if something happens here, we want to make sure the rows are Closed
101+
defer rows.Close()
102+
return scanAll(rows, dest, false)
103+
}
104+
105+
// GetContext using this NamedStmt
106+
// Any named placeholder parameters are replaced with fields from arg.
107+
func (n *NamedStmt) GetContext(ctx context.Context, dest interface{}, arg interface{}) error {
108+
r := n.QueryRowxContext(ctx, arg)
109+
return r.scanAny(dest, false)
110+
}
111+
112+
// NamedQueryContext binds a named query and then runs Query on the result using the
113+
// provided Ext (sqlx.Tx, sqlx.Db). It works with both structs and with
114+
// map[string]interface{} types.
115+
func NamedQueryContext(ctx context.Context, e ExtContext, query string, arg interface{}) (*Rows, error) {
116+
q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e))
117+
if err != nil {
118+
return nil, err
119+
}
120+
return e.QueryxContext(ctx, q, args...)
121+
}
122+
123+
// NamedExecContext uses BindStruct to get a query executable by the driver and
124+
// then runs Exec on the result. Returns an error from the binding
125+
// or the query excution itself.
126+
func NamedExecContext(ctx context.Context, e ExtContext, query string, arg interface{}) (sql.Result, error) {
127+
q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e))
128+
if err != nil {
129+
return nil, err
130+
}
131+
return e.ExecContext(ctx, q, args...)
132+
}

named_context_test.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
// +build go1.8
2+
3+
package sqlx
4+
5+
import (
6+
"context"
7+
"database/sql"
8+
"testing"
9+
)
10+
11+
func TestNamedContextQueries(t *testing.T) {
12+
RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) {
13+
loadDefaultFixture(db, t)
14+
test := Test{t}
15+
var ns *NamedStmt
16+
var err error
17+
18+
ctx := context.Background()
19+
20+
// Check that invalid preparations fail
21+
ns, err = db.PrepareNamedContext(ctx, "SELECT * FROM person WHERE first_name=:first:name")
22+
if err == nil {
23+
t.Error("Expected an error with invalid prepared statement.")
24+
}
25+
26+
ns, err = db.PrepareNamedContext(ctx, "invalid sql")
27+
if err == nil {
28+
t.Error("Expected an error with invalid prepared statement.")
29+
}
30+
31+
// Check closing works as anticipated
32+
ns, err = db.PrepareNamedContext(ctx, "SELECT * FROM person WHERE first_name=:first_name")
33+
test.Error(err)
34+
err = ns.Close()
35+
test.Error(err)
36+
37+
ns, err = db.PrepareNamedContext(ctx, `
38+
SELECT first_name, last_name, email
39+
FROM person WHERE first_name=:first_name AND email=:email`)
40+
test.Error(err)
41+
42+
// test Queryx w/ uses Query
43+
p := Person{FirstName: "Jason", LastName: "Moiron", Email: "[email protected]"}
44+
45+
rows, err := ns.QueryxContext(ctx, p)
46+
test.Error(err)
47+
for rows.Next() {
48+
var p2 Person
49+
rows.StructScan(&p2)
50+
if p.FirstName != p2.FirstName {
51+
t.Errorf("got %s, expected %s", p.FirstName, p2.FirstName)
52+
}
53+
if p.LastName != p2.LastName {
54+
t.Errorf("got %s, expected %s", p.LastName, p2.LastName)
55+
}
56+
if p.Email != p2.Email {
57+
t.Errorf("got %s, expected %s", p.Email, p2.Email)
58+
}
59+
}
60+
61+
// test Select
62+
people := make([]Person, 0, 5)
63+
err = ns.SelectContext(ctx, &people, p)
64+
test.Error(err)
65+
66+
if len(people) != 1 {
67+
t.Errorf("got %d results, expected %d", len(people), 1)
68+
}
69+
if p.FirstName != people[0].FirstName {
70+
t.Errorf("got %s, expected %s", p.FirstName, people[0].FirstName)
71+
}
72+
if p.LastName != people[0].LastName {
73+
t.Errorf("got %s, expected %s", p.LastName, people[0].LastName)
74+
}
75+
if p.Email != people[0].Email {
76+
t.Errorf("got %s, expected %s", p.Email, people[0].Email)
77+
}
78+
79+
// test Exec
80+
ns, err = db.PrepareNamedContext(ctx, `
81+
INSERT INTO person (first_name, last_name, email)
82+
VALUES (:first_name, :last_name, :email)`)
83+
test.Error(err)
84+
85+
js := Person{
86+
FirstName: "Julien",
87+
LastName: "Savea",
88+
89+
}
90+
_, err = ns.ExecContext(ctx, js)
91+
test.Error(err)
92+
93+
// Make sure we can pull him out again
94+
p2 := Person{}
95+
db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), js.Email)
96+
if p2.Email != js.Email {
97+
t.Errorf("expected %s, got %s", js.Email, p2.Email)
98+
}
99+
100+
// test Txn NamedStmts
101+
tx := db.MustBeginTx(ctx, nil)
102+
txns := tx.NamedStmtContext(ctx, ns)
103+
104+
// We're going to add Steven in this txn
105+
sl := Person{
106+
FirstName: "Steven",
107+
LastName: "Luatua",
108+
109+
}
110+
111+
_, err = txns.ExecContext(ctx, sl)
112+
test.Error(err)
113+
// then rollback...
114+
tx.Rollback()
115+
// looking for Steven after a rollback should fail
116+
err = db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email)
117+
if err != sql.ErrNoRows {
118+
t.Errorf("expected no rows error, got %v", err)
119+
}
120+
121+
// now do the same, but commit
122+
tx = db.MustBeginTx(ctx, nil)
123+
txns = tx.NamedStmtContext(ctx, ns)
124+
_, err = txns.ExecContext(ctx, sl)
125+
test.Error(err)
126+
tx.Commit()
127+
128+
// looking for Steven after a Commit should succeed
129+
err = db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email)
130+
test.Error(err)
131+
if p2.Email != sl.Email {
132+
t.Errorf("expected %s, got %s", sl.Email, p2.Email)
133+
}
134+
135+
})
136+
}

0 commit comments

Comments
 (0)