This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 658206a668 Support fixed_size_list for make_array (#6759)
658206a668 is described below
commit 658206a66825e229304ee744715a88908d281b9b
Author: Jay Zhan <[email protected]>
AuthorDate: Thu Jul 6 01:35:20 2023 +0800
Support fixed_size_list for make_array (#6759)
* support make_array for fixed_size_list
Signed-off-by: jayzhan211 <[email protected]>
* add arrow-typeof in test
Signed-off-by: jayzhan211 <[email protected]>
* fix schema mismatch
Signed-off-by: jayzhan211 <[email protected]>
* cleanup code
Signed-off-by: jayzhan211 <[email protected]>
* create array data with correct len
Signed-off-by: jayzhan211 <[email protected]>
---------
Signed-off-by: jayzhan211 <[email protected]>
---
datafusion/common/src/scalar.rs | 62 ++++++++-------
.../core/tests/data/fixed_size_list_array.parquet | Bin 0 -> 718 bytes
.../core/tests/sqllogictests/test_files/array.slt | 39 +++++++++-
datafusion/optimizer/src/analyzer/type_coercion.rs | 86 +++++++++++++++++++--
datafusion/physical-expr/src/array_expressions.rs | 4 +-
datafusion/proto/src/logical_plan/to_proto.rs | 4 +
datafusion/sql/src/expr/arrow_cast.rs | 19 ++++-
7 files changed, 174 insertions(+), 40 deletions(-)
diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 4fef60020f..b0769df1e9 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -101,7 +101,9 @@ pub enum ScalarValue {
FixedSizeBinary(i32, Option<Vec<u8>>),
/// large binary
LargeBinary(Option<Vec<u8>>),
- /// list of nested ScalarValue
+ /// Fixed size list of nested ScalarValue
+ Fixedsizelist(Option<Vec<ScalarValue>>, FieldRef, i32),
+ /// List of nested ScalarValue
List(Option<Vec<ScalarValue>>, FieldRef),
/// Date stored as a signed 32bit int days since UNIX epoch 1970-01-01
Date32(Option<i32>),
@@ -196,6 +198,10 @@ impl PartialEq for ScalarValue {
(FixedSizeBinary(_, _), _) => false,
(LargeBinary(v1), LargeBinary(v2)) => v1.eq(v2),
(LargeBinary(_), _) => false,
+ (Fixedsizelist(v1, t1, l1), Fixedsizelist(v2, t2, l2)) => {
+ v1.eq(v2) && t1.eq(t2) && l1.eq(l2)
+ }
+ (Fixedsizelist(_, _, _), _) => false,
(List(v1, t1), List(v2, t2)) => v1.eq(v2) && t1.eq(t2),
(List(_, _), _) => false,
(Date32(v1), Date32(v2)) => v1.eq(v2),
@@ -315,6 +321,14 @@ impl PartialOrd for ScalarValue {
(FixedSizeBinary(_, _), _) => None,
(LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2),
(LargeBinary(_), _) => None,
+ (Fixedsizelist(v1, t1, l1), Fixedsizelist(v2, t2, l2)) => {
+ if t1.eq(t2) && l1.eq(l2) {
+ v1.partial_cmp(v2)
+ } else {
+ None
+ }
+ }
+ (Fixedsizelist(_, _, _), _) => None,
(List(v1, t1), List(v2, t2)) => {
if t1.eq(t2) {
v1.partial_cmp(v2)
@@ -1518,6 +1532,11 @@ impl std::hash::Hash for ScalarValue {
Binary(v) => v.hash(state),
FixedSizeBinary(_, v) => v.hash(state),
LargeBinary(v) => v.hash(state),
+ Fixedsizelist(v, t, l) => {
+ v.hash(state);
+ t.hash(state);
+ l.hash(state);
+ }
List(v, t) => {
v.hash(state);
t.hash(state);
@@ -1994,6 +2013,10 @@ impl ScalarValue {
ScalarValue::Binary(_) => DataType::Binary,
ScalarValue::FixedSizeBinary(sz, _) => DataType::FixedSizeBinary(*sz),
ScalarValue::LargeBinary(_) => DataType::LargeBinary,
+ ScalarValue::Fixedsizelist(_, field, length) => DataType::FixedSizeList(
+ Arc::new(Field::new("item", field.data_type().clone(), true)),
+ *length,
+ ),
ScalarValue::List(_, field) => DataType::List(Arc::new(Field::new(
"item",
field.data_type().clone(),
@@ -2142,6 +2165,7 @@ impl ScalarValue {
ScalarValue::Binary(v) => v.is_none(),
ScalarValue::FixedSizeBinary(_, v) => v.is_none(),
ScalarValue::LargeBinary(v) => v.is_none(),
+ ScalarValue::Fixedsizelist(v, ..) => v.is_none(),
ScalarValue::List(v, _) => v.is_none(),
ScalarValue::Date32(v) => v.is_none(),
ScalarValue::Date64(v) => v.is_none(),
@@ -2847,6 +2871,9 @@ impl ScalarValue {
.collect::<LargeBinaryArray>(),
),
},
+ ScalarValue::Fixedsizelist(..) => {
+ unimplemented!("FixedSizeList is not supported yet")
+ }
ScalarValue::List(values, field) => Arc::new(match field.data_type() {
DataType::Boolean => build_list!(BooleanBuilder, Boolean, values, size),
DataType::Int8 => build_list!(Int8Builder, Int8, values, size),
@@ -3294,6 +3321,7 @@ impl ScalarValue {
ScalarValue::LargeBinary(val) => {
eq_array_primitive!(array, index, LargeBinaryArray, val)
}
+ ScalarValue::Fixedsizelist(..) => unimplemented!(),
ScalarValue::List(_, _) => unimplemented!(),
ScalarValue::Date32(val) => {
eq_array_primitive!(array, index, Date32Array, val)
@@ -3414,7 +3442,8 @@ impl ScalarValue {
| ScalarValue::LargeBinary(b) => {
b.as_ref().map(|b| b.capacity()).unwrap_or_default()
}
- ScalarValue::List(vals, field) => {
+ ScalarValue::Fixedsizelist(vals, field, _)
+ | ScalarValue::List(vals, field) => {
vals.as_ref()
.map(|vals| Self::size_of_vec(vals) - std::mem::size_of_val(vals))
.unwrap_or_default()
@@ -3732,29 +3761,9 @@ impl fmt::Display for ScalarValue {
ScalarValue::TimestampNanosecond(e, _) => format_option!(f, e)?,
ScalarValue::Utf8(e) => format_option!(f, e)?,
ScalarValue::LargeUtf8(e) => format_option!(f, e)?,
- ScalarValue::Binary(e) => match e {
- Some(l) => write!(
- f,
- "{}",
- l.iter()
- .map(|v| format!("{v}"))
- .collect::<Vec<_>>()
- .join(",")
- )?,
- None => write!(f, "NULL")?,
- },
- ScalarValue::FixedSizeBinary(_, e) => match e {
- Some(l) => write!(
- f,
- "{}",
- l.iter()
- .map(|v| format!("{v}"))
- .collect::<Vec<_>>()
- .join(",")
- )?,
- None => write!(f, "NULL")?,
- },
- ScalarValue::LargeBinary(e) => match e {
+ ScalarValue::Binary(e)
+ | ScalarValue::FixedSizeBinary(_, e)
+ | ScalarValue::LargeBinary(e) => match e {
Some(l) => write!(
f,
"{}",
@@ -3765,7 +3774,7 @@ impl fmt::Display for ScalarValue {
)?,
None => write!(f, "NULL")?,
},
- ScalarValue::List(e, _) => match e {
+ ScalarValue::Fixedsizelist(e, ..) | ScalarValue::List(e, _) => match e {
Some(l) => write!(
f,
"{}",
@@ -3849,6 +3858,7 @@ impl fmt::Debug for ScalarValue {
}
ScalarValue::LargeBinary(None) => write!(f, "LargeBinary({self})"),
ScalarValue::LargeBinary(Some(_)) => write!(f, "LargeBinary(\"{self}\")"),
+ ScalarValue::Fixedsizelist(..) => write!(f, "FixedSizeList([{self}])"),
ScalarValue::List(_, _) => write!(f, "List([{self}])"),
ScalarValue::Date32(_) => write!(f, "Date32(\"{self}\")"),
ScalarValue::Date64(_) => write!(f, "Date64(\"{self}\")"),
diff --git a/datafusion/core/tests/data/fixed_size_list_array.parquet b/datafusion/core/tests/data/fixed_size_list_array.parquet
new file mode 100644
index 0000000000..aafc5ce62f
Binary files /dev/null and b/datafusion/core/tests/data/fixed_size_list_array.parquet differ
diff --git a/datafusion/core/tests/sqllogictests/test_files/array.slt b/datafusion/core/tests/sqllogictests/test_files/array.slt
index 0d99e6cbb3..1f43c5f8e1 100644
--- a/datafusion/core/tests/sqllogictests/test_files/array.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/array.slt
@@ -417,8 +417,6 @@ select make_array(x, y) from foo2;
# array_contains
-
-
# array_contains scalar function #1
query BBB rowsort
select array_contains(make_array(1, 2, 3), make_array(1, 1, 2, 3)),
array_contains([1, 2, 3], [1, 1, 2]), array_contains([1, 2, 3], [2, 1, 3, 1]);
@@ -531,3 +529,40 @@ SELECT
FROM t
----
true true
+
+statement ok
+CREATE EXTERNAL TABLE fixed_size_list_array STORED AS PARQUET LOCATION 'tests/data/fixed_size_list_array.parquet';
+
+query T
+select arrow_typeof(f0) from fixed_size_list_array;
+----
+FixedSizeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 2)
+FixedSizeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 2)
+
+query ?
+select * from fixed_size_list_array;
+----
+[1, 2]
+[3, 4]
+
+query ?
+select f0 from fixed_size_list_array;
+----
+[1, 2]
+[3, 4]
+
+query ?
+select arrow_cast(f0, 'List(Int64)') from fixed_size_list_array;
+----
+[1, 2]
+[3, 4]
+
+query ?
+select make_array(arrow_cast(f0, 'List(Int64)')) from fixed_size_list_array
+----
+[[1, 2], [3, 4]]
+
+query ?
+select make_array(f0) from fixed_size_list_array
+----
+[[1, 2], [3, 4]]
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index 5d1fef5352..7cf4a233f7 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -330,8 +330,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
&self.schema,
&fun.signature,
)?;
- let expr = Expr::ScalarUDF(ScalarUDF::new(fun, new_expr));
- Ok(expr)
+ Ok(Expr::ScalarUDF(ScalarUDF::new(fun, new_expr)))
}
Expr::ScalarFunction(ScalarFunction { fun, args }) => {
let new_args = coerce_arguments_for_signature(
@@ -520,7 +519,7 @@ fn coerce_window_frame(
fn get_casted_expr_for_bool_op(expr: &Expr, schema: &DFSchemaRef) -> Result<Expr> {
let left_type = expr.get_type(schema)?;
get_input_types(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)?;
- expr.clone().cast_to(&DataType::Boolean, schema)
+ cast_expr(expr, &DataType::Boolean, schema)
}
/// Returns `expressions` coerced to types compatible with
@@ -559,6 +558,25 @@ fn coerce_arguments_for_fun(
return Ok(vec![]);
}
+ let mut expressions: Vec<Expr> = expressions.to_vec();
+
+ // Cast Fixedsizelist to List for array functions
+ if *fun == BuiltinScalarFunction::MakeArray {
+ expressions = expressions
+ .into_iter()
+ .map(|expr| {
+ let data_type = expr.get_type(schema).unwrap();
+ if let DataType::FixedSizeList(field, _) = data_type {
+ let field = field.as_ref().clone();
+ let to_type = DataType::List(Arc::new(field));
+ expr.cast_to(&to_type, schema)
+ } else {
+ Ok(expr)
+ }
+ })
+ .collect::<Result<Vec<_>>>()?;
+ }
+
if *fun == BuiltinScalarFunction::MakeArray {
// Find the final data type for the function arguments
let current_types = expressions
@@ -579,8 +597,7 @@ fn coerce_arguments_for_fun(
.map(|(expr, from_type)| cast_array_expr(expr, &from_type, &new_type, schema))
.collect();
}
-
- Ok(expressions.to_vec())
+ Ok(expressions)
}
/// Cast `expr` to the specified type, if possible
@@ -598,7 +615,7 @@ fn cast_array_expr(
if from_type.equals_datatype(&DataType::Null) {
Ok(expr.clone())
} else {
- expr.clone().cast_to(to_type, schema)
+ cast_expr(expr, to_type, schema)
}
}
@@ -625,7 +642,7 @@ fn coerce_agg_exprs_for_signature(
input_exprs
.iter()
.enumerate()
- .map(|(i, expr)| expr.clone().cast_to(&coerced_types[i], schema))
+ .map(|(i, expr)| cast_expr(expr, &coerced_types[i], schema))
.collect::<Result<Vec<_>>>()
}
@@ -746,6 +763,7 @@ mod test {
use arrow::datatypes::{DataType, TimeUnit};
+ use arrow::datatypes::Field;
use datafusion_common::tree_node::TreeNode;
use datafusion_common::{DFField, DFSchema, DFSchemaRef, Result, ScalarValue};
use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction};
@@ -763,7 +781,7 @@ mod test {
use datafusion_physical_expr::expressions::AvgAccumulator;
use crate::analyzer::type_coercion::{
- coerce_case_expression, TypeCoercion, TypeCoercionRewriter,
+ cast_expr, coerce_case_expression, TypeCoercion, TypeCoercionRewriter,
};
use crate::test::assert_analyzed_plan_eq;
@@ -1220,6 +1238,58 @@ mod test {
Ok(())
}
+ #[test]
+ fn test_casting_for_fixed_size_list() -> Result<()> {
+ let val = lit(ScalarValue::Fixedsizelist(
+ Some(vec![
+ ScalarValue::from(1i32),
+ ScalarValue::from(2i32),
+ ScalarValue::from(3i32),
+ ]),
+ Arc::new(Field::new("item", DataType::Int32, true)),
+ 3,
+ ));
+ let expr = Expr::ScalarFunction(ScalarFunction {
+ fun: BuiltinScalarFunction::MakeArray,
+ args: vec![val.clone()],
+ });
+ let schema = Arc::new(DFSchema::new_with_metadata(
+ vec![DFField::new_unqualified(
+ "item",
+ DataType::FixedSizeList(
+ Arc::new(Field::new("a", DataType::Int32, true)),
+ 3,
+ ),
+ true,
+ )],
+ std::collections::HashMap::new(),
+ )?);
+ let mut rewriter = TypeCoercionRewriter { schema };
+ let result = expr.rewrite(&mut rewriter)?;
+
+ let schema = Arc::new(DFSchema::new_with_metadata(
+ vec![DFField::new_unqualified(
+ "item",
+ DataType::List(Arc::new(Field::new("a", DataType::Int32, true))),
+ true,
+ )],
+ std::collections::HashMap::new(),
+ )?);
+ let expected_casted_expr = cast_expr(
+ &val,
+ &DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
+ &schema,
+ )?;
+
+ let expected = Expr::ScalarFunction(ScalarFunction {
+ fun: BuiltinScalarFunction::MakeArray,
+ args: vec![expected_casted_expr],
+ });
+
+ assert_eq!(result, expected);
+ Ok(())
+ }
+
#[test]
fn test_type_coercion_rewrite() -> Result<()> {
// gt
diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs
index 911c94b06d..bddeef526a 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -111,7 +111,7 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result<ArrayRef> {
DataType::List(..) => {
let arrays =
downcast_vec!(args, ListArray).collect::<Result<Vec<&ListArray>>>()?;
- let len: i32 = arrays.len() as i32;
+ let len = arrays.iter().map(|arr| arr.len() as i32).sum();
let capacity =
Capacities::Array(arrays.iter().map(|a| a.get_array_memory_size()).sum());
let array_data: Vec<_> =
@@ -125,7 +125,7 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result<ArrayRef> {
}
let list_data_type =
- DataType::List(Arc::new(Field::new("item", data_type, false)));
+ DataType::List(Arc::new(Field::new("item", data_type, true)));
let list_data = ArrayData::builder(list_data_type)
.len(1)
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index a046be35d4..4a4b16db80 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -1068,6 +1068,10 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue {
Value::LargeUtf8Value(s.to_owned())
})
}
+ ScalarValue::Fixedsizelist(..) => Err(Error::General(
+ "Proto serialization error: ScalarValue::Fixedsizelist not supported"
+ .to_string(),
+ )),
ScalarValue::List(values, boxed_field) => {
let is_null = values.is_none();
diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/sql/src/expr/arrow_cast.rs
index 91a42f4736..46957a9cdd 100644
--- a/datafusion/sql/src/expr/arrow_cast.rs
+++ b/datafusion/sql/src/expr/arrow_cast.rs
@@ -18,9 +18,9 @@
//! Implementation of the `arrow_cast` function that allows
//! casting to arbitrary arrow types (rather than SQL types)
-use std::{fmt::Display, iter::Peekable, str::Chars};
+use std::{fmt::Display, iter::Peekable, str::Chars, sync::Arc};
-use arrow_schema::{DataType, IntervalUnit, TimeUnit};
+use arrow_schema::{DataType, Field, IntervalUnit, TimeUnit};
use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue};
use datafusion_expr::{Expr, ExprSchemable};
@@ -150,6 +150,7 @@ impl<'a> Parser<'a> {
Token::Decimal128 => self.parse_decimal_128(),
Token::Decimal256 => self.parse_decimal_256(),
Token::Dictionary => self.parse_dictionary(),
+ Token::List => self.parse_list(),
tok => Err(make_error(
self.val,
&format!("finding next type, got unexpected '{tok}'"),
@@ -157,6 +158,16 @@ impl<'a> Parser<'a> {
}
}
+ /// Parses the List type
+ fn parse_list(&mut self) -> Result<DataType> {
+ self.expect_token(Token::LParen)?;
+ let data_type = self.parse_next_type()?;
+ self.expect_token(Token::RParen)?;
+ Ok(DataType::List(Arc::new(Field::new(
+ "item", data_type, true,
+ ))))
+ }
+
/// Parses the next timeunit
fn parse_time_unit(&mut self, context: &str) -> Result<TimeUnit> {
match self.next_token()? {
@@ -486,6 +497,8 @@ impl<'a> Tokenizer<'a> {
"Date32" => Token::SimpleType(DataType::Date32),
"Date64" => Token::SimpleType(DataType::Date64),
+ "List" => Token::List,
+
"Second" => Token::TimeUnit(TimeUnit::Second),
"Millisecond" => Token::TimeUnit(TimeUnit::Millisecond),
"Microsecond" => Token::TimeUnit(TimeUnit::Microsecond),
@@ -573,12 +586,14 @@ enum Token {
None,
Integer(i64),
DoubleQuotedString(String),
+ List,
}
impl Display for Token {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Token::SimpleType(t) => write!(f, "{t}"),
+ Token::List => write!(f, "List"),
Token::Timestamp => write!(f, "Timestamp"),
Token::Time32 => write!(f, "Time32"),
Token::Time64 => write!(f, "Time64"),