database/wrapper_client.go
2024-08-20 18:15:59 +08:00

206 lines
5.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package database ...
//
// Description : mysql客户端
//
// Author : go_developer@163.com<白茶清欢>
//
// Date : 2021-03-01 9:20 下午
package database
import (
"context"
"errors"
"fmt"
"git.zhangdeman.cn/zhangdeman/consts"
"git.zhangdeman.cn/zhangdeman/database/abstract"
"git.zhangdeman.cn/zhangdeman/database/define"
"git.zhangdeman.cn/zhangdeman/serialize"
"path/filepath"
"strings"
"sync"
"go.uber.org/zap"
"gorm.io/gorm"
)
// WrapperClient is the package-wide, wrapped database client registry.
//
// It is set by a package-level variable initializer instead of an init()
// function. Go initializes every package-level variable before any init()
// function in the package runs. Code in other init() functions, or in other
// variable initializers that depend on it, therefore never sees a nil value.
var WrapperClient abstract.IWrapperClient = NewWrapperClient()
// NewWrapperClient creates an empty client registry with its lock and flag ->
// client table allocated. Clients are registered later through AddWithConfig
// (directly or via the file/dir helpers). The logger field is left nil.
func NewWrapperClient() *wrapperClient {
	registry := new(wrapperClient)
	registry.lock = new(sync.RWMutex)
	registry.clientTable = map[string]abstract.IWrapperDatabaseClient{}
	return registry
}
// wrapperClient is the concrete registry type returned by NewWrapperClient.
// It maps a database flag to the database client registered under it.
type wrapperClient struct {
	// lock guards clientTable: AddWithConfig writes under Lock and
	// GetDBClient reads under RLock.
	lock *sync.RWMutex
	// clientTable maps a database flag to its client.
	clientTable map[string]abstract.IWrapperDatabaseClient
	// logger is neither assigned nor read anywhere in the visible code of
	// this file. NOTE(review): confirm whether it is still needed.
	logger *zap.Logger
}
// AddWithConfigFile registers one client built from a single config file. The
// flag and the database config both come from getCfg.
//
// The return value depends on what getCfg reports:
//   - If getCfg returns an error, that error is returned.
//   - If getCfg reports an unsupported file format (no config, no error), the
//     file is skipped and nil is returned.
//   - Otherwise the result of AddWithConfig is returned.
//
// Author : go_developer@163.com<白茶清欢>
//
// Date : 19:19 2022/6/5
func (c *wrapperClient) AddWithConfigFile(cfgFilePath string, logInstance *zap.Logger, extraFieldList []string) error {
	loaded, loadErr := c.getCfg(cfgFilePath)
	switch {
	case loadErr != nil:
		return loadErr
	case loaded == nil:
		// unsupported config file format: nothing to register
		return nil
	}
	return c.AddWithConfig(loaded.Flag, logInstance, loaded.Config, extraFieldList)
}
// AddWithConfig builds a DBClient for flag from databaseConfig, initializes it,
// and registers it under its flag.
//
// The client's Init runs before the registry lock is taken. If Init fails, its
// error is returned and nothing is registered. An existing entry with the same
// flag is overwritten; the previous client is not closed here.
//
// Author : go_developer@163.com<白茶清欢>
//
// Date : 20:41 2023/4/18
func (c *wrapperClient) AddWithConfig(flag string, logInstance *zap.Logger, databaseConfig *define.Database, extraFieldList []string) error {
	newClient := &DBClient{
		DbFlag:         flag,
		Cfg:            define.Driver{},
		LoggerInstance: logInstance,
		ExtraFieldList: extraFieldList,
	}
	initErr := newClient.Init(databaseConfig)
	if initErr != nil {
		return initErr
	}

	c.lock.Lock()
	defer c.lock.Unlock()
	c.clientTable[newClient.DbFlag] = newClient
	return nil
}
// BatchAddWithConfigDir registers one client per entry matched by the "*" glob
// directly inside cfgDir. The scan is not recursive. Each path is passed to
// AddWithConfigFile, which skips unsupported formats. The first error aborts
// the batch and is returned.
//
// filepath.Glob ignores file-system errors, so a missing or unreadable
// directory yields no matches and a nil error. A malformed pattern (e.g.
// cfgDir containing an unclosed '[') is reported as an error instead of
// silently loading nothing.
//
// Author : go_developer@163.com<白茶清欢>
//
// Date : 19:19 2022/6/5
func (c *wrapperClient) BatchAddWithConfigDir(cfgDir string, logInstance *zap.Logger, extraFieldList []string) error {
	filePathList, err := filepath.Glob(filepath.Join(cfgDir, "*"))
	if nil != err {
		// Glob only fails with filepath.ErrBadPattern.
		return fmt.Errorf("%s 配置目录匹配失败, 原因 : %w", cfgDir, err)
	}
	for _, filePath := range filePathList {
		if err = c.AddWithConfigFile(filePath, logInstance, extraFieldList); nil != err {
			return err
		}
	}
	return nil
}
// getCfg reads and decodes the database config file at cfgPath.
//
// The flag is the file's base name with its final extension removed, e.g.
// "./conf/mysql.yaml" -> "mysql". Dots in directory names are ignored. The
// extension, compared case-insensitively, selects the decoder (yaml/yml or
// json). Both decoders write into result.Config.
//
// Returns:
//   - (cfg, nil) on success.
//   - (nil, nil) when the extension is not a supported format, so callers can
//     skip the file.
//   - (nil, err) when the file name has no extension or the file fails to
//     decode. Decode errors wrap the underlying error.
//
// Author : go_developer@163.com<白茶清欢>
//
// Date : 18:05 2022/6/11
func (c *wrapperClient) getCfg(cfgPath string) (*define.CfgFile, error) {
	baseName := filepath.Base(cfgPath)
	ext := filepath.Ext(baseName)
	if ext == "" {
		// 获取不到类型 (the file type cannot be determined)
		return nil, errors.New("文件格式必须是JSON或者YAML")
	}
	fileType := strings.ToLower(strings.TrimPrefix(ext, "."))
	result := &define.CfgFile{
		Path:   cfgPath,
		Type:   "",
		Flag:   strings.TrimSuffix(baseName, ext),
		Config: &define.Database{},
	}

	var err error
	switch fileType {
	case consts.FileTypeYaml, consts.FileTypeYml:
		result.Type = consts.FileTypeYaml
		err = serialize.File.ReadYmlContent(cfgPath, result.Config)
	case consts.FileTypeJson:
		result.Type = consts.FileTypeJson
		// Decode into result.Config, not a throwaway local, so the JSON
		// settings are actually returned.
		err = serialize.File.ReadJSONContent(cfgPath, result.Config)
	default:
		// Neither JSON nor YAML: let the caller skip this file.
		return nil, nil
	}
	if nil != err {
		return nil, fmt.Errorf("%s 配置文件解析失败, 原因 : %w", cfgPath, err)
	}

	if len(result.Config.Master.Timezone) == 0 {
		// Default to the local timezone.
		result.Config.Master.Timezone = "Local"
		// NOTE(review): the slave timezone is left untouched on this path,
		// but it is overwritten with the master's below. Confirm this
		// asymmetry is intended.
	} else {
		result.Config.Slave.Timezone = result.Config.Master.Timezone
	}
	return result, nil
}
// GetDBClient returns the database client registered under dbFlag. If no
// client is registered under that flag, it returns an error.
//
// The table lookup is done under the registry's read lock.
//
// Author : go_developer@163.com<白茶清欢>
//
// Date : 19:32 2022/6/5
func (c *wrapperClient) GetDBClient(dbFlag string) (abstract.IWrapperDatabaseClient, error) {
	c.lock.RLock()
	found, ok := c.clientTable[dbFlag]
	c.lock.RUnlock()

	if !ok {
		return nil, fmt.Errorf("%s 标识的数据库实例不存在! ", dbFlag)
	}
	return found, nil
}
// GetMasterClient returns the master (primary) *gorm.DB of the client
// registered under dbFlag. It passes through the lookup error from
// GetDBClient when the flag is unknown.
//
// Author : go_developer@163.com<白茶清欢>
//
// Date : 19:36 2022/6/5
func (c *wrapperClient) GetMasterClient(ctx context.Context, dbFlag string) (*gorm.DB, error) {
	registered, lookupErr := c.GetDBClient(dbFlag)
	if lookupErr == nil {
		return registered.GetMaster(ctx), nil
	}
	return nil, lookupErr
}
// GetSlaveClient returns the slave (replica) *gorm.DB of the client registered
// under dbFlag. It passes through the lookup error from GetDBClient when the
// flag is unknown.
//
// Author : go_developer@163.com<白茶清欢>
//
// Date : 19:37 2022/6/5
func (c *wrapperClient) GetSlaveClient(ctx context.Context, dbFlag string) (*gorm.DB, error) {
	registered, lookupErr := c.GetDBClient(dbFlag)
	if lookupErr == nil {
		return registered.GetSlave(ctx), nil
	}
	return nil, lookupErr
}