@@ -14,11 +14,12 @@ use crate::error::{self, InvalidListArraySnafu};
1414use arrow:: array:: {
1515 Array , ArrayRef , ArrowPrimitiveType , BinaryArray , BooleanArray , DictionaryArray ,
1616 FixedSizeBinaryArray , Float32Array , Float64Array , Int8Array , Int16Array , Int32Array ,
17- Int64Array , PrimitiveArray , RecordBatch , StringArray , TimestampNanosecondArray , UInt8Array ,
18- UInt16Array , UInt32Array , UInt64Array ,
17+ Int64Array , PrimitiveArray , RecordBatch , StringArray , StructArray , TimestampNanosecondArray ,
18+ UInt8Array , UInt16Array , UInt32Array , UInt64Array ,
19+ } ;
20+ use arrow:: datatypes:: {
21+ ArrowDictionaryKeyType , ArrowNativeType , DataType , TimeUnit , UInt8Type , UInt16Type ,
1922} ;
20- use arrow:: datatypes:: { ArrowDictionaryKeyType , TimeUnit } ;
21- use arrow:: datatypes:: { ArrowNativeType , DataType , UInt8Type , UInt16Type } ;
2223use paste:: paste;
2324use snafu:: { OptionExt , ensure} ;
2425
@@ -139,9 +140,7 @@ macro_rules! impl_downcast {
139140
140141 pub fn [ <get_ $suffix _array> ] <' a>( rb: & ' a RecordBatch , name: & str ) -> error:: Result <& ' a $array_type> {
141142 use arrow:: datatypes:: DataType :: * ;
142- let arr = rb. column_by_name( name) . context( error:: ColumnNotFoundSnafu {
143- name,
144- } ) ?;
143+ let arr = get_required_array( rb, name) ?;
145144
146145 arr. as_any( )
147146 . downcast_ref:: <$array_type>( )
@@ -177,6 +176,17 @@ impl_downcast!(
177176 TimestampNanosecondArray
178177) ;
179178
179+ /// Get a reference to an array that the caller requires to be present in the record batch.
180+ /// If the column is not in the record batch, returns a `ColumnNotFound` error.
181+ pub fn get_required_array < ' a > (
182+ record_batch : & ' a RecordBatch ,
183+ column_name : & str ,
184+ ) -> error:: Result < & ' a ArrayRef > {
185+ record_batch
186+ . column_by_name ( column_name)
187+ . context ( error:: ColumnNotFoundSnafu { name : column_name } )
188+ }
189+
180190trait NullableInt64ArrayAccessor {
181191 fn i64_at ( & self , idx : usize ) -> error:: Result < Option < i64 > > ;
182192}
@@ -243,6 +253,13 @@ pub enum ByteArrayAccessor<'a> {
243253}
244254
245255impl < ' a > ByteArrayAccessor < ' a > {
256+ pub fn try_new_for_column (
257+ record_batch : & ' a RecordBatch ,
258+ column_name : & str ,
259+ ) -> error:: Result < Self > {
260+ Self :: try_new ( get_required_array ( record_batch, column_name) ?)
261+ }
262+
246263 pub fn try_new ( arr : & ' a ArrayRef ) -> error:: Result < Self > {
247264 match arr. data_type ( ) {
248265 DataType :: Binary => {
@@ -339,7 +356,11 @@ where
339356 fn try_new_with_datatype ( data_type : DataType , arr : & ' a ArrayRef ) -> error:: Result < Self > {
340357 // if the type isn't a dictionary, we treat it as an unencoded array
341358 if * arr. data_type ( ) == data_type {
342- return Ok ( Self :: Native ( arr. as_any ( ) . downcast_ref :: < T > ( ) . unwrap ( ) ) ) ;
359+ return Ok ( Self :: Native (
360+ arr. as_any ( )
361+ . downcast_ref :: < T > ( )
362+ . expect ( "array can be downcast to it's native datatype" ) ,
363+ ) ) ;
343364 }
344365
345366 // determine if the type is a dictionary where the value is the desired datatype
@@ -356,13 +377,13 @@ where
356377 DataType :: UInt8 => Self :: Dictionary8 ( DictionaryArrayAccessor :: new (
357378 arr. as_any ( )
358379 . downcast_ref :: < DictionaryArray < UInt8Type > > ( )
359- . unwrap ( ) ,
360- ) ) ,
380+ . expect ( "array can be downcast to DictionaryArray<UInt8Type" ) ,
381+ ) ? ) ,
361382 DataType :: UInt16 => Self :: Dictionary16 ( DictionaryArrayAccessor :: new (
362383 arr. as_any ( )
363384 . downcast_ref :: < DictionaryArray < UInt16Type > > ( )
364- . unwrap ( ) ,
365- ) ) ,
385+ . expect ( "array can be downcast to DictionaryArray<UInt16Type>" ) ,
386+ ) ? ) ,
366387 _ => {
367388 return error:: UnsupportedDictionaryKeyTypeSnafu {
368389 expect_oneof : vec ! [ DataType :: UInt8 , DataType :: UInt16 ] ,
@@ -394,6 +415,13 @@ where
394415 pub fn try_new ( arr : & ' a ArrayRef ) -> error:: Result < Self > {
395416 Self :: try_new_with_datatype ( V :: DATA_TYPE , arr)
396417 }
418+
419+ pub fn try_new_for_column (
420+ record_batch : & ' a RecordBatch ,
421+ column_name : & str ,
422+ ) -> error:: Result < Self > {
423+ Self :: try_new ( get_required_array ( record_batch, column_name) ?)
424+ }
397425}
398426
399427impl < ' a > MaybeDictArrayAccessor < ' a , BinaryArray > {
@@ -412,6 +440,13 @@ impl<'a> MaybeDictArrayAccessor<'a, StringArray> {
412440 pub fn try_new ( arr : & ' a ArrayRef ) -> error:: Result < Self > {
413441 Self :: try_new_with_datatype ( StringArray :: DATA_TYPE , arr)
414442 }
443+
444+ pub fn try_new_for_column (
445+ record_batch : & ' a RecordBatch ,
446+ column_name : & str ,
447+ ) -> error:: Result < Self > {
448+ Self :: try_new ( get_required_array ( record_batch, column_name) ?)
449+ }
415450}
416451
417452pub type Int32ArrayAccessor < ' a > = MaybeDictArrayAccessor < ' a , Int32Array > ;
@@ -431,22 +466,109 @@ where
431466 K : ArrowDictionaryKeyType ,
432467 V : Array + NullableArrayAccessor + ' static ,
433468{
434- pub fn new ( a : & ' a DictionaryArray < K > ) -> Self {
435- let dict = a. as_any ( ) . downcast_ref :: < DictionaryArray < K > > ( ) . unwrap ( ) ;
436- let value = dict. values ( ) . as_any ( ) . downcast_ref :: < V > ( ) . unwrap ( ) ;
437- Self { inner : dict, value }
469+ pub fn new ( dict : & ' a DictionaryArray < K > ) -> error:: Result < Self > {
470+ let value = dict
471+ . values ( )
472+ . as_any ( )
473+ . downcast_ref :: < V > ( )
474+ . with_context ( || error:: InvalidListArraySnafu {
475+ expect_oneof : Vec :: new ( ) ,
476+ actual : dict. values ( ) . data_type ( ) . clone ( ) ,
477+ } ) ?;
478+ Ok ( Self { inner : dict, value } )
438479 }
439480
440481 pub fn value_at ( & self , idx : usize ) -> Option < V :: Native > {
441482 if self . inner . is_valid ( idx) {
442- let offset = self . inner . key ( idx) . unwrap ( ) ;
483+ let offset = self
484+ . inner
485+ . key ( idx)
486+ . expect ( "dictionary should be valid at index" ) ;
443487 self . value . value_at ( offset)
444488 } else {
445489 None
446490 }
447491 }
448492}
449493
494+ /// Helper for accessing columns of a struct array
495+ ///
496+ /// Methods return errors of this crate's Error type
497+ /// if the caller's requirements for the struct columns are not met (for
498+ /// example `ColumnDataTypeMismatchSnafu`)
499+ pub struct StructColumnAccessor < ' a > {
500+ inner : & ' a StructArray ,
501+ }
502+
503+ impl < ' a > StructColumnAccessor < ' a > {
504+ pub fn new ( arr : & ' a StructArray ) -> Self {
505+ Self { inner : arr }
506+ }
507+
508+ pub fn primitive_column < T : ArrowPrimitiveType + ' static > (
509+ & self ,
510+ column_name : & str ,
511+ ) -> error:: Result < & ' a PrimitiveArray < T > > {
512+ self . primitive_column_op ( column_name) ?
513+ . with_context ( || error:: ColumnNotFoundSnafu {
514+ name : column_name. to_string ( ) ,
515+ } )
516+ }
517+
518+ pub fn primitive_column_op < T : ArrowPrimitiveType + ' static > (
519+ & self ,
520+ column_name : & str ,
521+ ) -> error:: Result < Option < & ' a PrimitiveArray < T > > > {
522+ self . inner
523+ . column_by_name ( column_name)
524+ . map ( |arr| {
525+ arr. as_any ( )
526+ . downcast_ref :: < PrimitiveArray < T > > ( )
527+ . with_context ( || error:: ColumnDataTypeMismatchSnafu {
528+ name : column_name. to_string ( ) ,
529+ expect : T :: DATA_TYPE ,
530+ actual : arr. data_type ( ) . clone ( ) ,
531+ } )
532+ } )
533+ . transpose ( )
534+ }
535+
536+ pub fn bool_column_op ( & self , column_name : & str ) -> error:: Result < Option < & ' a BooleanArray > > {
537+ self . inner
538+ . column_by_name ( column_name)
539+ . map ( |arr| {
540+ arr. as_any ( )
541+ . downcast_ref ( )
542+ . with_context ( || error:: ColumnDataTypeMismatchSnafu {
543+ name : column_name. to_string ( ) ,
544+ expect : DataType :: Boolean ,
545+ actual : arr. data_type ( ) . clone ( ) ,
546+ } )
547+ } )
548+ . transpose ( )
549+ }
550+
551+ pub fn string_column_op (
552+ & self ,
553+ column_name : & str ,
554+ ) -> error:: Result < Option < StringArrayAccessor < ' a > > > {
555+ self . inner
556+ . column_by_name ( column_name)
557+ . map ( StringArrayAccessor :: try_new)
558+ . transpose ( )
559+ }
560+
561+ pub fn byte_array_column_op (
562+ & self ,
563+ column_name : & str ,
564+ ) -> error:: Result < Option < ByteArrayAccessor < ' a > > > {
565+ self . inner
566+ . column_by_name ( column_name)
567+ . map ( ByteArrayAccessor :: try_new)
568+ . transpose ( )
569+ }
570+ }
571+
450572#[ cfg( test) ]
451573mod tests {
452574 use crate :: arrays:: { NullableArrayAccessor , StringArrayAccessor } ;
0 commit comments