Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: 新增Wrapper类对注解@Param设置别名的支持&setParamAlias可以重复设置&增加SqlSegment的缓存删除访问 #6450

Open
wants to merge 1 commit into
base: 3.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -650,17 +650,17 @@ public String getParamAlias() {
}

/**
* 参数别名设置,初始化时优先设置该值、重复设置异常
* 参数别名设置,初始化时优先设置该值
*
* @param paramAlias 参数别名
* @return Children
*/
@SuppressWarnings("unused")
public Children setParamAlias(String paramAlias) {
Assert.notEmpty(paramAlias, "paramAlias can not be empty!");
Assert.isEmpty(paramNameValuePairs, "Please call this method before working!");
Assert.isNull(this.paramAlias, "Please do not call the method repeatedly!");
String oldParamAlias = getParamAlias();
this.paramAlias = new SharedString(paramAlias);
if (this.expression != null && !oldParamAlias.equals(paramAlias)) {
expression.clearSqlSegmentCache();
}
return typedThis;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public String getSqlFirst() {
* 1. 逻辑删除需要自己拼接条件 (之前自定义也同样)
* 2. 不支持wrapper中附带实体的情况 (wrapper自带实体会更麻烦)
* 3. 用法 ${ew.customSqlSegment} (不需要where标签包裹,切记!)
* 4. ew是wrapper定义别名,不能使用其他的替换
* 4. ew是wrapper定义的默认别名,可通过{@link org.apache.ibatis.annotations.Param}注解修改,也可以通过{@link AbstractWrapper#setParamAlias(String)}方式设置别名
*/
public String getCustomSqlSegment() {
MergeSegments expression = getExpression();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,4 +119,9 @@ public void clear() {
sqlSegment = EMPTY;
cacheSqlSegment = true;
}

public void clearSqlSegmentCache() {
sqlSegment = EMPTY;
cacheSqlSegment = false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,15 @@ public String getSqlSegment() {
return sqlSegment;
}

public void clearSqlSegmentCache() {
sqlSegment = StringPool.EMPTY;
cacheSqlSegment = false;
normal.clearSqlSegmentCache();
groupBy.clearSqlSegmentCache();
having.clearSqlSegmentCache();
orderBy.clearSqlSegmentCache();
}

/**
* 清理
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@
*/
package com.baomidou.mybatisplus.core.override;

import com.baomidou.mybatisplus.core.conditions.AbstractWrapper;
import com.baomidou.mybatisplus.core.conditions.Wrapper;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.core.toolkit.Assert;

import org.apache.ibatis.annotations.Param;
import org.apache.ibatis.binding.BindingException;
import org.apache.ibatis.binding.MapperMethod;
import org.apache.ibatis.cursor.Cursor;
Expand All @@ -27,11 +31,15 @@
import org.apache.ibatis.session.RowBounds;
import org.apache.ibatis.session.SqlSession;

import java.lang.annotation.Annotation;
import java.lang.reflect.Array;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.SortedMap;
import java.util.TreeMap;

/**
* 从 {@link MapperMethod} copy 过来 </br>
Expand All @@ -45,27 +53,46 @@
public class MybatisMapperMethod {
private final MapperMethod.SqlCommand command;
private final MapperMethod.MethodSignature method;
private final Map<Integer, String> wrapperParamsAliasNameMap;

public MybatisMapperMethod(Class<?> mapperInterface, Method method, Configuration config) {
Annotation[][] paramAnnotations = method.getParameterAnnotations();
Class<?>[] parameterTypes = method.getParameterTypes();
int paramCount = method.getParameterCount();
final SortedMap<Integer, String> map = new TreeMap<>();
// get names from @Param annotations
for (int paramIndex = 0; paramIndex < paramCount; paramIndex++) {
String name = null;
for (Annotation annotation : paramAnnotations[paramIndex]) {
Class<?> parameterType = parameterTypes[paramIndex];
if (annotation instanceof Param && Wrapper.class.isAssignableFrom(parameterType)) {
name = ((Param) annotation).value();
break;
}
}
map.put(paramIndex, name);
}
wrapperParamsAliasNameMap = Collections.unmodifiableSortedMap(map);
this.command = new MapperMethod.SqlCommand(config, mapperInterface, method);
this.method = new MapperMethod.MethodSignature(config, mapperInterface, method);
}


public Object execute(SqlSession sqlSession, Object[] args) {
Object result;
switch (command.getType()) {
case INSERT: {
Object param = method.convertArgsToSqlCommandParam(args);
Object param = this.convertArgsToSqlCommandParam(args);
result = rowCountResult(sqlSession.insert(command.getName(), param));
break;
}
case UPDATE: {
Object param = method.convertArgsToSqlCommandParam(args);
Object param = this.convertArgsToSqlCommandParam(args);
result = rowCountResult(sqlSession.update(command.getName(), param));
break;
}
case DELETE: {
Object param = method.convertArgsToSqlCommandParam(args);
Object param = this.convertArgsToSqlCommandParam(args);
result = rowCountResult(sqlSession.delete(command.getName(), param));
break;
}
Expand All @@ -83,7 +110,7 @@ public Object execute(SqlSession sqlSession, Object[] args) {
if (IPage.class.isAssignableFrom(method.getReturnType())) {
result = executeForIPage(sqlSession, args);
} else {
Object param = method.convertArgsToSqlCommandParam(args);
Object param = this.convertArgsToSqlCommandParam(args);
result = sqlSession.selectOne(command.getName(), param);
if (method.returnsOptional()
&& (result == null || !method.getReturnType().equals(result.getClass()))) {
Expand Down Expand Up @@ -115,7 +142,7 @@ private <E> Object executeForIPage(SqlSession sqlSession, Object[] args) {
}
}
Assert.notNull(result, "can't found IPage for args!");
Object param = method.convertArgsToSqlCommandParam(args);
Object param = this.convertArgsToSqlCommandParam(args);
List<E> list = sqlSession.selectList(command.getName(), param);
result.setRecords(list);
return result;
Expand Down Expand Up @@ -145,7 +172,7 @@ private void executeWithResultHandler(SqlSession sqlSession, Object[] args) {
+ " needs either a @ResultMap annotation, a @ResultType annotation,"
+ " or a resultType attribute in XML so a ResultHandler can be used as a parameter.");
}
Object param = method.convertArgsToSqlCommandParam(args);
Object param = this.convertArgsToSqlCommandParam(args);
if (method.hasRowBounds()) {
RowBounds rowBounds = method.extractRowBounds(args);
sqlSession.select(command.getName(), param, rowBounds, method.extractResultHandler(args));
Expand All @@ -156,7 +183,7 @@ private void executeWithResultHandler(SqlSession sqlSession, Object[] args) {

private <E> Object executeForMany(SqlSession sqlSession, Object[] args) {
List<E> result;
Object param = method.convertArgsToSqlCommandParam(args);
Object param = this.convertArgsToSqlCommandParam(args);
if (method.hasRowBounds()) {
RowBounds rowBounds = method.extractRowBounds(args);
result = sqlSession.selectList(command.getName(), param, rowBounds);
Expand All @@ -176,7 +203,7 @@ private <E> Object executeForMany(SqlSession sqlSession, Object[] args) {

private <T> Cursor<T> executeForCursor(SqlSession sqlSession, Object[] args) {
Cursor<T> result;
Object param = method.convertArgsToSqlCommandParam(args);
Object param = this.convertArgsToSqlCommandParam(args);
if (method.hasRowBounds()) {
RowBounds rowBounds = method.extractRowBounds(args);
result = sqlSession.selectCursor(command.getName(), param, rowBounds);
Expand Down Expand Up @@ -208,7 +235,7 @@ private <E> Object convertToArray(List<E> list) {

private <K, V> Map<K, V> executeForMap(SqlSession sqlSession, Object[] args) {
Map<K, V> result;
Object param = method.convertArgsToSqlCommandParam(args);
Object param = this.convertArgsToSqlCommandParam(args);
if (method.hasRowBounds()) {
RowBounds rowBounds = method.extractRowBounds(args);
result = sqlSession.selectMap(command.getName(), param, method.getMapKey(), rowBounds);
Expand All @@ -217,4 +244,19 @@ private <K, V> Map<K, V> executeForMap(SqlSession sqlSession, Object[] args) {
}
return result;
}

private Object convertArgsToSqlCommandParam(Object[] args) {
if (args == null) {
return null;
}
int argCount = args.length;
for (int i = 0; i < argCount; i++) {
Object arg = args[i];
String s = wrapperParamsAliasNameMap.get(i);
if (s != null && arg instanceof AbstractWrapper) {
((AbstractWrapper<?, ?, ?>) arg).setParamAlias(s);
}
}
return method.convertArgsToSqlCommandParam(args);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,38 @@ void testSqlInjectionByCustomSqlSegment() {
Assertions.assertEquals(2, h2Users.size());
}


@Test
@Order(33)
void testWrapperSetAliasByParam() {
// Preparing: select * from h2user WHERE (name LIKE ?)
// Parameters: %y%%(String)
List<H2User> h2Users = userService.testWrapperSetAliasByParam(new QueryWrapper<H2User>().like("name", "y%"));
Assertions.assertEquals(2, h2Users.size());
}

@Test
@Order(34)
void testMultiWrapperQuery() {
// Preparing: select * from h2user a inner join h2user b on a.name=b.name WHERE (a.name LIKE ?) and (b.name = ?)
// Parameters: %y%%(String), Jerry(String)
QueryWrapper<H2User> leftTable = new QueryWrapper<H2User>() {
@Override
protected String columnToString(String column) {
return "a." + super.columnToString(column);
}
}.like("name", "y%");
System.out.println(leftTable.getCustomSqlSegment());
QueryWrapper<H2User> rightTable = new QueryWrapper<H2User>() {
@Override
protected String columnToString(String column) {
return "b." + super.columnToString(column);
}
}.eq("name", "Jerry");
List<H2User> h2Users = userService.testMultiWrapperQuery(leftTable, rightTable);
Assertions.assertEquals(1, h2Users.size());
}

@Test
void myQueryWithGroupByOrderBy() {
userService.mySelectMaps().forEach(System.out::println);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ public interface H2UserMapper extends SuperMapper<H2User> {
@Select("select * from h2user ${ew.customSqlSegment}")
List<H2User> selectTestCustomSqlSegment(@Param(Constants.WRAPPER) Wrapper wrapper);

@Select("select * from h2user ${ewAlias.customSqlSegment}")
List<H2User> testWrapperSetAliasByParam(@Param("ewAlias") Wrapper wrapper);

@Select("select * from h2user a inner join h2user b on a.name=b.name ${asd.customSqlSegment} and ${haha.sqlSegment}")
List<H2User> testMultiWrapperQuery(@Param("haha") Wrapper haha, @Param("asd") Wrapper wrapper2);

@Select("select count(1) from (" +
"select test_id as id, CAST(#{nameParam} AS VARCHAR) as name" +
" from h2user " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ public interface IH2UserService extends IService<H2User> {

List<H2User> testCustomSqlSegment(Wrapper wrapper);

List<H2User> testWrapperSetAliasByParam(Wrapper wrapper);

List<H2User> testMultiWrapperQuery(Wrapper wrapper, Wrapper wrapper2);

void testSaveOrUpdateTransactional1(List<H2User> users);

void testSaveOrUpdateTransactional2(List<H2User> users);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,16 @@ public List<H2User> testCustomSqlSegment(Wrapper wrapper) {
return baseMapper.selectTestCustomSqlSegment(wrapper);
}

@Override
public List<H2User> testWrapperSetAliasByParam(Wrapper wrapper) {
return baseMapper.testWrapperSetAliasByParam(wrapper);
}

@Override
public List<H2User> testMultiWrapperQuery(Wrapper wrapper, Wrapper wrapper2) {
return baseMapper.testMultiWrapperQuery(wrapper, wrapper2);
}

@Override
@Transactional(rollbackFor = RuntimeException.class)
public void testSaveOrUpdateTransactional1(List<H2User> users) {
Expand Down