Skip to content

Commit 7b11ef2

Browse files
committed
Refactor to_local_time() signature away from user_defined
1 parent 1dddf03 commit 7b11ef2

File tree

2 files changed: 164 additions and 206 deletions.

datafusion/functions/src/datetime/to_local_time.rs

Lines changed: 143 additions & 185 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ use std::ops::Add;
2020
use std::sync::Arc;
2121

2222
use arrow::array::timezone::Tz;
23-
use arrow::array::{Array, ArrayRef, PrimitiveBuilder};
23+
use arrow::array::{ArrayRef, PrimitiveBuilder};
2424
use arrow::datatypes::DataType::Timestamp;
2525
use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second};
2626
use arrow::datatypes::{
@@ -31,11 +31,12 @@ use chrono::{DateTime, MappedLocalTime, Offset, TimeDelta, TimeZone, Utc};
3131

3232
use datafusion_common::cast::as_primitive_array;
3333
use datafusion_common::{
34-
exec_err, internal_datafusion_err, plan_err, utils::take_function_args, Result,
34+
exec_err, internal_datafusion_err, internal_err, utils::take_function_args, Result,
3535
ScalarValue,
3636
};
3737
use datafusion_expr::{
38-
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
38+
Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignatureClass,
39+
Volatility,
3940
};
4041
use datafusion_macros::user_doc;
4142

@@ -111,133 +112,152 @@ impl Default for ToLocalTimeFunc {
111112
impl ToLocalTimeFunc {
112113
pub fn new() -> Self {
113114
Self {
114-
signature: Signature::user_defined(Volatility::Immutable),
115+
signature: Signature::coercible(
116+
vec![Coercion::new_exact(TypeSignatureClass::Timestamp)],
117+
Volatility::Immutable,
118+
),
115119
}
116120
}
121+
}
117122

118-
fn to_local_time(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
119-
let [time_value] = take_function_args(self.name(), args)?;
123+
impl ScalarUDFImpl for ToLocalTimeFunc {
124+
fn as_any(&self) -> &dyn Any {
125+
self
126+
}
120127

121-
let arg_type = time_value.data_type();
122-
match arg_type {
123-
Timestamp(_, None) => {
124-
// if no timezone specified, just return the input
125-
Ok(time_value.clone())
126-
}
127-
// If has timezone, adjust the underlying time value. The current time value
128-
// is stored as i64 in UTC, even though the timezone may not be in UTC. Therefore,
129-
// we need to adjust the time value to the local time. See [`adjust_to_local_time`]
130-
// for more details.
131-
//
132-
// Then remove the timezone in return type, i.e. return None
133-
Timestamp(_, Some(timezone)) => {
134-
let tz: Tz = timezone.parse()?;
135-
136-
match time_value {
137-
ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(
138-
Some(ts),
139-
Some(_),
140-
)) => {
141-
let adjusted_ts =
142-
adjust_to_local_time::<TimestampNanosecondType>(*ts, tz)?;
143-
Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(
144-
Some(adjusted_ts),
145-
None,
146-
)))
147-
}
148-
ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(
149-
Some(ts),
150-
Some(_),
151-
)) => {
152-
let adjusted_ts =
153-
adjust_to_local_time::<TimestampMicrosecondType>(*ts, tz)?;
154-
Ok(ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(
155-
Some(adjusted_ts),
156-
None,
157-
)))
158-
}
159-
ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
160-
Some(ts),
161-
Some(_),
162-
)) => {
163-
let adjusted_ts =
164-
adjust_to_local_time::<TimestampMillisecondType>(*ts, tz)?;
165-
Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
166-
Some(adjusted_ts),
167-
None,
168-
)))
169-
}
170-
ColumnarValue::Scalar(ScalarValue::TimestampSecond(
171-
Some(ts),
172-
Some(_),
173-
)) => {
174-
let adjusted_ts =
175-
adjust_to_local_time::<TimestampSecondType>(*ts, tz)?;
176-
Ok(ColumnarValue::Scalar(ScalarValue::TimestampSecond(
177-
Some(adjusted_ts),
178-
None,
179-
)))
180-
}
181-
ColumnarValue::Array(array) => {
182-
fn transform_array<T: ArrowTimestampType>(
183-
array: &ArrayRef,
184-
tz: Tz,
185-
) -> Result<ColumnarValue> {
186-
let mut builder = PrimitiveBuilder::<T>::new();
187-
188-
let primitive_array = as_primitive_array::<T>(array)?;
189-
for ts_opt in primitive_array.iter() {
190-
match ts_opt {
191-
None => builder.append_null(),
192-
Some(ts) => {
193-
let adjusted_ts: i64 =
194-
adjust_to_local_time::<T>(ts, tz)?;
195-
builder.append_value(adjusted_ts)
196-
}
197-
}
198-
}
199-
200-
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
201-
}
202-
203-
match array.data_type() {
204-
Timestamp(_, None) => {
205-
// if no timezone specified, just return the input
206-
Ok(time_value.clone())
207-
}
208-
Timestamp(Nanosecond, Some(_)) => {
209-
transform_array::<TimestampNanosecondType>(array, tz)
210-
}
211-
Timestamp(Microsecond, Some(_)) => {
212-
transform_array::<TimestampMicrosecondType>(array, tz)
213-
}
214-
Timestamp(Millisecond, Some(_)) => {
215-
transform_array::<TimestampMillisecondType>(array, tz)
216-
}
217-
Timestamp(Second, Some(_)) => {
218-
transform_array::<TimestampSecondType>(array, tz)
219-
}
220-
_ => {
221-
exec_err!("to_local_time function requires timestamp argument in array, got {:?}", array.data_type())
222-
}
223-
}
224-
}
225-
_ => {
226-
exec_err!(
227-
"to_local_time function requires timestamp argument, got {:?}",
228-
time_value.data_type()
229-
)
230-
}
231-
}
232-
}
233-
_ => {
234-
exec_err!(
235-
"to_local_time function requires timestamp argument, got {:?}",
236-
arg_type
237-
)
128+
fn name(&self) -> &str {
129+
"to_local_time"
130+
}
131+
132+
fn signature(&self) -> &Signature {
133+
&self.signature
134+
}
135+
136+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
137+
match &arg_types[0] {
138+
DataType::Null => Ok(Timestamp(Nanosecond, None)),
139+
Timestamp(timeunit, _) => Ok(Timestamp(*timeunit, None)),
140+
dt => internal_err!(
141+
"The to_local_time function can only accept timestamp as the arg, got {dt}"
142+
),
143+
}
144+
}
145+
146+
fn invoke_with_args(
147+
&self,
148+
args: datafusion_expr::ScalarFunctionArgs,
149+
) -> Result<ColumnarValue> {
150+
let [time_value] = take_function_args(self.name(), &args.args)?;
151+
to_local_time(time_value)
152+
}
153+
154+
fn documentation(&self) -> Option<&Documentation> {
155+
self.doc()
156+
}
157+
}
158+
159+
fn transform_array<T: ArrowTimestampType>(
160+
array: &ArrayRef,
161+
tz: Tz,
162+
) -> Result<ColumnarValue> {
163+
let mut builder = PrimitiveBuilder::<T>::new();
164+
165+
let primitive_array = as_primitive_array::<T>(array)?;
166+
for ts_opt in primitive_array.iter() {
167+
match ts_opt {
168+
None => builder.append_null(),
169+
Some(ts) => {
170+
let adjusted_ts: i64 = adjust_to_local_time::<T>(ts, tz)?;
171+
builder.append_value(adjusted_ts)
238172
}
239173
}
240174
}
175+
176+
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
177+
}
178+
179+
fn to_local_time(time_value: &ColumnarValue) -> Result<ColumnarValue> {
180+
let arg_type = time_value.data_type();
181+
182+
let tz: Tz = match &arg_type {
183+
Timestamp(_, Some(timezone)) => timezone.parse()?,
184+
Timestamp(_, None) => {
185+
// if no timezone specified, just return the input
186+
return Ok(time_value.clone());
187+
}
188+
DataType::Null => {
189+
return Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(
190+
None, None,
191+
)));
192+
}
193+
dt => {
194+
return internal_err!(
195+
"to_local_time function requires timestamp argument, got {dt}"
196+
)
197+
}
198+
};
199+
200+
// If has timezone, adjust the underlying time value. The current time value
201+
// is stored as i64 in UTC, even though the timezone may not be in UTC. Therefore,
202+
// we need to adjust the time value to the local time. See [`adjust_to_local_time`]
203+
// for more details.
204+
//
205+
// Then remove the timezone in return type, i.e. return None
206+
match time_value {
207+
ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(ts), Some(_))) => {
208+
let adjusted_ts = adjust_to_local_time::<TimestampNanosecondType>(*ts, tz)?;
209+
Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(
210+
Some(adjusted_ts),
211+
None,
212+
)))
213+
}
214+
ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(ts), Some(_))) => {
215+
let adjusted_ts = adjust_to_local_time::<TimestampMicrosecondType>(*ts, tz)?;
216+
Ok(ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(
217+
Some(adjusted_ts),
218+
None,
219+
)))
220+
}
221+
ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(Some(ts), Some(_))) => {
222+
let adjusted_ts = adjust_to_local_time::<TimestampMillisecondType>(*ts, tz)?;
223+
Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
224+
Some(adjusted_ts),
225+
None,
226+
)))
227+
}
228+
ColumnarValue::Scalar(ScalarValue::TimestampSecond(Some(ts), Some(_))) => {
229+
let adjusted_ts = adjust_to_local_time::<TimestampSecondType>(*ts, tz)?;
230+
Ok(ColumnarValue::Scalar(ScalarValue::TimestampSecond(
231+
Some(adjusted_ts),
232+
None,
233+
)))
234+
}
235+
ColumnarValue::Array(array)
236+
if matches!(array.data_type(), Timestamp(Nanosecond, Some(_))) =>
237+
{
238+
transform_array::<TimestampNanosecondType>(array, tz)
239+
}
240+
ColumnarValue::Array(array)
241+
if matches!(array.data_type(), Timestamp(Microsecond, Some(_))) =>
242+
{
243+
transform_array::<TimestampMicrosecondType>(array, tz)
244+
}
245+
ColumnarValue::Array(array)
246+
if matches!(array.data_type(), Timestamp(Millisecond, Some(_))) =>
247+
{
248+
transform_array::<TimestampMillisecondType>(array, tz)
249+
}
250+
ColumnarValue::Array(array)
251+
if matches!(array.data_type(), Timestamp(Second, Some(_))) =>
252+
{
253+
transform_array::<TimestampSecondType>(array, tz)
254+
}
255+
_ => {
256+
internal_err!(
257+
"to_local_time function requires timestamp argument, got {arg_type}"
258+
)
259+
}
260+
}
241261
}
242262

243263
/// This function converts a timestamp with a timezone to a timestamp without a timezone.
@@ -343,68 +363,6 @@ fn adjust_to_local_time<T: ArrowTimestampType>(ts: i64, tz: Tz) -> Result<i64> {
343363
}
344364
}
345365

346-
impl ScalarUDFImpl for ToLocalTimeFunc {
347-
fn as_any(&self) -> &dyn Any {
348-
self
349-
}
350-
351-
fn name(&self) -> &str {
352-
"to_local_time"
353-
}
354-
355-
fn signature(&self) -> &Signature {
356-
&self.signature
357-
}
358-
359-
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
360-
let [time_value] = take_function_args(self.name(), arg_types)?;
361-
362-
match time_value {
363-
Timestamp(timeunit, _) => Ok(Timestamp(*timeunit, None)),
364-
_ => exec_err!(
365-
"The to_local_time function can only accept timestamp as the arg, got {:?}", time_value
366-
)
367-
}
368-
}
369-
370-
fn invoke_with_args(
371-
&self,
372-
args: datafusion_expr::ScalarFunctionArgs,
373-
) -> Result<ColumnarValue> {
374-
let [time_value] = take_function_args(self.name(), args.args)?;
375-
376-
self.to_local_time(std::slice::from_ref(&time_value))
377-
}
378-
379-
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
380-
if arg_types.len() != 1 {
381-
return plan_err!(
382-
"to_local_time function requires 1 argument, got {:?}",
383-
arg_types.len()
384-
);
385-
}
386-
387-
let first_arg = arg_types[0].clone();
388-
match &first_arg {
389-
DataType::Null => Ok(vec![Timestamp(Nanosecond, None)]),
390-
Timestamp(Nanosecond, timezone) => {
391-
Ok(vec![Timestamp(Nanosecond, timezone.clone())])
392-
}
393-
Timestamp(Microsecond, timezone) => {
394-
Ok(vec![Timestamp(Microsecond, timezone.clone())])
395-
}
396-
Timestamp(Millisecond, timezone) => {
397-
Ok(vec![Timestamp(Millisecond, timezone.clone())])
398-
}
399-
Timestamp(Second, timezone) => Ok(vec![Timestamp(Second, timezone.clone())]),
400-
_ => plan_err!("The to_local_time function can only accept Timestamp as the arg got {first_arg}"),
401-
}
402-
}
403-
fn documentation(&self) -> Option<&Documentation> {
404-
self.doc()
405-
}
406-
}
407-
408366
#[cfg(test)]
409367
mod tests {
410368
use std::sync::Arc;

0 commit comments

Comments (0)