#!/usr/bin/env python
# -*- coding: UTF-8 -*-
__author__ = "Yate"
__version__ = "1.0.0"
__maintainer__ = "Yate"
import sys
import os
import re
class Database(object):
    """In-memory schema model: an ordered collection of tables."""

    def __init__(self):
        # Tables in the order they were registered through AddTable().
        self.tables = []

    def AddTable(self, name):
        """Create a ``Table`` called *name* and append it to ``self.tables``."""
        self.tables.append(Table(name))

    def GetLastTable(self):
        """Return the most recently added table.

        Raises IndexError when no table has been added yet.
        """
        newest = self.tables[-1]
        return newest
class Table:
    """One table: its name, its columns and the names of its key columns."""

    def __init__(self, name):
        self.name = name
        # Field objects, in declaration order.
        self.fields = []
        # Names passed to AddPK(), in call order.
        self.pks = []

    def AddField(self, name, t, length, IsPK):
        """Append a new ``Field`` built from the given column attributes."""
        self.fields.append(Field(name, t, length, IsPK))

    def AddPK(self, name):
        """Record *name* as a primary-key column.

        The name is always appended to ``self.pks``; every already-registered
        field whose name equals *name* is flagged with ``IsPK = True``.
        """
        self.pks.append(name)
        matching = [column for column in self.fields if column.name == name]
        for column in matching:
            column.IsPK = True
class Field:
    """A single table column as read from the SQL source."""

    def __init__(self, name, t, length, IsPK):
        # name: column name, t: SQL type, length: declared length,
        # IsPK: whether the column is part of the primary key.
        self.name, self.type = name, t
        self.length, self.IsPK = length, IsPK
class JavaType:
    """Description of the Java class generated for one SQL table."""

    def __init__(self, pkg, tn):
        self.pkg = pkg          # base Java package
        self.table_name = tn    # SQL table name
        self.clazz = ""         # Java class name, assigned by the caller
        self.fields = []        # JavaFieldType objects, in column order

    def AddJavaField(self, sql_fn, java_fn, sql_type, java_type, IsPK):
        """Append a ``JavaFieldType`` describing one column/field pair."""
        self.fields.append(
            JavaFieldType(sql_fn, java_fn, sql_type, java_type, IsPK))
class JavaFieldType:
    """Pairs a SQL column (name, type) with its Java field (name, type)."""

    def __init__(self, sql_fn, java_fn, sql_type, java_type, IsPK):
        self.sql_fn, self.sql_type = sql_fn, sql_type
        self.java_fn, self.java_type = java_fn, java_type
        # True when the column belongs to the primary key.
        self.IsPK = IsPK
def issubstr_exp(exp, s):
    """Return True when regex *exp* matches at the start of *s*, ignoring case."""
    return re.match(exp, s, re.I) is not None
def issubstr(s1, s2):
    """Return whether *s1* is contained in *s2* (plain ``in`` test)."""
    contained = s1 in s2
    return contained
def get_between(str):
    """Extract an identifier from *str*.

    When *str* contains a pair of backticks, the text between the first two
    is returned. Otherwise the last whitespace-separated word is returned,
    with every "(" treated as a separator, so that both ``foo (`` and
    ``foo(`` yield ``foo`` without a trailing newline. Returns "" when *str*
    contains no word at all.
    """
    first = str.find("`")
    second = str.find("`", first + 1)
    if first != -1 and second != -1:
        return str[first + 1:second]
    # Unquoted name: a detached "(" (or the line terminator) must not be
    # mistaken for the last word.
    words = str.replace("(", " ").split()
    return words[-1] if words else ""
def get_field_name(str):
    """Return the text between the first two backticks of *str*.

    *str* is returned unchanged when it holds fewer than two backticks.
    """
    _, opening, remainder = str.partition("`")
    name, closing, _ = remainder.partition("`")
    if opening and closing:
        return name
    return str
def get_field_type(str):
    """Return the bare SQL type name from a column-type token.

    Examples: ``int(11)`` -> ``int``, ``varchar(255),`` -> ``varchar``,
    ``datetime,`` -> ``datetime``. Surrounding whitespace and trailing
    ``,`` / ``;`` separators are dropped, as is a stray closing ``)`` left
    over from the end of the table body. A token that has nothing before its
    "(" is returned as-is (minus the separators).
    """
    token = str.strip().rstrip(",;")
    # Everything before the first "(" is the type name; the length/precision
    # part is handled by get_field_length().
    name = token.split("(", 1)[0].rstrip(")")
    return name or token
def get_field_length(str):
    """Return the text inside the parentheses of a column-type token.

    ``varchar(255)`` -> ``"255"``, ``decimal(10,2)`` -> ``"10,2"``. Returns
    the integer 0 when the token has no "(" after its first character or no
    matching ")" after it.
    """
    # The search starts at index 1: a "(" in first position has no type name
    # in front of it and is not treated as a length specifier.
    open_at = str.find("(", 1)
    if open_at == -1:
        return 0
    close_at = str.find(")", open_at + 1)
    if close_at == -1:
        return 0
    return str[open_at + 1:close_at]
# Lower-case SQL column type -> Java type used for the generated field.
_SQL_TO_JAVA_TYPE = {
    "int": "Integer",
    "integer": "Integer",
    "tinyint": "Integer",
    "smallint": "Integer",
    "mediumint": "Integer",
    "bigint": "Long",
    "varchar": "String",
    "char": "String",
    "text": "String",
    "tinytext": "String",
    "mediumtext": "String",
    "longtext": "String",
    "timestamp": "java.util.Date",
    "datetime": "java.util.Date",
    "date": "java.util.Date",
    "decimal": "java.math.BigDecimal",
    "double": "java.math.BigDecimal",
    "float": "java.math.BigDecimal",
}


def sqltype2javatype(type):
    """Map a SQL column type name to the Java type of the generated field.

    The lookup ignores case and surrounding whitespace, so ``INT`` and
    ``int`` both give ``Integer``. Returns None for a type that has no
    mapping.
    """
    # NOTE: the parameter name shadows the ``type`` builtin; it is kept
    # because it is part of the function's signature.
    key = type.strip().lower() if hasattr(type, "lower") else type
    return _SQL_TO_JAVA_TYPE.get(key)
def py2java(pkg, Database):
    """Generate every Java/MyBatis artifact for each table of a schema.

    pkg      -- base Java package of the generated sources.
    Database -- the parsed schema; its ``tables`` are iterated in order.
                (The parameter name shadows the ``Database`` class inside
                this function; it is kept for signature compatibility.)

    For each table a ``JavaType`` is built (class name = CamelCase of the
    table name, field names = lowerCamelCase of the column names) and handed
    to the five writer functions. Tables without any column are skipped with
    a message on stderr, since the writers need at least one field.
    """
    # Use the schema that was passed in rather than a module-level global.
    for table in Database.tables:
        jt = JavaType(pkg, table.name.replace("\n", ""))
        class_name = "".join(
            part.capitalize() for part in table.name.split("_"))
        jt.clazz = class_name.replace("\n", "")

        for column in table.fields:
            java_fn = ""
            for part in column.name.split("_"):
                # Leading part(s) stay as written until the name is
                # non-empty; later parts are capitalized.
                java_fn += part if java_fn == "" else part.capitalize()
            jt.AddJavaField(column.name, java_fn, column.type,
                            sqltype2javatype(column.type), column.IsPK)

        if not jt.fields:
            sys.stderr.write(
                "sql2java: table %r has no columns, skipped\n" % table.name)
            continue

        writeDotJava(jt)
        wirteMyBatis(jt)
        wirteMapper(jt)
        wirteService(jt)
        wirteBaseService(jt)
def writeDotJava(jt):
    """Write the entity class of *jt* to
    ``src/main/java/<pkg path>/metadata/entity/<Clazz>.java``.

    The class gets one private field plus a setter/getter pair per entry of
    ``jt.fields``; the first field is excluded from ``@EqualsAndHashCode``.
    The path is relative to the current working directory.

    Raises IndexError when ``jt.fields`` is empty and TypeError when a field
    has no Java type (``java_type`` is None).
    """
    author = """
/**
* 本段代码由sql2java自动生成.
* https://github.com/yangting/sql2java
* @author Yate
*/\n"""
    for f in jt.fields:
        if f.java_type is None:
            raise TypeError("no Java type mapped for field %r of class %s"
                            % (f.java_fn, jt.clazz))

    parts = [
        "package " + jt.pkg + ".metadata.entity;\n\n",
        "import java.io.Serializable;\n\nimport lombok.EqualsAndHashCode;\nimport lombok.ToString;\nimport lombok.experimental.Accessors;",
        author,
        "@ToString\n@EqualsAndHashCode(exclude={\"" + jt.fields[0].java_fn + "\"})\n",
        "public class " + jt.clazz + " implements Serializable{\n",
    ]
    for f in jt.fields:
        # Progress output, one line per generated field.
        print("%s %s %s" % (jt.clazz, f.java_type, f.java_fn))
        parts.append("private " + f.java_type + " " + f.java_fn + ";\n")
    for f in jt.fields:
        # Accessor suffix: field name with its first letter upper-cased.
        suffix = f.java_fn[:1].capitalize() + f.java_fn[1:]
        parts.append("public void set" + suffix + "(" + f.java_type + " v){\n\tthis." + f.java_fn + "=v;\n}\n\n")
        parts.append("public " + f.java_type + " get" + suffix + "(){\n\treturn this." + f.java_fn + ";\n}\n\n")
    parts.append("}")

    target_dir = "src/main/java/" + jt.pkg.replace(".", "/") + "/metadata/entity"
    if not os.path.isdir(target_dir):
        os.makedirs(target_dir)
    # The context manager closes the file even if the write fails.
    with open(target_dir + "/" + jt.clazz + ".java", 'w') as out:
        out.write("".join(parts))
def wirteMyBatis(jt):
    """Write the MyBatis mapper file of *jt* to
    ``src/main/resources/mybatis/mapper/<Clazz>Mapper.xml`` (relative to the
    current working directory).

    The SQL statements emitted are: single-row insert, batch insert, update,
    delete by key and batch delete. Non-key columns (``IsPK`` false) form the
    insert/update column lists; the first field of ``jt.fields`` is used as
    the key in the ``where`` clauses.

    NOTE(review): many of the emitted lines are bare newlines. Presumably XML
    mapper markup once stood on those lines -- confirm against the upstream
    project before relying on the generated file. They are emitted here
    exactly as before, one list entry per original line.

    Raises IndexError when ``jt.fields`` is empty.
    """
    key = jt.fields[0]
    plain = [f for f in jt.fields if not f.IsPK]
    # Joining (instead of appending "," and chopping the last character)
    # keeps the "(" intact when a table has no non-key column.
    columns = ",".join(f.sql_fn for f in plain)
    params = ",".join("#{" + f.java_fn + "}" for f in plain)
    item_params = ",".join("#{item." + f.java_fn + "}" for f in plain)
    key_clause = "where " + key.sql_fn + "=#{" + key.java_fn + "}\n"

    out = ["\n", "\n", "\n\n", "\n"]
    # One line per column (key and non-key columns emit the same text).
    out.extend("\n" for _ in jt.fields)
    out += ["\n\n", "\n", "\n\n"]

    # Single-row insert.
    out += [
        "\n",
        "\n",
        "SELECT LAST_INSERT_ID() AS id\n",
        "\n",
        "insert into " + jt.table_name + "(" + columns + ") values(" + params + ")\n",
        "\n\n",
    ]

    # Batch insert.
    out += [
        "insert into " + jt.table_name + "(" + columns + ") values\n",
        "\n",
        "(" + item_params + ")\n",
        "\n",
        "\n\n",
    ]

    # Update by key.
    out += ["\n", "update " + jt.table_name + "\n\n"]
    for f in plain:
        out += ["\n", f.sql_fn + "=#{" + f.java_fn + "},\n", "\n"]
    out += ["\n", key_clause, "\n\n"]

    # Delete by key.
    out += ["\n", "delete from " + jt.table_name + "\n", key_clause, "\n\n"]

    # Batch delete.
    # NOTE(review): the column name "id" is hard-coded here while the other
    # statements use the first field's SQL name -- confirm this is intended.
    out += [
        "\n",
        "delete from " + jt.table_name + " where id in \n",
        "\n",
        "#{item." + key.java_fn + "}\n",
        "\n",
        "\n\n",
    ]

    target_dir = "src/main/resources/mybatis/mapper"
    if not os.path.isdir(target_dir):
        os.makedirs(target_dir)
    # The context manager closes the file even if the write fails.
    with open(target_dir + "/" + jt.clazz + "Mapper.xml", 'w') as xml:
        xml.write("".join(out))
def wirteMapper(jt):
    """Write the mapper interface of *jt* and the shared base interface.

    Two files are written below ``src/main/java/<pkg path>/metadata/dao``
    (relative to the current working directory):

    * ``mapper/<Clazz>Mapper.java`` -- extends
      ``IBaseMapperDao<Clazz, T>`` where T is the Java type of the first
      field of ``jt.fields``;
    * ``IBaseMapperDao.java`` -- the generic base interface, rewritten on
      every call.

    Raises IndexError when ``jt.fields`` is empty.
    """
    author = """
/**
* 本段代码由sql2java自动生成.
* https://github.com/yangting/sql2java
* @author Yate
*/\n"""
    # The base interface declares its type parameters <E, PK>: its methods
    # use E and PK, and every generated mapper passes two type arguments.
    base = """import java.util.List;
import org.apache.ibatis.annotations.Options;
import org.apache.ibatis.annotations.Param;
/**
* 本段代码由sql2java自动生成.
* https://github.com/yangting/sql2java
* @author Yate
*/
public interface IBaseMapperDao<E, PK> {
/**
* @description 通过实体进行添加
* @param e
*/
@Options(useGeneratedKeys = true, keyProperty = "id")
void add(final E e);
/**
* @description 通过集合进行批量添加
* @param e
*/
@Options(useGeneratedKeys = true, keyProperty = "id")
void batchAdd(final List<E> list);
/**
* @description 通过主键进行删除
* @param e
*/
Integer remove(@Param(value = "id") final PK e);
/**
* @description 通过主键进行批量删除
* @param e
*/
void batchRemove(final PK[] ids);
/**
* @description 通过实体更新
* @param e
*/
Integer update(final E e);
/**
* @description 通过主键查询
* @param id
* @return
* @throws Exception
*/
E getEntity(@Param(value = "id") final PK id);
}"""
    mapper_src = "".join([
        "package " + jt.pkg + ".metadata.dao.mapper;\n",
        "import org.springframework.stereotype.Repository;\n",
        "import " + jt.pkg + ".metadata.dao.IBaseMapperDao;\n",
        "import " + jt.pkg + ".metadata.entity." + jt.clazz + ";\n\n",
        author,
        "@Repository\n",
        "public interface " + jt.clazz + "Mapper extends IBaseMapperDao<" + jt.clazz + ", " + jt.fields[0].java_type + ">{\n",
        "}\n",
    ])

    dao_dir = "src/main/java/" + jt.pkg.replace(".", "/") + "/metadata/dao"
    mapper_dir = dao_dir + "/mapper"
    # Creating the mapper directory also creates its parent, the dao
    # directory that receives IBaseMapperDao.java.
    if not os.path.isdir(mapper_dir):
        os.makedirs(mapper_dir)
    # Context managers close the files even if a write fails.
    with open(mapper_dir + "/" + jt.clazz + "Mapper.java", 'w') as out:
        out.write(mapper_src)
    with open(dao_dir + "/IBaseMapperDao.java", 'w') as out:
        out.write("package " + jt.pkg + ".metadata.dao;\n" + base)
def wirteService(jt):
    """Write the service interface and its implementation for *jt*.

    Files written below ``src/main/java/<pkg path>/service`` (relative to
    the current working directory):

    * ``I<Clazz>Service.java`` -- extends ``IBaseService<Clazz,T>``;
    * ``impl/<Clazz>ServiceImpl.java`` -- extends ``BaseServiceImpl<Clazz,T>``
      and returns the injected ``<Clazz>Mapper`` from ``getMapperDao()``.

    T is the Java type of the first field of ``jt.fields``.
    Raises IndexError when ``jt.fields`` is empty.
    """
    author = """
/**
* 本段代码由sql2java自动生成.
* https://github.com/yangting/sql2java
* @author Yate
*/\n"""
    type_args = "<" + jt.clazz + "," + jt.fields[0].java_type + ">"

    interface_src = "".join([
        "package " + jt.pkg + ".service;\n\n",
        "import " + jt.pkg + ".metadata.entity." + jt.clazz + ";\n\n",
        author,
        "public interface I" + jt.clazz + "Service extends IBaseService" + type_args + "{\n\n",
        "}",
    ])
    impl_src = "".join([
        "package " + jt.pkg + ".service.impl;\n\n",
        "import org.springframework.stereotype.Service;\n",
        "import javax.annotation.Resource;\n",
        "import " + jt.pkg + ".metadata.dao.IBaseMapperDao;\n",
        "import " + jt.pkg + ".metadata.dao.mapper." + jt.clazz + "Mapper;\n",
        "import " + jt.pkg + ".service.I" + jt.clazz + "Service;\n",
        "import " + jt.pkg + ".metadata.entity." + jt.clazz + ";\n\n",
        author,
        "@Service\n",
        "public class " + jt.clazz + "ServiceImpl extends BaseServiceImpl" + type_args + " implements I" + jt.clazz + "Service {\n",
        "@Resource\n",
        "private " + jt.clazz + "Mapper mapper;\n\n",
        "protected IBaseMapperDao" + type_args + " getMapperDao() {\n",
        "return mapper;\n",
        "}\n}",
    ])

    service_dir = "src/main/java/" + jt.pkg.replace(".", "/") + "/service"
    outputs = (
        (service_dir, "I" + jt.clazz + "Service.java", interface_src),
        (service_dir + "/impl", jt.clazz + "ServiceImpl.java", impl_src),
    )
    for directory, filename, text in outputs:
        if not os.path.isdir(directory):
            os.makedirs(directory)
        # The context manager closes the file even if the write fails.
        with open(directory + "/" + filename, 'w') as out:
            out.write(text)
def wirteBaseService(jt):
    """Write the generic service base types shared by all generated services.

    Files written below ``src/main/java/<pkg path>/service`` (relative to
    the current working directory), both rewritten on every call:

    * ``IBaseService.java`` -- generic CRUD interface ``IBaseService<E, PK>``;
    * ``impl/BaseServiceImpl.java`` -- abstract implementation delegating
      every call to ``getMapperDao()``.

    Only ``jt.pkg`` is read. Both target directories are created when
    missing, so the function does not depend on another writer having run
    first.
    """
    author = """
/**
* 本段代码由sql2java自动生成.
* https://github.com/yangting/sql2java
* @author Yate
*/\n"""
    # The type parameters <E, PK> are declared explicitly: the methods use
    # E and PK, and the generated services pass two type arguments.
    ibase = """
import java.util.List;
/**
* 本段代码由sql2java自动生成.
* https://github.com/yangting/sql2java
* @author Yate
*/
public interface IBaseService<E, PK> {
void add(final E e);
void batchAdd(final List<E> list);
int remove(final PK id);
void batchRemove(final PK[] ids);
int update(final E e);
E getEntity(final PK id);
}"""
    xbase = "".join([
        "import " + jt.pkg + ".metadata.dao.IBaseMapperDao;\n",
        "import " + jt.pkg + ".service.IBaseService;\n",
        "import java.util.List;\n\n",
        author,
        "public abstract class BaseServiceImpl<E, PK> implements IBaseService<E, PK> {\n",
        "protected abstract IBaseMapperDao<E, PK> getMapperDao();\n",
        "public void add(E e) {\n",
        "this.getMapperDao().add(e);\n",
        "}\n\n",
        "public void batchAdd(List<E> list) {\n",
        "this.getMapperDao().batchAdd(list);\n",
        "}\n\n",
        "public int remove(PK id) {\n",
        "return this.getMapperDao().remove(id);\n",
        "}\n\n",
        "public void batchRemove(PK[] ids) {\n",
        "this.getMapperDao().batchRemove(ids);\n",
        "}\n\n",
        "public int update(E e) {\n",
        "return this.getMapperDao().update(e);\n",
        "}\n\n",
        "public E getEntity(PK id) {\n",
        "return this.getMapperDao().getEntity(id);\n",
        "}\n\n}",
    ])

    service_dir = "src/main/java/" + jt.pkg.replace(".", "/") + "/service"
    outputs = (
        (service_dir, "IBaseService.java",
         "package " + jt.pkg + ".service;\n" + ibase),
        (service_dir + "/impl", "BaseServiceImpl.java",
         "package " + jt.pkg + ".service.impl;\n\n" + xbase),
    )
    for directory, filename, text in outputs:
        if not os.path.isdir(directory):
            os.makedirs(directory)
        # The context manager closes the file even if the write fails.
        with open(directory + "/" + filename, 'w') as out:
            out.write(text)
# Header line of a table definition, e.g. "CREATE TABLE `t` (".
_CREATE_TABLE_LINE = r"^\s*create\s+(temporary\s+)?table\b"
# Table-level key clause: "PRIMARY KEY (...)" or "CONSTRAINT x PRIMARY KEY (...)".
_TABLE_PK_LINE = r"^\s*(constraint\b.*)?primary\s+key\b"
# Lines of a table body that open/close it or declare indexes, constraints
# or table options rather than a column.
_NON_COLUMN_LINE = (r"^\s*([()]|(unique|fulltext|spatial|index|key|constraint"
                    r"|foreign|check|engine)\b)")


def _pk_columns(line):
    """Return the column names listed inside the parentheses of a key clause."""
    found = re.search(r"\((.*)\)", line)
    if not found:
        return []
    names = []
    for raw in found.group(1).split(","):
        # Drop quoting, an optional prefix length such as `name`(10), and
        # the surrounding blanks / stray closing parenthesis.
        name = raw.replace("`", "").split("(")[0].strip(" \t)")
        if name:
            names.append(name)
    return names


def _parse_schema(stream, db):
    """Fill *db* with the tables, columns and primary keys read from *stream*.

    *stream* is iterated line by line. SQL comments are ignored, a
    ``CREATE TABLE`` header registers a new table, and the lines of its body
    are turned into fields of that table. A table body ends at a line that
    starts with ")" or ends with ";". A primary-key clause found outside a
    table body is applied to the most recently registered table.
    """
    in_comment = False  # inside a /* ... */ comment spanning several lines
    in_table = False    # between a CREATE TABLE header and the end of its body
    for line in stream:
        print(line)
        stripped = line.strip()

        if in_comment:
            if "*/" in line:
                in_comment = False
            continue
        if stripped.startswith("/*"):
            # A comment closed on the same line must not hide what follows.
            in_comment = "*/" not in line
            continue
        if not stripped or stripped.startswith("--") or stripped.startswith("#"):
            continue

        if issubstr_exp(_CREATE_TABLE_LINE, line):
            table_name = get_between(line)
            print(table_name)
            db.AddTable(table_name)
            in_table = not stripped.endswith(";")
            continue
        if not db.tables:
            # Nothing to attach columns or keys to yet.
            continue

        table = db.GetLastTable()
        has_pk = issubstr_exp(r".*primary\s+key", line)
        if has_pk and (not in_table or issubstr_exp(_TABLE_PK_LINE, line)):
            for name in _pk_columns(line):
                table.AddPK(name)
        elif in_table and not issubstr_exp(_NON_COLUMN_LINE, line):
            ds = line.split()
            # A column definition needs at least a name and a type.
            if len(ds) >= 2:
                fname = get_field_name(ds[0])
                ftype = get_field_type(ds[1])
                flen = get_field_length(ds[1])
                print("test= %s %s %s %s" % (table.name, fname, ftype, flen))
                table.AddField(fname, ftype, flen, False)
                if has_pk:
                    # Column declared with an inline PRIMARY KEY attribute.
                    table.AddPK(fname)

        if stripped.startswith(")") or stripped.endswith(";"):
            in_table = False


if __name__=='__main__':
    if len(sys.argv) < 3:
        sys.exit("usage: %s <schema.sql> <java.package>" % sys.argv[0])
    sql_file = sys.argv[1]
    pkg_name = sys.argv[2]
    # ``db`` stays a module-level name on purpose.
    db = Database()
    # The context manager closes the input even if parsing fails.
    with open(sql_file) as f:
        _parse_schema(f, db)
    py2java(pkg_name, db)