Skip to content

Commit

Permalink
minor: simplify union_extract code
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Feb 13, 2025
1 parent f785f27 commit e189946
Showing 1 changed file with 10 additions and 19 deletions.
29 changes: 10 additions & 19 deletions datafusion/functions/src/core/union_extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
use arrow::array::Array;
use arrow::datatypes::{DataType, FieldRef, UnionFields};
use datafusion_common::cast::as_union_array;
use datafusion_common::utils::take_function_args;
use datafusion_common::{
exec_datafusion_err, exec_err, internal_err, Result, ScalarValue,
};
Expand Down Expand Up @@ -113,22 +114,15 @@ impl ScalarUDFImpl for UnionExtractFun {
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let args = args.args;
let [array, target_name] = take_function_args("union_extract", args.args)?;

if args.len() != 2 {
return exec_err!(
"union_extract expects 2 arguments, got {} instead",
args.len()
);
}

let target_name = match &args[1] {
let target_name = match target_name {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(target_name))) => Ok(target_name),
ColumnarValue::Scalar(ScalarValue::Utf8(None)) => exec_err!("union_extract second argument must be a non-null string literal, got a null instead"),
_ => exec_err!("union_extract second argument must be a non-null string literal, got {} instead", &args[1].data_type()),
};
_ => exec_err!("union_extract second argument must be a non-null string literal, got {} instead", target_name.data_type()),
}?;

match &args[0] {
match array {
ColumnarValue::Array(array) => {
let union_array = as_union_array(&array).map_err(|_| {
exec_datafusion_err!(
Expand All @@ -140,19 +134,16 @@ impl ScalarUDFImpl for UnionExtractFun {
Ok(ColumnarValue::Array(
arrow::compute::kernels::union_extract::union_extract(
union_array,
target_name?,
&target_name,
)?,
))
}
ColumnarValue::Scalar(ScalarValue::Union(value, fields, _)) => {
let target_name = target_name?;
let (target_type_id, target) = find_field(fields, target_name)?;
let (target_type_id, target) = find_field(&fields, &target_name)?;

let result = match value {
Some((type_id, value)) if target_type_id == *type_id => {
*value.clone()
}
_ => ScalarValue::try_from(target.data_type())?,
Some((type_id, value)) if target_type_id == type_id => *value,
_ => ScalarValue::try_new_null(target.data_type())?,
};

Ok(ColumnarValue::Scalar(result))
Expand Down

0 comments on commit e189946

Please sign in to comment.