Skip to content

Commit

Permalink
refactor: Refactor string input check (#101)
Browse files Browse the repository at this point in the history
Refactor string input tensor checks
  • Loading branch information
yinggeh authored Jul 31, 2024
1 parent cb72fd5 commit 30fa78a
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 1 deletion.
17 changes: 17 additions & 0 deletions include/triton/backend/backend_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -671,4 +671,21 @@ TRITONSERVER_Error* BufferAsTypedString(
/// \return a formatted string for logging the request ID.
std::string GetRequestId(TRITONBACKEND_Request* request);

/// Validate the contiguous string buffer with correct format
/// <int32_len><bytes>...<int32_len><bytes> and parse string
/// elements into list of pairs of memory address and length.
/// Each length prefix occupies 4 bytes and is interpreted as a uint32_t in
/// the host's native byte order; the string bytes that follow are not
/// null-terminated. Parsing stops once fewer than 4 bytes remain, so a
/// trailing fragment shorter than a length prefix is not examined.
/// Note the returned list of pairs points to valid memory as long
/// as memory pointed by buffer remains allocated.
///
/// \param buffer The pointer to the contiguous string buffer.
/// \param buffer_byte_size The size of the buffer in bytes.
/// \param expected_element_cnt The number of expected string elements.
/// \param input_name The name of the input buffer. Only used to build
/// error messages.
/// \param str_list Returns pairs of address and length of parsed strings.
/// May be nullptr, in which case the buffer is only validated. Pairs are
/// appended to the list; entries already present are kept, and pairs
/// appended before a validation failure are not removed.
/// \return a TRITONSERVER_Error indicating success or failure. nullptr is
/// returned on success. TRITONSERVER_ERROR_INVALID_ARG is returned when
/// the buffer holds more than 'expected_element_cnt' elements or when a
/// length prefix exceeds the bytes remaining in the buffer.
/// TRITONSERVER_ERROR_INTERNAL is returned when fewer than
/// 'expected_element_cnt' elements were found.
TRITONSERVER_Error* ValidateStringBuffer(
const char* buffer, size_t buffer_byte_size,
const size_t expected_element_cnt, const char* input_name,
std::vector<std::pair<const char*, const uint32_t>>* str_list);

}} // namespace triton::backend
62 changes: 61 additions & 1 deletion src/backend_common.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -1372,4 +1372,64 @@ GetRequestId(TRITONBACKEND_Request* request)
return std::string("[request id: ") + request_id + "] ";
}

// Walk a serialized string tensor of the form
// <uint32 len><bytes>...<uint32 len><bytes> and check that it contains exactly
// 'expected_element_cnt' complete elements.
//
// buffer               Start of the serialized data.
// buffer_byte_size     Number of bytes available at 'buffer'.
// expected_element_cnt Number of string elements the buffer must contain.
// input_name           Input name, used only when building error messages.
// str_list             Optional (may be nullptr). When provided, a
//                      (pointer, length) pair referencing the bytes inside
//                      'buffer' is appended for every element parsed. Pairs
//                      appended before a failure are left in place.
//
// Returns nullptr on success, otherwise a newly created TRITONSERVER_Error:
// INVALID_ARG if there are too many elements or an element is truncated,
// INTERNAL if there are too few elements.
TRITONSERVER_Error*
ValidateStringBuffer(
    const char* buffer, size_t buffer_byte_size,
    const size_t expected_element_cnt, const char* input_name,
    std::vector<std::pair<const char*, const uint32_t>>* str_list)
{
  size_t element_idx = 0;
  size_t remaining_bytes = buffer_byte_size;

  // Each string in 'buffer' is a 4-byte length followed by the string itself
  // with no null-terminator. A trailing fragment shorter than a length prefix
  // ends the loop without being inspected.
  while (remaining_bytes >= sizeof(uint32_t)) {
    // Do not modify this line. str_list->size() must not exceed
    // expected_element_cnt.
    if (element_idx >= expected_element_cnt) {
      return TRITONSERVER_ErrorNew(
          TRITONSERVER_ERROR_INVALID_ARG,
          std::string(
              "unexpected number of string elements " +
              std::to_string(element_idx + 1) + " for inference input '" +
              input_name + "', expecting " +
              std::to_string(expected_element_cnt))
              .c_str());
    }

    // A length prefix directly follows the previous string, whose size is
    // arbitrary, so 'buffer' is generally not aligned for uint32_t here.
    // Dereferencing it as a uint32_t* would be undefined behavior; copy the
    // object representation byte by byte instead (same native byte order,
    // and no additional header is required).
    uint32_t len = 0;
    unsigned char* len_bytes = reinterpret_cast<unsigned char*>(&len);
    for (size_t i = 0; i < sizeof(len); ++i) {
      len_bytes[i] = static_cast<unsigned char>(buffer[i]);
    }
    remaining_bytes -= sizeof(uint32_t);
    buffer += sizeof(uint32_t);

    // The declared length must fit in what is left of the buffer.
    if (remaining_bytes < len) {
      return TRITONSERVER_ErrorNew(
          TRITONSERVER_ERROR_INVALID_ARG,
          std::string(
              "incomplete string data for inference input '" +
              std::string(input_name) + "', expecting string of length " +
              std::to_string(len) + " but only " +
              std::to_string(remaining_bytes) + " bytes available")
              .c_str());
    }

    // Record a view of the string bytes; nothing is copied.
    if (str_list) {
      str_list->push_back({buffer, len});
    }
    buffer += len;
    remaining_bytes -= len;
    element_idx++;
  }

  // The loop can only end with element_idx <= expected_element_cnt, so a
  // mismatch here means the buffer held too few elements.
  if (element_idx != expected_element_cnt) {
    return TRITONSERVER_ErrorNew(
        TRITONSERVER_ERROR_INTERNAL,
        std::string(
            "expected " + std::to_string(expected_element_cnt) +
            " strings for inference input '" + input_name + "', got " +
            std::to_string(element_idx))
            .c_str());
  }
  return nullptr;
}

}} // namespace triton::backend

0 comments on commit 30fa78a

Please sign in to comment.