Skip to content

Commit

Permalink
[*] improve log message for storage view content
Browse files Browse the repository at this point in the history
  • Loading branch information
tirivo committed Jun 24, 2024
1 parent 59c7dda commit 1e4a638
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 32 deletions.
3 changes: 3 additions & 0 deletions include/ctranslate2/storage_view.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,9 @@ namespace ctranslate2 {

friend std::ostream& operator<<(std::ostream& os, const StorageView& storage);

template <typename T>
void print_tensor(std::ostream& os, const T* data, const std::vector<dim_t>& shape, size_t dim, size_t offset, int indent) const;

protected:
DataType _dtype = DataType::FLOAT32;
Device _device = Device::CPU;
Expand Down
97 changes: 65 additions & 32 deletions src/storage_view.cc
Original file line number Diff line number Diff line change
Expand Up @@ -440,44 +440,77 @@ namespace ctranslate2 {
return os;
}

std::ostream& operator<<(std::ostream& os, const StorageView& storage) {
StorageView printable(storage.dtype());
printable.copy_from(storage);
TYPE_DISPATCH(
printable.dtype(),
const auto* values = printable.data<T>();
if (printable.size() <= PRINT_MAX_VALUES) {
for (dim_t i = 0; i < printable.size(); ++i) {
os << ' ';
print_value(os, values[i]);
std::ostream& operator<<(std::ostream& os, const StorageView& storage) {
// Create a printable copy of the storage
StorageView printable(storage.dtype());
printable.copy_from(storage);

// Check the data type and print accordingly
TYPE_DISPATCH(
printable.dtype(),
const auto* values = printable.data<T>();
const auto& shape = printable.shape();

// Print tensor contents based on dimensionality
if (shape.empty()) { // Scalar case
os << "Data (Scalar): " << values[0] << std::endl;
} else {
os << "Data (" << shape.size() << "D ";
if (shape.size() == 1)
os << "Vector";
else if (shape.size() == 2)
os << "Matrix";
else
os << "Tensor";
os << "):" << std::endl;
printable.print_tensor(os, values, shape, 0, 0, 0);
os << std::endl;
}
);

// Print metadata
os << "[device:" << device_to_str(storage.device(), storage.device_index())
<< ", dtype:" << dtype_name(storage.dtype()) << ", storage viewed as ";
if (storage.is_scalar())
os << "scalar";
else {
for (dim_t i = 0; i < storage.rank(); ++i) {
if (i > 0)
os << 'x';
os << storage.dim(i);
}
}
else {
for (dim_t i = 0; i < PRINT_MAX_VALUES / 2; ++i) {
os << ' ';
print_value(os, values[i]);
os << ']';
return os;
}

template <typename T>
void StorageView::print_tensor(std::ostream& os, const T* data, const std::vector<dim_t>& shape, size_t dim, size_t offset, int indent) const {
std::string indentation(indent, ' ');

os << indentation << "[";
bool is_last_dim = (dim == shape.size() - 1);
for (dim_t i = 0; i < shape[dim]; ++i) {
if (i > 0) {
os << ", ";
if (!is_last_dim) {
os << "\n" << std::string(indent, ' ');
}
}
os << " ...";
for (dim_t i = printable.size() - (PRINT_MAX_VALUES / 2); i < printable.size(); ++i) {
os << ' ';
print_value(os, values[i]);

if (i == PRINT_MAX_VALUES / 2) {
os << "...";
i = shape[dim] - PRINT_MAX_VALUES / 2 - 1; // Skip to the last part
} else {
if (is_last_dim) {
os << data[offset + i];
} else {
print_tensor(os, data, shape, dim + 1, offset + i * shape[dim + 1], indent);
}
}
}
os << std::endl);
os << "[" << device_to_str(storage.device(), storage.device_index())
<< " " << dtype_name(storage.dtype()) << " storage viewed as ";
if (storage.is_scalar())
os << "scalar";
else {
for (dim_t i = 0; i < storage.rank(); ++i) {
if (i > 0)
os << 'x';
os << storage.dim(i);
}
os << "]";
}
os << ']';
return os;
}

#define DECLARE_IMPL(T) \
template \
Expand Down

0 comments on commit 1e4a638

Please sign in to comment.