diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 5db0f5ed5cc0..e0436946f384 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -976,6 +976,129 @@ impl ScalarValue { ) } + /// Create a Null instance of ScalarValue for this datatype + /// + /// Example + /// ``` + /// use datafusion_common::ScalarValue; + /// use arrow::datatypes::DataType; + /// + /// let scalar = ScalarValue::try_new_null(&DataType::Int32).unwrap(); + /// assert_eq!(scalar.is_null(), true); + /// assert_eq!(scalar.data_type(), DataType::Int32); + /// ``` + pub fn try_new_null(data_type: &DataType) -> Result { + Ok(match data_type { + DataType::Boolean => ScalarValue::Boolean(None), + DataType::Float16 => ScalarValue::Float16(None), + DataType::Float64 => ScalarValue::Float64(None), + DataType::Float32 => ScalarValue::Float32(None), + DataType::Int8 => ScalarValue::Int8(None), + DataType::Int16 => ScalarValue::Int16(None), + DataType::Int32 => ScalarValue::Int32(None), + DataType::Int64 => ScalarValue::Int64(None), + DataType::UInt8 => ScalarValue::UInt8(None), + DataType::UInt16 => ScalarValue::UInt16(None), + DataType::UInt32 => ScalarValue::UInt32(None), + DataType::UInt64 => ScalarValue::UInt64(None), + DataType::Decimal128(precision, scale) => { + ScalarValue::Decimal128(None, *precision, *scale) + } + DataType::Decimal256(precision, scale) => { + ScalarValue::Decimal256(None, *precision, *scale) + } + DataType::Utf8 => ScalarValue::Utf8(None), + DataType::LargeUtf8 => ScalarValue::LargeUtf8(None), + DataType::Utf8View => ScalarValue::Utf8View(None), + DataType::Binary => ScalarValue::Binary(None), + DataType::BinaryView => ScalarValue::BinaryView(None), + DataType::FixedSizeBinary(len) => ScalarValue::FixedSizeBinary(*len, None), + DataType::LargeBinary => ScalarValue::LargeBinary(None), + DataType::Date32 => ScalarValue::Date32(None), + DataType::Date64 => ScalarValue::Date64(None), + DataType::Time32(TimeUnit::Second) => ScalarValue::Time32Second(None), + DataType::Time32(TimeUnit::Millisecond) => { + ScalarValue::Time32Millisecond(None) + } + DataType::Time64(TimeUnit::Microsecond) => { + ScalarValue::Time64Microsecond(None) + } + DataType::Time64(TimeUnit::Nanosecond) => ScalarValue::Time64Nanosecond(None), + DataType::Timestamp(TimeUnit::Second, tz_opt) => { + ScalarValue::TimestampSecond(None, tz_opt.clone()) + } + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { + ScalarValue::TimestampMillisecond(None, tz_opt.clone()) + } + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { + ScalarValue::TimestampMicrosecond(None, tz_opt.clone()) + } + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { + ScalarValue::TimestampNanosecond(None, tz_opt.clone()) + } + DataType::Interval(IntervalUnit::YearMonth) => { + ScalarValue::IntervalYearMonth(None) + } + DataType::Interval(IntervalUnit::DayTime) => { + ScalarValue::IntervalDayTime(None) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + ScalarValue::IntervalMonthDayNano(None) + } + DataType::Duration(TimeUnit::Second) => ScalarValue::DurationSecond(None), + DataType::Duration(TimeUnit::Millisecond) => { + ScalarValue::DurationMillisecond(None) + } + DataType::Duration(TimeUnit::Microsecond) => { + ScalarValue::DurationMicrosecond(None) + } + DataType::Duration(TimeUnit::Nanosecond) => { + ScalarValue::DurationNanosecond(None) + } + DataType::Dictionary(index_type, value_type) => ScalarValue::Dictionary( + index_type.clone(), + Box::new(value_type.as_ref().try_into()?), + ), + // `ScalaValue::List` contains single element `ListArray`. + DataType::List(field_ref) => ScalarValue::List(Arc::new( + GenericListArray::new_null(Arc::clone(field_ref), 1), + )), + // `ScalarValue::LargeList` contains single element `LargeListArray`. + DataType::LargeList(field_ref) => ScalarValue::LargeList(Arc::new( + GenericListArray::new_null(Arc::clone(field_ref), 1), + )), + // `ScalaValue::FixedSizeList` contains single element `FixedSizeList`. + DataType::FixedSizeList(field_ref, fixed_length) => { + ScalarValue::FixedSizeList(Arc::new(FixedSizeListArray::new_null( + Arc::clone(field_ref), + *fixed_length, + 1, + ))) + } + DataType::Struct(fields) => ScalarValue::Struct( + new_null_array(&DataType::Struct(fields.to_owned()), 1) + .as_struct() + .to_owned() + .into(), + ), + DataType::Map(fields, sorted) => ScalarValue::Map( + new_null_array(&DataType::Map(fields.to_owned(), sorted.to_owned()), 1) + .as_map() + .to_owned() + .into(), + ), + DataType::Union(fields, mode) => { + ScalarValue::Union(None, fields.clone(), *mode) + } + DataType::Null => ScalarValue::Null, + _ => { + return _not_impl_err!( + "Can't create a null scalar from data_type \"{data_type:?}\"" + ); + } + }) + } + /// Returns a [`ScalarValue::Utf8`] representing `val` pub fn new_utf8(val: impl Into) -> Self { ScalarValue::from(val.into()) @@ -3457,115 +3580,7 @@ impl TryFrom<&DataType> for ScalarValue { /// Create a Null instance of ScalarValue for this datatype fn try_from(data_type: &DataType) -> Result { - Ok(match data_type { - DataType::Boolean => ScalarValue::Boolean(None), - DataType::Float16 => ScalarValue::Float16(None), - DataType::Float64 => ScalarValue::Float64(None), - DataType::Float32 => ScalarValue::Float32(None), - DataType::Int8 => ScalarValue::Int8(None), - DataType::Int16 => ScalarValue::Int16(None), - DataType::Int32 => ScalarValue::Int32(None), - DataType::Int64 => ScalarValue::Int64(None), - DataType::UInt8 => ScalarValue::UInt8(None), - DataType::UInt16 => ScalarValue::UInt16(None), - DataType::UInt32 => ScalarValue::UInt32(None), - DataType::UInt64 => ScalarValue::UInt64(None), - DataType::Decimal128(precision, scale) => { - ScalarValue::Decimal128(None, *precision, *scale) - } - DataType::Decimal256(precision, scale) => { - ScalarValue::Decimal256(None, *precision, *scale) - } - DataType::Utf8 => ScalarValue::Utf8(None), - DataType::LargeUtf8 => ScalarValue::LargeUtf8(None), - DataType::Utf8View => ScalarValue::Utf8View(None), - DataType::Binary => ScalarValue::Binary(None), - DataType::BinaryView => ScalarValue::BinaryView(None), - DataType::FixedSizeBinary(len) => ScalarValue::FixedSizeBinary(*len, None), - DataType::LargeBinary => ScalarValue::LargeBinary(None), - DataType::Date32 => ScalarValue::Date32(None), - DataType::Date64 => ScalarValue::Date64(None), - DataType::Time32(TimeUnit::Second) => ScalarValue::Time32Second(None), - DataType::Time32(TimeUnit::Millisecond) => { - ScalarValue::Time32Millisecond(None) - } - DataType::Time64(TimeUnit::Microsecond) => { - ScalarValue::Time64Microsecond(None) - } - DataType::Time64(TimeUnit::Nanosecond) => ScalarValue::Time64Nanosecond(None), - DataType::Timestamp(TimeUnit::Second, tz_opt) => { - ScalarValue::TimestampSecond(None, tz_opt.clone()) - } - DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { - ScalarValue::TimestampMillisecond(None, tz_opt.clone()) - } - DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { - ScalarValue::TimestampMicrosecond(None, tz_opt.clone()) - } - DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { - ScalarValue::TimestampNanosecond(None, tz_opt.clone()) - } - DataType::Interval(IntervalUnit::YearMonth) => { - ScalarValue::IntervalYearMonth(None) - } - DataType::Interval(IntervalUnit::DayTime) => { - ScalarValue::IntervalDayTime(None) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - ScalarValue::IntervalMonthDayNano(None) - } - DataType::Duration(TimeUnit::Second) => ScalarValue::DurationSecond(None), - DataType::Duration(TimeUnit::Millisecond) => { - ScalarValue::DurationMillisecond(None) - } - DataType::Duration(TimeUnit::Microsecond) => { - ScalarValue::DurationMicrosecond(None) - } - DataType::Duration(TimeUnit::Nanosecond) => { - ScalarValue::DurationNanosecond(None) - } - DataType::Dictionary(index_type, value_type) => ScalarValue::Dictionary( - index_type.clone(), - Box::new(value_type.as_ref().try_into()?), - ), - // `ScalaValue::List` contains single element `ListArray`. - DataType::List(field_ref) => ScalarValue::List(Arc::new( - GenericListArray::new_null(Arc::clone(field_ref), 1), - )), - // `ScalarValue::LargeList` contains single element `LargeListArray`. - DataType::LargeList(field_ref) => ScalarValue::LargeList(Arc::new( - GenericListArray::new_null(Arc::clone(field_ref), 1), - )), - // `ScalaValue::FixedSizeList` contains single element `FixedSizeList`. - DataType::FixedSizeList(field_ref, fixed_length) => { - ScalarValue::FixedSizeList(Arc::new(FixedSizeListArray::new_null( - Arc::clone(field_ref), - *fixed_length, - 1, - ))) - } - DataType::Struct(fields) => ScalarValue::Struct( - new_null_array(&DataType::Struct(fields.to_owned()), 1) - .as_struct() - .to_owned() - .into(), - ), - DataType::Map(fields, sorted) => ScalarValue::Map( - new_null_array(&DataType::Map(fields.to_owned(), sorted.to_owned()), 1) - .as_map() - .to_owned() - .into(), - ), - DataType::Union(fields, mode) => { - ScalarValue::Union(None, fields.clone(), *mode) - } - DataType::Null => ScalarValue::Null, - _ => { - return _not_impl_err!( - "Can't create a scalar from data_type \"{data_type:?}\"" - ); - } - }) + Self::try_new_null(data_type) } } @@ -7269,4 +7284,88 @@ mod tests { let dictionary_array = dictionary_scalar.to_array().unwrap(); assert!(dictionary_array.is_null(0)); } + + #[test] + fn test_scalar_value_try_new_null() { + let scalars = vec![ + ScalarValue::try_new_null(&DataType::Boolean).unwrap(), + ScalarValue::try_new_null(&DataType::Int8).unwrap(), + ScalarValue::try_new_null(&DataType::Int16).unwrap(), + ScalarValue::try_new_null(&DataType::Int32).unwrap(), + ScalarValue::try_new_null(&DataType::Int64).unwrap(), + ScalarValue::try_new_null(&DataType::UInt8).unwrap(), + ScalarValue::try_new_null(&DataType::UInt16).unwrap(), + ScalarValue::try_new_null(&DataType::UInt32).unwrap(), + ScalarValue::try_new_null(&DataType::UInt64).unwrap(), + ScalarValue::try_new_null(&DataType::Float16).unwrap(), + ScalarValue::try_new_null(&DataType::Float32).unwrap(), + ScalarValue::try_new_null(&DataType::Float64).unwrap(), + ScalarValue::try_new_null(&DataType::Decimal128(42, 42)).unwrap(), + ScalarValue::try_new_null(&DataType::Decimal256(42, 42)).unwrap(), + ScalarValue::try_new_null(&DataType::Utf8).unwrap(), + ScalarValue::try_new_null(&DataType::LargeUtf8).unwrap(), + ScalarValue::try_new_null(&DataType::Utf8View).unwrap(), + ScalarValue::try_new_null(&DataType::Binary).unwrap(), + ScalarValue::try_new_null(&DataType::BinaryView).unwrap(), + ScalarValue::try_new_null(&DataType::FixedSizeBinary(42)).unwrap(), + ScalarValue::try_new_null(&DataType::LargeBinary).unwrap(), + ScalarValue::try_new_null(&DataType::Date32).unwrap(), + ScalarValue::try_new_null(&DataType::Date64).unwrap(), + ScalarValue::try_new_null(&DataType::Time32(TimeUnit::Second)).unwrap(), + ScalarValue::try_new_null(&DataType::Time32(TimeUnit::Millisecond)).unwrap(), + ScalarValue::try_new_null(&DataType::Time64(TimeUnit::Microsecond)).unwrap(), + ScalarValue::try_new_null(&DataType::Time64(TimeUnit::Nanosecond)).unwrap(), + ScalarValue::try_new_null(&DataType::Timestamp(TimeUnit::Second, None)) + .unwrap(), + ScalarValue::try_new_null(&DataType::Timestamp(TimeUnit::Millisecond, None)) + .unwrap(), + ScalarValue::try_new_null(&DataType::Timestamp(TimeUnit::Microsecond, None)) + .unwrap(), + ScalarValue::try_new_null(&DataType::Timestamp(TimeUnit::Nanosecond, None)) + .unwrap(), + ScalarValue::try_new_null(&DataType::Interval(IntervalUnit::YearMonth)) + .unwrap(), + ScalarValue::try_new_null(&DataType::Interval(IntervalUnit::DayTime)) + .unwrap(), + ScalarValue::try_new_null(&DataType::Interval(IntervalUnit::MonthDayNano)) + .unwrap(), + ScalarValue::try_new_null(&DataType::Duration(TimeUnit::Second)).unwrap(), + ScalarValue::try_new_null(&DataType::Duration(TimeUnit::Microsecond)) + .unwrap(), + ScalarValue::try_new_null(&DataType::Duration(TimeUnit::Nanosecond)).unwrap(), + ScalarValue::try_new_null(&DataType::Null).unwrap(), + ]; + assert!(scalars.iter().all(|s| s.is_null())); + + let field_ref = Arc::new(Field::new("foo", DataType::Int32, true)); + let map_field_ref = Arc::new(Field::new( + "foo", + DataType::Struct(Fields::from(vec![ + Field::new("bar", DataType::Utf8, true), + Field::new("baz", DataType::Int32, true), + ])), + true, + )); + let scalars = vec![ + ScalarValue::try_new_null(&DataType::List(Arc::clone(&field_ref))).unwrap(), + ScalarValue::try_new_null(&DataType::LargeList(Arc::clone(&field_ref))) + .unwrap(), + ScalarValue::try_new_null(&DataType::FixedSizeList( + Arc::clone(&field_ref), + 42, + )) + .unwrap(), + ScalarValue::try_new_null(&DataType::Struct( + vec![Arc::clone(&field_ref)].into(), + )) + .unwrap(), + ScalarValue::try_new_null(&DataType::Map(map_field_ref, false)).unwrap(), + ScalarValue::try_new_null(&DataType::Union( + UnionFields::new(vec![42], vec![field_ref]), + UnionMode::Dense, + )) + .unwrap(), + ]; + assert!(scalars.iter().all(|s| s.is_null())); + } }