Skip to content

Commit 28d7d29

Browse files
committed
support arabic-indic digits date.
1 parent a8f3f0c commit 28d7d29

File tree

4 files changed

+253
-2
lines changed

4 files changed

+253
-2
lines changed

backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS
6363
.set("spark.sql.shuffle.partitions", "5")
6464
.set("spark.sql.autoBroadcastJoinThreshold", "10MB")
6565
.set(GlutenConfig.GLUTEN_SUPPORTED_SCALA_UDFS.key, "compare_substrings:compare_substrings")
66+
.set(
67+
"spark.sql.optimizer.excludedRules",
68+
"org.apache.spark.sql.catalyst.optimizer.ConstantFolding" +
69+
"," +
70+
"org.apache.spark.sql.catalyst.optimizer.NullPropagation"
71+
)
6672
}
6773

6874
override def beforeAll(): Unit = {
@@ -1386,6 +1392,38 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS
13861392
}
13871393
}
13881394

1395+
test("arabic_indic digit date") {
1396+
withSQLConf(
1397+
SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
1398+
(ConstantFolding.ruleName + "," + NullPropagation.ruleName)) {
1399+
sql("create table tb_arabic_date(d string) using parquet")
1400+
sql("""
1401+
|insert into tb_arabic_date values
1402+
|('2020-01-01'),
1403+
|(cast(unbase64('2aLZoNmi2aQt2aDZpi3ZoNmh') as string)),
1404+
|(cast(unbase64('2aLZoNmi2aQt2aHZoi3Zo9mh') as string)),
1405+
|('2022-10-11'),
1406+
|(cast(unbase64('2aLZoNmi2aQt2aHZoi3Zo9mh') as string))
1407+
|""".stripMargin)
1408+
var query_sql = "select from_unixtime(unix_timestamp(d, 'yyyy-MM-dd')) from tb_arabic_date"
1409+
compareResultsAgainstVanillaSpark(query_sql, true, { _ => })
1410+
1411+
query_sql = """
1412+
|select from_unixtime(
1413+
| unix_timestamp(cast(unbase64('2aLZoNmi2aQt2aDZpi3ZoNmh') as string),
1414+
| 'yyyy-MM-dd'))
1415+
|""".stripMargin
1416+
compareResultsAgainstVanillaSpark(query_sql, true, { _ => })
1417+
1418+
query_sql = """
1419+
|select from_unixtime(unix_timestamp('2020-01-01', 'yyyy-MM-dd'))
1420+
|""".stripMargin
1421+
compareResultsAgainstVanillaSpark(query_sql, true, { _ => })
1422+
1423+
sql("drop table tb_arabic_date")
1424+
}
1425+
}
1426+
13891427
test("Test map with nullable key") {
13901428
val sql = "select map(string_field1, int_field1) from json_test where string_field1 is not null"
13911429
compareResultsAgainstVanillaSpark(sql, true, { _ => })
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
19+
#include <Columns/IColumn.h>
20+
#include <Columns/ColumnNullable.h>
21+
#include <Columns/ColumnConst.h>
22+
#include <Columns/ColumnString.h>
23+
#include <DataTypes/DataTypeNullable.h>
24+
#include <DataTypes/IDataType.h>
25+
#include <Functions/FunctionFactory.h>
26+
#include <Functions/FunctionHelpers.h>
27+
#include <Functions/IFunction.h>
28+
#include <Common/Exception.h>
29+
#include <Common/logger_useful.h>
30+
31+
namespace DB
32+
{
33+
namespace ErrorCodes
34+
{
35+
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
36+
}
37+
}
38+
39+
namespace local_engine
40+
{
41+
// Since spark 3.3, unix_timestamp support arabic number input, e.g., "٢٠٢١-٠٧-٠١ ١٢:٠٠:٠٠".
42+
// We implement a function to translate arabic indic digits to ascii digits here.
43+
class ArabicIndicToAsciiDigitForDateFunction : public DB::IFunction
44+
{
45+
public:
46+
static constexpr auto name = "arabic_indic_to_ascii_digit_for_date";
47+
48+
static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared<ArabicIndicToAsciiDigitForDateFunction>(); }
49+
50+
String getName() const override { return name; }
51+
52+
bool isSuitableForShortCircuitArgumentsExecution(const DB::DataTypesWithConstInfo & /*arguments*/) const override { return false; }
53+
size_t getNumberOfArguments() const override { return 1; }
54+
55+
DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes & arguments) const override
56+
{
57+
auto nested_type = DB::removeNullable(arguments[0]);
58+
if (!DB::WhichDataType(nested_type).isString())
59+
throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument for function {} must be String, but got {}", getName(), arguments[0]->getName());
60+
return arguments[0];
61+
}
62+
63+
DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr &, size_t input_rows_count) const override
64+
{
65+
auto data_col = arguments[0].column;
66+
const DB::ColumnString * col_str = nullptr;
67+
const DB::ColumnNullable * col_nullable = nullptr;
68+
const DB::NullMap * null_map = nullptr;
69+
if (data_col->isConst())
70+
{
71+
if (data_col->isNullAt(0))
72+
{
73+
return data_col;
74+
}
75+
const DB::ColumnConst * col_const = DB::checkAndGetColumn<DB::ColumnConst>(data_col.get());
76+
data_col = col_const->getDataColumnPtr();
77+
if (data_col->isNullable())
78+
{
79+
col_nullable = DB::checkAndGetColumn<DB::ColumnNullable>(data_col.get());
80+
null_map = &(col_nullable->getNullMapData());
81+
col_str = DB::checkAndGetColumn<DB::ColumnString>(&(col_nullable->getNestedColumn()));
82+
}
83+
else
84+
{
85+
col_str = DB::checkAndGetColumn<DB::ColumnString>(data_col.get());
86+
}
87+
if (!col_str)
88+
throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument for function {} must be String, but got {}", getName(), data_col->getName());
89+
auto date_str = col_str->getDataAt(0);
90+
auto new_str = convertArabicIndicDigit(date_str);
91+
auto new_data_col = data_col->cloneEmpty();
92+
new_data_col->insertData(new_str.c_str(), new_str.size());
93+
return DB::ColumnConst::create(std::move(new_data_col), input_rows_count);
94+
}
95+
96+
if (data_col->isNullable())
97+
{
98+
col_nullable = DB::checkAndGetColumn<DB::ColumnNullable>(data_col.get());
99+
null_map = &(col_nullable->getNullMapData());
100+
col_str = DB::checkAndGetColumn<DB::ColumnString>(&(col_nullable->getNestedColumn()));
101+
}
102+
else
103+
{
104+
col_str = DB::checkAndGetColumn<DB::ColumnString>(data_col.get());
105+
}
106+
if (!col_str)
107+
throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument for function {} must be String, but got {}", getName(), data_col->getName());
108+
109+
auto nested_data_col = DB::removeNullable(arguments[0].column);
110+
bool has_arabic_indic_digit = false;
111+
size_t row_index = 0;
112+
for (row_index = 0; row_index < input_rows_count; ++row_index)
113+
{
114+
if (null_map && (*null_map)[row_index])
115+
{
116+
continue;
117+
}
118+
auto str = col_str->getDataAt(row_index);
119+
if (hasArabicIndicDigit(str))
120+
{
121+
has_arabic_indic_digit = true;
122+
break;
123+
}
124+
}
125+
126+
if (!has_arabic_indic_digit)
127+
{
128+
// No Arabic indic digits found, return the original column
129+
return arguments[0].column;
130+
}
131+
132+
auto res_col = data_col->cloneEmpty();
133+
if (row_index)
134+
{
135+
res_col->insertManyFrom(*data_col, 0, row_index);
136+
}
137+
for (; row_index < input_rows_count; ++row_index)
138+
{
139+
if (null_map && (*null_map)[row_index])
140+
{
141+
res_col->insertDefault();
142+
continue;
143+
}
144+
auto str = convertArabicIndicDigit(col_str->getDataAt(row_index));
145+
res_col->insertData(str.c_str(), str.size());
146+
}
147+
return res_col;
148+
}
149+
150+
private:
151+
bool hasArabicIndicDigit(StringRef str) const
152+
{
153+
// In most cases, the first byte is a digit.
154+
char c = reinterpret_cast<char>(str.data[0]);
155+
if ('0' <= c && c <= '9')
156+
{
157+
return false;
158+
}
159+
return true;
160+
}
161+
162+
163+
bool isArabicIndicDigit(char32_t c) const { return c >= 0x0660 && c <= 0x0669; }
164+
char toAsciiDigit(char32_t c) const { return static_cast<char>(c - 0x0660 + '0'); }
165+
166+
String convertArabicIndicDigit(const StringRef & str) const
167+
{
168+
std::string result;
169+
result.reserve(str.size);
170+
for (size_t i = 0; i < str.size;)
171+
{
172+
unsigned char c = str.data[i];
173+
char32_t cp = 0;
174+
if ((c & 0x80) == 0) // 1-byte
175+
{
176+
cp = c;
177+
i += 1;
178+
}
179+
else if ((c & 0xE0) == 0xC0) // 2-byte
180+
{
181+
cp = ((c & 0x1F) << 6) | (str.data[i + 1] & 0x3F);
182+
i += 2;
183+
}
184+
else if ((c & 0xF0) == 0xE0) // 3-byte
185+
{
186+
cp = ((c & 0x0F) << 12) | ((str.data[i + 1] & 0x3F) << 6) | (str.data[i + 2] & 0x3F);
187+
i += 3;
188+
}
189+
else if ((c & 0xF8) == 0xF0) // 4-byte
190+
{
191+
cp = ((c & 0x07) << 18) | ((str.data[i + 1] & 0x3F) << 12) | ((str.data[i + 2] & 0x3F) << 6) | (str.data[i + 3] & 0x3F);
192+
i += 4;
193+
}
194+
if (isArabicIndicDigit(cp))
195+
result.push_back(toAsciiDigit(cp));
196+
else
197+
result.push_back(cp);
198+
}
199+
return result;
200+
}
201+
};
202+
203+
using namespace DB;
204+
REGISTER_FUNCTION(ArabicIndicToAsciiDigitForDate)
205+
{
206+
factory.registerFunction<ArabicIndicToAsciiDigitForDateFunction>();
207+
}
208+
}

cpp-ch/local-engine/Parser/scalar_function_parser/getTimestamp.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class FunctionParserGetTimestamp : public FunctionParser
6161
auto parsed_args = parseFunctionArguments(substrait_func, actions_dag);
6262
if (parsed_args.size() != 2)
6363
throw DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName());
64-
const auto * expr_arg = parsed_args[0];
64+
const auto * expr_arg = convertArabicIndicDigit(actions_dag, parsed_args[0]);
6565
const auto * fmt_arg = parsed_args[1];
6666

6767
const auto & args = substrait_func.arguments();
@@ -129,5 +129,11 @@ class FunctionParserGetTimestamp : public FunctionParser
129129
return std::regex_match(fmt, fmtPattern);
130130
}
131131
}
132+
133+
const DB::ActionsDAG::Node * convertArabicIndicDigit(DB::ActionsDAG & actions_dag, const DB::ActionsDAG::Node * node) const
134+
{
135+
const auto * func_node = toFunctionNode(actions_dag, "arabic_indic_to_ascii_digit_for_date", {node});
136+
return func_node;
137+
}
132138
};
133139
}

cpp-ch/local-engine/Parser/scalar_function_parser/unixTimestamp.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ class FunctionParserUnixTimestamp : public FunctionParserGetTimestamp
5959
throw DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName());
6060

6161
const auto * expr_arg = parsed_args[0];
62-
const auto * fmt_arg = parsed_args[1];
6362
auto expr_type = removeNullable(expr_arg->result_type);
6463
if (isString(expr_type))
6564
return FunctionParserGetTimestamp::parse(substrait_func, actions_dag);

0 commit comments

Comments
 (0)