From 7a89c460db15e9cd7d36b25dc64a1f24d669b71e Mon Sep 17 00:00:00 2001 From: "thierry.vo" Date: Mon, 3 Jun 2024 16:36:28 +0200 Subject: [PATCH] [*] improve log message for storage view content --- include/ctranslate2/storage_view.h | 3 ++ src/storage_view.cc | 70 ++++++++++++++++++++++-------- 2 files changed, 54 insertions(+), 19 deletions(-) diff --git a/include/ctranslate2/storage_view.h b/include/ctranslate2/storage_view.h index 8834ef651..d25d47048 100644 --- a/include/ctranslate2/storage_view.h +++ b/include/ctranslate2/storage_view.h @@ -238,6 +238,9 @@ namespace ctranslate2 { friend std::ostream& operator<<(std::ostream& os, const StorageView& storage); + template + void print_tensor(std::ostream& os, const T* data, const std::vector& shape, size_t dim, size_t offset, int indent) const; + protected: DataType _dtype = DataType::FLOAT32; Device _device = Device::CPU; diff --git a/src/storage_view.cc b/src/storage_view.cc index 0cbdf25c2..d8b7615c7 100644 --- a/src/storage_view.cc +++ b/src/storage_view.cc @@ -441,31 +441,35 @@ namespace ctranslate2 { } std::ostream& operator<<(std::ostream& os, const StorageView& storage) { + // Create a printable copy of the storage StorageView printable(storage.dtype()); printable.copy_from(storage); + + // Check the data type and print accordingly TYPE_DISPATCH( printable.dtype(), const auto* values = printable.data(); - if (printable.size() <= PRINT_MAX_VALUES) { - for (dim_t i = 0; i < printable.size(); ++i) { - os << ' '; - print_value(os, values[i]); - } - } - else { - for (dim_t i = 0; i < PRINT_MAX_VALUES / 2; ++i) { - os << ' '; - print_value(os, values[i]); - } - os << " ..."; - for (dim_t i = printable.size() - (PRINT_MAX_VALUES / 2); i < printable.size(); ++i) { - os << ' '; - print_value(os, values[i]); - } + const auto& shape = printable.shape(); + + // Print tensor contents based on dimensionality + if (shape.empty()) { // Scalar case + os << "Data (Scalar): " << values[0] << std::endl; + } else { + os << "Data (" << shape.size() << "D "; + if (shape.size() == 1) + os << "Vector"; + else if (shape.size() == 2) + os << "Matrix"; + else + os << "Tensor"; + os << "):" << std::endl; + printable.print_tensor(os, values, shape, 0, 0, 0); + os << std::endl; } - os << std::endl); - os << "[" << device_to_str(storage.device(), storage.device_index()) - << " " << dtype_name(storage.dtype()) << " storage viewed as "; + ); + + os << "[device:" << device_to_str(storage.device(), storage.device_index()) + << ", dtype:" << dtype_name(storage.dtype()) << ", storage viewed as "; if (storage.is_scalar()) os << "scalar"; else { @@ -479,6 +483,34 @@ namespace ctranslate2 { return os; } + template + void StorageView::print_tensor(std::ostream& os, const T* data, const std::vector& shape, size_t dim, size_t offset, int indent) const { + std::string indentation(indent, ' '); + + os << indentation << "["; + bool is_last_dim = (dim == shape.size() - 1); + for (dim_t i = 0; i < shape[dim]; ++i) { + if (i > 0) { + os << ", "; + if (!is_last_dim) { + os << "\n" << std::string(indent, ' '); + } + } + + if (i == PRINT_MAX_VALUES / 2 && shape[dim] > PRINT_MAX_VALUES) { + os << "..."; + i = shape[dim] - PRINT_MAX_VALUES / 2 - 1; // Skip to the last part + } else { + if (is_last_dim) { + os << +data[offset + i]; + } else { + print_tensor(os, data, shape, dim + 1, offset + i * shape[dim + 1], indent); + } + } + } + os << "]"; + } + #define DECLARE_IMPL(T) \ template \ StorageView::StorageView(Shape shape, T init, Device device); \