From e189946190944b755c6ef590b527b359bdc138b7 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 13 Feb 2025 05:54:30 -0500 Subject: [PATCH] minor: simplify `union_extract` code --- .../functions/src/core/union_extract.rs | 29 +++++++------------ 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs index d54627f73598..95814197d8df 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs @@ -18,6 +18,7 @@ use arrow::array::Array; use arrow::datatypes::{DataType, FieldRef, UnionFields}; use datafusion_common::cast::as_union_array; +use datafusion_common::utils::take_function_args; use datafusion_common::{ exec_datafusion_err, exec_err, internal_err, Result, ScalarValue, }; @@ -113,22 +114,15 @@ impl ScalarUDFImpl for UnionExtractFun { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let args = args.args; + let [array, target_name] = take_function_args("union_extract", args.args)?; - if args.len() != 2 { - return exec_err!( - "union_extract expects 2 arguments, got {} instead", - args.len() - ); - } - - let target_name = match &args[1] { + let target_name = match target_name { ColumnarValue::Scalar(ScalarValue::Utf8(Some(target_name))) => Ok(target_name), ColumnarValue::Scalar(ScalarValue::Utf8(None)) => exec_err!("union_extract second argument must be a non-null string literal, got a null instead"), - _ => exec_err!("union_extract second argument must be a non-null string literal, got {} instead", &args[1].data_type()), - }; + _ => exec_err!("union_extract second argument must be a non-null string literal, got {} instead", target_name.data_type()), + }?; - match &args[0] { + match array { ColumnarValue::Array(array) => { let union_array = as_union_array(&array).map_err(|_| { exec_datafusion_err!( @@ -140,19 +134,16 @@ impl ScalarUDFImpl for UnionExtractFun { Ok(ColumnarValue::Array( arrow::compute::kernels::union_extract::union_extract( union_array, - target_name?, + &target_name, )?, )) } ColumnarValue::Scalar(ScalarValue::Union(value, fields, _)) => { - let target_name = target_name?; - let (target_type_id, target) = find_field(fields, target_name)?; + let (target_type_id, target) = find_field(&fields, &target_name)?; let result = match value { - Some((type_id, value)) if target_type_id == *type_id => { - *value.clone() - } - _ => ScalarValue::try_from(target.data_type())?, + Some((type_id, value)) if target_type_id == type_id => *value, + _ => ScalarValue::try_new_null(target.data_type())?, }; Ok(ColumnarValue::Scalar(result))