diff --git a/datafusion/functions-nested/src/repeat.rs b/datafusion/functions-nested/src/repeat.rs index da0aa5f12fde..2bc4721f3cfa 100644 --- a/datafusion/functions-nested/src/repeat.rs +++ b/datafusion/functions-nested/src/repeat.rs @@ -20,14 +20,15 @@ use crate::utils::make_scalar_function; use arrow::array::{Capacities, MutableArrayData}; use arrow::compute; +use arrow::compute::cast; use arrow_array::{ - new_null_array, Array, ArrayRef, GenericListArray, Int64Array, ListArray, - OffsetSizeTrait, + new_null_array, Array, ArrayRef, GenericListArray, ListArray, OffsetSizeTrait, + UInt64Array, }; use arrow_buffer::OffsetBuffer; use arrow_schema::DataType::{LargeList, List}; use arrow_schema::{DataType, Field}; -use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; +use datafusion_common::cast::{as_large_list_array, as_list_array, as_uint64_array}; use datafusion_common::{exec_err, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, @@ -86,7 +87,7 @@ impl Default for ArrayRepeat { impl ArrayRepeat { pub fn new() -> Self { Self { - signature: Signature::variadic_any(Volatility::Immutable), + signature: Signature::user_defined(Volatility::Immutable), aliases: vec![String::from("list_repeat")], } } @@ -124,6 +125,30 @@ impl ScalarUDFImpl for ArrayRepeat { &self.aliases } + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 2 { + return exec_err!("array_repeat expects two arguments"); + } + + let element_type = &arg_types[0]; + let first = element_type.clone(); + + let count_type = &arg_types[1]; + + // Coerce the second argument to Int64/UInt64 if it's a numeric type + let second = match count_type { + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + DataType::Int64 + } + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { + DataType::UInt64 + } + _ => return exec_err!("count must be an integer type"), + }; + + Ok(vec![first, second]) + } + fn documentation(&self) -> Option<&Documentation> { self.doc() } @@ -131,12 +156,16 @@ impl ScalarUDFImpl for ArrayRepeat { /// Array_repeat SQL function pub fn array_repeat_inner(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_repeat expects two arguments"); - } - let element = &args[0]; - let count_array = as_int64_array(&args[1])?; + let count_array = &args[1]; + + let count_array = match count_array.data_type() { + DataType::Int64 => &cast(count_array, &DataType::UInt64)?, + DataType::UInt64 => count_array, + _ => return exec_err!("count must be an integer type"), + }; + + let count_array = as_uint64_array(count_array)?; match element.data_type() { List(_) => { @@ -165,7 +194,7 @@ pub fn array_repeat_inner(args: &[ArrayRef]) -> Result { /// ``` fn general_repeat( array: &ArrayRef, - count_array: &Int64Array, + count_array: &UInt64Array, ) -> Result { let data_type = array.data_type(); let mut new_values = vec![]; @@ -219,7 +248,7 @@ fn general_repeat( /// ``` fn general_list_repeat( list_array: &GenericListArray, - count_array: &Int64Array, + count_array: &UInt64Array, ) -> Result { let data_type = list_array.data_type(); let value_type = list_array.value_type(); diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 22a85eb15512..baf4ef7795e7 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -2760,6 +2760,30 @@ select ---- [[1], [1], [1], [1], [1]] [[1.1, 2.2, 3.3], [1.1, 2.2, 3.3], [1.1, 2.2, 3.3]] [[NULL, NULL], [NULL, NULL], [NULL, NULL]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]] +# array_repeat scalar function with count of different integer types +query ???????? +Select + array_repeat(1, arrow_cast(2,'Int8')), + array_repeat(2, arrow_cast(2,'Int16')), + array_repeat(3, arrow_cast(2,'Int32')), + array_repeat(4, arrow_cast(2,'Int64')), + array_repeat(1, arrow_cast(2,'UInt8')), + array_repeat(2, arrow_cast(2,'UInt16')), + array_repeat(3, arrow_cast(2,'UInt32')), + array_repeat(4, arrow_cast(2,'UInt64')); +---- +[1, 1] [2, 2] [3, 3] [4, 4] [1, 1] [2, 2] [3, 3] [4, 4] + +# array_repeat scalar function with count of negative integer types +query ???? +Select + array_repeat(1, arrow_cast(-2,'Int8')), + array_repeat(2, arrow_cast(-2,'Int16')), + array_repeat(3, arrow_cast(-2,'Int32')), + array_repeat(4, arrow_cast(-2,'Int64')); +---- +[] [] [] [] + # array_repeat with columns #1 statement ok