Skip to content

Commit 7400168

Browse files
committed
Optimize sqlx.In and add benchmark
1 parent 56b62f2 commit 7400168

File tree

2 files changed

+79
-45
lines changed

2 files changed

+79
-45
lines changed

bind.go

Lines changed: 70 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"errors"
66
"reflect"
77
"strconv"
8+
"strings"
89
)
910

1011
// Bindvar types supported by Rebind, BindMap and BindStruct.
@@ -92,70 +93,95 @@ func rebindBuff(bindType int, query string) string {
9293
// and a new arg list that can be executed by a database. The `query` should
9394
// use the `?` bindVar. The return value uses the `?` bindVar.
9495
func In(query string, args ...interface{}) (string, []interface{}, error) {
95-
type ra struct {
96-
v reflect.Value
97-
t reflect.Type
98-
isSlice bool
96+
// argMeta stores reflect.Value and length for slices and
97+
// the value itself for non-slice arguments
98+
type argMeta struct {
99+
v reflect.Value
100+
i interface{}
101+
length int
99102
}
100-
ras := make([]ra, 0, len(args))
101-
for _, arg := range args {
103+
104+
var flatArgsCount, sliceCount int
105+
106+
meta := make([]argMeta, len(args))
107+
108+
for i, arg := range args {
102109
v := reflect.ValueOf(arg)
103110
t, _ := baseType(v.Type(), reflect.Slice)
104-
ras = append(ras, ra{v, t, t != nil})
105-
}
106111

107-
anySlices := false
108-
for _, s := range ras {
109-
if s.isSlice {
110-
anySlices = true
111-
if s.v.Len() == 0 {
112+
if t != nil {
113+
meta[i].length = v.Len()
114+
meta[i].v = v
115+
116+
sliceCount++
117+
flatArgsCount += meta[i].length
118+
119+
if meta[i].length == 0 {
112120
return "", nil, errors.New("empty slice passed to 'in' query")
113121
}
122+
} else {
123+
meta[i].i = arg
124+
flatArgsCount++
114125
}
115126
}
116127

117128
// don't do any parsing if there aren't any slices; note that this means
118129
// some errors that we might have caught below will not be returned.
119-
if !anySlices {
130+
if sliceCount == 0 {
120131
return query, args, nil
121132
}
122133

123-
var a []interface{}
134+
newArgs := make([]interface{}, 0, flatArgsCount)
135+
136+
var arg, offset int
124137
var buf bytes.Buffer
125-
var pos int
126138

127-
for _, r := range query {
128-
if r == '?' {
129-
if pos >= len(ras) {
130-
// if this argument wasn't passed, lets return an error; this is
131-
// not actually how database/sql Exec/Query works, but since we are
132-
// creating an argument list programmatically, we want to be able
133-
// to catch these programmer errors earlier.
134-
return "", nil, errors.New("number of bindVars exceeds arguments")
135-
} else if ras[pos].isSlice {
136-
// if this argument is a slice, expand the slice into arguments and
137-
// assume that the bindVars should be comma separated.
138-
length := ras[pos].v.Len()
139-
for i := 0; i < length-1; i++ {
140-
buf.Write([]byte("?, "))
141-
a = append(a, ras[pos].v.Index(i).Interface())
142-
}
143-
a = append(a, ras[pos].v.Index(length-1).Interface())
144-
buf.WriteRune('?')
145-
} else {
146-
// a normal argument, procede as normal.
147-
a = append(a, args[pos])
148-
buf.WriteRune(r)
149-
}
150-
pos++
151-
} else {
152-
buf.WriteRune(r)
139+
for i := strings.IndexByte(query[offset:], '?'); i != -1 && arg < len(meta); i = strings.IndexByte(query[offset:], '?') {
140+
argMeta := meta[arg]
141+
arg++
142+
143+
// not a slice, continue.
144+
// our questionmark will either be written before the next expansion
145+
// of a slice or after the loop when writing the rest of the query
146+
if argMeta.length == 0 {
147+
offset = offset + i + 1
148+
newArgs = append(newArgs, argMeta.i)
149+
continue
150+
}
151+
152+
// write everything up to and including our ? character
153+
buf.WriteString(query[:offset+i+1])
154+
155+
newArgs = append(newArgs, argMeta.v.Index(0).Interface())
156+
157+
for si := 1; si < argMeta.length; si++ {
158+
buf.WriteString(", ?")
159+
newArgs = append(newArgs, argMeta.v.Index(si).Interface())
153160
}
161+
162+
// slice the query and reset the offset. this avoids some bookkeeping for
163+
// the write after the loop
164+
query = query[offset+i+1:]
165+
offset = 0
154166
}
155167

156-
if pos != len(ras) {
168+
buf.WriteString(query)
169+
170+
if arg < len(meta) {
157171
return "", nil, errors.New("number of bindVars less than number arguments")
158172
}
159173

160-
return buf.String(), a, nil
174+
// get the result as bytes first, to avoid converting to a string if we return
175+
// an error
176+
res := buf.Bytes()
177+
178+
if bytes.Count(res, []byte{'?'}) > flatArgsCount {
179+
// if an argument wasn't passed, lets return an error; this is
180+
// not actually how database/sql Exec/Query works, but since we are
181+
// creating an argument list programmatically, we want to be able
182+
// to catch these programmer errors earlier.
183+
return "", nil, errors.New("number of bindVars exceeds arguments")
184+
}
185+
186+
return string(res), newArgs, nil
161187
}

sqlx_test.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1228,7 +1228,7 @@ func TestIn(t *testing.T) {
12281228
t.Error(err)
12291229
}
12301230
if len(a) != test.c {
1231-
t.Errorf("Expected %d args, but got %d (%+v)", len(a), test.c, a)
1231+
t.Errorf("Expected %d args, but got %d (%+v)", test.c, len(a), a)
12321232
}
12331233
if strings.Count(q, "?") != test.c {
12341234
t.Errorf("Expected %d bindVars, got %d", test.c, strings.Count(q, "?"))
@@ -1460,6 +1460,14 @@ func BenchmarkBindMap(b *testing.B) {
14601460
}
14611461
}
14621462

1463+
func BenchmarkIn(b *testing.B) {
1464+
q := `SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?`
1465+
1466+
for i := 0; i < b.N; i++ {
1467+
_, _, _ = In(q, []interface{}{"foo", []int{0, 5, 7, 2, 9}, "bar"}...)
1468+
}
1469+
}
1470+
14631471
func BenchmarkRebind(b *testing.B) {
14641472
b.StopTimer()
14651473
q1 := `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`

0 commit comments

Comments
 (0)