package_enter.go 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. package ast
  2. import (
  3. "go/ast"
  4. "go/token"
  5. "io"
  6. )
  7. // PackageEnter 模块化入口
  8. type PackageEnter struct {
  9. Base
  10. Type Type // 类型
  11. Path string // 文件路径
  12. ImportPath string // 导包路径
  13. StructName string // 结构体名称
  14. PackageName string // 包名
  15. RelativePath string // 相对路径
  16. PackageStructName string // 包结构体名称
  17. }
  18. func (a *PackageEnter) Parse(filename string, writer io.Writer) (file *ast.File, err error) {
  19. if filename == "" {
  20. if a.RelativePath == "" {
  21. filename = a.Path
  22. a.RelativePath = a.Base.RelativePath(a.Path)
  23. return a.Base.Parse(filename, writer)
  24. }
  25. a.Path = a.Base.AbsolutePath(a.RelativePath)
  26. filename = a.Path
  27. }
  28. return a.Base.Parse(filename, writer)
  29. }
  30. func (a *PackageEnter) Rollback(file *ast.File) error {
  31. // 无需回滚
  32. return nil
  33. }
  34. func (a *PackageEnter) Injection(file *ast.File) error {
  35. _ = NewImport(a.ImportPath).Injection(file)
  36. ast.Inspect(file, func(n ast.Node) bool {
  37. genDecl, ok := n.(*ast.GenDecl)
  38. if !ok || genDecl.Tok != token.TYPE {
  39. return true
  40. }
  41. for _, spec := range genDecl.Specs {
  42. typeSpec, specok := spec.(*ast.TypeSpec)
  43. if !specok || typeSpec.Name.Name != a.Type.Group() {
  44. continue
  45. }
  46. structType, structTypeOK := typeSpec.Type.(*ast.StructType)
  47. if !structTypeOK {
  48. continue
  49. }
  50. for _, field := range structType.Fields.List {
  51. if len(field.Names) == 1 && field.Names[0].Name == a.StructName {
  52. return true
  53. }
  54. }
  55. field := &ast.Field{
  56. Names: []*ast.Ident{{Name: a.StructName}},
  57. Type: &ast.SelectorExpr{
  58. X: &ast.Ident{Name: a.PackageName},
  59. Sel: &ast.Ident{Name: a.PackageStructName},
  60. },
  61. }
  62. structType.Fields.List = append(structType.Fields.List, field)
  63. return false
  64. }
  65. return true
  66. })
  67. return nil
  68. }
  69. func (a *PackageEnter) Format(filename string, writer io.Writer, file *ast.File) error {
  70. if filename == "" {
  71. filename = a.Path
  72. }
  73. return a.Base.Format(filename, writer, file)
  74. }