Skip to content

Commit

Permalink
Fix bug when INT_MAX < byte_size <= LLONG_MAX
Browse files: browse the repository at this point in the history
  • Loading branch information
yinggeh committed Aug 29, 2024
1 parent c131fa1 commit 794a5f0
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions src/infer_request.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1202,26 +1202,20 @@ InferenceRequest::Normalize()
// TensorRT backend.
if (!input.IsNonLinearFormatIo()) {
TRITONSERVER_MemoryType input_memory_type;
- // Because Triton expects STRING type to be in special format
- // (prepend 4 bytes to specify string length), so need to add all the
- // first 4 bytes for each element to find expected byte size
if (data_type == inference::DataType::TYPE_STRING) {
RETURN_IF_ERROR(ValidateBytesInputs(
input_name, input, model_name, &input_memory_type));
// FIXME: Temporarily skips byte size checks for GPU tensors. See
// DLIS-6820.
} else {
// Shape tensor with dynamic batching does not introduce a new
// dimension to the tensor but adds an additional value to the 1-D
// array.
const std::vector<int64_t>& input_dims =
input.IsShapeTensor() ? input.OriginalShape()
: input.ShapeWithBatchDim();
- int64_t expected_byte_size = INT_MAX;
- expected_byte_size =
+ int64_t expected_byte_size =
triton::common::GetByteSize(data_type, input_dims);
const size_t& byte_size = input.Data()->TotalByteSize();
- if ((byte_size > INT_MAX) ||
+ if ((byte_size > LLONG_MAX) ||
(static_cast<int64_t>(byte_size) != expected_byte_size)) {
return Status(
Status::Code::INVALID_ARG,
Expand Down Expand Up @@ -1331,9 +1325,13 @@ InferenceRequest::ValidateBytesInputs(
}
}

- constexpr size_t kElementSizeIndicator = sizeof(uint32_t);
// Get the next element if not currently processing one.
if (!remaining_element_size) {
+ // Triton expects STRING type to be in special format
+ // (prepend 4 bytes to specify string length), so need to add the
+ // first 4 bytes for each element to find expected byte size.
+ constexpr size_t kElementSizeIndicator = sizeof(uint32_t);

// FIXME: Assume the string element's byte size indicator is not spread
// across buffer boundaries for simplicity.
if (remaining_buffer_size < kElementSizeIndicator) {
Expand Down

0 comments on commit 794a5f0

Please sign in to comment.