jayzhan211 commented on code in PR #10268: URL: https://github.com/apache/datafusion/pull/10268#discussion_r1583906718
########## datafusion/expr/src/type_coercion/binary.rs: ########## @@ -289,15 +290,164 @@ fn bitwise_coercion(left_type: &DataType, right_type: &DataType) -> Option<DataT } } +#[derive(Debug, PartialEq, Eq, Hash, Clone)] +enum TypeCategory { + Array, + Boolean, + Numeric, + // String, well-defined type, but are considered as unknown type. + DateTime, + Composite, + Unknown, + NotSupported, +} + +fn data_type_category(data_type: &DataType) -> TypeCategory { + if data_type.is_numeric() { + return TypeCategory::Numeric; + } + + if matches!(data_type, DataType::Boolean) { + return TypeCategory::Boolean; + } + + if matches!( + data_type, + DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) + ) { + return TypeCategory::Array; + } + + // String literal is possible to cast to many other types like numeric or datetime, + // therefore, it is categorized as a unknown type + if matches!( + data_type, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Null + ) { + return TypeCategory::Unknown; + } + + if matches!( + data_type, + DataType::Date32 + | DataType::Date64 + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Interval(_) + | DataType::Duration(_) + ) { + return TypeCategory::DateTime; + } + + if matches!( + data_type, + DataType::Dictionary(_, _) | DataType::Struct(_) | DataType::Union(_, _) + ) { + return TypeCategory::Composite; + } + + TypeCategory::NotSupported +} + +/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of constructs including +/// CASE, ARRAY, VALUES, and the GREATEST and LEAST functions. +/// See <https://www.postgresql.org/docs/current/typeconv-union-case.html> for more information. 
+/// The actual rules follows the behavior of Postgres and DuckDB +pub fn type_resolution(data_types: &[DataType]) -> Option<DataType> { + if data_types.is_empty() { + return None; + } + + // if all the data_types is the same return first one + if data_types.iter().all(|t| t == &data_types[0]) { + return Some(data_types[0].clone()); + } + + // if all the data_types are null, return string + if data_types.iter().all(|t| t == &DataType::Null) { + return Some(DataType::Utf8); + } + + // Ignore Nulls, if any data_type category is not the same, return None + let data_types_category: Vec<TypeCategory> = data_types + .iter() + .filter(|&t| t != &DataType::Null) + .map(data_type_category) + .collect(); + + if data_types_category + .iter() + .any(|t| t == &TypeCategory::NotSupported) + { + return None; + } + + // check if there is only one category excluding Unknown + let categories: HashSet<TypeCategory> = HashSet::from_iter( + data_types_category + .iter() + .filter(|&c| c != &TypeCategory::Unknown) + .cloned(), + ); + if categories.len() > 1 { + return None; + } + + // Ignore Nulls + let mut candidate_type: Option<DataType> = None; + for data_type in data_types.iter() { + if data_type == &DataType::Null { + continue; + } + if let Some(ref candidate_t) = candidate_type { + // Find candidate type that all the data types can be coerced to + // Follows the behavior of Postgres and DuckDB + // Coerced type may be different from the candidate and current data type + // For example, + // i64 and decimal(7, 2) are expect to get coerced type decimal(22, 2) + // numeric string ('1') and numeric (2) are expect to get coerced type numeric (1, 2) + if let Some(t) = type_resolution_coercion(data_type, candidate_t) { + candidate_type = Some(t); + } else { + return None; + } + } else { + candidate_type = Some(data_type.clone()); + } + } + + candidate_type +} + +/// See [type_resolution] for more information. 
+fn type_resolution_coercion( + lhs_type: &DataType, + rhs_type: &DataType, +) -> Option<DataType> { + if lhs_type == rhs_type { + return Some(lhs_type.clone()); + } + + // numeric coercion is the same as comparison coercion, both find the narrowest type + // that can accommodate both types + binary_numeric_coercion(lhs_type, rhs_type) + .or_else(|| pure_string_coercion(lhs_type, rhs_type)) + .or_else(|| numeric_string_coercion(lhs_type, rhs_type)) +} + /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation +/// Unlike `coerced_from`, usually the coerced type is for comparison only. +/// For example, compare with Dictionary and Dictionary, only value type is what we care about pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { if lhs_type == rhs_type { // same type => equality is possible return Some(lhs_type.clone()); } - comparison_binary_numeric_coercion(lhs_type, rhs_type) + binary_numeric_coercion(lhs_type, rhs_type) .or_else(|| dictionary_coercion(lhs_type, rhs_type, true)) .or_else(|| temporal_coercion(lhs_type, rhs_type)) + .or_else(|| pure_string_coercion(lhs_type, rhs_type)) Review Comment: I don't think the list coercion in string coercion is correct, but to avoid modifying it in this PR, I decided to pull out the string-only part ``` match (lhs_type, rhs_type) { // TODO: cast between array elements (#6558) (List(_), List(_)) => Some(lhs_type.clone()), (List(_), _) => Some(lhs_type.clone()), (_, List(_)) => Some(rhs_type.clone()), _ => None, } ``` ########## datafusion/expr/src/type_coercion/binary.rs: ########## @@ -289,15 +290,164 @@ fn bitwise_coercion(left_type: &DataType, right_type: &DataType) -> Option<DataT } } +#[derive(Debug, PartialEq, Eq, Hash, Clone)] +enum TypeCategory { + Array, + Boolean, + Numeric, + // String, well-defined type, but are considered as unknown type. 
+ DateTime, + Composite, + Unknown, + NotSupported, +} + +fn data_type_category(data_type: &DataType) -> TypeCategory { + if data_type.is_numeric() { + return TypeCategory::Numeric; + } + + if matches!(data_type, DataType::Boolean) { + return TypeCategory::Boolean; + } + + if matches!( + data_type, + DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) + ) { + return TypeCategory::Array; + } + + // String literal is possible to cast to many other types like numeric or datetime, + // therefore, it is categorized as a unknown type + if matches!( + data_type, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Null + ) { + return TypeCategory::Unknown; + } + + if matches!( + data_type, + DataType::Date32 + | DataType::Date64 + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Interval(_) + | DataType::Duration(_) + ) { + return TypeCategory::DateTime; + } + + if matches!( + data_type, + DataType::Dictionary(_, _) | DataType::Struct(_) | DataType::Union(_, _) + ) { + return TypeCategory::Composite; + } + + TypeCategory::NotSupported +} + +/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of constructs including +/// CASE, ARRAY, VALUES, and the GREATEST and LEAST functions. +/// See <https://www.postgresql.org/docs/current/typeconv-union-case.html> for more information. 
+/// The actual rules follows the behavior of Postgres and DuckDB +pub fn type_resolution(data_types: &[DataType]) -> Option<DataType> { + if data_types.is_empty() { + return None; + } + + // if all the data_types is the same return first one + if data_types.iter().all(|t| t == &data_types[0]) { + return Some(data_types[0].clone()); + } + + // if all the data_types are null, return string + if data_types.iter().all(|t| t == &DataType::Null) { + return Some(DataType::Utf8); + } + + // Ignore Nulls, if any data_type category is not the same, return None + let data_types_category: Vec<TypeCategory> = data_types + .iter() + .filter(|&t| t != &DataType::Null) + .map(data_type_category) + .collect(); + + if data_types_category + .iter() + .any(|t| t == &TypeCategory::NotSupported) + { + return None; + } + + // check if there is only one category excluding Unknown + let categories: HashSet<TypeCategory> = HashSet::from_iter( + data_types_category + .iter() + .filter(|&c| c != &TypeCategory::Unknown) + .cloned(), + ); + if categories.len() > 1 { + return None; + } + + // Ignore Nulls + let mut candidate_type: Option<DataType> = None; + for data_type in data_types.iter() { + if data_type == &DataType::Null { + continue; + } + if let Some(ref candidate_t) = candidate_type { + // Find candidate type that all the data types can be coerced to + // Follows the behavior of Postgres and DuckDB + // Coerced type may be different from the candidate and current data type + // For example, + // i64 and decimal(7, 2) are expect to get coerced type decimal(22, 2) + // numeric string ('1') and numeric (2) are expect to get coerced type numeric (1, 2) + if let Some(t) = type_resolution_coercion(data_type, candidate_t) { + candidate_type = Some(t); + } else { + return None; + } + } else { + candidate_type = Some(data_type.clone()); + } + } + + candidate_type +} + +/// See [type_resolution] for more information. 
+fn type_resolution_coercion( + lhs_type: &DataType, + rhs_type: &DataType, +) -> Option<DataType> { + if lhs_type == rhs_type { + return Some(lhs_type.clone()); + } + + // numeric coercion is the same as comparison coercion, both find the narrowest type + // that can accommodate both types + binary_numeric_coercion(lhs_type, rhs_type) + .or_else(|| pure_string_coercion(lhs_type, rhs_type)) + .or_else(|| numeric_string_coercion(lhs_type, rhs_type)) +} + /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation +/// Unlike `coerced_from`, usually the coerced type is for comparison only. +/// For example, compare with Dictionary and Dictionary, only value type is what we care about pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { if lhs_type == rhs_type { // same type => equality is possible return Some(lhs_type.clone()); } - comparison_binary_numeric_coercion(lhs_type, rhs_type) + binary_numeric_coercion(lhs_type, rhs_type) .or_else(|| dictionary_coercion(lhs_type, rhs_type, true)) .or_else(|| temporal_coercion(lhs_type, rhs_type)) + .or_else(|| pure_string_coercion(lhs_type, rhs_type)) Review Comment: I don't think the list coercion in string coercion is correct, but to avoid modifying it in this PR, I decided to pull out the string-only part ``` match (lhs_type, rhs_type) { // TODO: cast between array elements (#6558) (List(_), List(_)) => Some(lhs_type.clone()), (List(_), _) => Some(lhs_type.clone()), (_, List(_)) => Some(rhs_type.clone()), _ => None, } ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For additional commands, e-mail: github-h...@datafusion.apache.org