database/sql2go/parser.go

119 lines
3.0 KiB
Go

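// Package sql2go generates Go model struct source code from MySQL CREATE TABLE statements.
//
// Description : converts CREATE TABLE DDL into a Go struct definition with json and gorm tags.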
//
// Author : go_developer@163.com<白茶清欢>
//
// Date : 2021-10-25 4:49 下午
package sql2go
import (
"errors"
"strings"
"git.zhangdeman.cn/zhangdeman/util"
"github.com/xwb1989/sqlparser"
)
// CreateSQLColumnTPL is the template for one generated struct field line.
// The placeholders {FIELD}, {TYPE}, {JSON_TAG}, {COLUMN}, {DEFAULT_VALUE},
// {NOT_NULL} and {COMMENT} are substituted per column by ParseCreateTableSql.
const CreateSQLColumnTPL = " {FIELD} {TYPE} " +
	"`json:\"{JSON_TAG}\" gorm:\"column:{COLUMN};default:{DEFAULT_VALUE};{NOT_NULL}\"`" +
	" // {COMMENT}"
// BasicTableInfo summarizes the table described by a CREATE TABLE statement,
// as populated by ParseCreateTableSql.
//
// Author : go_developer@163.com<白茶清欢>
//
// Date : 11:47 AM 2021/11/17
type BasicTableInfo struct {
	// TableName is the table name taken from the statement; ModelStruct is the
	// Go struct name generated from it (snake_case converted to CamelCase).
	TableName, ModelStruct string
	// PrimaryField is the Go field name treated as the primary key (set to "ID"
	// by ParseCreateTableSql) and PrimaryFieldType is that field's Go type,
	// left empty when no column maps to that field.
	PrimaryField, PrimaryFieldType string
}
// ParseCreateTableSql parses a MySQL CREATE TABLE statement and generates the
// source code of a Go model struct for it.
//
// Each column becomes one struct field rendered from CreateSQLColumnTPL:
//   - the field name is util.String.SnakeCaseToCamel of the column name;
//   - the Go type is looked up in sqlTypeMap by the column's SQL type;
//   - the json tag and gorm column are the raw column name;
//   - the gorm default is the column DEFAULT value (empty when absent), and
//     "NOT NULL" is emitted only for NOT NULL columns (empty otherwise);
//   - the trailing comment is the column COMMENT, or the column name when the
//     column has none; line breaks in it are replaced by spaces so the
//     generated line comment stays on one line.
//
// When withGetTableFunc is true, a TableName() method returning the table
// name is appended after the struct.
//
// The returned BasicTableInfo carries the unquoted table name (prefixed with
// its qualifier, if any), the generated struct name, PrimaryField (always
// "ID") and PrimaryFieldType (the Go type of the column whose field name is
// "ID", empty if there is none).
//
// An error is returned when sql cannot be parsed, is not a DDL statement, or
// is a DDL statement without a table definition (e.g. DROP or RENAME TABLE).
//
// Author : go_developer@163.com<白茶清欢>
//
// Date : 4:49 PM 2021/10/25
func ParseCreateTableSql(sql string, withGetTableFunc bool) (string, *BasicTableInfo, error) {
	// Normalize the function-call form of CURRENT_TIMESTAMP to the bare keyword
	// before parsing (presumably sqlparser rejects the parenthesized form).
	sql = strings.ReplaceAll(sql, "CURRENT_TIMESTAMP()", "CURRENT_TIMESTAMP")
	sql = strings.ReplaceAll(sql, "current_timestamp()", "CURRENT_TIMESTAMP")

	stmt, err := sqlparser.ParseStrictDDL(sql)
	if err != nil {
		return "", nil, err
	}
	ddl, ok := stmt.(*sqlparser.DDL)
	if !ok {
		return "", nil, errors.New("input sql is not ddl")
	}
	// Only CREATE TABLE carries a column list. For other DDL TableSpec is nil,
	// and ranging over its columns would panic.
	if ddl.TableSpec == nil {
		return "", nil, errors.New("input sql is not a create table statement")
	}

	// Use the raw identifier rather than sqlparser.String. That formatter wraps
	// reserved words and special characters in backticks (e.g. `order`), which
	// would leak into the struct name, the receiver and the TableName() result.
	name := ddl.NewName.Name.String()
	tableName := name
	if !ddl.NewName.Qualifier.IsEmpty() {
		tableName = ddl.NewName.Qualifier.String() + "." + name
	}
	basic := &BasicTableInfo{
		TableName:    tableName,
		ModelStruct:  util.String.SnakeCaseToCamel(name),
		PrimaryField: "ID",
	}

	// The column comment is emitted as a // line comment, so it must not
	// contain line breaks.
	lineBreaks := strings.NewReplacer("\r\n", " ", "\n", " ", "\r", " ")

	var out strings.Builder
	out.WriteString("type " + basic.ModelStruct + " struct { \n")
	for _, col := range ddl.TableSpec.Columns {
		column := col.Name.String()
		field := util.String.SnakeCaseToCamel(column)
		goType := sqlTypeMap[col.Type.Type]
		if field == "ID" {
			basic.PrimaryFieldType = goType
		}

		comment := column
		if col.Type.Comment != nil {
			comment = string(col.Type.Comment.Val)
		}
		defaultValue := ""
		if col.Type.Default != nil {
			defaultValue = string(col.Type.Default.Val)
		}
		// Always bind {NOT_NULL}. Leaving it unbound for nullable columns
		// emitted the literal placeholder into the gorm tag.
		notNull := ""
		if col.Type.NotNull {
			notNull = "NOT NULL"
		}

		// A single-pass Replacer never rescans substituted text, so output is
		// deterministic even if a comment or default contains a placeholder.
		out.WriteString(strings.NewReplacer(
			"{FIELD}", field,
			"{TYPE}", goType,
			"{JSON_TAG}", column,
			"{COLUMN}", column,
			"{DEFAULT_VALUE}", defaultValue,
			"{NOT_NULL}", notNull,
			"{COMMENT}", lineBreaks.Replace(comment),
		).Replace(CreateSQLColumnTPL))
		out.WriteString("\n")
	}
	out.WriteString("}")

	if withGetTableFunc {
		// Receiver: the lowercase first letter of the struct name. Fall back to
		// "m" when the name is empty or does not start with an ASCII letter;
		// slicing the first byte unguarded could panic or split a rune.
		receiver := "m"
		if basic.ModelStruct != "" {
			if c := strings.ToLower(basic.ModelStruct[:1]); c >= "a" && c <= "z" {
				receiver = c
			}
		}
		funcTpl := `
// TableName 获取表名称
func ({{TABLE_FIRST}} {{TABLE_STRUCT_NAME}}) TableName() string {
return "{{TABLE_NAME}}"
}`
		out.WriteString(strings.NewReplacer(
			"{{TABLE_FIRST}}", receiver,
			"{{TABLE_STRUCT_NAME}}", basic.ModelStruct,
			"{{TABLE_NAME}}", basic.TableName,
		).Replace(funcTpl))
	}
	return out.String(), basic, nil
}