Skip to content

Commit e2189dc

Browse files
committed
wrap sql.Conn
1 parent d7d9517 commit e2189dc

File tree

4 files changed

+160
-0
lines changed

4 files changed

+160
-0
lines changed

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,5 @@ require (
55
github.com/lib/pq v1.0.0
66
github.com/mattn/go-sqlite3 v1.9.0
77
)
8+
9+
go 1.13

sqlx.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,14 @@ func (db *DB) PrepareNamed(query string) (*NamedStmt, error) {
380380
return prepareNamed(db, query)
381381
}
382382

383+
// Conn is a wrapper around sql.Conn with extra functionality
384+
type Conn struct {
385+
*sql.Conn
386+
driverName string
387+
unsafe bool
388+
Mapper *reflectx.Mapper
389+
}
390+
383391
// Tx is an sqlx wrapper around sql.Tx with extra functionality
384392
type Tx struct {
385393
*sql.Tx

sqlx_context.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,74 @@ func (db *DB) BeginTxx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
208208
return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err
209209
}
210210

211+
// Connx returns an *sqlx.Conn instead of an *sql.Conn.
212+
func (db *DB) Connx(ctx context.Context) (*Conn, error) {
213+
conn, err := db.DB.Conn(ctx)
214+
if err != nil {
215+
return nil, err
216+
}
217+
218+
return &Conn{Conn: conn, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, nil
219+
}
220+
221+
// BeginTxx begins a transaction and returns an *sqlx.Tx instead of an
222+
// *sql.Tx.
223+
//
224+
// The provided context is used until the transaction is committed or rolled
225+
// back. If the context is canceled, the sql package will roll back the
226+
// transaction. Tx.Commit will return an error if the context provided to
227+
// BeginxContext is canceled.
228+
func (c *Conn) BeginTxx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
229+
tx, err := c.Conn.BeginTx(ctx, opts)
230+
if err != nil {
231+
return nil, err
232+
}
233+
return &Tx{Tx: tx, driverName: c.driverName, unsafe: c.unsafe, Mapper: c.Mapper}, err
234+
}
235+
236+
// SelectContext using this Conn.
237+
// Any placeholder parameters are replaced with supplied args.
238+
func (c *Conn) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
239+
return SelectContext(ctx, c, dest, query, args...)
240+
}
241+
242+
// GetContext using this Conn.
243+
// Any placeholder parameters are replaced with supplied args.
244+
// An error is returned if the result set is empty.
245+
func (c *Conn) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
246+
return GetContext(ctx, c, dest, query, args...)
247+
}
248+
249+
// PreparexContext returns an sqlx.Stmt instead of a sql.Stmt.
250+
//
251+
// The provided context is used for the preparation of the statement, not for
252+
// the execution of the statement.
253+
func (c *Conn) PreparexContext(ctx context.Context, query string) (*Stmt, error) {
254+
return PreparexContext(ctx, c, query)
255+
}
256+
257+
// QueryxContext queries the database and returns an *sqlx.Rows.
258+
// Any placeholder parameters are replaced with supplied args.
259+
func (c *Conn) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
260+
r, err := c.Conn.QueryContext(ctx, query, args...)
261+
if err != nil {
262+
return nil, err
263+
}
264+
return &Rows{Rows: r, unsafe: c.unsafe, Mapper: c.Mapper}, err
265+
}
266+
267+
// QueryRowxContext queries the database and returns an *sqlx.Row.
268+
// Any placeholder parameters are replaced with supplied args.
269+
func (c *Conn) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row {
270+
rows, err := c.Conn.QueryContext(ctx, query, args...)
271+
return &Row{rows: rows, err: err, unsafe: c.unsafe, Mapper: c.Mapper}
272+
}
273+
274+
// Rebind a query within a Conn's bindvar type.
275+
func (c *Conn) Rebind(query string) string {
276+
return Rebind(BindType(c.driverName), query)
277+
}
278+
211279
// StmtxContext returns a version of the prepared statement which runs within a
212280
// transaction. Provided stmt can be either *sql.Stmt or *sqlx.Stmt.
213281
func (tx *Tx) StmtxContext(ctx context.Context, stmt interface{}) *Stmt {

sqlx_context_test.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1342,3 +1342,85 @@ func TestEmbeddedLiteralsContext(t *testing.T) {
13421342
}
13431343
})
13441344
}
1345+
1346+
func TestConn(t *testing.T) {
1347+
var schema = Schema{
1348+
create: `
1349+
CREATE TABLE tt_conn (
1350+
id integer,
1351+
value text NULL DEFAULT NULL
1352+
);`,
1353+
drop: "drop table tt_conn;",
1354+
}
1355+
1356+
RunWithSchemaContext(context.Background(), schema, t, func(ctx context.Context, db *DB, t *testing.T) {
1357+
conn, err := db.Connx(ctx)
1358+
defer conn.Close()
1359+
if err != nil {
1360+
t.Fatal(err)
1361+
}
1362+
1363+
_, err = conn.ExecContext(ctx, conn.Rebind(`INSERT INTO tt_conn (id, value) VALUES (?, ?), (?, ?)`), 1, "a", 2, "b")
1364+
if err != nil {
1365+
t.Fatal(err)
1366+
}
1367+
1368+
type s struct {
1369+
ID int `db:"id"`
1370+
Value string `db:"value"`
1371+
}
1372+
1373+
v := []s{}
1374+
1375+
err = conn.SelectContext(ctx, &v, "SELECT * FROM tt_conn ORDER BY id ASC")
1376+
if err != nil {
1377+
t.Fatal(err)
1378+
}
1379+
1380+
if v[0].ID != 1 {
1381+
t.Errorf("Expecting ID of 1, got %d", v[0].ID)
1382+
}
1383+
1384+
v1 := s{}
1385+
err = conn.GetContext(ctx, &v1, conn.Rebind("SELECT * FROM tt_conn WHERE id=?"), 1)
1386+
1387+
if err != nil {
1388+
t.Fatal(err)
1389+
}
1390+
if v1.ID != 1 {
1391+
t.Errorf("Expecting to get back 1, but got %v\n", v1.ID)
1392+
}
1393+
1394+
stmt, err := conn.PreparexContext(ctx, conn.Rebind("SELECT * FROM tt_conn WHERE id=?"))
1395+
if err != nil {
1396+
t.Fatal(err)
1397+
}
1398+
v1 = s{}
1399+
tx, err := conn.BeginTxx(ctx, nil)
1400+
if err != nil {
1401+
t.Fatal(err)
1402+
}
1403+
tstmt := tx.Stmtx(stmt)
1404+
row := tstmt.QueryRowx(1)
1405+
err = row.StructScan(&v1)
1406+
if err != nil {
1407+
t.Error(err)
1408+
}
1409+
tx.Commit()
1410+
if v1.ID != 1 {
1411+
t.Errorf("Expecting to get back 1, but got %v\n", v1.ID)
1412+
}
1413+
1414+
rows, err := conn.QueryxContext(ctx, "SELECT * FROM tt_conn")
1415+
if err != nil {
1416+
t.Fatal(err)
1417+
}
1418+
1419+
for rows.Next() {
1420+
err = rows.StructScan(&v1)
1421+
if err != nil {
1422+
t.Fatal(err)
1423+
}
1424+
}
1425+
})
1426+
}

0 commit comments

Comments
 (0)