Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507) #9515

Open
wants to merge 1 commit into
base: release-5.4
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
357 changes: 357 additions & 0 deletions dbms/src/Functions/FunctionsString.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1523,6 +1523,7 @@ class FunctionSubstringUTF8 : public IFunction
{
const ColumnPtr column_string = block.getByPosition(arguments[0]).column;

<<<<<<< HEAD
const ColumnPtr column_start = block.getByPosition(arguments[1]).column;
if (!column_start->isColumnConst())
throw Exception("2nd arguments of function " + getName() + " must be constants.");
Expand All @@ -1531,6 +1532,216 @@ class FunctionSubstringUTF8 : public IFunction
throw Exception("2nd argument of function " + getName() + " must have UInt/Int type.");
Int64 start;
if (start_field.getType() == Field::Types::Int64)
=======
const ColumnPtr & column_start = block.getByPosition(arguments[1]).column;

bool implicit_length = (arguments.size() == 2);

bool is_start_type_valid
= getNumberType(block.getByPosition(arguments[1]).type, [&](const auto & start_type, bool) {
using StartType = std::decay_t<decltype(start_type)>;
using StartFieldType = typename StartType::FieldType;
const ColumnVector<StartFieldType> * column_vector_start
= getInnerColumnVector<StartFieldType>(column_start);
if unlikely (!column_vector_start)
throw Exception(
fmt::format(
"Illegal type {} of argument 2 of function {}",
block.getByPosition(arguments[1]).type->getName(),
getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

// vector const const
if (!column_string->isColumnConst() && column_start->isColumnConst()
&& (implicit_length || block.getByPosition(arguments[2]).column->isColumnConst()))
{
auto [is_positive, start_abs] = getValueFromStartColumn<StartFieldType>(*column_vector_start, 0);
UInt64 length = 0;
if (!implicit_length)
{
bool is_length_type_valid = getNumberType(
block.getByPosition(arguments[2]).type,
[&](const auto & length_type, bool) {
using LengthType = std::decay_t<decltype(length_type)>;
using LengthFieldType = typename LengthType::FieldType;
const ColumnVector<LengthFieldType> * column_vector_length
= getInnerColumnVector<LengthFieldType>(block.getByPosition(arguments[2]).column);
if unlikely (!column_vector_length)
throw Exception(
fmt::format(
"Illegal type {} of argument 3 of function {}",
block.getByPosition(arguments[2]).type->getName(),
getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

length = getValueFromLengthColumn<LengthFieldType>(*column_vector_length, 0);
return true;
});

if (!is_length_type_valid)
throw Exception(
fmt::format("3nd argument of function {} must have UInt/Int type.", getName()));
}

// for const zero start or const zero length, return const blank string.
if (start_abs == 0 || (!implicit_length && length == 0))
{
block.getByPosition(result).column
= DataTypeString().createColumnConst(column_string->size(), toField(String("")));
return true;
}

const auto * col = checkAndGetColumn<ColumnString>(column_string.get());
assert(col);
auto col_res = ColumnString::create();
getVectorConstConstFunc(implicit_length, is_positive)(
col->getChars(),
col->getOffsets(),
start_abs,
length,
col_res->getChars(),
col_res->getOffsets());
block.getByPosition(result).column = std::move(col_res);
}
else // all other cases are converted to vector vector vector
{
std::function<std::pair<bool, size_t>(size_t)> get_start_func;
if (column_start->isColumnConst())
{
// func always return const value
auto start_const = getValueFromStartColumn<StartFieldType>(*column_vector_start, 0);
get_start_func = [start_const](size_t) {
return start_const;
};
}
else
{
get_start_func = [column_vector_start](size_t i) {
return getValueFromStartColumn<StartFieldType>(*column_vector_start, i);
};
}

// if implicit_length, get_length_func be nil is ok.
std::function<size_t(size_t)> get_length_func;
if (!implicit_length)
{
const ColumnPtr & column_length = block.getByPosition(arguments[2]).column;
bool is_length_type_valid = getNumberType(
block.getByPosition(arguments[2]).type,
[&](const auto & length_type, bool) {
using LengthType = std::decay_t<decltype(length_type)>;
using LengthFieldType = typename LengthType::FieldType;
const ColumnVector<LengthFieldType> * column_vector_length
= getInnerColumnVector<LengthFieldType>(column_length);
if unlikely (!column_vector_length)
throw Exception(
fmt::format(
"Illegal type {} of argument 3 of function {}",
block.getByPosition(arguments[2]).type->getName(),
getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

if (column_length->isColumnConst())
{
// func always return const value
auto length_const
= getValueFromLengthColumn<LengthFieldType>(*column_vector_length, 0);
get_length_func = [length_const](size_t) {
return length_const;
};
}
else
{
get_length_func = [column_vector_length](size_t i) {
return getValueFromLengthColumn<LengthFieldType>(*column_vector_length, i);
};
}
return true;
});

if unlikely (!is_length_type_valid)
throw Exception(
fmt::format("3nd argument of function {} must have UInt/Int type.", getName()));
}

// convert to vector if string is const.
ColumnPtr full_column_string = column_string->isColumnConst()
? column_string->convertToFullColumnIfConst()
: column_string;
const auto * col = checkAndGetColumn<ColumnString>(full_column_string.get());
assert(col);
auto col_res = ColumnString::create();
if (implicit_length)
{
SubstringUTF8Impl::vectorVectorVector<true>(
col->getChars(),
col->getOffsets(),
get_start_func,
get_length_func,
col_res->getChars(),
col_res->getOffsets());
}
else
{
SubstringUTF8Impl::vectorVectorVector<false>(
col->getChars(),
col->getOffsets(),
get_start_func,
get_length_func,
col_res->getChars(),
col_res->getOffsets());
}
block.getByPosition(result).column = std::move(col_res);
}

return true;
});

if unlikely (!is_start_type_valid)
throw Exception(fmt::format("2nd argument of function {} must have UInt/Int type.", getName()));
}

// Locates the ColumnVector<Integer> that holds the values of `column`.
// A constant column is first viewed as a ColumnConst and its data column is
// inspected; any other column is inspected directly. The result of
// checkAndGetColumn is returned unchanged (callers in this file compare it
// against null to report an illegal argument type).
template <typename Integer>
static const ColumnVector<Integer> * getInnerColumnVector(const ColumnPtr & column)
{
    const auto * storage = column->isColumnConst()
        ? checkAndGetColumn<ColumnConst>(column.get())->getDataColumnPtr().get()
        : column.get();
    return checkAndGetColumn<ColumnVector<Integer>>(storage);
}

// Reads the length stored at `index` of `column` and returns it as size_t.
// Negative values of the signed integer types are clamped to 0; every other
// value is returned as-is. Instantiating with anything other than the eight
// Int8..Int64 / UInt8..UInt64 types is rejected at compile time.
template <typename Integer>
static size_t getValueFromLengthColumn(const ColumnVector<Integer> & column, size_t index)
{
    constexpr bool is_signed_integer = std::is_same_v<Integer, Int8> || std::is_same_v<Integer, Int16>
        || std::is_same_v<Integer, Int32> || std::is_same_v<Integer, Int64>;
    constexpr bool is_unsigned_integer = std::is_same_v<Integer, UInt8> || std::is_same_v<Integer, UInt16>
        || std::is_same_v<Integer, UInt32> || std::is_same_v<Integer, UInt64>;
    static_assert(is_signed_integer || is_unsigned_integer);

    const Integer raw_length = column.getElement(index);
    if constexpr (is_signed_integer)
    {
        // Only signed types can carry a negative length.
        if (raw_length < 0)
            return 0;
    }
    return static_cast<size_t>(raw_length);
}

private:
// Callable type returned by getVectorConstConstFunc for the
// "vector string, const start, const length" case.
// At the call site in this file the arguments are passed in this order:
// source chars, source offsets, absolute start value, length,
// result chars, result offsets.
using VectorConstConstFunc = std::function<void(
const ColumnString::Chars_t &,
const ColumnString::Offsets &,
size_t,
size_t,
ColumnString::Chars_t &,
ColumnString::Offsets &)>;

static VectorConstConstFunc getVectorConstConstFunc(bool implicit_length, bool is_positive_start)
{
if (implicit_length)
>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507))
{
start = start_field.get<Int64>();
}
Expand All @@ -1542,6 +1753,7 @@ class FunctionSubstringUTF8 : public IFunction
start = static_cast<Int64>(u_start);
}

<<<<<<< HEAD
bool implicit_length = true;
UInt64 length = 0;
if (arguments.size() == 3)
Expand Down Expand Up @@ -1590,6 +1802,42 @@ class FunctionSubstringUTF8 : public IFunction
throw Exception(
"Illegal column " + block.getByPosition(arguments[0]).column->getName() + " of first argument of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN);
=======
// Reads the start position stored at `index` of `column`.
// Returns {is_positive, abs}:
//   - is_positive is false only when the stored value is negative;
//   - abs is the magnitude of the stored value as size_t.
// Only the eight Int8..Int64 / UInt8..UInt64 types are accepted; any other
// instantiation fails the static_assert.
template <typename Integer>
static std::pair<bool, size_t> getValueFromStartColumn(const ColumnVector<Integer> & column, size_t index)
{
    Integer val = column.getElement(index);
    if constexpr (
        std::is_same_v<Integer, Int8> || std::is_same_v<Integer, Int16> || std::is_same_v<Integer, Int32>
        || std::is_same_v<Integer, Int64>)
    {
        if (val < 0)
        {
            // Negate in unsigned arithmetic. Writing `-val` overflows (undefined
            // behaviour) when val is the minimum value of Int32/Int64, whereas
            // 0 - size_t(val) is well defined and yields the same magnitude for
            // every negative input.
            return {false, static_cast<size_t>(0) - static_cast<size_t>(val)};
        }
        return {true, static_cast<size_t>(val)};
    }
    else
    {
        static_assert(
            std::is_same_v<Integer, UInt8> || std::is_same_v<Integer, UInt16> || std::is_same_v<Integer, UInt32>
            || std::is_same_v<Integer, UInt64>);
        return {true, val};
    }
}

// Dispatches on the concrete integer data type behind `type`.
// The candidates handed to castTypeToEither are DataTypeUInt8/16/32/64 and
// DataTypeInt8/16/32/64; `f` is forwarded to it unchanged. Presumably `f` is
// invoked with the first candidate that matches - confirm against
// castTypeToEither. The result of castTypeToEither is returned; callers in
// this file treat `false` as "argument is not a UInt/Int type" and throw.
// NOTE(review): the ">>>>>>>" line inside this function is an unresolved
// merge-conflict marker, not code; it must be resolved before this compiles.
template <typename F>
static bool getNumberType(DataTypePtr type, F && f)
{
return castTypeToEither<
DataTypeUInt8,
DataTypeUInt16,
DataTypeUInt32,
DataTypeUInt64,
DataTypeInt8,
DataTypeInt16,
DataTypeInt32,
DataTypeInt64>(type.get(), std::forward<F>(f));
>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507))
}
};

Expand Down Expand Up @@ -1643,6 +1891,7 @@ class FunctionRightUTF8 : public IFunction
const ColumnPtr column_string = block.getByPosition(arguments[0]).column;

const ColumnPtr column_length = block.getByPosition(arguments[1]).column;
<<<<<<< HEAD
if (!column_length->isColumnConst())
throw Exception("2nd arguments of function " + getName() + " must be constants.");
Field length_field = (*block.getByPosition(arguments[1]).column)[0];
Expand Down Expand Up @@ -1677,6 +1926,114 @@ class FunctionRightUTF8 : public IFunction
throw Exception(
"Illegal column " + block.getByPosition(arguments[0]).column->getName() + " of first argument of function " + getName(),
ErrorCodes::ILLEGAL_COLUMN);
=======

bool is_length_type_valid
= getLengthType(block.getByPosition(arguments[1]).type, [&](const auto & length_type, bool) {
using LengthType = std::decay_t<decltype(length_type)>;
using LengthFieldType = typename LengthType::FieldType;

const ColumnVector<LengthFieldType> * column_vector_length
= FunctionSubstringUTF8::getInnerColumnVector<LengthFieldType>(column_length);
if unlikely (!column_vector_length)
throw Exception(
fmt::format(
"Illegal type {} of argument 2 of function {}",
block.getByPosition(arguments[1]).type->getName(),
getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);


auto col_res = ColumnString::create();
if (const auto * col_string = checkAndGetColumn<ColumnString>(column_string.get()))
{
if (column_length->isColumnConst())
{
// vector const
size_t length = FunctionSubstringUTF8::getValueFromLengthColumn<LengthFieldType>(
*column_vector_length,
0);

// for const 0, return const blank string.
if (0 == length)
{
block.getByPosition(result).column
= DataTypeString().createColumnConst(column_string->size(), toField(String("")));
return true;
}

RightUTF8Impl::vectorConst(
col_string->getChars(),
col_string->getOffsets(),
length,
col_res->getChars(),
col_res->getOffsets());
}
else
{
// vector vector
auto get_length_func = [column_vector_length](size_t i) {
return FunctionSubstringUTF8::getValueFromLengthColumn<LengthFieldType>(
*column_vector_length,
i);
};
RightUTF8Impl::vectorVector(
col_string->getChars(),
col_string->getOffsets(),
get_length_func,
col_res->getChars(),
col_res->getOffsets());
}
}
else if (
const ColumnConst * col_const_string = checkAndGetColumnConst<ColumnString>(column_string.get()))
{
// const vector
const auto * col_string_from_const
= checkAndGetColumn<ColumnString>(col_const_string->getDataColumnPtr().get());
assert(col_string_from_const);
// When useDefaultImplementationForConstants is true, string and length are not both constants
assert(!column_length->isColumnConst());
auto get_length_func = [column_vector_length](size_t i) {
return FunctionSubstringUTF8::getValueFromLengthColumn<LengthFieldType>(
*column_vector_length,
i);
};
RightUTF8Impl::constVector(
column_length->size(),
col_string_from_const->getChars(),
col_string_from_const->getOffsets(),
get_length_func,
col_res->getChars(),
col_res->getOffsets());
}
else
{
// Impossible to reach here
return false;
}
block.getByPosition(result).column = std::move(col_res);
return true;
});

if (!is_length_type_valid)
throw Exception(fmt::format("2nd argument of function {} must have UInt/Int type.", getName()));
}

private:
// Dispatches on the concrete integer data type of the length argument.
// The candidates handed to castTypeToEither are DataTypeUInt8/16/32/64 and
// DataTypeInt8/16/32/64; `f` is forwarded to it unchanged. Presumably `f` is
// invoked with the first candidate that matches - confirm against
// castTypeToEither. The result of castTypeToEither is returned; the caller in
// this file treats `false` as "2nd argument is not a UInt/Int type" and throws.
// NOTE(review): the ">>>>>>>" line inside this function is an unresolved
// merge-conflict marker, not code; it must be resolved before this compiles.
template <typename F>
static bool getLengthType(DataTypePtr type, F && f)
{
return castTypeToEither<
DataTypeUInt8,
DataTypeUInt16,
DataTypeUInt32,
DataTypeUInt64,
DataTypeInt8,
DataTypeInt16,
DataTypeInt32,
DataTypeInt64>(type.get(), std::forward<F>(f));
>>>>>>> 9f85bb2f16 (Support other integer types for SubstringUTF8 & RightUTF8 functions (#9507))
}
};

Expand Down
Loading