Skip to content

Commit

Permalink
clean up code
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Jan 30, 2025
1 parent ecabac1 commit ac21c83
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 72 deletions.
52 changes: 1 addition & 51 deletions datafusion/expr-common/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ use arrow::datatypes::{
DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
};
use datafusion_common::types::NativeType;
use datafusion_common::{
exec_datafusion_err, exec_err, internal_err, plan_datafusion_err, plan_err, Result,
};
use datafusion_common::{exec_err, internal_err, plan_datafusion_err, plan_err, Result};
use itertools::Itertools;

/// The type signature of an instantiation of binary operator expression such as
Expand Down Expand Up @@ -869,54 +867,6 @@ fn get_wider_decimal_type(
}
}

/// Returns the wider type among arguments `lhs` and `rhs`.
/// The wider type is the type that can safely represent values from both types
/// without information loss. Returns an Error if types are incompatible.
pub fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result<DataType> {
use arrow::datatypes::DataType::*;
Ok(match (lhs, rhs) {
(lhs, rhs) if lhs == rhs => lhs.clone(),
// Right UInt is larger than left UInt.
(UInt8, UInt16 | UInt32 | UInt64) | (UInt16, UInt32 | UInt64) | (UInt32, UInt64) |
// Right Int is larger than left Int.
(Int8, Int16 | Int32 | Int64) | (Int16, Int32 | Int64) | (Int32, Int64) |
// Right Float is larger than left Float.
(Float16, Float32 | Float64) | (Float32, Float64) |
// Right String is larger than left String.
(Utf8, LargeUtf8) |
// Any right type is wider than a left hand side Null.
(Null, _) => rhs.clone(),
// Left UInt is larger than right UInt.
(UInt16 | UInt32 | UInt64, UInt8) | (UInt32 | UInt64, UInt16) | (UInt64, UInt32) |
// Left Int is larger than right Int.
(Int16 | Int32 | Int64, Int8) | (Int32 | Int64, Int16) | (Int64, Int32) |
// Left Float is larger than right Float.
(Float32 | Float64, Float16) | (Float64, Float32) |
// Left String is larger than right String.
(LargeUtf8, Utf8) |
// Any left type is wider than a right hand side Null.
(_, Null) => lhs.clone(),
(List(lhs_field), List(rhs_field)) => {
let field_type =
get_wider_type(lhs_field.data_type(), rhs_field.data_type())?;
if lhs_field.name() != rhs_field.name() {
return Err(exec_datafusion_err!(
"There is no wider type that can represent both {lhs} and {rhs}."
));
}
assert_eq!(lhs_field.name(), rhs_field.name());
let field_name = lhs_field.name();
let nullable = lhs_field.is_nullable() | rhs_field.is_nullable();
List(Arc::new(Field::new(field_name, field_type, nullable)))
}
(_, _) => {
return Err(exec_datafusion_err!(
"There is no wider type that can represent both {lhs} and {rhs}."
));
}
})
}

/// Convert the numeric data type to the decimal data type.
/// We support signed and unsigned integer types and floating-point type.
fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option<DataType> {
Expand Down
45 changes: 26 additions & 19 deletions datafusion/functions-nested/src/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use datafusion_common::{
cast::as_generic_list_array, exec_err, not_impl_err, plan_err, utils::list_ndims,
};
use datafusion_expr::{
type_coercion::binary::get_wider_type, ColumnarValue, Documentation, ScalarUDFImpl,
ColumnarValue, Documentation, ScalarUDFImpl,
Signature, Volatility,
};
use datafusion_macros::user_doc;
Expand Down Expand Up @@ -276,25 +276,32 @@ impl ScalarUDFImpl for ArrayConcat {
let mut expr_type = DataType::Null;
let mut max_dims = 0;
for arg_type in arg_types {
match arg_type {
DataType::List(field) => {
if !field.data_type().equals_datatype(&DataType::Null) {
let dims = list_ndims(arg_type);
expr_type = match max_dims.cmp(&dims) {
Ordering::Greater => expr_type,
Ordering::Equal => get_wider_type(&expr_type, arg_type)?,
Ordering::Less => {
max_dims = dims;
arg_type.clone()
}
};
let DataType::List(field) = arg_type else {
return plan_err!(
"The array_concat function can only accept list as the args."
);
};
if !field.data_type().equals_datatype(&DataType::Null) {
let dims = list_ndims(arg_type);
expr_type = match max_dims.cmp(&dims) {
Ordering::Greater => expr_type,
Ordering::Equal => {
if expr_type == DataType::Null {
arg_type.clone()
} else if !expr_type.equals_datatype(arg_type) {
return plan_err!(
"It is not possible to concatenate arrays of different types. Expected: {}, got: {}", expr_type, arg_type
);
} else {
expr_type
}
}
}
_ => {
return plan_err!(
"The array_concat function can only accept list as the args."
)
}

Ordering::Less => {
max_dims = dims;
arg_type.clone()
}
};
}
}

Expand Down
4 changes: 2 additions & 2 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2876,14 +2876,14 @@ select array_concat(
[1, 2, 3]

# Concatenating Mixed types (doesn't work)
query error DataFusion error: Arrow error: Invalid argument error: It is not possible to concatenate arrays of different data types\.
query error DataFusion error: Error during planning: It is not possible to concatenate arrays of different types\. Expected: List\(Field \{ name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), got: List\(Field \{ name: "item", data_type: LargeUtf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\)
select array_concat(
[arrow_cast('1', 'Utf8'), arrow_cast('2', 'Utf8')],
[arrow_cast('3', 'LargeUtf8')]
);

# Concatenating Mixed types (doesn't work)
query error DataFusion error: Execution error: There is no wider type that can represent both Utf8 and Utf8View\.
query error DataFusion error: Error during planning: It is not possible to concatenate arrays of different types\. Expected: List\(Field \{ name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), got: List\(Field \{ name: "item", data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\)
select array_concat(
[arrow_cast('1', 'Utf8'), arrow_cast('2', 'Utf8')],
[arrow_cast('3', 'Utf8View')]
Expand Down

0 comments on commit ac21c83

Please sign in to comment.