|
5 | 5 | "errors" |
6 | 6 | "reflect" |
7 | 7 | "strconv" |
| 8 | + "strings" |
8 | 9 | ) |
9 | 10 |
|
10 | 11 | // Bindvar types supported by Rebind, BindMap and BindStruct. |
@@ -92,70 +93,95 @@ func rebindBuff(bindType int, query string) string { |
92 | 93 | // and a new arg list that can be executed by a database. The `query` should |
93 | 94 | // use the `?` bindVar. The return value uses the `?` bindVar. |
94 | 95 | func In(query string, args ...interface{}) (string, []interface{}, error) { |
95 | | - type ra struct { |
96 | | - v reflect.Value |
97 | | - t reflect.Type |
98 | | - isSlice bool |
| 96 | + // argMeta stores reflect.Value and length for slices and |
| 97 | + // the value itself for non-slice arguments |
| 98 | + type argMeta struct { |
| 99 | + v reflect.Value |
| 100 | + i interface{} |
| 101 | + length int |
99 | 102 | } |
100 | | - ras := make([]ra, 0, len(args)) |
101 | | - for _, arg := range args { |
| 103 | + |
| 104 | + var flatArgsCount, sliceCount int |
| 105 | + |
| 106 | + meta := make([]argMeta, len(args)) |
| 107 | + |
| 108 | + for i, arg := range args { |
102 | 109 | v := reflect.ValueOf(arg) |
103 | 110 | t, _ := baseType(v.Type(), reflect.Slice) |
104 | | - ras = append(ras, ra{v, t, t != nil}) |
105 | | - } |
106 | 111 |
|
107 | | - anySlices := false |
108 | | - for _, s := range ras { |
109 | | - if s.isSlice { |
110 | | - anySlices = true |
111 | | - if s.v.Len() == 0 { |
| 112 | + if t != nil { |
| 113 | + meta[i].length = v.Len() |
| 114 | + meta[i].v = v |
| 115 | + |
| 116 | + sliceCount++ |
| 117 | + flatArgsCount += meta[i].length |
| 118 | + |
| 119 | + if meta[i].length == 0 { |
112 | 120 | return "", nil, errors.New("empty slice passed to 'in' query") |
113 | 121 | } |
| 122 | + } else { |
| 123 | + meta[i].i = arg |
| 124 | + flatArgsCount++ |
114 | 125 | } |
115 | 126 | } |
116 | 127 |
|
117 | 128 | // don't do any parsing if there aren't any slices; note that this means |
118 | 129 | // some errors that we might have caught below will not be returned. |
119 | | - if !anySlices { |
| 130 | + if sliceCount == 0 { |
120 | 131 | return query, args, nil |
121 | 132 | } |
122 | 133 |
|
123 | | - var a []interface{} |
| 134 | + newArgs := make([]interface{}, 0, flatArgsCount) |
| 135 | + |
| 136 | + var arg, offset int |
124 | 137 | var buf bytes.Buffer |
125 | | - var pos int |
126 | 138 |
|
127 | | - for _, r := range query { |
128 | | - if r == '?' { |
129 | | - if pos >= len(ras) { |
130 | | - // if this argument wasn't passed, lets return an error; this is |
131 | | - // not actually how database/sql Exec/Query works, but since we are |
132 | | - // creating an argument list programmatically, we want to be able |
133 | | - // to catch these programmer errors earlier. |
134 | | - return "", nil, errors.New("number of bindVars exceeds arguments") |
135 | | - } else if ras[pos].isSlice { |
136 | | - // if this argument is a slice, expand the slice into arguments and |
137 | | - // assume that the bindVars should be comma separated. |
138 | | - length := ras[pos].v.Len() |
139 | | - for i := 0; i < length-1; i++ { |
140 | | - buf.Write([]byte("?, ")) |
141 | | - a = append(a, ras[pos].v.Index(i).Interface()) |
142 | | - } |
143 | | - a = append(a, ras[pos].v.Index(length-1).Interface()) |
144 | | - buf.WriteRune('?') |
145 | | - } else { |
146 | | - // a normal argument, procede as normal. |
147 | | - a = append(a, args[pos]) |
148 | | - buf.WriteRune(r) |
149 | | - } |
150 | | - pos++ |
151 | | - } else { |
152 | | - buf.WriteRune(r) |
| 139 | + for i := strings.IndexByte(query[offset:], '?'); i != -1 && arg < len(meta); i = strings.IndexByte(query[offset:], '?') { |
| 140 | + argMeta := meta[arg] |
| 141 | + arg++ |
| 142 | + |
| 143 | + // not a slice, continue. |
| 144 | + // our questionmark will either be written before the next expansion |
| 145 | + // of a slice or after the loop when writing the rest of the query |
| 146 | + if argMeta.length == 0 { |
| 147 | + offset = offset + i + 1 |
| 148 | + newArgs = append(newArgs, argMeta.i) |
| 149 | + continue |
| 150 | + } |
| 151 | + |
| 152 | + // write everything up to and including our ? character |
| 153 | + buf.WriteString(query[:offset+i+1]) |
| 154 | + |
| 155 | + newArgs = append(newArgs, argMeta.v.Index(0).Interface()) |
| 156 | + |
| 157 | + for si := 1; si < argMeta.length; si++ { |
| 158 | + buf.WriteString(", ?") |
| 159 | + newArgs = append(newArgs, argMeta.v.Index(si).Interface()) |
153 | 160 | } |
| 161 | + |
| 162 | + // slice the query and reset the offset. this avoids some bookkeeping for |
| 163 | + // the write after the loop |
| 164 | + query = query[offset+i+1:] |
| 165 | + offset = 0 |
154 | 166 | } |
155 | 167 |
|
156 | | - if pos != len(ras) { |
| 168 | + buf.WriteString(query) |
| 169 | + |
| 170 | + if arg < len(meta) { |
157 | 171 | return "", nil, errors.New("number of bindVars less than number arguments") |
158 | 172 | } |
159 | 173 |
|
160 | | - return buf.String(), a, nil |
| 174 | + // get the result as bytes first, to avoid converting to a string if we return |
| 175 | + // an error |
| 176 | + res := buf.Bytes() |
| 177 | + |
| 178 | + if bytes.Count(res, []byte{'?'}) > flatArgsCount { |
| 179 | + // if an argument wasn't passed, lets return an error; this is |
| 180 | + // not actually how database/sql Exec/Query works, but since we are |
| 181 | + // creating an argument list programmatically, we want to be able |
| 182 | + // to catch these programmer errors earlier. |
| 183 | + return "", nil, errors.New("number of bindVars exceeds arguments") |
| 184 | + } |
| 185 | + |
| 186 | + return string(res), newArgs, nil |
161 | 187 | } |
0 commit comments