Skip to content

Commit 2545fb9

Browse files
committed
Merge pull request jmoiron#156 from nuss-justin/master
Refactor and optimize 'In'. Add benchmark.
2 parents 56b62f2 + 56cd151 commit 2545fb9

File tree

2 files changed

+77
-44
lines changed

2 files changed

+77
-44
lines changed

bind.go

Lines changed: 68 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ import (
55
"errors"
66
"reflect"
77
"strconv"
8+
"strings"
9+
10+
"github.com/jmoiron/sqlx/reflectx"
811
)
912

1013
// Bindvar types supported by Rebind, BindMap and BindStruct.
@@ -92,25 +95,36 @@ func rebindBuff(bindType int, query string) string {
9295
// and a new arg list that can be executed by a database. The `query` should
9396
// use the `?` bindVar. The return value uses the `?` bindVar.
9497
func In(query string, args ...interface{}) (string, []interface{}, error) {
95-
type ra struct {
96-
v reflect.Value
97-
t reflect.Type
98-
isSlice bool
98+
// argMeta stores reflect.Value and length for slices and
99+
// the value itself for non-slice arguments
100+
type argMeta struct {
101+
v reflect.Value
102+
i interface{}
103+
length int
99104
}
100-
ras := make([]ra, 0, len(args))
101-
for _, arg := range args {
105+
106+
var flatArgsCount int
107+
var anySlices bool
108+
109+
meta := make([]argMeta, len(args))
110+
111+
for i, arg := range args {
102112
v := reflect.ValueOf(arg)
103-
t, _ := baseType(v.Type(), reflect.Slice)
104-
ras = append(ras, ra{v, t, t != nil})
105-
}
113+
t := reflectx.Deref(v.Type())
114+
115+
if t.Kind() == reflect.Slice {
116+
meta[i].length = v.Len()
117+
meta[i].v = v
106118

107-
anySlices := false
108-
for _, s := range ras {
109-
if s.isSlice {
110119
anySlices = true
111-
if s.v.Len() == 0 {
120+
flatArgsCount += meta[i].length
121+
122+
if meta[i].length == 0 {
112123
return "", nil, errors.New("empty slice passed to 'in' query")
113124
}
125+
} else {
126+
meta[i].i = arg
127+
flatArgsCount++
114128
}
115129
}
116130

@@ -120,42 +134,53 @@ func In(query string, args ...interface{}) (string, []interface{}, error) {
120134
return query, args, nil
121135
}
122136

123-
var a []interface{}
137+
newArgs := make([]interface{}, 0, flatArgsCount)
138+
139+
var arg, offset int
124140
var buf bytes.Buffer
125-
var pos int
126141

127-
for _, r := range query {
128-
if r == '?' {
129-
if pos >= len(ras) {
130-
// if this argument wasn't passed, lets return an error; this is
131-
// not actually how database/sql Exec/Query works, but since we are
132-
// creating an argument list programmatically, we want to be able
133-
// to catch these programmer errors earlier.
134-
return "", nil, errors.New("number of bindVars exceeds arguments")
135-
} else if ras[pos].isSlice {
136-
// if this argument is a slice, expand the slice into arguments and
137-
// assume that the bindVars should be comma separated.
138-
length := ras[pos].v.Len()
139-
for i := 0; i < length-1; i++ {
140-
buf.Write([]byte("?, "))
141-
a = append(a, ras[pos].v.Index(i).Interface())
142-
}
143-
a = append(a, ras[pos].v.Index(length-1).Interface())
144-
buf.WriteRune('?')
145-
} else {
146-
// a normal argument, procede as normal.
147-
a = append(a, args[pos])
148-
buf.WriteRune(r)
149-
}
150-
pos++
151-
} else {
152-
buf.WriteRune(r)
142+
for i := strings.IndexByte(query[offset:], '?'); i != -1; i = strings.IndexByte(query[offset:], '?') {
143+
if arg >= len(meta) {
144+
// if an argument wasn't passed, lets return an error; this is
145+
// not actually how database/sql Exec/Query works, but since we are
146+
// creating an argument list programmatically, we want to be able
147+
// to catch these programmer errors earlier.
148+
return "", nil, errors.New("number of bindVars exceeds arguments")
149+
}
150+
151+
argMeta := meta[arg]
152+
arg++
153+
154+
// not a slice, continue.
155+
// our questionmark will either be written before the next expansion
156+
// of a slice or after the loop when writing the rest of the query
157+
if argMeta.length == 0 {
158+
offset = offset + i + 1
159+
newArgs = append(newArgs, argMeta.i)
160+
continue
153161
}
162+
163+
// write everything up to and including our ? character
164+
buf.WriteString(query[:offset+i+1])
165+
166+
newArgs = append(newArgs, argMeta.v.Index(0).Interface())
167+
168+
for si := 1; si < argMeta.length; si++ {
169+
buf.WriteString(", ?")
170+
newArgs = append(newArgs, argMeta.v.Index(si).Interface())
171+
}
172+
173+
// slice the query and reset the offset. this avoids some bookkeeping for
174+
// the write after the loop
175+
query = query[offset+i+1:]
176+
offset = 0
154177
}
155178

156-
if pos != len(ras) {
179+
buf.WriteString(query)
180+
181+
if arg < len(meta) {
157182
return "", nil, errors.New("number of bindVars less than number arguments")
158183
}
159184

160-
return buf.String(), a, nil
185+
return buf.String(), newArgs, nil
161186
}

sqlx_test.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1228,7 +1228,7 @@ func TestIn(t *testing.T) {
12281228
t.Error(err)
12291229
}
12301230
if len(a) != test.c {
1231-
t.Errorf("Expected %d args, but got %d (%+v)", len(a), test.c, a)
1231+
t.Errorf("Expected %d args, but got %d (%+v)", test.c, len(a), a)
12321232
}
12331233
if strings.Count(q, "?") != test.c {
12341234
t.Errorf("Expected %d bindVars, got %d", test.c, strings.Count(q, "?"))
@@ -1460,6 +1460,14 @@ func BenchmarkBindMap(b *testing.B) {
14601460
}
14611461
}
14621462

1463+
func BenchmarkIn(b *testing.B) {
1464+
q := `SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?`
1465+
1466+
for i := 0; i < b.N; i++ {
1467+
_, _, _ = In(q, []interface{}{"foo", []int{0, 5, 7, 2, 9}, "bar"}...)
1468+
}
1469+
}
1470+
14631471
func BenchmarkRebind(b *testing.B) {
14641472
b.StopTimer()
14651473
q1 := `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`

0 commit comments

Comments
 (0)