Skip to content

Commit 1723f86

Browse files
authored
Merge pull request jmoiron#718 from abraithwaite/alan/fix-named-batch
NamedExec Bulk Insert Fix
2 parents a1d5e64 + df9bf98 commit 1723f86

File tree

2 files changed

+87
-9
lines changed

2 files changed

+87
-9
lines changed

named.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -224,21 +224,28 @@ func bindStruct(bindType int, query string, arg interface{}, m *reflectx.Mapper)
224224
return bound, arglist, nil
225225
}
226226

227-
var valueBracketReg = regexp.MustCompile(`\([^(]*.[^(]\)\s*$`)
227+
var valueBracketReg = regexp.MustCompile(`VALUES\s+(\([^(]*.[^(]\))`)
228228

229229
func fixBound(bound string, loop int) string {
230-
loc := valueBracketReg.FindStringIndex(bound)
231-
if len(loc) != 2 {
230+
loc := valueBracketReg.FindAllStringSubmatchIndex(bound, -1)
231+
// Either no VALUES () found or more than one found??
232+
if len(loc) != 1 {
233+
return bound
234+
}
235+
// defensive guard. loc should be len 4 representing the starting and
236+
// ending index for the whole regex match and the starting + ending
237+
// index for the single inside group
238+
if len(loc[0]) != 4 {
232239
return bound
233240
}
234241
var buffer bytes.Buffer
235242

236-
buffer.WriteString(bound[0:loc[1]])
243+
buffer.WriteString(bound[0:loc[0][1]])
237244
for i := 0; i < loop-1; i++ {
238245
buffer.WriteString(",")
239-
buffer.WriteString(bound[loc[0]:loc[1]])
246+
buffer.WriteString(bound[loc[0][2]:loc[0][3]])
240247
}
241-
buffer.WriteString(bound[loc[1]:])
248+
buffer.WriteString(bound[loc[0][1]:])
242249
return buffer.String()
243250
}
244251

named_test.go

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package sqlx
33
import (
44
"database/sql"
55
"fmt"
6+
"regexp"
67
"testing"
78
)
89

@@ -202,7 +203,10 @@ func TestNamedQueries(t *testing.T) {
202203
{FirstName: "Ngani", LastName: "Laumape", Email: "[email protected]"},
203204
}
204205

205-
insert := fmt.Sprintf("INSERT INTO person (first_name, last_name, email, added_at) VALUES (:first_name, :last_name, :email, %v)\n", now)
206+
insert := fmt.Sprintf(
207+
"INSERT INTO person (first_name, last_name, email, added_at) VALUES (:first_name, :last_name, :email, %v)\n",
208+
now,
209+
)
206210
_, err = db.NamedExec(insert, sls)
207211
test.Error(err)
208212

@@ -214,7 +218,7 @@ func TestNamedQueries(t *testing.T) {
214218
}
215219

216220
_, err = db.NamedExec(`INSERT INTO person (first_name, last_name, email)
217-
VALUES (:first_name, :last_name, :email) `, slsMap)
221+
VALUES (:first_name, :last_name, :email) ;--`, slsMap)
218222
test.Error(err)
219223

220224
type A map[string]interface{}
@@ -226,7 +230,7 @@ func TestNamedQueries(t *testing.T) {
226230
}
227231

228232
_, err = db.NamedExec(`INSERT INTO person (first_name, last_name, email)
229-
VALUES (:first_name, :last_name, :email) `, typedMap)
233+
VALUES (:first_name, :last_name, :email) ;--`, typedMap)
230234
test.Error(err)
231235

232236
for _, p := range sls {
@@ -296,3 +300,70 @@ func TestNamedQueries(t *testing.T) {
296300

297301
})
298302
}
303+
304+
func TestFixBounds(t *testing.T) {
305+
table := []struct {
306+
name, query, expect string
307+
loop int
308+
}{
309+
{
310+
name: `named syntax`,
311+
query: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`,
312+
expect: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last),(:name, :age, :first, :last)`,
313+
loop: 2,
314+
},
315+
{
316+
name: `mysql syntax`,
317+
query: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`,
318+
expect: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?),(?, ?, ?, ?)`,
319+
loop: 2,
320+
},
321+
{
322+
name: `named syntax w/ trailer`,
323+
query: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last) ;--`,
324+
expect: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last),(:name, :age, :first, :last) ;--`,
325+
loop: 2,
326+
},
327+
{
328+
name: `mysql syntax w/ trailer`,
329+
query: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?) ;--`,
330+
expect: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?),(?, ?, ?, ?) ;--`,
331+
loop: 2,
332+
},
333+
{
334+
name: `not found test`,
335+
query: `INSERT INTO foo (a,b,c,d) (:name, :age, :first, :last)`,
336+
expect: `INSERT INTO foo (a,b,c,d) (:name, :age, :first, :last)`,
337+
loop: 2,
338+
},
339+
{
340+
name: `found twice test`,
341+
query: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last) VALUES (:name, :age, :first, :last)`,
342+
expect: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last) VALUES (:name, :age, :first, :last)`,
343+
loop: 2,
344+
},
345+
}
346+
347+
for _, tc := range table {
348+
t.Run(tc.name, func(t *testing.T) {
349+
res := fixBound(tc.query, tc.loop)
350+
if res != tc.expect {
351+
t.Errorf("mismatched results")
352+
}
353+
})
354+
}
355+
356+
t.Run("regex changed", func(t *testing.T) {
357+
var valueBracketRegChanged = regexp.MustCompile(`(VALUES)\s+(\([^(]*.[^(]\))`)
358+
saveRegexp := valueBracketReg
359+
defer func() {
360+
valueBracketReg = saveRegexp
361+
}()
362+
valueBracketReg = valueBracketRegChanged
363+
364+
res := fixBound("VALUES (:a, :b)", 2)
365+
if res != "VALUES (:a, :b)" {
366+
t.Errorf("changed regex should return string")
367+
}
368+
})
369+
}

0 commit comments

Comments
 (0)