2929#include < Common/CHUtil.h>
3030#include < Common/JNIUtils.h>
3131#include < Common/logger_useful.h>
32+ #include < DataTypes/DataTypesNumber.h>
3233
3334namespace DB
3435{
@@ -67,12 +68,12 @@ DB::Block resetBuildTableBlockName(Block & block, bool only_one = false)
6768 // add a sequence to avoid duplicate name in some rare cases
6869 if (names.find (col.name ) == names.end ())
6970 {
70- new_name << BlockUtil::RIHGT_COLUMN_PREFIX << col.name ;
71+ new_name << BlockUtil::RIGHT_COLUMN_PREFIX << col.name ;
7172 names.insert (col.name );
7273 }
7374 else
7475 {
75- new_name << BlockUtil::RIHGT_COLUMN_PREFIX << (seq++) << " _" << col.name ;
76+ new_name << BlockUtil::RIGHT_COLUMN_PREFIX << (seq++) << " _" << col.name ;
7677 }
7778 new_cols.emplace_back (col.column , col.type , new_name.str ());
7879
@@ -108,6 +109,51 @@ std::shared_ptr<StorageJoinFromReadBuffer> getJoin(const std::string & key)
108109 return wrapper;
109110}
110111
/// Tells whether `key` names a broadcast table that was built for a cross rel,
/// i.e. whether it begins with the "BuiltBNLJBroadcastTable-" prefix.
///
/// @param key  Identifier of the broadcast table.
/// @return true when `key` starts with the cross rel prefix, false otherwise
///         (including for an empty key or one shorter than the prefix).
static bool isCrossRelJoin(const std::string & key)
{
    // The prefix must not carry surrounding whitespace, otherwise no real key can match.
    static constexpr char cross_rel_key_prefix[] = "BuiltBNLJBroadcastTable-";
    // rfind anchored at position 0 can only match at the very beginning of `key`,
    // which makes this a pure prefix test.
    return key.rfind(cross_rel_key_prefix, 0) == 0;
}
117+
118+ static void collectBlocksForCountingRows (NativeReader & block_stream, Block & header, Blocks & result)
119+ {
120+ ProfileInfo profile;
121+ Block block = block_stream.read ();
122+ while (!block.empty ())
123+ {
124+ const auto & col = block.getByPosition (0 );
125+ auto counting_col = BlockUtil::buildRowCountBlock (col.column ->size ()).getColumnsWithTypeAndName ()[0 ];
126+ DB::ColumnsWithTypeAndName columns;
127+ columns.emplace_back (counting_col.column ->convertToFullColumnIfConst (), counting_col.type , counting_col.name );
128+ DB::Block new_block (columns);
129+ profile.update (new_block);
130+ result.emplace_back (std::move (new_block));
131+ block = block_stream.read ();
132+ }
133+ header = BlockUtil::buildRowCountHeader ();
134+ }
135+
136+ static void collectBlocksForJoinRel (NativeReader & reader, Block & header, Blocks & result)
137+ {
138+ ProfileInfo profile;
139+ Block block = reader.read ();
140+ while (!block.empty ())
141+ {
142+ DB::ColumnsWithTypeAndName columns;
143+ for (size_t i = 0 ; i < block.columns (); ++i)
144+ {
145+ const auto & column = block.getByPosition (i);
146+ columns.emplace_back (BlockUtil::convertColumnAsNecessary (column, header.getByPosition (i)));
147+ }
148+
149+ DB::Block final_block (columns);
150+ profile.update (final_block);
151+ result.emplace_back (std::move (final_block));
152+
153+ block = reader.read ();
154+ }
155+ }
156+
111157std::shared_ptr<StorageJoinFromReadBuffer> buildJoin (
112158 const std::string & key,
113159 DB::ReadBuffer & input,
@@ -123,12 +169,14 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
123169 auto join_key_list = Poco::StringTokenizer (join_keys, " ," );
124170 Names key_names;
125171 for (const auto & key_name : join_key_list)
126- key_names.emplace_back (BlockUtil::RIHGT_COLUMN_PREFIX + key_name);
172+ key_names.emplace_back (BlockUtil::RIGHT_COLUMN_PREFIX + key_name);
127173
128174 DB::JoinKind kind;
129175 DB::JoinStrictness strictness;
176+ bool is_cross_rel_join = isCrossRelJoin (key);
177+ assert (is_cross_rel_join && key_names.empty ()); // cross rel join should not have join keys
130178
131- if (key. starts_with ( " BuiltBNLJBroadcastTable- " ) )
179+ if (is_cross_rel_join )
132180 std::tie (kind, strictness) = JoinUtil::getCrossJoinKindAndStrictness (static_cast <substrait::CrossRel_JoinType>(join_type));
133181 else
134182 std::tie (kind, strictness) = JoinUtil::getJoinKindAndStrictness (static_cast <substrait::JoinRel_JoinType>(join_type), is_existence_join);
@@ -139,40 +187,41 @@ std::shared_ptr<StorageJoinFromReadBuffer> buildJoin(
139187 Block header = TypeParser::buildBlockFromNamedStruct (substrait_struct);
140188 header = resetBuildTableBlockName (header);
141189
190+ bool only_one_column = header.getNamesAndTypesList ().empty ();
191+ if (only_one_column)
192+ header = BlockUtil::buildRowCountBlock (0 ).getColumnsWithTypeAndName ();
193+
142194 Blocks data;
143- auto collect_data = [&]
195+ auto collect_data = [&]()
144196 {
145- bool only_one_column = header. getNamesAndTypesList (). empty ( );
197+ NativeReader block_stream (input );
146198 if (only_one_column)
147- header = BlockUtil::buildRowCountBlock (0 ).getColumnsWithTypeAndName ();
199+ collectBlocksForCountingRows (block_stream, header, data);
200+ else
201+ collectBlocksForJoinRel (block_stream, header, data);
148202
149- NativeReader block_stream (input);
150- ProfileInfo info;
151- Block block = block_stream.read ();
152- while (!block.empty ())
203+ // For a cross rel whose join kind is not Cross, we need to add a constant join key column
204+ // to make it behave like a normal join.
205+ if (is_cross_rel_join && kind != JoinKind::Cross)
153206 {
154- DB::ColumnsWithTypeAndName columns;
155- for (size_t i = 0 ; i < block.columns (); ++i)
207+ auto data_type_u8 = std::make_shared<DataTypeUInt8>();
208+ UInt8 const_key_val = 0 ;
209+ String const_key_name = JoinUtil::CROSS_REL_RIGHT_CONST_KEY_COLUMN;
210+ Blocks new_data;
211+ for (const auto & block : data)
156212 {
157- const auto & column = block.getByPosition (i);
158- if (only_one_column)
159- {
160- auto virtual_block = BlockUtil::buildRowCountBlock (column.column ->size ()).getColumnsWithTypeAndName ();
161- header = virtual_block;
162- columns.emplace_back (virtual_block.back ());
163- break ;
164- }
165-
166- columns.emplace_back (BlockUtil::convertColumnAsNecessary (column, header.getByPosition (i)));
213+ auto cols = block.getColumnsWithTypeAndName ();
214+ cols.emplace_back (data_type_u8->createColumnConst (block.rows (), const_key_val), data_type_u8, const_key_name);
215+ new_data.emplace_back (Block (cols));
167216 }
168-
169- DB::Block final_block (columns);
170- info.update (final_block);
171- data.emplace_back (std::move (final_block));
172-
173- block = block_stream.read ();
217+ data.swap (new_data);
218+ key_names.emplace_back (const_key_name);
219+ auto cols = header.getColumnsWithTypeAndName ();
220+ cols.emplace_back (data_type_u8->createColumnConst (0 , const_key_val), data_type_u8, const_key_name);
221+ header = Block (cols);
174222 }
175223 };
224+
176225 // / Record memory usage in Total Memory Tracker
177226 ThreadFromGlobalPoolNoTracingContextPropagation thread (collect_data);
178227 thread.join ();
0 commit comments