2
0

package_initialize_gorm.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. package ast
  2. import (
  3. "fmt"
  4. "go/ast"
  5. "go/token"
  6. "io"
  7. )
  8. // PackageInitializeGorm 包初始化gorm
  9. type PackageInitializeGorm struct {
  10. Base
  11. Type Type // 类型
  12. Path string // 文件路径
  13. ImportPath string // 导包路径
  14. Business string // 业务库 gva => gva, 不要传"gva"
  15. StructName string // 结构体名称
  16. PackageName string // 包名
  17. RelativePath string // 相对路径
  18. IsNew bool // 是否使用new关键字 true: new(PackageName.StructName) false: &PackageName.StructName{}
  19. }
  20. func (a *PackageInitializeGorm) Parse(filename string, writer io.Writer) (file *ast.File, err error) {
  21. if filename == "" {
  22. if a.RelativePath == "" {
  23. filename = a.Path
  24. a.RelativePath = a.Base.RelativePath(a.Path)
  25. return a.Base.Parse(filename, writer)
  26. }
  27. a.Path = a.Base.AbsolutePath(a.RelativePath)
  28. filename = a.Path
  29. }
  30. return a.Base.Parse(filename, writer)
  31. }
  32. func (a *PackageInitializeGorm) Rollback(file *ast.File) error {
  33. packageNameNum := 0
  34. // 寻找目标结构
  35. ast.Inspect(file, func(n ast.Node) bool {
  36. // 总调用的db变量根据business来决定
  37. varDB := a.Business + "Db"
  38. if a.Business == "" {
  39. varDB = "db"
  40. }
  41. callExpr, ok := n.(*ast.CallExpr)
  42. if !ok {
  43. return true
  44. }
  45. // 检查是不是 db.AutoMigrate() 方法
  46. selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
  47. if !ok || selExpr.Sel.Name != "AutoMigrate" {
  48. return true
  49. }
  50. // 检查调用方是不是 db
  51. ident, ok := selExpr.X.(*ast.Ident)
  52. if !ok || ident.Name != varDB {
  53. return true
  54. }
  55. // 删除结构体参数
  56. for i := 0; i < len(callExpr.Args); i++ {
  57. if com, comok := callExpr.Args[i].(*ast.CompositeLit); comok {
  58. if selector, exprok := com.Type.(*ast.SelectorExpr); exprok {
  59. if x, identok := selector.X.(*ast.Ident); identok {
  60. if x.Name == a.PackageName {
  61. packageNameNum++
  62. if selector.Sel.Name == a.StructName {
  63. callExpr.Args = append(callExpr.Args[:i], callExpr.Args[i+1:]...)
  64. i--
  65. }
  66. }
  67. }
  68. }
  69. }
  70. }
  71. return true
  72. })
  73. if packageNameNum == 1 {
  74. _ = NewImport(a.ImportPath).Rollback(file)
  75. }
  76. return nil
  77. }
  78. func (a *PackageInitializeGorm) Injection(file *ast.File) error {
  79. _ = NewImport(a.ImportPath).Injection(file)
  80. bizModelDecl := FindFunction(file, "bizModel")
  81. if bizModelDecl != nil {
  82. a.addDbVar(bizModelDecl.Body)
  83. }
  84. // 寻找目标结构
  85. ast.Inspect(file, func(n ast.Node) bool {
  86. // 总调用的db变量根据business来决定
  87. varDB := a.Business + "Db"
  88. if a.Business == "" {
  89. varDB = "db"
  90. }
  91. callExpr, ok := n.(*ast.CallExpr)
  92. if !ok {
  93. return true
  94. }
  95. // 检查是不是 db.AutoMigrate() 方法
  96. selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
  97. if !ok || selExpr.Sel.Name != "AutoMigrate" {
  98. return true
  99. }
  100. // 检查调用方是不是 db
  101. ident, ok := selExpr.X.(*ast.Ident)
  102. if !ok || ident.Name != varDB {
  103. return true
  104. }
  105. // 添加结构体参数
  106. callExpr.Args = append(callExpr.Args, &ast.CompositeLit{
  107. Type: &ast.SelectorExpr{
  108. X: ast.NewIdent(a.PackageName),
  109. Sel: ast.NewIdent(a.StructName),
  110. },
  111. })
  112. return true
  113. })
  114. return nil
  115. }
  116. func (a *PackageInitializeGorm) Format(filename string, writer io.Writer, file *ast.File) error {
  117. if filename == "" {
  118. filename = a.Path
  119. }
  120. return a.Base.Format(filename, writer, file)
  121. }
  122. // 创建businessDB变量
  123. func (a *PackageInitializeGorm) addDbVar(astBody *ast.BlockStmt) {
  124. for i := range astBody.List {
  125. if assignStmt, ok := astBody.List[i].(*ast.AssignStmt); ok {
  126. if ident, ok := assignStmt.Lhs[0].(*ast.Ident); ok {
  127. if (a.Business == "" && ident.Name == "db") || ident.Name == a.Business+"Db" {
  128. return
  129. }
  130. }
  131. }
  132. }
  133. // 添加 businessDb := global.GetGlobalDBByDBName("business") 变量
  134. assignNode := &ast.AssignStmt{
  135. Lhs: []ast.Expr{
  136. &ast.Ident{
  137. Name: a.Business + "Db",
  138. },
  139. },
  140. Tok: token.DEFINE,
  141. Rhs: []ast.Expr{
  142. &ast.CallExpr{
  143. Fun: &ast.SelectorExpr{
  144. X: &ast.Ident{
  145. Name: "global",
  146. },
  147. Sel: &ast.Ident{
  148. Name: "GetGlobalDBByDBName",
  149. },
  150. },
  151. Args: []ast.Expr{
  152. &ast.BasicLit{
  153. Kind: token.STRING,
  154. Value: fmt.Sprintf("\"%s\"", a.Business),
  155. },
  156. },
  157. },
  158. },
  159. }
  160. // 添加 businessDb.AutoMigrate() 方法
  161. autoMigrateCall := &ast.ExprStmt{
  162. X: &ast.CallExpr{
  163. Fun: &ast.SelectorExpr{
  164. X: &ast.Ident{
  165. Name: a.Business + "Db",
  166. },
  167. Sel: &ast.Ident{
  168. Name: "AutoMigrate",
  169. },
  170. },
  171. },
  172. }
  173. returnNode := astBody.List[len(astBody.List)-1]
  174. astBody.List = append(astBody.List[:len(astBody.List)-1], assignNode, autoMigrateCall, returnNode)
  175. }