Files
database/wrapper_client.go

179 lines
4.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Package database provides a registry of wrapped database clients.
//
// Description : MySQL client
//
// Author : go_developer@163.com<白茶清欢>
//
// Date : 2021-03-01 9:20 PM
package database
import (
"context"
"errors"
"fmt"
"path/filepath"
"strings"
"sync"
"git.zhangdeman.cn/zhangdeman/consts"
"git.zhangdeman.cn/zhangdeman/database/abstract"
"git.zhangdeman.cn/zhangdeman/database/define"
"git.zhangdeman.cn/zhangdeman/serialize"
"go.uber.org/zap"
"gorm.io/gorm"
)
// WrapperClient is the package-level registry of wrapped database clients
// (包装后的数据库客户端).
//
// It is initialised in its declaration rather than in an init function.
// Package variable initialisation runs before every init function of the
// package, so WrapperClient is never observed as nil by other initialisation
// code in this package.
var WrapperClient abstract.IWrapperClient = NewWrapperClient()
// NewWrapperClient returns an empty client registry.
//
// The returned value has its lock and client table allocated, so clients can be
// registered immediately. The logger field is left nil.
func NewWrapperClient() abstract.IWrapperClient {
	registry := new(wrapperClient)
	registry.lock = new(sync.RWMutex)
	registry.clientTable = map[string]abstract.IWrapperDatabaseClient{}
	return registry
}
// wrapperClient is the abstract.IWrapperClient implementation returned by
// NewWrapperClient: a registry of database clients keyed by flag.
type wrapperClient struct {
	// lock guards clientTable: lookups take the read lock, registration takes
	// the write lock.
	lock *sync.RWMutex
	// clientTable maps a database flag (for file-based configs, the file name
	// without its extension) to the registered client.
	clientTable map[string]abstract.IWrapperDatabaseClient
	// logger is not read or written by any method in this file.
	logger *zap.Logger
}
// AddWithConfigFile loads a database config file and registers a client for it.
// The client's flag is derived from the file name with its suffix removed.
//
// A parse or initialisation error is returned as-is. A file whose format is not
// supported (getCfg returns no config and no error) is skipped, and nil is
// returned.
func (c *wrapperClient) AddWithConfigFile(cfgFilePath string, logInstance *zap.Logger, extraFieldList []string) error {
	cfg, err := c.getCfg(cfgFilePath)
	switch {
	case err != nil:
		return err
	case cfg == nil:
		// Unsupported config file format: nothing to register.
		return nil
	}
	return c.AddWithConfig(cfg.Flag, logInstance, cfg.Config, extraFieldList)
}
// AddWithConfig builds a DBClient from databaseConfig and registers it under flag.
//
// logInstance and extraFieldList are handed to the new client unchanged. If
// the client's Init fails, that error is returned and nothing is registered.
// A client already registered under the same flag is overwritten. The previous
// client is not closed here.
func (c *wrapperClient) AddWithConfig(flag string, logInstance *zap.Logger, databaseConfig *define.Database, extraFieldList []string) error {
	newClient := &DBClient{
		Cfg:            define.Driver{},
		DbFlag:         flag,
		ExtraFieldList: extraFieldList,
		LoggerInstance: logInstance,
	}
	if initErr := newClient.Init(databaseConfig, nil); initErr != nil {
		return initErr
	}

	c.lock.Lock()
	defer c.lock.Unlock()
	c.clientTable[newClient.DbFlag] = newClient
	return nil
}
// BatchAddWithConfigDir registers one client per config file found directly in
// cfgDir (自动读取目录下配置文件, 生成客户端).
//
// Every path matched by filepath.Glob(cfgDir/*) is passed to AddWithConfigFile,
// which skips files whose format is unsupported. The first error aborts the
// batch; clients registered before that point stay registered.
//
// NOTE(review): sub-directories are not filtered out. cfgDir is also treated as
// a glob pattern, so '*', '?' or '[' in the directory name act as wildcards.
func (c *wrapperClient) BatchAddWithConfigDir(cfgDir string, logInstance *zap.Logger, extraFieldList []string) error {
	filepathNames, err := filepath.Glob(filepath.Join(cfgDir, "*"))
	if err != nil {
		// Glob only fails with filepath.ErrBadPattern, i.e. cfgDir is a malformed
		// pattern. Previously this was swallowed and nothing was registered.
		return fmt.Errorf("%s 配置目录读取失败, 原因 : %w", cfgDir, err)
	}
	for _, cfgFile := range filepathNames {
		if addErr := c.AddWithConfigFile(cfgFile, logInstance, extraFieldList); addErr != nil {
			return addErr
		}
	}
	return nil
}
// getCfg reads a single database config file into a define.CfgFile.
//
// The flag is the file's base name with its last extension removed, e.g.
// "/etc/app.d/mysql.yaml" -> "mysql". The extension (case-insensitive) selects
// the parser: yaml/yml or json.
//
// Returns:
//   - (nil, error) when the path has no extension or the file cannot be parsed;
//   - (nil, nil) when the extension is not a supported format, so callers skip it;
//   - (cfg, nil) otherwise. The master timezone defaults to "Local" when empty,
//     and the slave always uses the master's timezone.
func (c *wrapperClient) getCfg(cfgPath string) (*define.CfgFile, error) {
	// filepath.Ext and filepath.Base only inspect the last path element. Dots in
	// directory names ("./conf", "/etc/app.d") therefore no longer corrupt the
	// flag or the detected type, as splitting the whole path on "." did.
	ext := filepath.Ext(cfgPath)
	if ext == "" {
		// No extension: the format cannot be determined.
		return nil, errors.New("文件格式必须是JSON或者YAML")
	}
	fileType := consts.FileType(strings.ToLower(strings.TrimPrefix(ext, ".")))
	result := &define.CfgFile{
		Path:   cfgPath,
		Type:   "",
		Flag:   strings.TrimSuffix(filepath.Base(cfgPath), ext),
		Config: &define.Database{},
	}

	switch fileType {
	case consts.FileTypeYaml, consts.FileTypeYml:
		result.Type = consts.FileTypeYaml
		// Decode straight into the allocated *define.Database (not **Database).
		if err := serialize.File.ReadYmlContent(cfgPath, result.Config); err != nil {
			return nil, fmt.Errorf("%s 配置文件解析失败, 原因 : %w", cfgPath, err)
		}
	case consts.FileTypeJson:
		result.Type = consts.FileTypeJson
		// Previously this decoded into a discarded local, leaving Config empty.
		if err := serialize.File.ReadJSONContent(cfgPath, result.Config); err != nil {
			return nil, fmt.Errorf("%s 配置文件解析失败, 原因 : %w", cfgPath, err)
		}
	default:
		// Neither JSON nor YAML: skip silently.
		return nil, nil
	}

	if len(result.Config.Master.Timezone) == 0 {
		// Default to the local timezone.
		result.Config.Master.Timezone = "Local"
	}
	// The slave follows the master timezone in every case. Previously the
	// slave was left unset when the master fell back to "Local".
	result.Config.Slave.Timezone = result.Config.Master.Timezone
	return result, nil
}
// GetDBClient returns the client registered under dbFlag.
// An error is returned when no client has been registered for that flag.
func (c *wrapperClient) GetDBClient(dbFlag string) (abstract.IWrapperDatabaseClient, error) {
	c.lock.RLock()
	found, ok := c.clientTable[dbFlag]
	c.lock.RUnlock()

	if !ok {
		return nil, fmt.Errorf("%s 标识的数据库实例不存在! ", dbFlag)
	}
	return found, nil
}
// GetMasterClient returns the master (primary) *gorm.DB for dbFlag.
// It fails with GetDBClient's error when dbFlag is not registered.
func (c *wrapperClient) GetMasterClient(ctx context.Context, dbFlag string) (*gorm.DB, error) {
	registered, lookupErr := c.GetDBClient(dbFlag)
	if lookupErr != nil {
		return nil, lookupErr
	}
	return registered.GetMaster(ctx), nil
}
// GetSlaveClient returns the slave (replica) *gorm.DB for dbFlag.
// It fails with GetDBClient's error when dbFlag is not registered.
func (c *wrapperClient) GetSlaveClient(ctx context.Context, dbFlag string) (*gorm.DB, error) {
	registered, lookupErr := c.GetDBClient(dbFlag)
	if lookupErr != nil {
		return nil, lookupErr
	}
	return registered.GetSlave(ctx), nil
}