From b9f315ae18c0fb54071753ec489844d80fa61f73 Mon Sep 17 00:00:00 2001 From: Milkdove Date: Tue, 3 Sep 2024 10:45:06 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9EWrapper=E7=B1=BB?= =?UTF-8?q?=E5=AF=B9=E6=B3=A8=E8=A7=A3@Param=E8=AE=BE=E7=BD=AE=E5=88=AB?= =?UTF-8?q?=E5=90=8D=E7=9A=84=E6=94=AF=E6=8C=81&setParamAlias=E5=8F=AF?= =?UTF-8?q?=E4=BB=A5=E9=87=8D=E5=A4=8D=E8=AE=BE=E7=BD=AE&=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0SqlSegment=E7=9A=84=E7=BC=93=E5=AD=98=E5=88=A0?= =?UTF-8?q?=E9=99=A4=E8=AE=BF=E9=97=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/conditions/AbstractWrapper.java | 10 ++-- .../mybatisplus/core/conditions/Wrapper.java | 2 +- .../segments/AbstractISegmentList.java | 5 ++ .../conditions/segments/MergeSegments.java | 9 +++ .../core/override/MybatisMapperMethod.java | 60 ++++++++++++++++--- .../mybatisplus/test/h2/H2UserTest.java | 32 ++++++++++ .../test/h2/mapper/H2UserMapper.java | 6 ++ .../test/h2/service/IH2UserService.java | 4 ++ .../h2/service/impl/H2UserServiceImpl.java | 10 ++++ 9 files changed, 123 insertions(+), 15 deletions(-) diff --git a/mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/conditions/AbstractWrapper.java b/mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/conditions/AbstractWrapper.java index 31316b098..a19554156 100755 --- a/mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/conditions/AbstractWrapper.java +++ b/mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/conditions/AbstractWrapper.java @@ -650,17 +650,17 @@ public String getParamAlias() { } /** - * 参数别名设置,初始化时优先设置该值、重复设置异常 + * 参数别名设置,初始化时优先设置该值 * * @param paramAlias 参数别名 * @return Children */ - @SuppressWarnings("unused") public Children setParamAlias(String paramAlias) { - Assert.notEmpty(paramAlias, "paramAlias can not be empty!"); - Assert.isEmpty(paramNameValuePairs, "Please call this method before working!"); - Assert.isNull(this.paramAlias, "Please do not call the method repeatedly!"); + String oldParamAlias = getParamAlias(); this.paramAlias = new SharedString(paramAlias); + if (this.expression != null && !oldParamAlias.equals(paramAlias)) { + expression.clearSqlSegmentCache(); + } return typedThis; } diff --git a/mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/conditions/Wrapper.java b/mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/conditions/Wrapper.java index 19fbcea6c..e2a3e309d 100644 --- a/mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/conditions/Wrapper.java +++ b/mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/conditions/Wrapper.java @@ -70,7 +70,7 @@ public String getSqlFirst() { * 1. 逻辑删除需要自己拼接条件 (之前自定义也同样) * 2. 不支持wrapper中附带实体的情况 (wrapper自带实体会更麻烦) * 3. 用法 ${ew.customSqlSegment} (不需要where标签包裹,切记!) - * 4. ew是wrapper定义别名,不能使用其他的替换 + * 4. ew是wrapper定义的默认别名,可通过{@link org.apache.ibatis.annotations.Param}注解修改,也可以通过{@link AbstractWrapper#setParamAlias(String)}方式设置别名 */ public String getCustomSqlSegment() { MergeSegments expression = getExpression(); diff --git a/mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/conditions/segments/AbstractISegmentList.java b/mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/conditions/segments/AbstractISegmentList.java index 31fb4740d..19741b9b6 100644 --- a/mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/conditions/segments/AbstractISegmentList.java +++ b/mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/conditions/segments/AbstractISegmentList.java @@ -119,4 +119,9 @@ public void clear() { sqlSegment = EMPTY; cacheSqlSegment = true; } + + public void clearSqlSegmentCache() { + sqlSegment = EMPTY; + cacheSqlSegment = false; + } } diff --git a/mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/conditions/segments/MergeSegments.java b/mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/conditions/segments/MergeSegments.java index de4ccfcaf..d3f4f7ead 100644 --- a/mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/conditions/segments/MergeSegments.java +++ b/mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/conditions/segments/MergeSegments.java @@ -73,6 +73,15 @@ public String getSqlSegment() { return sqlSegment; } + public void clearSqlSegmentCache() { + sqlSegment = StringPool.EMPTY; + cacheSqlSegment = false; + normal.clearSqlSegmentCache(); + groupBy.clearSqlSegmentCache(); + having.clearSqlSegmentCache(); + orderBy.clearSqlSegmentCache(); + } + /** * 清理 * diff --git a/mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/override/MybatisMapperMethod.java b/mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/override/MybatisMapperMethod.java index ac766b65f..7c60830b6 100644 --- a/mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/override/MybatisMapperMethod.java +++ b/mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/override/MybatisMapperMethod.java @@ -15,8 +15,12 @@ */ package com.baomidou.mybatisplus.core.override; +import com.baomidou.mybatisplus.core.conditions.AbstractWrapper; +import com.baomidou.mybatisplus.core.conditions.Wrapper; import com.baomidou.mybatisplus.core.metadata.IPage; import com.baomidou.mybatisplus.core.toolkit.Assert; + +import org.apache.ibatis.annotations.Param; import org.apache.ibatis.binding.BindingException; import org.apache.ibatis.binding.MapperMethod; import org.apache.ibatis.cursor.Cursor; @@ -27,11 +31,15 @@ import org.apache.ibatis.session.RowBounds; import org.apache.ibatis.session.SqlSession; +import java.lang.annotation.Annotation; import java.lang.reflect.Array; import java.lang.reflect.Method; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.SortedMap; +import java.util.TreeMap; /** * 从 {@link MapperMethod} copy 过来
@@ -45,27 +53,46 @@ public class MybatisMapperMethod { private final MapperMethod.SqlCommand command; private final MapperMethod.MethodSignature method; + private final Map wrapperParamsAliasNameMap; public MybatisMapperMethod(Class mapperInterface, Method method, Configuration config) { + Annotation[][] paramAnnotations = method.getParameterAnnotations(); + Class[] parameterTypes = method.getParameterTypes(); + int paramCount = method.getParameterCount(); + final SortedMap map = new TreeMap<>(); + // get names from @Param annotations + for (int paramIndex = 0; paramIndex < paramCount; paramIndex++) { + String name = null; + for (Annotation annotation : paramAnnotations[paramIndex]) { + Class parameterType = parameterTypes[paramIndex]; + if (annotation instanceof Param && Wrapper.class.isAssignableFrom(parameterType)) { + name = ((Param) annotation).value(); + break; + } + } + map.put(paramIndex, name); + } + wrapperParamsAliasNameMap = Collections.unmodifiableSortedMap(map); this.command = new MapperMethod.SqlCommand(config, mapperInterface, method); this.method = new MapperMethod.MethodSignature(config, mapperInterface, method); } + public Object execute(SqlSession sqlSession, Object[] args) { Object result; switch (command.getType()) { case INSERT: { - Object param = method.convertArgsToSqlCommandParam(args); + Object param = this.convertArgsToSqlCommandParam(args); result = rowCountResult(sqlSession.insert(command.getName(), param)); break; } case UPDATE: { - Object param = method.convertArgsToSqlCommandParam(args); + Object param = this.convertArgsToSqlCommandParam(args); result = rowCountResult(sqlSession.update(command.getName(), param)); break; } case DELETE: { - Object param = method.convertArgsToSqlCommandParam(args); + Object param = this.convertArgsToSqlCommandParam(args); result = rowCountResult(sqlSession.delete(command.getName(), param)); break; } @@ -83,7 +110,7 @@ public Object execute(SqlSession sqlSession, Object[] args) { if (IPage.class.isAssignableFrom(method.getReturnType())) { result = executeForIPage(sqlSession, args); } else { - Object param = method.convertArgsToSqlCommandParam(args); + Object param = this.convertArgsToSqlCommandParam(args); result = sqlSession.selectOne(command.getName(), param); if (method.returnsOptional() && (result == null || !method.getReturnType().equals(result.getClass()))) { @@ -115,7 +142,7 @@ private Object executeForIPage(SqlSession sqlSession, Object[] args) { } } Assert.notNull(result, "can't found IPage for args!"); - Object param = method.convertArgsToSqlCommandParam(args); + Object param = this.convertArgsToSqlCommandParam(args); List list = sqlSession.selectList(command.getName(), param); result.setRecords(list); return result; @@ -145,7 +172,7 @@ private void executeWithResultHandler(SqlSession sqlSession, Object[] args) { + " needs either a @ResultMap annotation, a @ResultType annotation," + " or a resultType attribute in XML so a ResultHandler can be used as a parameter."); } - Object param = method.convertArgsToSqlCommandParam(args); + Object param = this.convertArgsToSqlCommandParam(args); if (method.hasRowBounds()) { RowBounds rowBounds = method.extractRowBounds(args); sqlSession.select(command.getName(), param, rowBounds, method.extractResultHandler(args)); @@ -156,7 +183,7 @@ private void executeWithResultHandler(SqlSession sqlSession, Object[] args) { private Object executeForMany(SqlSession sqlSession, Object[] args) { List result; - Object param = method.convertArgsToSqlCommandParam(args); + Object param = this.convertArgsToSqlCommandParam(args); if (method.hasRowBounds()) { RowBounds rowBounds = method.extractRowBounds(args); result = sqlSession.selectList(command.getName(), param, rowBounds); @@ -176,7 +203,7 @@ private Object executeForMany(SqlSession sqlSession, Object[] args) { private Cursor executeForCursor(SqlSession sqlSession, Object[] args) { Cursor result; - Object param = method.convertArgsToSqlCommandParam(args); + Object param = this.convertArgsToSqlCommandParam(args); if (method.hasRowBounds()) { RowBounds rowBounds = method.extractRowBounds(args); result = sqlSession.selectCursor(command.getName(), param, rowBounds); @@ -208,7 +235,7 @@ private Object convertToArray(List list) { private Map executeForMap(SqlSession sqlSession, Object[] args) { Map result; - Object param = method.convertArgsToSqlCommandParam(args); + Object param = this.convertArgsToSqlCommandParam(args); if (method.hasRowBounds()) { RowBounds rowBounds = method.extractRowBounds(args); result = sqlSession.selectMap(command.getName(), param, method.getMapKey(), rowBounds); @@ -217,4 +244,19 @@ private Map executeForMap(SqlSession sqlSession, Object[] args) { } return result; } + + private Object convertArgsToSqlCommandParam(Object[] args) { + if (args == null) { + return null; + } + int argCount = args.length; + for (int i = 0; i < argCount; i++) { + Object arg = args[i]; + String s = wrapperParamsAliasNameMap.get(i); + if (s != null && arg instanceof AbstractWrapper) { + ((AbstractWrapper) arg).setParamAlias(s); + } + } + return method.convertArgsToSqlCommandParam(args); + } } diff --git a/mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/H2UserTest.java b/mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/H2UserTest.java index b2ab34094..4dd81c087 100644 --- a/mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/H2UserTest.java +++ b/mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/H2UserTest.java @@ -487,6 +487,38 @@ void testSqlInjectionByCustomSqlSegment() { Assertions.assertEquals(2, h2Users.size()); } + + @Test + @Order(33) + void testWrapperSetAliasByParam() { + // Preparing: select * from h2user WHERE (name LIKE ?) + // Parameters: %y%%(String) + List h2Users = userService.testWrapperSetAliasByParam(new QueryWrapper().like("name", "y%")); + Assertions.assertEquals(2, h2Users.size()); + } + + @Test + @Order(34) + void testMultiWrapperQuery() { + // Preparing: select * from h2user a inner join h2user b on a.name=b.name WHERE (a.name LIKE ?) and (b.name = ?) + // Parameters: %y%%(String), Jerry(String) + QueryWrapper leftTable = new QueryWrapper() { + @Override + protected String columnToString(String column) { + return "a." + super.columnToString(column); + } + }.like("name", "y%"); + System.out.println(leftTable.getCustomSqlSegment()); + QueryWrapper rightTable = new QueryWrapper() { + @Override + protected String columnToString(String column) { + return "b." + super.columnToString(column); + } + }.eq("name", "Jerry"); + List h2Users = userService.testMultiWrapperQuery(leftTable, rightTable); + Assertions.assertEquals(1, h2Users.size()); + } + @Test void myQueryWithGroupByOrderBy() { userService.mySelectMaps().forEach(System.out::println); diff --git a/mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/mapper/H2UserMapper.java b/mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/mapper/H2UserMapper.java index 400448b79..cf63125bf 100644 --- a/mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/mapper/H2UserMapper.java +++ b/mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/mapper/H2UserMapper.java @@ -66,6 +66,12 @@ public interface H2UserMapper extends SuperMapper { @Select("select * from h2user ${ew.customSqlSegment}") List selectTestCustomSqlSegment(@Param(Constants.WRAPPER) Wrapper wrapper); + @Select("select * from h2user ${ewAlias.customSqlSegment}") + List testWrapperSetAliasByParam(@Param("ewAlias") Wrapper wrapper); + + @Select("select * from h2user a inner join h2user b on a.name=b.name ${asd.customSqlSegment} and ${haha.sqlSegment}") + List testMultiWrapperQuery(@Param("haha") Wrapper haha, @Param("asd") Wrapper wrapper2); + @Select("select count(1) from (" + "select test_id as id, CAST(#{nameParam} AS VARCHAR) as name" + " from h2user " + diff --git a/mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/service/IH2UserService.java b/mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/service/IH2UserService.java index 94f9d2426..ca7243f78 100644 --- a/mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/service/IH2UserService.java +++ b/mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/service/IH2UserService.java @@ -46,6 +46,10 @@ public interface IH2UserService extends IService { List testCustomSqlSegment(Wrapper wrapper); + List testWrapperSetAliasByParam(Wrapper wrapper); + + List testMultiWrapperQuery(Wrapper wrapper, Wrapper wrapper2); + void testSaveOrUpdateTransactional1(List users); void testSaveOrUpdateTransactional2(List users); diff --git a/mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/service/impl/H2UserServiceImpl.java b/mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/service/impl/H2UserServiceImpl.java index 315027f2f..e910f74c6 100644 --- a/mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/service/impl/H2UserServiceImpl.java +++ b/mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/service/impl/H2UserServiceImpl.java @@ -128,6 +128,16 @@ public List testCustomSqlSegment(Wrapper wrapper) { return baseMapper.selectTestCustomSqlSegment(wrapper); } + @Override + public List testWrapperSetAliasByParam(Wrapper wrapper) { + return baseMapper.testWrapperSetAliasByParam(wrapper); + } + + @Override + public List testMultiWrapperQuery(Wrapper wrapper, Wrapper wrapper2) { + return baseMapper.testMultiWrapperQuery(wrapper, wrapper2); + } + @Override @Transactional(rollbackFor = RuntimeException.class) public void testSaveOrUpdateTransactional1(List users) {