@@ -88,61 +88,62 @@ func rebindBuff(bindType int, query string) string {
8888 return rqb .String ()
8989}
9090
91- // In expands slice query params in args, returning the modified query string
92- // and a new list of args passable to Exec/Query/etc. It requires queries using
93- // the '?' bindvar and returns queries using the '?' bindvar .
91+ // In expands slice values in args, returning the modified query string
92+ // and a new arg list that can be executed by a database. The `query` should
93+ // use the `?` bindVar. The return value uses the `?` bindVar .
9494func In (query string , args ... interface {}) (string , []interface {}, error ) {
95- // TODO: validate this short circuit as actually saving any time..
96- type slice struct {
97- v reflect.Value
98- t reflect.Type
99- l int
95+ type ra struct {
96+ v reflect.Value
97+ t reflect.Type
98+ isSlice bool
10099 }
101- slices := make ([]* slice , 0 , len (args ))
100+ ras := make ([]ra , 0 , len (args ))
102101 for _ , arg := range args {
103102 v := reflect .ValueOf (arg )
104103 t , _ := baseType (v .Type (), reflect .Slice )
105- if t != nil {
106- slices = append (slices , & slice {v , t , v .Len ()})
107- } else {
108- slices = append (slices , nil )
109- }
104+ ras = append (ras , ra {v , t , t != nil })
110105 }
111- numArgs := 0
106+
112107 anySlices := false
113- for _ , s := range slices {
114- if s != nil {
108+ for _ , s := range ras {
109+ if s . isSlice {
115110 anySlices = true
116- numArgs += s .l
117- if s .l == 0 {
111+ if s .v .Len () == 0 {
118112 return "" , nil , errors .New ("empty slice passed to 'in' query" )
119113 }
120- } else {
121- numArgs ++
122114 }
123115 }
124116
125- // if there's no slice kind args at all, just return the original query & args
117+ // don't do any parsing if there aren't any slices; note that this means
118+ // some errors that we might have caught below will not be returned.
126119 if ! anySlices {
127120 return query , args , nil
128121 }
129122
130- a := make ( []interface {}, 0 , numArgs )
123+ var a []interface {}
131124 var buf bytes.Buffer
132125 var pos int
126+
133127 for _ , r := range query {
134128 if r == '?' {
135- // XXX: we have probably done something quite wrong here
136- if pos >= len (slices ) {
129+ if pos >= len (ras ) {
130+ // if this argument wasn't passed, lets return an error; this is
131+ // not actually how database/sql Exec/Query works, but since we are
132+ // creating an argument list programmatically, we want to be able
133+ // to catch these programmer errors earlier.
137134 return "" , nil , errors .New ("number of bindVars exceeds arguments" )
138- } else if slices [pos ] != nil {
139- for i := 0 ; i < slices [pos ].l - 1 ; i ++ {
135+ } else if ras [pos ].isSlice {
136+ // if this argument is a slice, expand the slice into arguments and
137+ // assume that the bindVars should be comma separated.
138+ length := ras [pos ].v .Len ()
139+ for i := 0 ; i < length - 1 ; i ++ {
140140 buf .Write ([]byte ("?, " ))
141- a = append (a , slices [pos ].v .Index (i ).Interface ())
141+ a = append (a , ras [pos ].v .Index (i ).Interface ())
142142 }
143- a = append (a , slices [pos ].v .Index (slices [ pos ]. l - 1 ).Interface ())
143+ a = append (a , ras [pos ].v .Index (length - 1 ).Interface ())
144144 buf .WriteRune ('?' )
145145 } else {
146+ // a normal argument, procede as normal.
146147 a = append (a , args [pos ])
147148 buf .WriteRune (r )
148149 }
@@ -152,7 +153,7 @@ func In(query string, args ...interface{}) (string, []interface{}, error) {
152153 }
153154 }
154155
155- if pos != len (slices ) {
156+ if pos != len (ras ) {
156157 return "" , nil , errors .New ("number of bindVars less than number arguments" )
157158 }
158159
0 commit comments