This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch active_release
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/active_release by this push:
     new 6c570cf  Support reading decimal data from the CSV reader if the user provides a schema with a decimal data type (#941) (#974)
6c570cf is described below

commit 6c570cfe98d6a7a4ec74b139b733c5c72ed10015
Author: Andrew Lamb <and...@nerdnetworks.org>
AuthorDate: Wed Nov 24 07:10:55 2021 -0500

    Support reading decimal data from the CSV reader if the user provides a schema with a decimal data type (#941) (#974)
    
    * support decimal data type for csv reader
    
    * format code and fix lint check
    
    * fix the clippy error
    
    * enhance the CSV-to-decimal parsing and add more tests
    
    Co-authored-by: Kun Liu <liu...@apache.org>
---
 arrow/src/array/builder.rs       |   4 +-
 arrow/src/array/mod.rs           |   2 +
 arrow/src/csv/reader.rs          | 263 ++++++++++++++++++++++++++++++++++++++-
 arrow/test/data/decimal_test.csv |  10 ++
 4 files changed, 275 insertions(+), 4 deletions(-)

diff --git a/arrow/src/array/builder.rs b/arrow/src/array/builder.rs
index d08816c..af6f3c3 100644
--- a/arrow/src/array/builder.rs
+++ b/arrow/src/array/builder.rs
@@ -1118,7 +1118,7 @@ pub struct FixedSizeBinaryBuilder {
     builder: FixedSizeListBuilder<UInt8Builder>,
 }
 
-const MAX_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
+pub const MAX_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
     9,
     99,
     999,
@@ -1158,7 +1158,7 @@ const MAX_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
     9999999999999999999999999999999999999,
     170141183460469231731687303715884105727,
 ];
-const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
+pub const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
     -9,
     -99,
     -999,
diff --git a/arrow/src/array/mod.rs b/arrow/src/array/mod.rs
index 235d868..26b410e 100644
--- a/arrow/src/array/mod.rs
+++ b/arrow/src/array/mod.rs
@@ -391,6 +391,8 @@ pub use self::builder::StringBuilder;
 pub use self::builder::StringDictionaryBuilder;
 pub use self::builder::StructBuilder;
 pub use self::builder::UnionBuilder;
+pub use self::builder::MAX_DECIMAL_FOR_EACH_PRECISION;
+pub use self::builder::MIN_DECIMAL_FOR_EACH_PRECISION;
 
 pub type Int8Builder = PrimitiveBuilder<Int8Type>;
 pub type Int16Builder = PrimitiveBuilder<Int16Type>;
diff --git a/arrow/src/csv/reader.rs b/arrow/src/csv/reader.rs
index 4940ea2..ac72939 100644
--- a/arrow/src/csv/reader.rs
+++ b/arrow/src/csv/reader.rs
@@ -50,7 +50,8 @@ use std::io::{Read, Seek, SeekFrom};
 use std::sync::Arc;
 
 use crate::array::{
-    ArrayRef, BooleanArray, DictionaryArray, PrimitiveArray, StringArray,
+    ArrayRef, BooleanArray, DecimalBuilder, DictionaryArray, PrimitiveArray, 
StringArray,
+    MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION,
 };
 use crate::compute::kernels::cast_utils::string_to_timestamp_nanos;
 use crate::datatypes::*;
@@ -58,8 +59,11 @@ use crate::error::{ArrowError, Result};
 use crate::record_batch::RecordBatch;
 
 use csv_crate::{ByteRecord, StringRecord};
+use std::ops::Neg;
 
 lazy_static! {
+    static ref PARSE_DECIMAL_RE: Regex =
+        Regex::new(r"^-?(\d+\.?\d*|\d*\.?\d+)$").unwrap();
     static ref DECIMAL_RE: Regex = 
Regex::new(r"^-?(\d*\.\d+|\d+\.\d*)$").unwrap();
     static ref INTEGER_RE: Regex = Regex::new(r"^-?(\d+)$").unwrap();
     static ref BOOLEAN_RE: Regex = RegexBuilder::new(r"^(true)$|^(false)$")
@@ -99,7 +103,7 @@ fn infer_field_schema(string: &str) -> DataType {
 ///
 /// If `max_read_records` is not set, the whole file is read to infer its 
schema.
 ///
-/// Return infered schema and number of records used for inference. This 
function does not change
+/// Return inferred schema and number of records used for inference. This 
function does not change
 /// reader cursor offset.
 pub fn infer_file_schema<R: Read + Seek>(
     reader: &mut R,
@@ -513,6 +517,9 @@ fn parse(
             let field = &fields[i];
             match field.data_type() {
                 DataType::Boolean => build_boolean_array(line_number, rows, i),
+                DataType::Decimal(precision, scale) => {
+                    build_decimal_array(line_number, rows, i, *precision, 
*scale)
+                }
                 DataType::Int8 => 
build_primitive_array::<Int8Type>(line_number, rows, i),
                 DataType::Int16 => {
                     build_primitive_array::<Int16Type>(line_number, rows, i)
@@ -728,6 +735,161 @@ fn parse_bool(string: &str) -> Option<bool> {
     }
 }
 
+// Parse column `col_idx` of `rows` into a DecimalArray with the given precision
+// and scale. A missing field or an empty string becomes a null; every other value
+// is parsed with `parse_decimal_with_parameter`. `line_number` is added to the
+// row's index within `rows` so that a parse error names the offending line.
+fn build_decimal_array(
+    line_number: usize,
+    rows: &[StringRecord],
+    col_idx: usize,
+    precision: usize,
+    scale: usize,
+) -> Result<ArrayRef> {
+    let mut decimal_builder = DecimalBuilder::new(rows.len(), precision, scale);
+    for (row_index, row) in rows.iter().enumerate() {
+        match row.get(col_idx) {
+            // No data for this row, or an empty cell: both are stored as null.
+            None | Some("") => decimal_builder.append_null()?,
+            Some(s) => {
+                let value = parse_decimal_with_parameter(s, precision, scale)
+                    .map_err(|e| match e {
+                        // Add the column and line so the failing cell can be located.
+                        ArrowError::ParseError(msg) => ArrowError::ParseError(format!(
+                            "{} for column {} at line {}",
+                            msg,
+                            col_idx,
+                            line_number + row_index
+                        )),
+                        other => other,
+                    })?;
+                decimal_builder.append_value(value)?;
+            }
+        }
+    }
+    Ok(Arc::new(decimal_builder.finish()))
+}
+
+// Parse the string format decimal value to i128 format and checking the 
precision and scale.
+// The result i128 value can't be out of bounds.
+fn parse_decimal_with_parameter(s: &str, precision: usize, scale: usize) -> 
Result<i128> {
+    if PARSE_DECIMAL_RE.is_match(s) {
+        let mut offset = s.len();
+        let len = s.len();
+        // each byte is digit、'-' or '.'
+        let mut base = 1;
+
+        // handle the value after the '.' and meet the scale
+        let delimiter_position = s.find('.');
+        match delimiter_position {
+            None => {
+                // there is no '.'
+                base = 10_i128.pow(scale as u32);
+            }
+            Some(mid) => {
+                // there is the '.'
+                if len - mid >= scale + 1 {
+                    // If the string value is "123.12345" and the scale is 2, 
we should just remain '.12' and drop the '345' value.
+                    offset -= len - mid - 1 - scale;
+                } else {
+                    // If the string value is "123.12" and the scale is 4, we 
should append '00' to the tail.
+                    base = 10_i128.pow((scale + 1 + mid - len) as u32);
+                }
+            }
+        };
+
+        let bytes = s.as_bytes();
+        let mut negative = false;
+        let mut result: i128 = 0;
+
+        while offset > 0 {
+            match bytes[offset - 1] {
+                b'-' => {
+                    negative = true;
+                }
+                b'.' => {
+                    // do nothing
+                }
+                b'0'..=b'9' => {
+                    result += i128::from(bytes[offset - 1] - b'0') * base;
+                    base *= 10;
+                }
+                _ => {
+                    return Err(ArrowError::ParseError(format!(
+                        "can't match byte {}",
+                        bytes[offset - 1]
+                    )));
+                }
+            }
+            offset -= 1;
+        }
+        if negative {
+            result = result.neg();
+        }
+        if result > MAX_DECIMAL_FOR_EACH_PRECISION[precision - 1]
+            || result < MIN_DECIMAL_FOR_EACH_PRECISION[precision - 1]
+        {
+            return Err(ArrowError::ParseError(format!(
+                "parse decimal overflow, the precision {}, the scale {}, the 
value {}",
+                precision, scale, s
+            )));
+        }
+        Ok(result)
+    } else {
+        Err(ArrowError::ParseError(format!(
+            "can't parse the string value {} to decimal",
+            s
+        )))
+    }
+}
+
+// Parse a decimal string into an i128 made of all of its digits, ignoring the '.'
+// and without checking precision or scale, e.g. "125.12" -> 12512_i128.
+// Returns a `ParseError` when the string is not a decimal literal or when its
+// digits do not fit in an i128, instead of panicking on arithmetic overflow.
+// Only the unit tests below call this helper, so it is compiled for tests only
+// to avoid a dead-code warning in regular builds.
+#[cfg(test)]
+fn parse_decimal(s: &str) -> Result<i128> {
+    // `\d` in the regex also matches non-ASCII Unicode digits, which the byte
+    // arithmetic below cannot handle.
+    if !s.is_ascii() || !PARSE_DECIMAL_RE.is_match(s) {
+        return Err(ArrowError::ParseError(format!(
+            "can't parse the string value {} to decimal",
+            s
+        )));
+    }
+    // The regex only allows '-' as the first character and at most one '.'.
+    let (negative, unsigned) = match s.strip_prefix('-') {
+        Some(rest) => (true, rest),
+        None => (false, s),
+    };
+    let mut result: i128 = 0;
+    for digit in unsigned.bytes().filter(|&b| b != b'.') {
+        result = result
+            .checked_mul(10)
+            .and_then(|r| r.checked_add(i128::from(digit - b'0')))
+            .ok_or_else(|| {
+                ArrowError::ParseError(format!("parse decimal overflow, the value {}", s))
+            })?;
+    }
+    Ok(if negative { result.neg() } else { result })
+}
+
 // parses a specific column (col_idx) into an Arrow Array.
 fn build_primitive_array<T: ArrowPrimitiveType + Parser>(
     line_number: usize,
@@ -1056,6 +1218,37 @@ mod tests {
     }
 
     #[test]
+    fn test_csv_reader_with_decimal() {
+        let schema = Schema::new(vec![
+            Field::new("city", DataType::Utf8, false),
+            Field::new("lat", DataType::Decimal(26, 6), false),
+            Field::new("lng", DataType::Decimal(26, 6), false),
+        ]);
+
+        let file = File::open("test/data/decimal_test.csv").unwrap();
+
+        let mut csv = Reader::new(file, Arc::new(schema), false, None, 1024, None, None);
+        let batch = csv.next().unwrap().unwrap();
+        assert_eq!(10, batch.num_rows());
+
+        // Both decimal columns are read as DecimalArray. Inputs with more than six
+        // fractional digits are truncated, shorter ones are zero-padded, and ".123"
+        // / "123." / "123" forms are accepted.
+        let lat = batch
+            .column(1)
+            .as_any()
+            .downcast_ref::<DecimalArray>()
+            .unwrap();
+        let expected_lat = [
+            "57.653484",
+            "53.002666",
+            "52.412811",
+            "51.481583",
+            "12.123456",
+            "50.760000",
+            "0.123000",
+            "123.000000",
+            "123.000000",
+            "-50.760000",
+        ];
+        for (row, expected) in expected_lat.iter().enumerate() {
+            assert_eq!(*expected, lat.value_as_string(row));
+        }
+
+        let lng = batch
+            .column(2)
+            .as_any()
+            .downcast_ref::<DecimalArray>()
+            .unwrap();
+        assert_eq!("-3.335724", lng.value_as_string(0));
+        assert_eq!("-3.179090", lng.value_as_string(3));
+        assert_eq!("0.290472", lng.value_as_string(5));
+    }
+
+    #[test]
     fn test_csv_from_buf_reader() {
         let schema = Schema::new(vec![
             Field::new("city", DataType::Utf8, false),
@@ -1348,6 +1541,8 @@ mod tests {
         assert_eq!(infer_field_schema("false"), DataType::Boolean);
         assert_eq!(infer_field_schema("2020-11-08"), DataType::Date32);
         assert_eq!(infer_field_schema("2020-11-08T14:20:01"), 
DataType::Date64);
+        assert_eq!(infer_field_schema("-5.13"), DataType::Float64);
+        assert_eq!(infer_field_schema("0.1300"), DataType::Float64);
     }
 
     #[test]
@@ -1374,6 +1569,70 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_parse_decimal() {
+        // (input, expected): every digit is kept and the '.' is simply skipped.
+        let cases: &[(&str, i128)] = &[
+            ("123.00", 12300),
+            ("123.123", 123123),
+            ("0.0123", 123),
+            ("0.12300", 12300),
+            ("-5.123", -5123),
+            ("-45.432432", -45432432),
+        ];
+        for &(input, expected) in cases {
+            let parsed = parse_decimal(input).unwrap();
+            assert_eq!(expected, parsed);
+        }
+    }
+
+    #[test]
+    fn test_parse_decimal_with_parameter() {
+        // Decimal(20, 3): extra fractional digits are truncated, missing ones padded.
+        let valid_cases: &[(&str, i128)] = &[
+            ("123.123", 123123),
+            ("123.1234", 123123),
+            ("123.1", 123100),
+            ("123", 123000),
+            ("-123.123", -123123),
+            ("-123.1234", -123123),
+            ("-123.1", -123100),
+            ("-123", -123000),
+            ("0.0000123", 0),
+            ("12.", 12000),
+            ("-12.", -12000),
+            ("00.1", 100),
+            ("-00.1", -100),
+            ("12345678912345678.1234", 12345678912345678123),
+            ("-12345678912345678.1234", -12345678912345678123),
+            ("99999999999999999.999", 99999999999999999999),
+            ("-99999999999999999.999", -99999999999999999999),
+            (".123", 123),
+            ("-.123", -123),
+            ("123.", 123000),
+            ("-123.", -123000),
+        ];
+        for &(input, expected) in valid_cases {
+            let parsed = parse_decimal_with_parameter(input, 20, 3).unwrap();
+            assert_eq!(expected, parsed);
+        }
+
+        // Strings that are not decimal literals are rejected outright.
+        for &input in &["123,123", "."] {
+            let err = parse_decimal_with_parameter(input, 20, 3).unwrap_err();
+            let expected = format!(
+                "Parser error: can't parse the string value {} to decimal",
+                input
+            );
+            assert_eq!(expected, err.to_string());
+        }
+
+        // Decimal(10, 3) leaves room for at most 7 integer digits.
+        for &input in &["12345678", "12345678.9", "99999999.99"] {
+            let err = parse_decimal_with_parameter(input, 10, 3).unwrap_err();
+            let expected = format!(
+                "Parser error: parse decimal overflow, the precision {}, the scale {}, the value {}",
+                10, 3, input
+            );
+            assert_eq!(expected, err.to_string());
+        }
+    }
+
     /// Interprets a naive_datetime (with no explicit timezone offset)
     /// using the local timezone and returns the timestamp in UTC (0
     /// offset)
diff --git a/arrow/test/data/decimal_test.csv b/arrow/test/data/decimal_test.csv
new file mode 100644
index 0000000..460ed80
--- /dev/null
+++ b/arrow/test/data/decimal_test.csv
@@ -0,0 +1,10 @@
+"Elgin, Scotland, the UK",57.653484,-3.335724
+"Stoke-on-Trent, Staffordshire, the UK",53.002666,-2.179404
+"Solihull, Birmingham, UK",52.412811,-1.778197
+"Cardiff, Cardiff county, UK",51.481583,-3.179090
+"Cardiff, Cardiff county, UK",12.12345678,-3.179090
+"Eastbourne, East Sussex, UK",50.76,0.290472
+"Eastbourne, East Sussex, UK",.123,0.290472
+"Eastbourne, East Sussex, UK",123.,0.290472
+"Eastbourne, East Sussex, UK",123,0.290472
+"Eastbourne, East Sussex, UK",-50.76,0.290472
\ No newline at end of file

Reply via email to