diff --git a/dbms/src/Functions/FunctionsString.cpp b/dbms/src/Functions/FunctionsString.cpp index 5cf47842889..39c9ca8968e 100644 --- a/dbms/src/Functions/FunctionsString.cpp +++ b/dbms/src/Functions/FunctionsString.cpp @@ -1523,6 +1523,7 @@ class FunctionSubstringUTF8 : public IFunction { const ColumnPtr column_string = block.getByPosition(arguments[0]).column; +<<<<<<< HEAD const ColumnPtr column_start = block.getByPosition(arguments[1]).column; if (!column_start->isColumnConst()) throw Exception("2nd arguments of function " + getName() + " must be constants."); @@ -1531,6 +1532,216 @@ class FunctionSubstringUTF8 : public IFunction throw Exception("2nd argument of function " + getName() + " must have UInt/Int type."); Int64 start; if (start_field.getType() == Field::Types::Int64) +======= + const ColumnPtr & column_start = block.getByPosition(arguments[1]).column; + + bool implicit_length = (arguments.size() == 2); + + bool is_start_type_valid + = getNumberType(block.getByPosition(arguments[1]).type, [&](const auto & start_type, bool) { + using StartType = std::decay_t; + using StartFieldType = typename StartType::FieldType; + const ColumnVector * column_vector_start + = getInnerColumnVector(column_start); + if unlikely (!column_vector_start) + throw Exception( + fmt::format( + "Illegal type {} of argument 2 of function {}", + block.getByPosition(arguments[1]).type->getName(), + getName()), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + // vector const const + if (!column_string->isColumnConst() && column_start->isColumnConst() + && (implicit_length || block.getByPosition(arguments[2]).column->isColumnConst())) + { + auto [is_positive, start_abs] = getValueFromStartColumn(*column_vector_start, 0); + UInt64 length = 0; + if (!implicit_length) + { + bool is_length_type_valid = getNumberType( + block.getByPosition(arguments[2]).type, + [&](const auto & length_type, bool) { + using LengthType = std::decay_t; + using LengthFieldType = typename LengthType::FieldType; + const ColumnVector * column_vector_length + = getInnerColumnVector(block.getByPosition(arguments[2]).column); + if unlikely (!column_vector_length) + throw Exception( + fmt::format( + "Illegal type {} of argument 3 of function {}", + block.getByPosition(arguments[2]).type->getName(), + getName()), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + length = getValueFromLengthColumn(*column_vector_length, 0); + return true; + }); + + if (!is_length_type_valid) + throw Exception( + fmt::format("3nd argument of function {} must have UInt/Int type.", getName())); + } + + // for const zero start or const zero length, return const blank string. + if (start_abs == 0 || (!implicit_length && length == 0)) + { + block.getByPosition(result).column + = DataTypeString().createColumnConst(column_string->size(), toField(String(""))); + return true; + } + + const auto * col = checkAndGetColumn(column_string.get()); + assert(col); + auto col_res = ColumnString::create(); + getVectorConstConstFunc(implicit_length, is_positive)( + col->getChars(), + col->getOffsets(), + start_abs, + length, + col_res->getChars(), + col_res->getOffsets()); + block.getByPosition(result).column = std::move(col_res); + } + else // all other cases are converted to vector vector vector + { + std::function(size_t)> get_start_func; + if (column_start->isColumnConst()) + { + // func always return const value + auto start_const = getValueFromStartColumn(*column_vector_start, 0); + get_start_func = [start_const](size_t) { + return start_const; + }; + } + else + { + get_start_func = [column_vector_start](size_t i) { + return getValueFromStartColumn(*column_vector_start, i); + }; + } + + // if implicit_length, get_length_func be nil is ok. + std::function get_length_func; + if (!implicit_length) + { + const ColumnPtr & column_length = block.getByPosition(arguments[2]).column; + bool is_length_type_valid = getNumberType( + block.getByPosition(arguments[2]).type, + [&](const auto & length_type, bool) { + using LengthType = std::decay_t; + using LengthFieldType = typename LengthType::FieldType; + const ColumnVector * column_vector_length + = getInnerColumnVector(column_length); + if unlikely (!column_vector_length) + throw Exception( + fmt::format( + "Illegal type {} of argument 3 of function {}", + block.getByPosition(arguments[2]).type->getName(), + getName()), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + if (column_length->isColumnConst()) + { + // func always return const value + auto length_const + = getValueFromLengthColumn(*column_vector_length, 0); + get_length_func = [length_const](size_t) { + return length_const; + }; + } + else + { + get_length_func = [column_vector_length](size_t i) { + return getValueFromLengthColumn(*column_vector_length, i); + }; + } + return true; + }); + + if unlikely (!is_length_type_valid) + throw Exception( + fmt::format("3nd argument of function {} must have UInt/Int type.", getName())); + } + + // convert to vector if string is const. + ColumnPtr full_column_string = column_string->isColumnConst() + ? column_string->convertToFullColumnIfConst() + : column_string; + const auto * col = checkAndGetColumn(full_column_string.get()); + assert(col); + auto col_res = ColumnString::create(); + if (implicit_length) + { + SubstringUTF8Impl::vectorVectorVector( + col->getChars(), + col->getOffsets(), + get_start_func, + get_length_func, + col_res->getChars(), + col_res->getOffsets()); + } + else + { + SubstringUTF8Impl::vectorVectorVector( + col->getChars(), + col->getOffsets(), + get_start_func, + get_length_func, + col_res->getChars(), + col_res->getOffsets()); + } + block.getByPosition(result).column = std::move(col_res); + } + + return true; + }); + + if unlikely (!is_start_type_valid) + throw Exception(fmt::format("2nd argument of function {} must have UInt/Int type.", getName())); + } + + template + static const ColumnVector * getInnerColumnVector(const ColumnPtr & column) + { + if (column->isColumnConst()) + return checkAndGetColumn>( + checkAndGetColumn(column.get())->getDataColumnPtr().get()); + return checkAndGetColumn>(column.get()); + } + + template + static size_t getValueFromLengthColumn(const ColumnVector & column, size_t index) + { + Integer val = column.getElement(index); + if constexpr ( + std::is_same_v || std::is_same_v || std::is_same_v + || std::is_same_v) + { + return val < 0 ? 0 : val; + } + else + { + static_assert( + std::is_same_v || std::is_same_v || std::is_same_v + || std::is_same_v); + return val; + } + } + +private: + using VectorConstConstFunc = std::function; + + static VectorConstConstFunc getVectorConstConstFunc(bool implicit_length, bool is_positive_start) + { + if (implicit_length) +>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507)) { start = start_field.get(); } @@ -1542,6 +1753,7 @@ class FunctionSubstringUTF8 : public IFunction start = static_cast(u_start); } +<<<<<<< HEAD bool implicit_length = true; UInt64 length = 0; if (arguments.size() == 3) @@ -1590,6 +1802,42 @@ class FunctionSubstringUTF8 : public IFunction throw Exception( "Illegal column " + block.getByPosition(arguments[0]).column->getName() + " of first argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN); +======= + // return {is_positive, abs} + template + static std::pair getValueFromStartColumn(const ColumnVector & column, size_t index) + { + Integer val = column.getElement(index); + if constexpr ( + std::is_same_v || std::is_same_v || std::is_same_v + || std::is_same_v) + { + if (val < 0) + return {false, static_cast(-val)}; + return {true, static_cast(val)}; + } + else + { + static_assert( + std::is_same_v || std::is_same_v || std::is_same_v + || std::is_same_v); + return {true, val}; + } + } + + template + static bool getNumberType(DataTypePtr type, F && f) + { + return castTypeToEither< + DataTypeUInt8, + DataTypeUInt16, + DataTypeUInt32, + DataTypeUInt64, + DataTypeInt8, + DataTypeInt16, + DataTypeInt32, + DataTypeInt64>(type.get(), std::forward(f)); +>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507)) } }; @@ -1643,6 +1891,7 @@ class FunctionRightUTF8 : public IFunction const ColumnPtr column_string = block.getByPosition(arguments[0]).column; const ColumnPtr column_length = block.getByPosition(arguments[1]).column; +<<<<<<< HEAD if (!column_length->isColumnConst()) throw Exception("2nd arguments of function " + getName() + " must be constants."); Field length_field = (*block.getByPosition(arguments[1]).column)[0]; @@ -1677,6 +1926,114 @@ class FunctionRightUTF8 : public IFunction throw Exception( "Illegal column " + block.getByPosition(arguments[0]).column->getName() + " of first argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN); +======= + + bool is_length_type_valid + = getLengthType(block.getByPosition(arguments[1]).type, [&](const auto & length_type, bool) { + using LengthType = std::decay_t; + using LengthFieldType = typename LengthType::FieldType; + + const ColumnVector * column_vector_length + = FunctionSubstringUTF8::getInnerColumnVector(column_length); + if unlikely (!column_vector_length) + throw Exception( + fmt::format( + "Illegal type {} of argument 2 of function {}", + block.getByPosition(arguments[1]).type->getName(), + getName()), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + + auto col_res = ColumnString::create(); + if (const auto * col_string = checkAndGetColumn(column_string.get())) + { + if (column_length->isColumnConst()) + { + // vector const + size_t length = FunctionSubstringUTF8::getValueFromLengthColumn( + *column_vector_length, + 0); + + // for const 0, return const blank string. + if (0 == length) + { + block.getByPosition(result).column + = DataTypeString().createColumnConst(column_string->size(), toField(String(""))); + return true; + } + + RightUTF8Impl::vectorConst( + col_string->getChars(), + col_string->getOffsets(), + length, + col_res->getChars(), + col_res->getOffsets()); + } + else + { + // vector vector + auto get_length_func = [column_vector_length](size_t i) { + return FunctionSubstringUTF8::getValueFromLengthColumn( + *column_vector_length, + i); + }; + RightUTF8Impl::vectorVector( + col_string->getChars(), + col_string->getOffsets(), + get_length_func, + col_res->getChars(), + col_res->getOffsets()); + } + } + else if ( + const ColumnConst * col_const_string = checkAndGetColumnConst(column_string.get())) + { + // const vector + const auto * col_string_from_const + = checkAndGetColumn(col_const_string->getDataColumnPtr().get()); + assert(col_string_from_const); + // When useDefaultImplementationForConstants is true, string and length are not both constants + assert(!column_length->isColumnConst()); + auto get_length_func = [column_vector_length](size_t i) { + return FunctionSubstringUTF8::getValueFromLengthColumn( + *column_vector_length, + i); + }; + RightUTF8Impl::constVector( + column_length->size(), + col_string_from_const->getChars(), + col_string_from_const->getOffsets(), + get_length_func, + col_res->getChars(), + col_res->getOffsets()); + } + else + { + // Impossible to reach here + return false; + } + block.getByPosition(result).column = std::move(col_res); + return true; + }); + + if (!is_length_type_valid) + throw Exception(fmt::format("2nd argument of function {} must have UInt/Int type.", getName())); + } + +private: + template + static bool getLengthType(DataTypePtr type, F && f) + { + return castTypeToEither< + DataTypeUInt8, + DataTypeUInt16, + DataTypeUInt32, + DataTypeUInt64, + DataTypeInt8, + DataTypeInt16, + DataTypeInt32, + DataTypeInt64>(type.get(), std::forward(f)); +>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507)) } }; diff --git a/dbms/src/Functions/tests/gtest_string_left.cpp b/dbms/src/Functions/tests/gtest_string_left.cpp new file mode 100644 index 00000000000..f5be8fcdfbf --- /dev/null +++ b/dbms/src/Functions/tests/gtest_string_left.cpp @@ -0,0 +1,163 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include +#include + +namespace DB +{ +namespace tests +{ +class StringLeftTest : public DB::tests::FunctionTest +{ +public: + static constexpr auto func_name = "leftUTF8"; + + template + void testBoundary() + { + std::vector> strings = {"", "www.pingcap", "中文.测.试。。。", {}}; + for (const auto & str : strings) + { + auto return_null_if_str_null = [&](const std::optional & not_null_result) { + return !str.has_value() ? std::optional{} : not_null_result; + }; + test(str, 0, return_null_if_str_null("")); + test(str, std::numeric_limits::max(), return_null_if_str_null(str)); + test(str, std::optional{}, std::optional{}); + if constexpr (std::is_signed_v) + { + test(str, -1, return_null_if_str_null("")); + test(str, std::numeric_limits::min(), return_null_if_str_null("")); + } + } + } + + template + void test( + const std::optional & str, + const std::optional & length, + const std::optional & result) + { + auto inner_test = [&](bool is_str_const, bool is_length_const) { + bool is_one_of_args_null_const + = (is_str_const && !str.has_value()) || (is_length_const && !length.has_value()); + bool is_result_const = (is_str_const && is_length_const) || is_one_of_args_null_const; + auto expected_res_column = is_result_const + ? (is_one_of_args_null_const ? createConstColumn>(1, result) + : createConstColumn(1, result.value())) + : createColumn>({result}); + auto str_column + = is_str_const ? createConstColumn>(1, str) : createColumn>({str}); + auto length_column = is_length_const ? createConstColumn>(1, length) + : createColumn>({length}); + auto actual_res_column = executeFunction(func_name, str_column, length_column); + ASSERT_COLUMN_EQ(expected_res_column, actual_res_column); + }; + std::vector is_consts = {true, false}; + for (bool is_str_const : is_consts) + for (bool is_length_const : is_consts) + inner_test(is_str_const, is_length_const); + } +}; + +TEST_F(StringLeftTest, testBoundary) +try +{ + testBoundary(); + testBoundary(); + testBoundary(); + testBoundary(); + testBoundary(); + testBoundary(); + testBoundary(); + testBoundary(); +} +CATCH + +TEST_F(StringLeftTest, testMoreCases) +try +{ +#define CALL(A, B, C) \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); + + // test big string + // big_string.size() > length + String big_string; + // unit_string length = 22 + String unit_string = "big string is 我!!!!!!!"; + for (size_t i = 0; i < 1000; ++i) + big_string += unit_string; + CALL(big_string, 22, unit_string); + + // test origin_str.size() == length + String origin_str = "我的 size = 12"; + CALL(origin_str, 12, origin_str); + + // test origin_str.size() < length + CALL(origin_str, 22, origin_str); + + // Mixed language + String english_str = "This is English"; + String mixed_language_str = english_str + ",这是中文,C'est français,これが日本の"; + CALL(mixed_language_str, english_str.size(), english_str); + + // column size != 1 + // case 1 + ASSERT_COLUMN_EQ( + createColumn>({unit_string, origin_str, origin_str, english_str}), + executeFunction( + func_name, + createColumn>({big_string, origin_str, origin_str, mixed_language_str}), + createColumn>({22, 12, 22, english_str.size()}))); + // case 2 + String second_case_string = "abc"; + ASSERT_COLUMN_EQ( + createColumn>({"", "a", "", "a", "", "", "a", "a"}), + executeFunction( + func_name, + createColumn>( + {second_case_string, + second_case_string, + second_case_string, + second_case_string, + second_case_string, + second_case_string, + second_case_string, + second_case_string}), + createColumn>({0, 1, 0, 1, 0, 0, 1, 1}))); + ASSERT_COLUMN_EQ( + createColumn>({"", "a", "", "a", "", "", "a", "a"}), + executeFunction( + func_name, + createConstColumn>(8, second_case_string), + createColumn>({0, 1, 0, 1, 0, 0, 1, 1}))); + +#undef CALL +} +CATCH + +} // namespace tests +} // namespace DB diff --git a/dbms/src/Functions/tests/gtest_strings_right.cpp b/dbms/src/Functions/tests/gtest_strings_right.cpp new file mode 100644 index 00000000000..1371077f005 --- /dev/null +++ b/dbms/src/Functions/tests/gtest_strings_right.cpp @@ -0,0 +1,158 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include +#include +#include + +namespace DB::tests +{ +class StringRightTest : public DB::tests::FunctionTest +{ +public: + static constexpr auto func_name = "rightUTF8"; + + template + void testBoundary() + { + std::vector> strings = {"", "www.pingcap", "中文.测.试。。。", {}}; + for (const auto & str : strings) + { + auto return_null_if_str_null = [&](const std::optional & not_null_result) { + return !str.has_value() ? std::optional{} : not_null_result; + }; + test(str, 0, return_null_if_str_null("")); + test(str, std::numeric_limits::max(), return_null_if_str_null(str)); + test(str, std::optional{}, std::optional{}); + if constexpr (std::is_signed_v) + { + test(str, -1, return_null_if_str_null("")); + test(str, std::numeric_limits::min(), return_null_if_str_null("")); + } + } + } + + template + void test(const std::optional & str, std::optional length, const std::optional & result) + { + auto inner_test = [&](bool is_str_const, bool is_length_const) { + bool is_one_of_args_null_const + = (is_str_const && !str.has_value()) || (is_length_const && !length.has_value()); + bool is_result_const = (is_str_const && is_length_const) || is_one_of_args_null_const; + auto expected_res_column = is_result_const + ? (is_one_of_args_null_const ? createConstColumn>(1, result) + : createConstColumn(1, result.value())) + : createColumn>({result}); + auto str_column + = is_str_const ? createConstColumn>(1, str) : createColumn>({str}); + auto length_column = is_length_const ? createConstColumn>(1, length) + : createColumn>({length}); + auto actual_res_column = executeFunction(func_name, str_column, length_column); + ASSERT_COLUMN_EQ(expected_res_column, actual_res_column); + }; + std::vector is_consts = {true, false}; + for (bool is_str_const : is_consts) + for (bool is_length_const : is_consts) + inner_test(is_str_const, is_length_const); + } +}; + +TEST_F(StringRightTest, testBoundary) +try +{ + testBoundary(); + testBoundary(); + testBoundary(); + testBoundary(); + testBoundary(); + testBoundary(); + testBoundary(); + testBoundary(); +} +CATCH + +TEST_F(StringRightTest, testMoreCases) +try +{ +#define CALL(A, B, C) \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); \ + test(A, B, C); + + // test big string + // big_string.size() > length + String big_string; + // unit_string length = 22 + String unit_string = "big string is 我!!!!!!!"; + for (size_t i = 0; i < 1000; ++i) + big_string += unit_string; + CALL(big_string, 22, unit_string); + + // test origin_str.size() == length + String origin_str = "我的 size = 12"; + CALL(origin_str, 12, origin_str); + + // test origin_str.size() < length + CALL(origin_str, 22, origin_str); + + // Mixed language + String english_str = "This is English"; + String mixed_language_str = "这是中文,C'est français,これが日本の," + english_str; + CALL(mixed_language_str, english_str.size(), english_str); + + // column size != 1 + // case 1 + ASSERT_COLUMN_EQ( + createColumn>({unit_string, origin_str, origin_str, english_str}), + executeFunction( + func_name, + createColumn>({big_string, origin_str, origin_str, mixed_language_str}), + createColumn>({22, 12, 22, english_str.size()}))); + // case 2 + String second_case_string = "abc"; + ASSERT_COLUMN_EQ( + createColumn>({"", "c", "", "c", "", "", "c", "c"}), + executeFunction( + func_name, + createColumn>( + {second_case_string, + second_case_string, + second_case_string, + second_case_string, + second_case_string, + second_case_string, + second_case_string, + second_case_string}), + createColumn>({0, 1, 0, 1, 0, 0, 1, 1}))); + ASSERT_COLUMN_EQ( + createColumn>({"", "c", "", "c", "", "", "c", "c"}), + executeFunction( + func_name, + createConstColumn>(8, second_case_string), + createColumn>({0, 1, 0, 1, 0, 0, 1, 1}))); + +#undef CALL +} +CATCH + +} // namespace DB::tests diff --git a/dbms/src/Functions/tests/gtest_substring.cpp b/dbms/src/Functions/tests/gtest_substring.cpp index 87cb351f407..57b2b008b5c 100644 --- a/dbms/src/Functions/tests/gtest_substring.cpp +++ b/dbms/src/Functions/tests/gtest_substring.cpp @@ -14,9 +14,160 @@ class SubString : public DB::tests::FunctionTest { }; +template +class TestNullableSigned +{ +public: + static void operator()(SubString & sub_string) + { + ASSERT_COLUMN_EQ( + createColumn>({"p.co", "ww.p", "pingcap", "com", ".com", "", "", "", {}, {}, {}}), + sub_string.executeFunction( + "substringUTF8", + createColumn>( + {"www.pingcap.com", + "ww.pingcap.com", + "w.pingcap.com", + ".pingcap.com", + "pingcap.com", + "pingcap.com", + "pingcap.com", + "pingcap.com", + {}, + "pingcap", + "pingcap"}), + createColumn({-5, 1, 3, -3, 8, 2, -100, 0, 2, {}, -3}), + createColumn({4, 4, 7, 4, 5, -5, 2, 3, 6, 4, {}}))); + } +}; + +template +class TestSigned +{ +public: + static void operator()(SubString & sub_string) + { + ASSERT_COLUMN_EQ( + createColumn>({"p.co", "ww.p", "pingcap", "com", ".com", "", "", "", {}}), + sub_string.executeFunction( + "substringUTF8", + createColumn>( + {"www.pingcap.com", + "ww.pingcap.com", + "w.pingcap.com", + ".pingcap.com", + "pingcap.com", + "pingcap.com", + "pingcap.com", + "pingcap.com", + {}}), + createColumn({-5, 1, 3, -3, 8, 2, -100, 0, 2}), + createColumn({4, 4, 7, 4, 5, -5, 2, 3, 6}))); + } +}; + +template +class TestNullableUnsigned +{ +public: + static void operator()(SubString & sub_string) + { + ASSERT_COLUMN_EQ( + createColumn>({"p.co", "ww.p", "pingcap", "com", ".com", "", "", {}, {}, {}}), + sub_string.executeFunction( + "substringUTF8", + createColumn>( + {"www.pingcap.com", + "ww.pingcap.com", + "w.pingcap.com", + ".pingcap.com", + "pingcap.com", + "pingcap.com", + "pingcap.com", + {}, + "pingcap", + "pingcap"}), + createColumn({11, 1, 3, 10, 8, 2, 0, 9, {}, 7}), + createColumn({4, 4, 7, 4, 5, 0, 3, 6, 1, {}}))); + } +}; + +template +class TestUnsigned +{ +public: + static void operator()(SubString & sub_string) + { + ASSERT_COLUMN_EQ( + createColumn>({"p.co", "ww.p", "pingcap", "com", ".com", "", "", {}}), + sub_string.executeFunction( + "substringUTF8", + createColumn>( + {"www.pingcap.com", + "ww.pingcap.com", + "w.pingcap.com", + ".pingcap.com", + "pingcap.com", + "pingcap.com", + "pingcap.com", + {}}), + createColumn({11, 1, 3, 10, 8, 2, 0, 2}), + createColumn({4, 4, 7, 4, 5, 0, 3, 1}))); + } +}; + +template +class TestConstPos +{ +public: + static void operator()(SubString & sub_string) + { + ASSERT_COLUMN_EQ( + createColumn>({"w", "ww", "w.p", ".pin"}), + sub_string.executeFunction( + "substringUTF8", + createColumn>({"www.pingcap.com", "ww.pingcap.com", "w.pingcap.com", ".pingcap.com"}), + createConstColumn(4, 1), + createColumn({1, 2, 3, 4}))); + } +}; + +template +class TestConstLength +{ +public: + static void operator()(SubString & sub_string) + { + ASSERT_COLUMN_EQ( + createColumn>({"www.", "w.pi", "ping", "ngca"}), + sub_string.executeFunction( + "substringUTF8", + createColumn>({"www.pingcap.com", "ww.pingcap.com", "w.pingcap.com", ".pingcap.com"}), + createColumn({1, 2, 3, 4}), + createConstColumn(4, 4))); + } +}; + TEST_F(SubString, subStringUTF8Test) try { + TestTypePair::run(*this); + TestTypePair::run(*this); + + TestTypePair::run(*this); + TestTypePair::run(*this); + TestTypePair::run(*this); + + TestTypePair::run(*this); + TestTypePair::run(*this); + TestTypePair::run(*this); + + TestTypePair::run(*this); + TestTypePair::run(*this); + + TestTypePair::run(*this); + TestTypePair::run(*this); + // column, const, const ASSERT_COLUMN_EQ( createColumn>({"www.", "ww.p", "w.pi", ".pin"}), @@ -25,6 +176,7 @@ try createColumn>({"www.pingcap.com", "ww.pingcap.com", "w.pingcap.com", ".pingcap.com"}), createConstColumn>(4, 1), createConstColumn>(4, 4))); + // const, const, const ASSERT_COLUMN_EQ( createConstColumn>(1, "www."), @@ -33,6 +185,7 @@ try createConstColumn>(1, "www.pingcap.com"), createConstColumn>(1, 1), createConstColumn>(1, 4))); +<<<<<<< HEAD // Test Null ASSERT_COLUMN_EQ( createColumn>({{}, "www."}), @@ -42,8 +195,10 @@ try {{}, "www.pingcap.com"}), createConstColumn>(2, 1), createConstColumn>(2, 4))); +======= +>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507)) } CATCH } // namespace tests -} // namespace DB \ No newline at end of file +} // namespace DB diff --git a/dbms/src/TestUtils/FunctionTestUtils.h b/dbms/src/TestUtils/FunctionTestUtils.h index 9217b651bcb..b27f591a3df 100644 --- a/dbms/src/TestUtils/FunctionTestUtils.h +++ b/dbms/src/TestUtils/FunctionTestUtils.h @@ -462,6 +462,85 @@ class FunctionTest : public ::testing::Test std::unique_ptr dag_context_ptr; }; +template +struct TestTypeList +{ +}; + +using TestNullableIntTypes = TestTypeList, Nullable, Nullable, Nullable>; + +using TestNullableUIntTypes = TestTypeList, Nullable, Nullable, Nullable>; + +using TestIntTypes = TestTypeList; + +using TestUIntTypes = TestTypeList; + +using TestAllIntTypes + = TestTypeList, Nullable, Nullable, Nullable, Int8, Int16, Int32, Int64>; + +using TestAllUIntTypes = TestTypeList< + Nullable, + Nullable, + Nullable, + Nullable, + UInt8, + UInt16, + UInt32, + UInt64>; + +template class Func, typename FuncParam> +struct TestTypeSingle; + +template class Func, typename FuncParam> +struct TestTypeSingle, Func, FuncParam> +{ + static void run(FuncParam & p) + { + Func::operator()(p); + // Recursively handle the rest of T2List + TestTypeSingle, Func, FuncParam>::run(p); + } +}; + +template class Func, typename FuncParam> +struct TestTypeSingle, Func, FuncParam> +{ + static void run(FuncParam &) + { + // Do nothing when T2List is empty + } +}; + +template class Func, typename FuncParam> +struct TestTypePair; + +template < + typename T1, + typename... T1Rest, + typename T2List, + template + class Func, + typename FuncParam> +struct TestTypePair, T2List, Func, FuncParam> +{ + static void run(FuncParam & p) + { + // For the current T1, traverse all types in T2List + TestTypeSingle::run(p); + // Recursively handle the rest of T1List + TestTypePair, T2List, Func, FuncParam>::run(p); + } +}; + +template class Func, typename FuncParam> +struct TestTypePair, T2List, Func, FuncParam> +{ + static void run(FuncParam &) + { + // Do nothing when T1List is empty + } +}; + #define ASSERT_COLUMN_EQ(expected, actual) ASSERT_TRUE(DB::tests::columnEqual((expected), (actual))) } // namespace tests } // namespace DB diff --git a/tests/fullstack-test/expr/substring_utf8.test b/tests/fullstack-test/expr/substring_utf8.test index cea9ec5e009..ba656e89663 100644 --- a/tests/fullstack-test/expr/substring_utf8.test +++ b/tests/fullstack-test/expr/substring_utf8.test @@ -1,10 +1,11 @@ mysql> drop table if exists test.t -mysql> create table test.t(a char(10)) -mysql> insert into test.t values(''), ('abc') +mysql> create table test.t(a char(10), b int, c tinyint unsigned) +mysql> insert into test.t values('', -3, 2), ('abc', -3, 2) mysql> alter table test.t set tiflash replica 1 func> wait_table test t +<<<<<<< HEAD mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; select * from test.t where substring(a, -3, 4) = 'abc' a abc @@ -14,6 +15,21 @@ a abc mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; select * from test.t where substring(a, -4, 3) = 'abc' +======= +mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select a from test.t where substring(a, -3, 4) = 'abc' +a +abc + +mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select a from test.t where substring(a, -3, 2) = 'ab' +a +abc + +mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select a from test.t where substring(a, b, c) = 'ab' +a +abc + +mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; set tidb_allow_tiflash_cop = ON; select a from test.t where substring(a, -4, 3) = 'abc' +>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507)) # Empty mysql> set session tidb_isolation_read_engines='tiflash'; set tidb_allow_mpp=0; select count(*) from test.t where substring(a, 0, 3) = '' order by a