Skip to content

Commit

Permalink
core: enhance partition prune when comparing partition key with const…
Browse files Browse the repository at this point in the history
…ant of different types
  • Loading branch information
L-maple committed Jan 23, 2025
1 parent ba79f50 commit 5abd231
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 8 deletions.
113 changes: 113 additions & 0 deletions pkg/planner/core/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,119 @@ func TestPartitionPruningForEQ(t *testing.T) {
require.Equal(t, 0, res[0])
}

func TestCast4PartitionPruning(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)

tk.MustExec("use test")
tk.MustExec(`drop table if exists t`)
tk.MustExec(`drop table if exists t_hash`)
tk.MustExec(`drop table if exists t_sub`)
tk.MustExec(`create table t(a int, b int, c int) partition by range(a) (
partition p1 values less than (100),
partition p2 values less than (200),
partition pm values less than (MAXVALUE));`)

// test between castIntAsReal(int) and real
tk.MustQuery(`explain select * from t where a between "123" and "199";`).Check(
testkit.Rows("TableReader_7 8000.00 root partition:p2 data:Selection_6",
"└─Selection_6 8000.00 cop[tikv] ge(cast(test.t.a, double BINARY), 123), le(cast(test.t.a, double BINARY), 199)",
" └─TableFullScan_5 10000.00 cop[tikv] table:t keep order:false, stats:pseudo"))
// MySQL explain:
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| id | select_type | table | partitions | type | possible_keys | key | key_len | ref | rows | filtered | Extra |
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| 1 | SIMPLE | t | p2 | ALL | NULL | NULL | NULL | NULL | 1 | 100.00 | Using where |
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+

// test between castIntAsDecimal(int) and decimal
tk.MustQuery(`explain select * from t where a between 123.12 and 199.99;`).Check(
testkit.Rows("TableReader_7 8000.00 root partition:p2 data:Selection_6",
"└─Selection_6 8000.00 cop[tikv] ge(cast(test.t.a, decimal(10,0) BINARY), 123.12), le(cast(test.t.a, decimal(10,0) BINARY), 199.99)",
" └─TableFullScan_5 10000.00 cop[tikv] table:t keep order:false, stats:pseudo"))
// MySQL explain:
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| id | select_type | table | partitions | type | possible_keys | key | key_len | ref | rows | filtered | Extra |
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| 1 | SIMPLE | t | p2 | ALL | NULL | NULL | NULL | NULL | 1 | 100.00 | Using where |
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+

// test between castIntAsReal(int) and real
tk.MustQuery(`explain select * from t where a between "123.12" and "199.99";`).Check(
testkit.Rows("TableReader_7 8000.00 root partition:p2 data:Selection_6",
"└─Selection_6 8000.00 cop[tikv] ge(cast(test.t.a, double BINARY), 123.12), le(cast(test.t.a, double BINARY), 199.99)",
" └─TableFullScan_5 10000.00 cop[tikv] table:t keep order:false, stats:pseudo"))
// MySQL explain:
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| id | select_type | table | partitions | type | possible_keys | key | key_len | ref | rows | filtered | Extra |
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| 1 | SIMPLE | t | p2 | ALL | NULL | NULL | NULL | NULL | 1 | 100.00 | Using where |
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+

// test between castIntAsReal(int) and real
tk.MustQuery(`explain select * from t where a between "ddd" and "99";`).Check(
testkit.Rows("TableReader_7 8000.00 root partition:p1 data:Selection_6",
"└─Selection_6 8000.00 cop[tikv] ge(cast(test.t.a, double BINARY), 0), le(cast(test.t.a, double BINARY), 99)",
" └─TableFullScan_5 10000.00 cop[tikv] table:t keep order:false, stats:pseudo"))
// MySQL explain:
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| id | select_type | table | partitions | type | possible_keys | key | key_len | ref | rows | filtered | Extra |
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| 1 | SIMPLE | t | p1 | ALL | NULL | NULL | NULL | NULL | 1 | 100.00 | Using where |
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+

// test between castIntAsReal(int) and real / between castIntAsReal(int) and real
tk.MustQuery(`explain select * from t where a between "123.12" and cast("199.99" as decimal);`).Check(
testkit.Rows("TableReader_7 8000.00 root partition:p2,pm data:Selection_6",
"└─Selection_6 8000.00 cop[tikv] ge(cast(test.t.a, double BINARY), 123.12), le(cast(test.t.a, double BINARY), 200)",
" └─TableFullScan_5 10000.00 cop[tikv] table:t keep order:false, stats:pseudo"))
// MySQL explain:
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| id | select_type | table | partitions | type | possible_keys | key | key_len | ref | rows | filtered | Extra |
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| 1 | SIMPLE | t | p2,pm | ALL | NULL | NULL | NULL | NULL | 1 | 100.00 | Using where |
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+

tk.MustExec(`CREATE TABLE t_hash(a int, b int) PARTITION BY HASH(a) PARTITIONS 6`)
tk.MustExec(`insert into t_hash values(1, 1), (10, 10), (26, 26)`)
tk.MustQuery(`select * from t_hash where a = '1'`).Check(testkit.Rows("1 1"))
tk.MustQuery(`explain select * from t_hash where a = '1'`).Check(testkit.Rows(
"TableReader_7 10.00 root partition:p1 data:Selection_6",
"└─Selection_6 10.00 cop[tikv] eq(test.t_hash.a, 1)",
" └─TableFullScan_5 10000.00 cop[tikv] table:t_hash keep order:false, stats:pseudo"))
// MySQL explain:
//+----+-------------+--------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| id | select_type | table | partitions | type | possible_keys | key | key_len | ref | rows | filtered | Extra |
//+----+-------------+--------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| 1 | SIMPLE | t_hash | p1 | ALL | NULL | NULL | NULL | NULL | 1 | 100.00 | Using where |
//+----+-------------+--------+------------+------+---------------+------+---------+------+------+----------+-------------+

tk.MustExec(`create table t_ts (report_updated timestamp) partition by range(unix_timestamp(report_updated)) (
partition p1 values less than (1732982400), -- 2024-12-01 00:00:00
partition p2 values less than (1733068800), -- 2024-12-02 00:00:00
partition pm values less than (MAXVALUE));`)
tk.MustExec("insert into t_ts values('2024-11-30 00:00:00'), ('2024-12-01 00:00:00'), ('2024-12-02 00:00:00')")
tk.MustQuery("select * from t_ts where report_updated = '2024-12-01 00:00:00'").Check(testkit.Rows("2024-12-01 00:00:00"))
tk.MustQuery("explain select * from t_ts where report_updated = 20241201").Check(testkit.Rows(
"TableReader_7 10.00 root partition:p2 data:Selection_6",
"└─Selection_6 10.00 cop[tikv] eq(test.t_ts.report_updated, 2024-12-01 00:00:00)",
" └─TableFullScan_5 10000.00 cop[tikv] table:t_ts keep order:false, stats:pseudo"))
tk.MustQuery("explain select * from t_ts where report_updated = '2024-12-01 00:00:00'").Check(testkit.Rows(
"TableReader_7 10.00 root partition:p2 data:Selection_6",
"└─Selection_6 10.00 cop[tikv] eq(test.t_ts.report_updated, 2024-12-01 00:00:00.000000)",
" └─TableFullScan_5 10000.00 cop[tikv] table:t_ts keep order:false, stats:pseudo"))
tk.MustQuery("explain select * from t_ts where report_updated > unix_timestamp('2008-05-01 00:00:00')").Check(testkit.Rows(
"TableReader_7 8000.00 root partition:all data:Selection_6",
"└─Selection_6 8000.00 cop[tikv] gt(cast(test.t_ts.report_updated, double BINARY), 1.2095712e+09)",
" └─TableFullScan_5 10000.00 cop[tikv] table:t_ts keep order:false, stats:pseudo"))
//MysQL explain:
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| id | select_type | table | partitions | type | possible_keys | key | key_len | ref | rows | filtered | Extra |
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+
//| 1 | SIMPLE | t_ts | p1,p2,pm | ALL | NULL | NULL | NULL | NULL | 3 | 33.33 | Using where |
//+----+-------------+-------+------------+------+---------------+------+---------+------+------+----------+-------------+
}

func TestNotReadOnlySQLOnTiFlash(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
Expand Down
63 changes: 55 additions & 8 deletions pkg/planner/core/rule_partition_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1593,10 +1593,23 @@ func (p *rangePruner) extractDataForPrune(sctx base.PlanContext, expr expression
if arg1, ok := op.GetArgs()[1].(*expression.Constant); ok {
col, con = arg0, arg1
}
} else if arg0, ok := op.GetArgs()[1].(*expression.Column); ok && arg0.ID == p.col.ID {
if arg1, ok := op.GetArgs()[0].(*expression.Constant); ok {
} else if arg1, ok := op.GetArgs()[1].(*expression.Column); ok && arg1.ID == p.col.ID {
if arg0, ok := op.GetArgs()[0].(*expression.Constant); ok {
ret.op = opposite(ret.op)
col, con = arg0, arg1
col, con = arg1, arg0
}
} else if sarg0, ok := op.GetArgs()[0].(*expression.ScalarFunction); ok && sarg0.FuncName.L == ast.Cast {
if arg0, ok := sarg0.GetArgs()[0].(*expression.Column); ok && arg0.ID == p.col.ID {
if arg1, ok := op.GetArgs()[1].(*expression.Constant); ok {
col, con = arg0, arg1
}
}
} else if sarg1, ok := op.GetArgs()[1].(*expression.ScalarFunction); ok && sarg1.FuncName.L == ast.Cast {
if arg1, ok := sarg1.GetArgs()[0].(*expression.Column); ok && arg1.ID == p.col.ID {
if arg0, ok := op.GetArgs()[0].(*expression.Constant); ok {
ret.op = opposite(ret.op)
col, con = arg1, arg0
}
}
}
if col == nil || con == nil {
Expand All @@ -1606,6 +1619,14 @@ func (p *rangePruner) extractDataForPrune(sctx base.PlanContext, expr expression
// Current expression is 'col op const'
var constExpr expression.Expression
if p.partFn != nil {
// If arg0 or arg1 is ScalarFunction, just skip it.
// Maybe more complicated cases would be considered in the future.
_, ok1 := op.GetArgs()[0].(*expression.ScalarFunction)
_, ok2 := op.GetArgs()[1].(*expression.ScalarFunction)
if ok1 || ok2 {
return ret, false
}

// If the partition function is not monotone, only EQ condition can be pruning.
if p.monotonous == monotoneModeInvalid && ret.op != ast.EQ {
return ret, false
Expand All @@ -1626,11 +1647,37 @@ func (p *rangePruner) extractDataForPrune(sctx base.PlanContext, expr expression
// If the partition expression is col, use constExpr.
constExpr = con
}
c, isNull, err := constExpr.EvalInt(sctx.GetExprCtx().GetEvalCtx(), chunk.Row{})
if err == nil && !isNull {
ret.c = c
ret.unsigned = mysql.HasUnsignedFlag(constExpr.GetType(sctx.GetExprCtx().GetEvalCtx()).GetFlag())
return ret, true
evalType := constExpr.GetType(sctx.GetExprCtx().GetEvalCtx()).EvalType()
if evalType == types.ETInt {
c, isNull, err := constExpr.EvalInt(sctx.GetExprCtx().GetEvalCtx(), chunk.Row{})
if err == nil && !isNull {
ret.c = c
ret.unsigned = mysql.HasUnsignedFlag(constExpr.GetType(sctx.GetExprCtx().GetEvalCtx()).GetFlag())
return ret, true
}
} else if evalType == types.ETReal {
f, isNull, err := constExpr.EvalReal(sctx.GetExprCtx().GetEvalCtx(), chunk.Row{})
c := int64(f)
if err == nil && !isNull {
ret.c = c
ret.unsigned = mysql.HasUnsignedFlag(constExpr.GetType(sctx.GetExprCtx().GetEvalCtx()).GetFlag())
return ret, true
}
} else if evalType == types.ETDecimal {
d, isNull, err := constExpr.EvalDecimal(sctx.GetExprCtx().GetEvalCtx(), chunk.Row{})
if err != nil {
return ret, false
}
f, err := d.ToFloat64()
if err != nil {
return ret, false
}
if err == nil && !isNull {
ret.c = int64(f)
ret.unsigned = mysql.HasUnsignedFlag(constExpr.GetType(sctx.GetExprCtx().GetEvalCtx()).GetFlag())
return ret, true
}
} else {
}
return ret, false
}
Expand Down

0 comments on commit 5abd231

Please sign in to comment.