Skip to content

Commit 55f1bb4

Browse files
committed
clean up the In code a little
1 parent 27e05d7 commit 55f1bb4

File tree

1 file changed

+31
-30
lines changed

1 file changed

+31
-30
lines changed

bind.go

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -88,61 +88,62 @@ func rebindBuff(bindType int, query string) string {
8888
return rqb.String()
8989
}
9090

91-
// In expands slice query params in args, returning the modified query string
92-
// and a new list of args passable to Exec/Query/etc. It requires queries using
93-
// the '?' bindvar and returns queries using the '?' bindvar.
91+
// In expands slice values in args, returning the modified query string
92+
// and a new arg list that can be executed by a database. The `query` should
93+
// use the `?` bindVar. The return value uses the `?` bindVar.
9494
func In(query string, args ...interface{}) (string, []interface{}, error) {
95-
// TODO: validate this short circuit as actually saving any time..
96-
type slice struct {
97-
v reflect.Value
98-
t reflect.Type
99-
l int
95+
type ra struct {
96+
v reflect.Value
97+
t reflect.Type
98+
isSlice bool
10099
}
101-
slices := make([]*slice, 0, len(args))
100+
ras := make([]ra, 0, len(args))
102101
for _, arg := range args {
103102
v := reflect.ValueOf(arg)
104103
t, _ := baseType(v.Type(), reflect.Slice)
105-
if t != nil {
106-
slices = append(slices, &slice{v, t, v.Len()})
107-
} else {
108-
slices = append(slices, nil)
109-
}
104+
ras = append(ras, ra{v, t, t != nil})
110105
}
111-
numArgs := 0
106+
112107
anySlices := false
113-
for _, s := range slices {
114-
if s != nil {
108+
for _, s := range ras {
109+
if s.isSlice {
115110
anySlices = true
116-
numArgs += s.l
117-
if s.l == 0 {
111+
if s.v.Len() == 0 {
118112
return "", nil, errors.New("empty slice passed to 'in' query")
119113
}
120-
} else {
121-
numArgs++
122114
}
123115
}
124116

125-
// if there's no slice kind args at all, just return the original query & args
117+
// don't do any parsing if there aren't any slices; note that this means
118+
// some errors that we might have caught below will not be returned.
126119
if !anySlices {
127120
return query, args, nil
128121
}
129122

130-
a := make([]interface{}, 0, numArgs)
123+
var a []interface{}
131124
var buf bytes.Buffer
132125
var pos int
126+
133127
for _, r := range query {
134128
if r == '?' {
135-
// XXX: we have probably done something quite wrong here
136-
if pos >= len(slices) {
129+
if pos >= len(ras) {
130+
// if this argument wasn't passed, lets return an error; this is
131+
// not actually how database/sql Exec/Query works, but since we are
132+
// creating an argument list programmatically, we want to be able
133+
// to catch these programmer errors earlier.
137134
return "", nil, errors.New("number of bindVars exceeds arguments")
138-
} else if slices[pos] != nil {
139-
for i := 0; i < slices[pos].l-1; i++ {
135+
} else if ras[pos].isSlice {
136+
// if this argument is a slice, expand the slice into arguments and
137+
// assume that the bindVars should be comma separated.
138+
length := ras[pos].v.Len()
139+
for i := 0; i < length-1; i++ {
140140
buf.Write([]byte("?, "))
141-
a = append(a, slices[pos].v.Index(i).Interface())
141+
a = append(a, ras[pos].v.Index(i).Interface())
142142
}
143-
a = append(a, slices[pos].v.Index(slices[pos].l-1).Interface())
143+
a = append(a, ras[pos].v.Index(length-1).Interface())
144144
buf.WriteRune('?')
145145
} else {
146+
// a normal argument, procede as normal.
146147
a = append(a, args[pos])
147148
buf.WriteRune(r)
148149
}
@@ -152,7 +153,7 @@ func In(query string, args ...interface{}) (string, []interface{}, error) {
152153
}
153154
}
154155

155-
if pos != len(slices) {
156+
if pos != len(ras) {
156157
return "", nil, errors.New("number of bindVars less than number arguments")
157158
}
158159

0 commit comments

Comments
 (0)