A ready-to-use MyBatis Plugin that dynamically adds/modifies columns in SQL statements (running in a real trace service)

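I. Preface

Some time ago at work I wrote a trace (call-chain tracking) service. It has to persist user information and service call-chain information to the business tables, and then canal + binlog is used for log auditing. To persist the user and call-chain information to the business tables, I used a custom MyBatis Plugin, which gives one interception-and-persist mechanism shared by all business modules.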

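For how MyBatis Plugin works internally, see these three posts of mine:

1. From JDK dynamic proxies, step by step, to how MyBatis Plugin works
2. The implementation of MyBatis plugins/interceptors (Plugin/Interceptor) turns out to be this simple
3. Four ways to make a custom MyBatis Plugin in a shared SDK take effect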

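Without further ado, here is the code. The complete source (service call-chain trace tracking, ultimately persisted to the business tables; where it is persisted can be adjusted as needed) will be open-sourced on GitHub after desensitization; I will add the link tomorrow.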

II. Code

The Maven dependency for jsqlparser is as follows:

<dependency>
    <groupId>com.github.jsqlparser</groupId>
    <artifactId>jsqlparser</artifactId>
    <version>3.2</version>
    <optional>true</optional>
</dependency>      

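The code mainly uses jsqlparser to modify the structure of SQL statements dynamically. The core logic of the code below is to adjust the SQL structure of insert and update statements only:

  1. For an insert statement, if a column required by the trace already exists in the original statement, its value is replaced with the value from the trace; otherwise the trace column and its value are appended.
  2. For an update statement the logic is similar. Note, however, that an update statement may contain the same column more than once, in which case the value actually written is the last one.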
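The replace-or-append rule above can be sketched without any SQL library. This is a minimal illustration, not the plugin's code: the class name `TraceColumnSketch`, the method `upsertColumn` and the `trace_id` column are hypothetical, and columns and values are modelled as two parallel lists of strings.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

final class TraceColumnSketch {

    /**
     * Overwrites the value when the column is already present,
     * otherwise appends both the column and its value.
     */
    static void upsertColumn(List<String> columns, List<String> values,
                             String column, String value) {
        // lastIndexOf: when a column is repeated, the last occurrence is the one that counts
        int index = columns.lastIndexOf(column);
        if (index >= 0) {
            values.set(index, value);
        } else {
            columns.add(column);
            values.add(value);
        }
    }

    public static void main(String[] args) {
        List<String> columns = new ArrayList<>(Arrays.asList("id", "user_name"));
        List<String> values = new ArrayList<>(Arrays.asList("?", "?"));
        upsertColumn(columns, values, "user_name", "'Saint-insert'"); // replaces the value
        upsertColumn(columns, values, "trace_id", "'t-1'");           // appends column and value
        System.out.println(columns + " -> " + values);
    }
}
```

Keeping columns and values index-aligned is the important part: the interceptor relies on the same alignment between the column list and the expression list of the parsed statement.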

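Finally, as a fallback, all of the interceptor logic is wrapped in try{}…catch{} so that a failure inside the interceptor does not break the business statement. One caveat: the business tables must be altered first to add the trace columns, otherwise the try{}…catch{} does not help, since a rewritten statement that references a missing column fails in the database itself, not in the interceptor's own logic. You may ask why the interceptor does not simply issue the DDL to alter the tables. That is a performance consideration: the interceptor should do the most generic work possible at the lowest time cost.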
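One way to picture this fallback is the following sketch, which uses plain functions instead of MyBatis types. All names (`BestEffortRewriteSketch`, `rewriteOrKeep`, `intercept`) are made up for illustration. The rewrite is best-effort, while the statement itself runs exactly once, outside the try block, so database errors such as an unknown column still reach the caller.

```java
import java.util.function.Function;
import java.util.function.UnaryOperator;

final class BestEffortRewriteSketch {

    /** Returns the rewritten SQL, or the original SQL if the rewriter throws. */
    static String rewriteOrKeep(String sql, UnaryOperator<String> rewriter) {
        try {
            return rewriter.apply(sql);
        } catch (RuntimeException e) {
            // fallback: a failed rewrite must not block the business statement
            return sql;
        }
    }

    /** Rewrites best-effort, then runs the statement exactly once, outside the try. */
    static <T> T intercept(String sql, UnaryOperator<String> rewriter,
                           Function<String, T> executor) {
        String finalSql = rewriteOrKeep(sql, rewriter);
        // errors raised here (for example an unknown column) propagate to the caller
        return executor.apply(finalSql);
    }
}
```

Keeping the execution step outside the try also avoids running a statement a second time when the first attempt throws.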

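One thing is a real trap: old and new projects may already use jsqlparser, and the versions in use differ a lot, so I wrote two versions of the SDK code to stay compatible with them. Although the jsqlparser Maven dependency in the SDK is marked <optional>true</optional>, a business project that brings in a different jsqlparser version will still get errors because of the version differences.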
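If it helps to see which API flavour is actually on the classpath, a reflection probe is one option. This is only a sketch and not part of the original SDK; `ApiProbeSketch` and `hasPublicMethod` are hypothetical names. It relies on a difference visible in the two listings in this post: the 3.2 code calls `update.getTable()`, while the 1.2 code calls `update.getTables()`.

```java
import java.lang.reflect.Method;

final class ApiProbeSketch {

    /** True if the class can be loaded and exposes a public method with the given name. */
    static boolean hasPublicMethod(String className, String methodName) {
        try {
            // initialize = false: only inspect the class, do not run its static initializers
            Class<?> type = Class.forName(className, false, ApiProbeSketch.class.getClassLoader());
            for (Method method : type.getMethods()) {
                if (method.getName().equals(methodName)) {
                    return true;
                }
            }
            return false;
        } catch (ClassNotFoundException | LinkageError e) {
            // class missing or not loadable with the current classpath
            return false;
        }
    }
}
```

For example, `hasPublicMethod("net.sf.jsqlparser.statement.update.Update", "getTable")` could be checked at startup to log which of the two interceptor variants matches the jsqlparser version that was resolved.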

1> Store the trace information in a ThreadLocal so that it can be persisted to the business table:

package com.saint.constant;

import lombok.Data;
import lombok.experimental.Accessors;

import java.io.Serializable;

/**
 * Trace information context
 *
 * @author Saint
 */
public class MybatisTraceContext implements Serializable {

    private final static InheritableThreadLocal<TraceContext> traceContextHolder = new InheritableThreadLocal<>();

    public static InheritableThreadLocal<TraceContext> get() {
        return traceContextHolder;
    }

    /**
     * Sets the traceContext
     *
     * @param traceContext traceContext
     */
    public static void setTraceContext(TraceContext traceContext) {
        traceContextHolder.set(traceContext);
    }

    /**
     * Gets the traceContext
     *
     * @return traceContext
     */
    public static TraceContext getTraceContext() {
        return traceContextHolder.get();
    }

    /**
     * Clears the trace context
     */
    public static void clear() {
        traceContextHolder.remove();
    }

    @Data
    @Accessors(chain = true)
    public static class TraceContext implements Serializable {
        private Long userId;
        private String traceId;
        private String controllerAction;
        private String visitIp;
        private String appName;
    }
}      
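A holder like this is usually paired with a set/clear discipline at the request boundary. The sketch below is an assumption about usage, not code from the service: `TraceScopeSketch` is a simplified stand-in that stores only a trace id in a plain ThreadLocal, and it shows why the clear step belongs in a finally block.

```java
final class TraceScopeSketch {

    // stand-in for the holder in MybatisTraceContext; a plain String keeps the sketch small
    private static final ThreadLocal<String> TRACE_ID = new ThreadLocal<>();

    static String currentTraceId() {
        return TRACE_ID.get();
    }

    /** Runs the work with the trace id bound to the current thread, then always clears it. */
    static void runWithTrace(String traceId, Runnable work) {
        TRACE_ID.set(traceId);
        try {
            work.run();
        } finally {
            // pooled threads are reused, so stale context must never be left behind
            TRACE_ID.remove();
        }
    }
}
```

In a web application the set and clear calls would typically sit in a filter or interceptor around each request. Note also that an InheritableThreadLocal copies the value only when a child thread is created, so threads taken from an existing pool do not automatically see the current request's context.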

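2> Constants used in the MyBatis interceptor:

package com.saint.constant;

import lombok.Getter;

/**
 * Constants used in the MyBatis interceptor
 *
 * @author Saint
 */
@Getter
public enum MyBatisPluginConst {

    /**
     * Change the business-table column names here
     */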
     .......
    DELEGATE_BOUND_SQL("delegate.boundSql.sql"),
    BOUND_SQL("boundSql"),
    DELEGATE_MAPPED_STATEMENT("delegate.mappedStatement"),
    MAPPED_STATEMENT("mappedStatement"),
    METHOD_PREPARE("prepare"),
    METHOD_SET_PARAMETERS("setParameters");

    private String vale;

    MyBatisPluginConst(String vale) {
        this.vale = vale;
    }
}      

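1. Custom MyBatis plugin using a newer jsqlparser version (e.g. 3.2)

Much of this listing is mocked: the original business logic is commented out, and mock column names and values are used instead.

  1. On insert, check whether the statement contains the user_name column; if it does, replace its value, otherwise add the column and its value.
  2. On update, the logic is the same as for insert: if the column exists, replace its value, otherwise add the column and its value.
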
package com.saint.mybatis;

import com.saint.constant.MyBatisPluginConst;
import com.saint.utils.MyBatisPluginUtils;
import com.saint.utils.StringUtil;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.ItemsListVisitor;
import net.sf.jsqlparser.expression.operators.relational.MultiExpressionList;
import net.sf.jsqlparser.expression.operators.relational.NamedExpressionList;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.statement.update.Update;
import net.sf.jsqlparser.statement.values.ValuesStatement;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.Configuration;
import org.springframework.beans.factory.annotation.Value;

import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.util.*;

/**
 * MyBatis interceptor
 *
 * @author Saint
 */
@Slf4j
@Intercepts({@Signature(type = StatementHandler.class,
        method = "prepare", args = {Connection.class, Integer.class}),
        @Signature(type = ParameterHandler.class, method = "setParameters", args = {PreparedStatement.class})
})
public class MybatisInterceptor implements Interceptor {

    /**
     * Tables not intercepted
     */
    @Value("#{'${mybatis.plugin.ignoreTables:}'.split(',')}")
    private List<String> ignoreTableList = Collections.emptyList();

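    /**
     * Reads the `ENABLE_MYBATIS_PLUGIN` value from the -D startup option; it controls whether the interceptor is enabled.
     * System.getenv() could be used instead to read it from an environment variable.
     */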
    private String enableMybatisPlugin = System.getProperty("ENABLE_MYBATIS_PLUGIN");

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        try {

            if (StringUtils.isEmpty(enableMybatisPlugin) || (!StringUtils.equals(enableMybatisPlugin, "true")
                    && !StringUtils.equals(enableMybatisPlugin, "TRUE"))) {
                return invocation.proceed();
            }

            String invocationName = invocation.getMethod().getName();

            if (Objects.equals(invocationName, MyBatisPluginConst.METHOD_PREPARE.getVale())) {

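                // case1: access object properties through MetaObject; here, the properties of the statementHandler.
                //   1. MetaObject is a MyBatis helper for convenient property access;
                //   2. it simplifies the code (no try/catch around reflection exceptions) and supports JavaBean, Collection and Map objects.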
                StatementHandler handler = (StatementHandler) invocation.getTarget();
                MetaObject metaObject = SystemMetaObject.forObject(handler);

                // case2: the intercepted object is a RoutingStatementHandler, which holds a delegate field of type StatementHandler;
                //        the delegate extends BaseStatementHandler, whose mappedStatement field is read here
                MappedStatement mappedStatement = (MappedStatement) metaObject.getValue(MyBatisPluginConst.DELEGATE_MAPPED_STATEMENT.getVale());

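                // id is the fully qualified name of the mapper method being executed, e.g. com.uv.dao.UserMapper.insertUser
                String id = mappedStatement.getId();

                // MyBatis global configuration (also gives access to the environment and data source)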
                Configuration configuration = mappedStatement.getConfiguration();

                // sql type: UNKNOWN, INSERT, UPDATE, DELETE, SELECT, FLUSH
                String sqlCommandType = mappedStatement.getSqlCommandType().toString();

                // only intercept update and insert dml
                if (!Objects.equals(sqlCommandType, SqlCommandType.UPDATE.toString())
                        && !Objects.equals(sqlCommandType, SqlCommandType.INSERT.toString())) {
                    return invocation.proceed();
                }

                // way1: obtain the original SQL through MetaObject
                String sql = metaObject.getValue(MyBatisPluginConst.DELEGATE_BOUND_SQL.getVale()).toString();
                // way2: the original SQL can also be obtained like this:
//                BoundSql boundSql = handler.getBoundSql();
//                String sql = boundSql.getSql();

                // parse the SQL with jsqlparser; the resulting statement is a typed Insert/Update/Select object
                Statement statement = CCJSqlParserUtil.parse(sql);

                // todo: for something as simple as adding a row limit, you can parse the SQL with jsqlparser, or just patch the SQL string via reflection, e.g.:
//                String mSql = sql + " limit 2";
//                BoundSql boundSql = handler.getBoundSql();
//                Field field = boundSql.getClass().getDeclaredField("sql");
//                field.setAccessible(true);
//                field.set(boundSql, mSql);

                switch (sqlCommandType) {
                    case "INSERT":
                        prepareInsertSql(statement, metaObject);
                        break;
                    case "UPDATE":
                        // best effort: a failure here is caught below and the original SQL still runs
                        prepareUpdateSql(statement, metaObject);
                        break;
                    default:
                        break;
                }
            } else if (Objects.equals(invocationName, MyBatisPluginConst.METHOD_SET_PARAMETERS.getVale())) {
                // unwrap any proxies to get the original ParameterHandler
                ParameterHandler handler = (ParameterHandler) MyBatisPluginUtils.realTarget(invocation.getTarget());
                MetaObject metaObject = SystemMetaObject.forObject(handler);
                MappedStatement mappedStatement = (MappedStatement) metaObject.getValue(MyBatisPluginConst.MAPPED_STATEMENT.getVale());
                // sql type: UNKNOWN, INSERT, UPDATE, DELETE, SELECT, FLUSH
                String sqlCommandType = mappedStatement.getSqlCommandType().toString();
                // only intercept update and insert dml
                if (!Objects.equals(sqlCommandType, SqlCommandType.UPDATE.toString())
                        && !Objects.equals(sqlCommandType, SqlCommandType.INSERT.toString())) {
                    return invocation.proceed();
                }

                BoundSql boundSql = (BoundSql) metaObject.getValue(MyBatisPluginConst.BOUND_SQL.getVale());
                Statement statement = CCJSqlParserUtil.parse(boundSql.getSql());
                switch (sqlCommandType) {
                    case "INSERT":
                        Insert insert = (Insert) statement;
                        if (!matchesIgnoreTables(insert.getTable().getName())) {
                            handleParameterMapping(boundSql);
                        }
                        break;
                    case "UPDATE":
                        Update update = (Update) statement;
                        if (!matchesIgnoreTables(update.getTable().getName())) {
                            handleParameterMapping(boundSql);
                        }
                        break;
                    default:
                        break;
                }
            }

        } catch (Exception e) {
            log.error("Exception in executing MyBatis Interceptor", e);
        }

        return invocation.proceed();
    }

    /**
     * handle update sql in StatementHandler#prepare() phase
     *
     * @param statement  statement
     * @param metaObject metaObject
     */
    private void prepareUpdateSql(Statement statement, MetaObject metaObject) {

        Update update = (Update) statement;
        if (matchesIgnoreTables(update.getTable().getName())) {
            return;
        }

        boolean isContainsUserIdColumn = false;
        int modifyDateColumnIndex = 0;

        for (int i = 0; i < update.getColumns().size(); i++) {
            Column column = update.getColumns().get(i);
            if (column.getColumnName().equals("user_name")) {
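                // the SQL already contains the column, so only its value needs to be set
                isContainsUserIdColumn = true;
                modifyDateColumnIndex = i;
            }
        }

        // if the SQL already contains the `user_name` column, replace its value; otherwise add the column and its value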
        if (isContainsUserIdColumn) {
            updateValueWithIndex(modifyDateColumnIndex, "Saint-update", update);
        } else {
            updateValue("user_name", "Saint-update", update);
        }

        log.debug("intercept update sql is : {}", update);
        metaObject.setValue("delegate.boundSql.sql", update.toString());

    }

    /**
     * handle insert sql in StatementHandler#prepare() phase
     *
     * @param statement  statement
     * @param metaObject metaObject
     */
    private void prepareInsertSql(Statement statement, MetaObject metaObject) {

        Insert insert = (Insert) statement;
        if (matchesIgnoreTables(insert.getTable().getName())) {
            return;
        }

        boolean isContainsUserIdColumn = false;
        int createDateColumnIndex = 0;
        for (int i = 0; i < insert.getColumns().size(); i++) {
            Column column = insert.getColumns().get(i);
            if (column.getColumnName().equals("user_name")) {
                // the SQL already contains the column, so only its value needs to be set
                isContainsUserIdColumn = true;
                createDateColumnIndex = i;
            }
        }

        if (isContainsUserIdColumn) {
            intoValueWithIndex(createDateColumnIndex, "Saint-insert", insert);
        } else {
            intoValue("user_name", "Saint-insert", insert);
        }

        log.debug("intercept insert sql is : {}", insert);

        metaObject.setValue("delegate.boundSql.sql", insert.toString());
    }

    /**
     * update sql update column value
     *
     * @param modifyDateColumnIndex
     * @param columnValue
     * @param update
     */
    private void updateValueWithIndex(int modifyDateColumnIndex, Object columnValue, Update update) {
        if (columnValue instanceof Long) {
            update.getExpressions().set(modifyDateColumnIndex, new LongValue((Long) columnValue));
        } else if (columnValue instanceof String) {
            update.getExpressions().set(modifyDateColumnIndex, new StringValue((String) columnValue));
        } else {
            // fallback for other types: use the string form (a plain cast would throw ClassCastException); add branches as needed
            update.getExpressions().set(modifyDateColumnIndex, new StringValue(String.valueOf(columnValue)));
        }
    }

    /**
     * update sql add column
     *
     * @param updateDateColumnName
     * @param columnValue
     * @param update
     */
    private void updateValue(String updateDateColumnName, Object columnValue, Update update) {
        // add the column
        update.getColumns().add(new Column(updateDateColumnName));
        if (columnValue instanceof Long) {
            update.getExpressions().add(new LongValue((Long) columnValue));
        } else if (columnValue instanceof String) {
            update.getExpressions().add(new StringValue((String) columnValue));
        } else {
            // fallback for other types: use the string form (a plain cast would throw ClassCastException); add branches as needed
            update.getExpressions().add(new StringValue(String.valueOf(columnValue)));
        }
    }

    /**
     * insert sql add column
     *
     * @param columnName
     * @param columnValue
     * @param insert
     */
    private void intoValue(String columnName, final Object columnValue, Insert insert) {
        // add the column
        insert.getColumns().add(new Column(columnName));
        // set the corresponding value through a visitor
        if (insert.getItemsList() == null) {
            insert.getSelect().getSelectBody().accept(new PlainSelectVisitor(-1, columnValue));
        } else {
            insert.getItemsList().accept(new ItemsListVisitor() {
                @Override
                public void visit(SubSelect subSelect) {
                    throw new UnsupportedOperationException("Not supported yet.");
                }

                @Override
                public void visit(ExpressionList expressionList) {
                    // the branch chosen here determines the literal type of the added value; only Long and String are handled so far, extend as needed
                    // todo: the same applies to every similar block below
                    if (columnValue instanceof String) {
                        expressionList.getExpressions().add(new StringValue((String) columnValue));
                    } else if (columnValue instanceof Long) {
                        expressionList.getExpressions().add(new LongValue((Long) columnValue));
                    } else {
                        // fallback for other types: use the string form (a plain cast would throw ClassCastException)
                        expressionList.getExpressions().add(new StringValue(String.valueOf(columnValue)));
                    }
                }

                @Override
                public void visit(NamedExpressionList namedExpressionList) {
                    throw new UnsupportedOperationException("Not supported yet.");
                }

                @Override
                public void visit(MultiExpressionList multiExpressionList) {
                    for (ExpressionList expressionList : multiExpressionList.getExprList()) {
                        if (columnValue instanceof String) {
                            expressionList.getExpressions().add(new StringValue((String) columnValue));
                        } else if (columnValue instanceof Long) {
                            expressionList.getExpressions().add(new LongValue((Long) columnValue));
                        } else {
                            // fallback for other types: use the string form (a plain cast would throw ClassCastException)
                            expressionList.getExpressions().add(new StringValue(String.valueOf(columnValue)));
                        }
                    }
                }
            });
        }
    }

    /**
     * insert sql update column value
     *
     * @param index
     * @param columnValue
     * @param insert
     */
    private void intoValueWithIndex(final int index, final Object columnValue, Insert insert) {
        // set the corresponding value through a visitor
        if (insert.getItemsList() == null) {
            insert.getSelect().getSelectBody().accept(new PlainSelectVisitor(index, columnValue));
        } else {
            insert.getItemsList().accept(new ItemsListVisitor() {
                @Override
                public void visit(SubSelect subSelect) {
                    throw new UnsupportedOperationException("Not supported yet.");
                }

                @Override
                public void visit(ExpressionList expressionList) {
                    if (columnValue instanceof String) {
                        expressionList.getExpressions().set(index, new StringValue((String) columnValue));
                    } else if (columnValue instanceof Long) {
                        expressionList.getExpressions().set(index, new LongValue((Long) columnValue));
                    } else {
                        // fallback for other types: use the string form (a plain cast would throw ClassCastException)
                        expressionList.getExpressions().set(index, new StringValue(String.valueOf(columnValue)));
                    }
                }

                @Override
                public void visit(NamedExpressionList namedExpressionList) {
                    throw new UnsupportedOperationException("Not supported yet.");
                }

                @Override
                public void visit(MultiExpressionList multiExpressionList) {
                    for (ExpressionList expressionList : multiExpressionList.getExprList()) {
                        if (columnValue instanceof String) {
                            expressionList.getExpressions().set(index, new StringValue((String) columnValue));
                        } else if (columnValue instanceof Long) {
                            expressionList.getExpressions().set(index, new LongValue((Long) columnValue));
                        } else {
                            // fallback for other types: use the string form (a plain cast would throw ClassCastException)
                            expressionList.getExpressions().set(index, new StringValue(String.valueOf(columnValue)));
                        }
                    }
                }
            });
        }
    }

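    /**
     * Removes the already-present column from the ParameterMapping list,
     * which fixes the parameter-count mismatch that occurs when the original SQL already contains the auto-added column
     *
     * @param boundSql
     */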
    private void handleParameterMapping(BoundSql boundSql) {
        List<ParameterMapping> parameterMappingList = boundSql.getParameterMappings();
        Iterator<ParameterMapping> it = parameterMappingList.iterator();
        String userIdProperty = StringUtil.snakeToCamelCase("user_name");
        while (it.hasNext()) {
            ParameterMapping pm = it.next();
            // the second condition supports batch inserts (contains() must not be used here)
            if (pm.getProperty().equals(userIdProperty) || pm.getProperty().endsWith("." + userIdProperty)) {
                log.debug("original SQL already contains the auto-added column: {}", userIdProperty);
                it.remove();
            }
        }
    }

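    /**
     * Checks whether the table is configured to be ignored
     *
     * @param tableName table of the SQL currently being executed
     * @return true if the table matches an ignored-table pattern, false otherwise
     */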
    private boolean matchesIgnoreTables(String tableName) {
        for (String ignoreTable : ignoreTableList) {
            if (tableName.matches(ignoreTable)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Supports INSERT INTO ... SELECT statements
     */
    private class PlainSelectVisitor implements SelectVisitor {
        int index;
        Object columnValue;

        public PlainSelectVisitor(int index, Object columnValue) {
            this.index = index;
            this.columnValue = columnValue;
        }

        @Override
        public void visit(PlainSelect plainSelect) {
            if (index != -1) {
                if (columnValue instanceof String) {
                    plainSelect.getSelectItems().set(index, new SelectExpressionItem(new StringValue((String) columnValue)));
                } else if (columnValue instanceof Long) {
                    plainSelect.getSelectItems().set(index, new SelectExpressionItem(new LongValue((Long) columnValue)));
                } else {
                    // fallback for other types: use the string form (a plain cast would throw ClassCastException)
                    plainSelect.getSelectItems().set(index, new SelectExpressionItem(new StringValue(String.valueOf(columnValue))));
                }
            } else {
                if (columnValue instanceof String) {
                    plainSelect.getSelectItems().add(new SelectExpressionItem(new StringValue((String) columnValue)));
                } else if (columnValue instanceof Long) {
                    plainSelect.getSelectItems().add(new SelectExpressionItem(new LongValue((Long) columnValue)));
                } else {
                    // fallback for other types: use the string form (a plain cast would throw ClassCastException)
                    plainSelect.getSelectItems().add(new SelectExpressionItem(new StringValue(String.valueOf(columnValue))));
                }
            }
        }

        @Override
        public void visit(SetOperationList setOperationList) {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        @Override
        public void visit(WithItem withItem) {
            if (index != -1) {
                if (columnValue instanceof String) {
                    withItem.getWithItemList().set(index, new SelectExpressionItem(new StringValue((String) columnValue)));
                } else if (columnValue instanceof Long) {
                    withItem.getWithItemList().set(index, new SelectExpressionItem(new LongValue((Long) columnValue)));
                } else {
                    // fallback for other types: use the string form (a plain cast would throw ClassCastException)
                    withItem.getWithItemList().set(index, new SelectExpressionItem(new StringValue(String.valueOf(columnValue))));
                }
            } else {
                if (columnValue instanceof String) {
                    withItem.getWithItemList().add(new SelectExpressionItem(new StringValue((String) columnValue)));
                } else if (columnValue instanceof Long) {
                    withItem.getWithItemList().add(new SelectExpressionItem(new LongValue((Long) columnValue)));
                } else {
                    // fallback for other types: use the string form (a plain cast would throw ClassCastException)
                    withItem.getWithItemList().add(new SelectExpressionItem(new StringValue(String.valueOf(columnValue))));
                }
            }
        }

        @Override
        public void visit(ValuesStatement valuesStatement) {
            if (index != -1) {
                if (columnValue instanceof String) {
                    valuesStatement.getExpressions().set(index, new StringValue((String) columnValue));
                } else if (columnValue instanceof Long) {
                    valuesStatement.getExpressions().set(index, new LongValue((Long) columnValue));
                } else {
                    // fallback for other types: use the string form (a plain cast would throw ClassCastException)
                    valuesStatement.getExpressions().set(index, new StringValue(String.valueOf(columnValue)));
                }
            } else {
                if (columnValue instanceof String) {
                    valuesStatement.getExpressions().add(new StringValue((String) columnValue));
                } else if (columnValue instanceof Long) {
                    valuesStatement.getExpressions().add(new LongValue((Long) columnValue));
                } else {
                    // fallback for other types: use the string form (a plain cast would throw ClassCastException)
                    valuesStatement.getExpressions().add(new StringValue(String.valueOf(columnValue)));
                }
            }
        }
    }

    @Override
    public Object plugin(Object o) {
        return Plugin.wrap(o, this);
    }

    @Override
    public void setProperties(Properties properties) {
        // receives the property values declared for this plugin in the configuration file
    }
}      
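The parameter-mapping cleanup in handleParameterMapping is easy to get subtly wrong, so here is a small standalone sketch of its matching rule. It works on plain property-name strings instead of ParameterMapping objects; `ParameterFilterSketch` and `dropProperty` are hypothetical names, and `snakeToCamelCase` is only a guess at what StringUtil.snakeToCamelCase (not shown in this post) does.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

final class ParameterFilterSketch {

    /** Assumed behaviour of the unseen helper: "user_name" becomes "userName". */
    static String snakeToCamelCase(String snake) {
        StringBuilder result = new StringBuilder();
        boolean upperNext = false;
        for (char c : snake.toCharArray()) {
            if (c == '_') {
                upperNext = true;
            } else if (upperNext) {
                result.append(Character.toUpperCase(c));
                upperNext = false;
            } else {
                result.append(c);
            }
        }
        return result.toString();
    }

    /** Removes names equal to the property, or nested names ending in "." + property. */
    static void dropProperty(List<String> properties, String property) {
        properties.removeIf(p -> p.equals(property) || p.endsWith("." + property));
    }

    public static void main(String[] args) {
        List<String> names = new ArrayList<>(
                Arrays.asList("id", "userName", "item.userName", "userNameAlias"));
        dropProperty(names, snakeToCamelCase("user_name"));
        System.out.println(names); // userNameAlias survives; contains() would have removed it
    }
}
```

The mappings have to be removed because the prepare phase replaces the column's `?` placeholder with a literal, so the statement ends up with one placeholder fewer than the original parameter list.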

2. Custom MyBatis plugin using an older jsqlparser version (e.g. 1.2)

Maven dependency:

<dependency>
    <groupId>com.github.jsqlparser</groupId>
    <artifactId>jsqlparser</artifactId>
    <version>1.2</version>
</dependency>      
package com.saint.mybatis;

import com.saint.constant.MyBatisPluginConst;
import com.saint.utils.MyBatisPluginUtils;
import com.saint.utils.StringUtil;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.ItemsListVisitor;
import net.sf.jsqlparser.expression.operators.relational.MultiExpressionList;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.statement.update.Update;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.Configuration;
import org.springframework.beans.factory.annotation.Value;

import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.util.*;

/**
 * MyBatis interceptor; custom logic that persists the TraceContext to the business tables
 *
 * @author Saint
 */
@Slf4j
@Intercepts({@Signature(type = StatementHandler.class,
        method = "prepare", args = {Connection.class, Integer.class}),
        @Signature(type = ParameterHandler.class, method = "setParameters", args = {PreparedStatement.class})
})
public class MybatisInterceptor implements Interceptor {

    /**
     * Tables not intercepted
     */
    @Value("#{'${mybatis.plugin.ignoreTables:}'.split(',')}")
    private List<String> ignoreTableList = Collections.emptyList();

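    /**
     * Reads the `ENABLE_MYBATIS_PLUGIN` value from the -D startup option; it controls whether the interceptor is enabled.
     * System.getenv() could be used instead to read it from an environment variable.
     */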
    private String enableMybatisPlugin = System.getProperty("ENABLE_MYBATIS_PLUGIN");

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        try {

            if (StringUtils.isEmpty(enableMybatisPlugin) || (!StringUtils.equals(enableMybatisPlugin, "true")
                    && !StringUtils.equals(enableMybatisPlugin, "TRUE"))) {
                return invocation.proceed();
            }

            String invocationName = invocation.getMethod().getName();

            if (Objects.equals(invocationName, MyBatisPluginConst.METHOD_PREPARE.getVale())) {

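                // case1: access object properties through MetaObject; here, the properties of the statementHandler.
                //   1. MetaObject is a MyBatis helper for convenient property access;
                //   2. it simplifies the code (no try/catch around reflection exceptions) and supports JavaBean, Collection and Map objects.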
                StatementHandler handler = (StatementHandler) invocation.getTarget();
                MetaObject metaObject = SystemMetaObject.forObject(handler);

                // case2: the intercepted object is a RoutingStatementHandler, which holds a delegate field of type StatementHandler;
                //        the delegate extends BaseStatementHandler, whose mappedStatement field is read here
                MappedStatement mappedStatement = (MappedStatement) metaObject.getValue(MyBatisPluginConst.DELEGATE_MAPPED_STATEMENT.getVale());

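                // id is the fully qualified name of the mapper method being executed, e.g. com.uv.dao.UserMapper.insertUser
                String id = mappedStatement.getId();

                // MyBatis global configuration (also gives access to the environment and data source)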
                Configuration configuration = mappedStatement.getConfiguration();

                // sql type: UNKNOWN, INSERT, UPDATE, DELETE, SELECT, FLUSH
                String sqlCommandType = mappedStatement.getSqlCommandType().toString();

                // only intercept update and insert dml
                if (!Objects.equals(sqlCommandType, SqlCommandType.UPDATE.toString())
                        && !Objects.equals(sqlCommandType, SqlCommandType.INSERT.toString())) {
                    return invocation.proceed();
                }

                // way1: obtain the original SQL through MetaObject
                String sql = metaObject.getValue(MyBatisPluginConst.DELEGATE_BOUND_SQL.getVale()).toString();
                // way2: the original SQL can also be obtained like this:
//                BoundSql boundSql = handler.getBoundSql();
//                String sql = boundSql.getSql();

                // parse the SQL with jsqlparser; the resulting statement is a typed Insert/Update/Select object
                Statement statement = CCJSqlParserUtil.parse(sql);

                // todo: for something as simple as adding a row limit, you can parse the SQL with jsqlparser, or just patch the SQL string via reflection, e.g.:
//                String mSql = sql + " limit 2";
//                BoundSql boundSql = handler.getBoundSql();
//                Field field = boundSql.getClass().getDeclaredField("sql");
//                field.setAccessible(true);
//                field.set(boundSql, mSql);

                switch (sqlCommandType) {
                    case "INSERT":
                        prepareInsertSql(statement, metaObject);
                        break;
                    case "UPDATE":
                        // best effort: a failure here is caught below and the original SQL still runs
                        prepareUpdateSql(statement, metaObject);
                        break;
                    default:
                        break;
                }
            } else if (Objects.equals(invocationName, MyBatisPluginConst.METHOD_SET_PARAMETERS.getVale())) {
                // unwrap any proxies to get the original ParameterHandler
                ParameterHandler handler = (ParameterHandler) MyBatisPluginUtils.realTarget(invocation.getTarget());
                MetaObject metaObject = SystemMetaObject.forObject(handler);
                MappedStatement mappedStatement = (MappedStatement) metaObject.getValue(MyBatisPluginConst.MAPPED_STATEMENT.getVale());
                // sql type: UNKNOWN, INSERT, UPDATE, DELETE, SELECT, FLUSH
                String sqlCommandType = mappedStatement.getSqlCommandType().toString();
                // only intercept update and insert dml; other statement types fall through
                // to the single invocation.proceed() after the try/catch, so a failure in
                // proceed() is never swallowed by the catch block and then retried
                if (Objects.equals(sqlCommandType, SqlCommandType.UPDATE.toString())
                        || Objects.equals(sqlCommandType, SqlCommandType.INSERT.toString())) {
                    BoundSql boundSql = (BoundSql) metaObject.getValue(MyBatisPluginConst.BOUND_SQL.getVale());
                    Statement statement = CCJSqlParserUtil.parse(boundSql.getSql());
                    switch (sqlCommandType) {
                        case "INSERT":
                            Insert insert = (Insert) statement;
                            if (!matchesIgnoreTables(insert.getTable().getName())) {
                                handleParameterMapping(boundSql);
                            }
                            break;
                        case "UPDATE":
                            Update update = (Update) statement;
                            if (!matchesIgnoreTables(update.getTables().get(0).getName())) {
                                handleParameterMapping(boundSql);
                            }
                            break;
                        default:
                            break;
                    }
                }
            }

        } catch (Exception e) {
            log.error("Exception in executing MyBatis Interceptor", e);
        }

        return invocation.proceed();
    }

    /**
     * handle update sql in StatementHandler#prepare() phase
     *
     * @param statement  statement
     * @param metaObject metaObject
     */
    private void prepareUpdateSql(Statement statement, MetaObject metaObject) {

        Update update = (Update) statement;
        if (matchesIgnoreTables(update.getTables().get(0).getName())) {
            return;
        }

        boolean isContainsUserIdColumn = false;
        int modifyDateColumnIndex = 0;

        for (int i = 0; i < update.getColumns().size(); i++) {
            Column column = update.getColumns().get(i);
            if (column.getColumnName().equals("user_name")) {
                // the SQL already contains the target column, so only its value needs to be set
                isContainsUserIdColumn = true;
                modifyDateColumnIndex = i;
            }
        }

        // If the SQL already contains the `user_name` column, overwrite its value; otherwise add the column and its value
        if (isContainsUserIdColumn) {
            updateValueWithIndex(modifyDateColumnIndex, "Saint-update", update);
        } else {
            updateValue("user_name", "Saint-update", update);
        }

        log.debug("intercept update sql is : {}", update);
        metaObject.setValue("delegate.boundSql.sql", update.toString());

    }

    /**
     * handle insert sql in StatementHandler#prepare() phase
     *
     * @param statement  statement
     * @param metaObject metaObject
     */
    private void prepareInsertSql(Statement statement, MetaObject metaObject) {

        Insert insert = (Insert) statement;
        if (matchesIgnoreTables(insert.getTable().getName())) {
            return;
        }

        boolean isContainsUserIdColumn = false;
        int createDateColumnIndex = 0;
        for (int i = 0; i < insert.getColumns().size(); i++) {
            Column column = insert.getColumns().get(i);
            if (column.getColumnName().equals("user_name")) {
                // the SQL already contains the target column, so only its value needs to be set
                isContainsUserIdColumn = true;
                createDateColumnIndex = i;
            }
        }

        if (isContainsUserIdColumn) {
            intoValueWithIndex(createDateColumnIndex, "Saint-insert", insert);
        } else {
            intoValue("user_name", "Saint-insert", insert);
        }

        log.debug("intercept insert sql is : {}", insert);

        metaObject.setValue("delegate.boundSql.sql", insert.toString());
    }

    /**
     * update sql update column value
     *
     * @param modifyDateColumnIndex
     * @param columnValue
     * @param update
     */
    private void updateValueWithIndex(int modifyDateColumnIndex, Object columnValue, Update update) {
        if (columnValue instanceof Long) {
            update.getExpressions().set(modifyDateColumnIndex, new LongValue((Long) columnValue));
        } else if (columnValue instanceof String) {
            update.getExpressions().set(modifyDateColumnIndex, new StringValue((String) columnValue));
        } else {
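            // fall back to the string form for any other type (a plain (String) cast would
            // throw ClassCastException here); add more branches if you need typed values
            update.getExpressions().set(modifyDateColumnIndex, new StringValue(String.valueOf(columnValue)));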
        }
    }

    /**
     * update sql add column
     *
     * @param updateDateColumnName
     * @param columnValue
     * @param update
     */
    private void updateValue(String updateDateColumnName, Object columnValue, Update update) {
        // add the column
        update.getColumns().add(new Column(updateDateColumnName));
        if (columnValue instanceof Long) {
            update.getExpressions().add(new LongValue((Long) columnValue));
        } else if (columnValue instanceof String) {
            update.getExpressions().add(new StringValue((String) columnValue));
        } else {
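            // fall back to the string form for any other type (a plain (String) cast would
            // throw ClassCastException here); add more branches if you need typed values
            update.getExpressions().add(new StringValue(String.valueOf(columnValue)));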
        }
    }

    /**
     * insert sql add column
     *
     * @param columnName
     * @param columnValue
     * @param insert
     */
    private void intoValue(String columnName, final Object columnValue, Insert insert) {
        // add the column
        insert.getColumns().add(new Column(columnName));
        // set the corresponding value via a visitor
        if (insert.getItemsList() == null) {
            insert.getSelect().getSelectBody().accept(new PlainSelectVisitor(-1, columnValue));
        } else {
            insert.getItemsList().accept(new ItemsListVisitor() {
                @Override
                public void visit(SubSelect subSelect) {
                    throw new UnsupportedOperationException("Not supported yet.");
                }

                @Override
                public void visit(ExpressionList expressionList) {
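                    // This picks the database data type of the value for the added column; only Long and String are used so far, extend as needed
                    // todo the same applies to every similar block below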
                    if (columnValue instanceof String) {
                        expressionList.getExpressions().add(new StringValue((String) columnValue));
                    } else if (columnValue instanceof Long) {
                        expressionList.getExpressions().add(new LongValue((Long) columnValue));
                    } else {
                        // fall back to the string form for any other type; add more branches if you need typed values
                        expressionList.getExpressions().add(new StringValue(String.valueOf(columnValue)));
                    }
                }

                @Override
                public void visit(MultiExpressionList multiExpressionList) {
                    for (ExpressionList expressionList : multiExpressionList.getExprList()) {
                        if (columnValue instanceof String) {
                            expressionList.getExpressions().add(new StringValue((String) columnValue));
                        } else if (columnValue instanceof Long) {
                            expressionList.getExpressions().add(new LongValue((Long) columnValue));
                        } else {
                            // fall back to the string form for any other type; add more branches if you need typed values
                            expressionList.getExpressions().add(new StringValue(String.valueOf(columnValue)));
                        }
                    }
                }
            });
        }
    }

    /**
     * insert sql update column value
     *
     * @param index
     * @param columnValue
     * @param insert
     */
    private void intoValueWithIndex(final int index, final Object columnValue, Insert insert) {
        // set the corresponding value via a visitor
        if (insert.getItemsList() == null) {
            insert.getSelect().getSelectBody().accept(new PlainSelectVisitor(index, columnValue));
        } else {
            insert.getItemsList().accept(new ItemsListVisitor() {
                @Override
                public void visit(SubSelect subSelect) {
                    throw new UnsupportedOperationException("Not supported yet.");
                }

                @Override
                public void visit(ExpressionList expressionList) {
                    if (columnValue instanceof String) {
                        expressionList.getExpressions().set(index, new StringValue((String) columnValue));
                    } else if (columnValue instanceof Long) {
                        expressionList.getExpressions().set(index, new LongValue((Long) columnValue));
                    } else {
                        // fall back to the string form for any other type; add more branches if you need typed values
                        expressionList.getExpressions().set(index, new StringValue(String.valueOf(columnValue)));
                    }
                }

                @Override
                public void visit(MultiExpressionList multiExpressionList) {
                    for (ExpressionList expressionList : multiExpressionList.getExprList()) {
                        if (columnValue instanceof String) {
                            expressionList.getExpressions().set(index, new StringValue((String) columnValue));
                        } else if (columnValue instanceof Long) {
                            expressionList.getExpressions().set(index, new LongValue((Long) columnValue));
                        } else {
                            // fall back to the string form for any other type; add more branches if you need typed values
                            expressionList.getExpressions().set(index, new StringValue(String.valueOf(columnValue)));
                        }
                    }
                }
            });
        }
    }

    /**
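     * Remove the already-present column from the ParameterMapping list.
     * This fixes the parameter-count mismatch that occurs when the original SQL
     * already contains the auto-added column.
     *
     * @param boundSql the BoundSql whose parameter mappings are pruned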
     */
    private void handleParameterMapping(BoundSql boundSql) {
        List<ParameterMapping> parameterMappingList = boundSql.getParameterMappings();
        Iterator<ParameterMapping> it = parameterMappingList.iterator();
        String userIdProperty = StringUtil.snakeToCamelCase("user_name");
        while (it.hasNext()) {
            ParameterMapping pm = it.next();
            // the second condition supports batch inserts (contains() must not be used here)
            if (pm.getProperty().equals(userIdProperty) || pm.getProperty().endsWith("." + userIdProperty)) {
                log.debug("Original SQL already contains the auto-added column: {}", userIdProperty);
                it.remove();
            }
        }
    }

    /**
     * Skip tables that are configured to be ignored
     *
     * @param tableName the table targeted by the current SQL
     * @return true if the table matches an ignore pattern, false otherwise
     */
    private boolean matchesIgnoreTables(String tableName) {
        for (String ignoreTable : ignoreTableList) {
            if (tableName.matches(ignoreTable)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Supports INSERT INTO ... SELECT statements
     */
    private class PlainSelectVisitor implements SelectVisitor {
        int index;
        Object columnValue;

        public PlainSelectVisitor(int index, Object columnValue) {
            this.index = index;
            this.columnValue = columnValue;
        }

        @Override
        public void visit(PlainSelect plainSelect) {
            if (index != -1) {
                if (columnValue instanceof String) {
                    plainSelect.getSelectItems().set(index, new SelectExpressionItem(new StringValue((String) columnValue)));
                } else if (columnValue instanceof Long) {
                    plainSelect.getSelectItems().set(index, new SelectExpressionItem(new LongValue((Long) columnValue)));
                } else {
                    // fall back to the string form for any other type; add more branches if you need typed values
                    plainSelect.getSelectItems().set(index, new SelectExpressionItem(new StringValue(String.valueOf(columnValue))));
                }
            } else {
                if (columnValue instanceof String) {
                    plainSelect.getSelectItems().add(new SelectExpressionItem(new StringValue((String) columnValue)));
                } else if (columnValue instanceof Long) {
                    plainSelect.getSelectItems().add(new SelectExpressionItem(new LongValue((Long) columnValue)));
                } else {
                    // fall back to the string form for any other type; add more branches if you need typed values
                    plainSelect.getSelectItems().add(new SelectExpressionItem(new StringValue(String.valueOf(columnValue))));
                }
            }
        }

        @Override
        public void visit(SetOperationList setOperationList) {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        @Override
        public void visit(WithItem withItem) {
            if (index != -1) {
                if (columnValue instanceof String) {
                    withItem.getWithItemList().set(index, new SelectExpressionItem(new StringValue((String) columnValue)));
                } else if (columnValue instanceof Long) {
                    withItem.getWithItemList().set(index, new SelectExpressionItem(new LongValue((Long) columnValue)));
                } else {
                    // fall back to the string form for any other type; add more branches if you need typed values
                    withItem.getWithItemList().set(index, new SelectExpressionItem(new StringValue(String.valueOf(columnValue))));
                }
            } else {
                if (columnValue instanceof String) {
                    withItem.getWithItemList().add(new SelectExpressionItem(new StringValue((String) columnValue)));
                } else if (columnValue instanceof Long) {
                    withItem.getWithItemList().add(new SelectExpressionItem(new LongValue((Long) columnValue)));
                } else {
                    // fall back to the string form for any other type; add more branches if you need typed values
                    withItem.getWithItemList().add(new SelectExpressionItem(new StringValue(String.valueOf(columnValue))));
                }
            }
        }
    }

    @Override
    public Object plugin(Object o) {
        return Plugin.wrap(o, this);
    }

    @Override
    public void setProperties(Properties properties) {
        // receives the property parameters passed in from the configuration file
    }
}      
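
Two helpers in the interceptor above depend on exact string-matching semantics that are easy to get wrong: handleParameterMapping (equals / endsWith rather than contains) and matchesIgnoreTables (String.matches, which must match the whole table name). The following is only a small stand-alone sketch of those two rules so they can be tried without MyBatis or jsqlparser. The class TraceMatchingDemo, the method isTraceProperty, and this snakeToCamelCase implementation are hypothetical stand-ins (the real StringUtil is not shown in this post), and the "__frch_item_0." prefix is just an illustrative example of a nested property path such as a batch insert might produce.

```java
import java.util.Arrays;
import java.util.List;

// Illustrative sketch only: the names below are hypothetical and not part of the plugin.
class TraceMatchingDemo {

    // Assumed behavior of StringUtil.snakeToCamelCase: "user_name" -> "userName".
    static String snakeToCamelCase(String snake) {
        StringBuilder sb = new StringBuilder();
        boolean upperNext = false;
        for (char c : snake.toCharArray()) {
            if (c == '_') {
                upperNext = true;          // drop the underscore, upper-case the next char
            } else if (upperNext) {
                sb.append(Character.toUpperCase(c));
                upperNext = false;
            } else {
                sb.append(c);
            }
        }
        return sb.toString();
    }

    // Same test as handleParameterMapping: the exact property name,
    // or the last segment of a nested path ("<prefix>.userName").
    static boolean isTraceProperty(String property, String traceProperty) {
        return property.equals(traceProperty) || property.endsWith("." + traceProperty);
    }

    // Same test as matchesIgnoreTables: String.matches requires the
    // regular expression to match the entire table name.
    static boolean matchesIgnoreTables(String tableName, List<String> ignoreTables) {
        for (String pattern : ignoreTables) {
            if (tableName.matches(pattern)) {
                return true;
            }
        }
        return false;
    }

    public static void main(String[] args) {
        String prop = snakeToCamelCase("user_name");
        System.out.println(prop);                                            // userName
        System.out.println(isTraceProperty("userName", prop));               // true
        System.out.println(isTraceProperty("__frch_item_0.userName", prop)); // true
        // contains("userName") would wrongly match this unrelated property:
        System.out.println(isTraceProperty("userNameBackup", prop));         // false

        List<String> ignore = Arrays.asList("sys_.*", "audit_log");
        System.out.println(matchesIgnoreTables("sys_config", ignore));       // true
        System.out.println(matchesIgnoreTables("audit_log_2024", ignore));   // false
    }
}
```

Because String.matches anchors the pattern at both ends, an ignore entry such as audit_log skips only that exact table; to skip a whole family of tables the configured entry has to be a real pattern like sys_.*.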

3. How do you make the custom MyBatis Plugin take effect?