Refactor buffer class so that storage can be inserted (#242)
mthrok authored Oct 4, 2024
1 parent 28b65a5 commit 7efbedc
Showing 3 changed files with 29 additions and 32 deletions.
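The refactor removes the type-erased Storage* that the Buffer base class used to own: Buffer keeps only shape metadata and a pure-virtual data(), while CPUBuffer and CUDABuffer hold their own typed storage as a std::shared_ptr that callers can supply instead of having the buffer allocate it. The sketch below is not part of the commit; the storage header path and the ElemClass::UInt8 enumerator name are assumptions, and the CPUStorage(size_t, bool) constructor is the one the factory diff below calls via std::make_shared.

#include <libspdl/core/buffer.h>
#include <libspdl/core/storage.h>  // header path assumed

using namespace spdl::core;

// Storage is created by the caller and injected into the buffer.
CPUBufferPtr make_injected_buffer() {
  auto storage = std::make_shared<CPUStorage>(480 * 640, /*pin_memory=*/false);
  return std::make_unique<CPUBuffer>(
      std::vector<size_t>{480, 640},
      ElemClass::UInt8,  // enumerator name assumed for illustration
      sizeof(uint8_t),
      std::move(storage));
}

Previously the only way to obtain the storage was through the factory functions, which allocated it internally with new; with the shared_ptr parameter the allocation policy is decided by the caller.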
25 changes: 12 additions & 13 deletions src/libspdl/core/buffer.h
@@ -27,7 +27,7 @@ struct CUDABuffer;
 using CPUBufferPtr = std::unique_ptr<CPUBuffer>;
 using CUDABufferPtr = std::unique_ptr<CUDABuffer>;
 
-/// Abstract base buffer class (to be exposed to Python)
+/// Abstract base buffer class (technically not needed)
 /// Represents contiguous array memory.
 struct Buffer {
   ///
@@ -40,45 +40,44 @@ struct Buffer {
   /// Size of unit element
   size_t depth = sizeof(uint8_t);
 
-  ///
-  /// The actual data.
-  std::shared_ptr<Storage> storage;
-
-  Buffer(
-      std::vector<size_t> shape,
-      ElemClass elem_class,
-      size_t depth,
-      Storage* storage);
+  Buffer(std::vector<size_t> shape, ElemClass elem_class, size_t depth);
   virtual ~Buffer() = default;
 
   ///
   /// Returns the pointer to the head of the data buffer.
-  void* data();
+  virtual void* data() = 0;
 };
 
 ///
 /// Contiguous array data on CPU memory.
 struct CPUBuffer : public Buffer {
+  std::shared_ptr<CPUStorage> storage;
+
   CPUBuffer(
       const std::vector<size_t>& shape,
       ElemClass elem_class,
       size_t depth,
-      CPUStorage* storage);
+      std::shared_ptr<CPUStorage> storage);
+
+  void* data() override;
 };
 
 ///
 /// Contiguous array data on a CUDA device.
 struct CUDABuffer : Buffer {
 #ifdef SPDL_USE_CUDA
+  std::shared_ptr<CUDAStorage> storage;
   int device_index;
 
   CUDABuffer(
       std::vector<size_t> shape,
       ElemClass elem_class,
       size_t depth,
-      CUDAStorage* storage,
+      std::shared_ptr<CUDAStorage> storage,
       int device_index);
 
+  void* data() override;
+
   uintptr_t get_cuda_stream() const;
 
 #endif
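With storage gone from the base class, Buffer now carries only shape, elem_class, and depth plus the pure-virtual data(), so a buffer-like type needs nothing beyond an implementation of data(). A hypothetical sketch, not added by this commit, of a view over externally managed memory:

#include <libspdl/core/buffer.h>

// Hypothetical type for illustration only.
struct ExternalViewBuffer : spdl::core::Buffer {
  void* ptr;

  ExternalViewBuffer(
      std::vector<size_t> shape,
      spdl::core::ElemClass elem_class,
      size_t depth,
      void* p)
      : Buffer(std::move(shape), elem_class, depth), ptr(p) {}

  // Returns the externally owned pointer; no Storage object is involved.
  void* data() override {
    return ptr;
  }
};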
25 changes: 9 additions & 16 deletions src/libspdl/core/buffer/cpu.cpp
@@ -18,19 +18,8 @@ namespace spdl::core {
 ////////////////////////////////////////////////////////////////////////////////
 // Buffer
 ////////////////////////////////////////////////////////////////////////////////
-Buffer::Buffer(
-    std::vector<size_t> shape_,
-    ElemClass elem_class_,
-    size_t depth_,
-    Storage* storage_)
-    : shape(std::move(shape_)),
-      elem_class(elem_class_),
-      depth(depth_),
-      storage(storage_) {}
-
-void* Buffer::data() {
-  return storage->data();
-}
+Buffer::Buffer(std::vector<size_t> shape_, ElemClass elem_class_, size_t depth_)
+    : shape(std::move(shape_)), elem_class(elem_class_), depth(depth_) {}
 
 ////////////////////////////////////////////////////////////////////////////////
 // CPUBuffer
@@ -39,8 +28,12 @@ CPUBuffer::CPUBuffer(
     const std::vector<size_t>& shape_,
     ElemClass elem_class_,
     size_t depth_,
-    CPUStorage* storage_)
-    : Buffer(shape_, elem_class_, depth_, (Storage*)storage_) {}
+    std::shared_ptr<CPUStorage> storage_)
+    : Buffer(shape_, elem_class_, depth_), storage(std::move(storage_)) {}
 
+void* CPUBuffer::data() {
+  return storage->data();
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // Factory functions
@@ -67,7 +60,7 @@ std::unique_ptr<CPUBuffer> cpu_buffer(
       fmt::join(shape, ", "),
       depth);
   return std::make_unique<CPUBuffer>(
-      shape, elem_class, depth, new CPUStorage{size, pin_memory});
+      shape, elem_class, depth, std::make_shared<CPUStorage>(size, pin_memory));
 }
 
 } // namespace spdl::core
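Because ownership is now through std::shared_ptr, one CPUStorage can back several buffers, for instance the same bytes viewed flat and as a 2-D array. A sketch under the same assumptions as above (ElemClass::UInt8 is an assumed enumerator name, and the CPUStorage size argument is taken to be in bytes):

#include <cassert>

#include <libspdl/core/buffer.h>
#include <libspdl/core/storage.h>  // header path assumed

using namespace spdl::core;

void two_views(size_t rows, size_t cols) {
  auto bytes = std::make_shared<CPUStorage>(rows * cols, /*pin_memory=*/false);

  // Both buffers share ownership of the same storage; no data is copied.
  CPUBuffer flat{{rows * cols}, ElemClass::UInt8, sizeof(uint8_t), bytes};
  CPUBuffer grid{{rows, cols}, ElemClass::UInt8, sizeof(uint8_t), bytes};

  // Both resolve to the same head pointer through the virtual data().
  assert(flat.data() == grid.data());
}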
11 changes: 8 additions & 3 deletions src/libspdl/core/buffer/cuda.cpp
@@ -17,11 +17,16 @@ CUDABuffer::CUDABuffer(
     std::vector<size_t> shape_,
     ElemClass elem_class_,
     size_t depth_,
-    CUDAStorage* storage_,
+    std::shared_ptr<CUDAStorage> storage_,
     int device_index_)
-    : Buffer(std::move(shape_), elem_class_, depth_, (Storage*)storage_),
+    : Buffer(std::move(shape_), elem_class_, depth_),
+      storage(std::move(storage_)),
       device_index(device_index_) {}
 
+void* CUDABuffer::data() {
+  return storage->data();
+}
+
 uintptr_t CUDABuffer::get_cuda_stream() const {
   return (uintptr_t)(((CUDAStorage*)(storage.get()))->stream);
 }
@@ -48,7 +53,7 @@ CUDABufferPtr cuda_buffer(
       shape,
       elem_class,
       depth,
-      new CUDAStorage{depth * prod(shape), cfg},
+      std::make_shared<CUDAStorage>(depth * prod(shape), cfg),
       cfg.device_index);
 }
 
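The CUDA path mirrors the CPU one, and get_cuda_stream() still exposes the stream owned by the CUDAStorage. A sketch of synchronizing on that stream before the device pointer is consumed elsewhere; it only compiles with SPDL_USE_CUDA defined, and cuda_buffer's full parameter list is not visible in this diff, so only the accessor is exercised:

#include <cuda_runtime.h>

#include <libspdl/core/buffer.h>

// Block until work enqueued on the buffer's stream has completed, so the
// device pointer returned by data() can be read safely elsewhere.
void sync_before_use(const spdl::core::CUDABuffer& buf) {
  auto stream = reinterpret_cast<cudaStream_t>(buf.get_cuda_stream());
  cudaStreamSynchronize(stream);
}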
