Skip to content

Commit

Permalink
Add graph capture validation pass
Browse files Browse the repository at this point in the history
  • Loading branch information
derdeljanTT committed Nov 7, 2024
1 parent 51f6356 commit e9a9237
Show file tree
Hide file tree
Showing 13 changed files with 210 additions and 23 deletions.
10 changes: 10 additions & 0 deletions env/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ project(ttmlir-toolchain LANGUAGES CXX C)
set(FLATBUFFERS_VERSION "fb9afbafc7dfe226b9db54d4923bfb8839635274")
set(LLVM_PROJECT_VERSION "e813750354bbc08551cf23ff559a54b4a9ea1f29")
set(STABLEHLO_VERSION "d40285ef3db0687e3f1e2bb0d716d748485a9739")
set(NLOHMANN_JSON_VERSION "9cca280a4d0ccf0c08f47a99aa71d1b0e52f8d03")

include(ExternalProject)

Expand Down Expand Up @@ -78,5 +79,14 @@ ExternalProject_Add(stablehlo
INSTALL_COMMAND ""
)

# Fetch nlohmann/json at the pinned NLOHMANN_JSON_VERSION commit.
# The library is header-only, so configure/build/install are disabled:
# consumers add <prefix>/src/nlohmann_json/include to their include path
# directly (see lib/CMakeLists.txt).
ExternalProject_Add(nlohmann_json
PREFIX ${TTMLIR_TOOLCHAIN_DIR}
GIT_REPOSITORY https://github.com/nlohmann/json.git
GIT_TAG ${NLOHMANN_JSON_VERSION}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
)

add_custom_target(llvm-lit ALL COMMAND cp llvm-project-prefix/src/llvm-project-build/bin/llvm-lit ${TTMLIR_TOOLCHAIN_DIR}/bin/llvm-lit DEPENDS llvm-project)
add_custom_target(run-clang-tidy-install ALL COMMAND cp llvm-project-prefix/src/llvm-project/clang-tools-extra/clang-tidy/tool/run-clang-tidy.py ${TTMLIR_TOOLCHAIN_DIR}/bin/run-clang-tidy.py DEPENDS llvm-project)
14 changes: 14 additions & 0 deletions include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ struct TTIRToTTNNBackendPipelineOptions

ListOption<int64_t> meshShape{
*this, "mesh-shape", llvm::cl::desc("Set the multi-device mesh shape.")};

// If this option is true, run the entire graph with graph capture to validate
// it.
//
Option<bool> graphCaptureValidationEnabled{
*this, "graph-capture-validation-enabled",
llvm::cl::desc("Enable TTNN graph validation using graph capture."),
llvm::cl::init(false)};
};

void createTTNNPipelineTTIRPasses(
Expand All @@ -121,6 +129,9 @@ void createTTNNPipelineLayoutDecompositionPass(
void createTTNNPipelineDeallocPass(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options);

void createTTNNPipelineValidateGraphCapturePass(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options);

void createTTNNPipelineTTIRPassesFromString(OpPassManager &pm,
std::string options);

Expand All @@ -136,6 +147,9 @@ void createTTNNPipelineLayoutDecompositionPassFromString(OpPassManager &pm,
void createTTNNPipelineDeallocPassFromString(OpPassManager &pm,
std::string options);

void createTTNNPipelineValidateGraphCapturePassFromString(OpPassManager &pm,
std::string options);

void createTTIRToTTNNBackendPipeline(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options);

Expand Down
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTNN/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,11 @@ def TTNNDecomposeLayouts: Pass<"ttnn-decompose-layouts", "::mlir::ModuleOp"> {
}];
}

// ODS definition of the graph-capture validation pass. Runs on ModuleOp;
// tablegen generates impl::TTNNValidateGraphCaptureBase, which the pass
// implementation in lib/Dialect/TTNN/Transforms/Passes.cpp derives from
// (via GEN_PASS_DEF_TTNNVALIDATEGRAPHCAPTURE).
def TTNNValidateGraphCapture: Pass<"ttnn-validate-graph-capture", "::mlir::ModuleOp"> {
let summary = "Validate op graph with graph capture.";
let description = [{
This pass validates that the produced TTNN op graph is valid using graph capture.
}];
}

#endif
1 change: 1 addition & 0 deletions lib/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo)
include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo-build)
include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/nlohmann_json/include)

add_subdirectory(CAPI)
add_subdirectory(Conversion)
Expand Down
16 changes: 16 additions & 0 deletions lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ void createTTNNPipelineDeallocPass(
pm.addPass(createTTNNDeallocate());
}

// Appends the TTNNValidateGraphCapture pass to `pm`. `options` is accepted
// only for signature parity with the other createTTNNPipeline* helpers and is
// not consulted here; gating on graphCaptureValidationEnabled happens in
// createTTIRToTTNNBackendPipeline.
void createTTNNPipelineValidateGraphCapturePass(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
pm.addPass(createTTNNValidateGraphCapture());
}

void createTTNNPipelineTTIRPassesFromString(OpPassManager &pm,
std::string options) {
auto optionsStruct =
Expand Down Expand Up @@ -109,13 +114,24 @@ void createTTNNPipelineDeallocPassFromString(OpPassManager &pm,
createTTNNPipelineDeallocPass(pm, *optionsStruct);
}

// String-options variant: parses `options` (pass-pipeline option syntax) into
// a TTIRToTTNNBackendPipelineOptions struct and forwards to the typed helper.
// NOTE(review): the result of createFromString is dereferenced without a
// validity check, matching every sibling *FromString helper in this file —
// presumably a malformed option string is diagnosed inside the parser; confirm.
void createTTNNPipelineValidateGraphCapturePassFromString(OpPassManager &pm,
std::string options) {
auto optionsStruct =
TTIRToTTNNBackendPipelineOptions::createFromString(options);
createTTNNPipelineValidateGraphCapturePass(pm, *optionsStruct);
}

// Top-level TTIR -> TTNN backend pipeline. Pass order matters: TTIR cleanup,
// lowering to TTNN, analysis, layout decomposition, deallocation insertion,
// and finally (when enabled) whole-graph validation via graph capture.
void createTTIRToTTNNBackendPipeline(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
createTTNNPipelineTTIRPasses(pm, options);
createTTNNPipelineLoweringPasses(pm, options);
createTTNNPipelineAnalysisPasses(pm, options);
createTTNNPipelineLayoutDecompositionPass(pm, options);
createTTNNPipelineDeallocPass(pm, options);

// Validation is opt-in via --graph-capture-validation-enabled (default off)
// and must run last, on the fully lowered graph.
if (options.graphCaptureValidationEnabled) {
createTTNNPipelineValidateGraphCapturePass(pm, options);
}
}

//===----------------------------------------------------------------------===//
Expand Down
86 changes: 78 additions & 8 deletions lib/Dialect/TTNN/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,23 @@
#include "mlir/IR/PatternMatch.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"
#include "ttmlir/Dialect/TTNN/Utils/Utils.h"
#include <algorithm>

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wcovered-switch-default"
#include "nlohmann/json.hpp"
#pragma clang diagnostic pop

#include <cstdio>
#include <filesystem>
#include <fstream>

#include <iostream>

namespace mlir::tt::ttnn {
#define GEN_PASS_DEF_TTNNDEALLOCATE
#define GEN_PASS_DEF_TTNNDECOMPOSELAYOUTS
#define GEN_PASS_DEF_TTNNVALIDATEGRAPHCAPTURE
#include "ttmlir/Dialect/TTNN/Transforms/Passes.h.inc"

class TTNNDeallocate : public impl::TTNNDeallocateBase<TTNNDeallocate> {
Expand Down Expand Up @@ -98,6 +111,66 @@ class TTNNDeallocate : public impl::TTNNDeallocateBase<TTNNDeallocate> {
}
};

// Validates the produced TTNN op graph via graph capture: serialize the
// module to MLIR in a temp dir, translate it to a TTNN flatbuffer with
// ttmlir-translate, run it under `ttrt --use-graph-capture`, then inspect
// the resulting JSON report. Any failure along the way (unwritable file,
// failing subprocess, malformed report) fails the pass instead of being
// silently ignored.
class TTNNValidateGraphCapture
    : public impl::TTNNValidateGraphCaptureBase<TTNNValidateGraphCapture> {

public:
  using impl::TTNNValidateGraphCaptureBase<
      TTNNValidateGraphCapture>::TTNNValidateGraphCaptureBase;

  void runOnOperation() final {
    const std::filesystem::path tmpDirPath =
        std::filesystem::temp_directory_path();

    const std::string mlirFilePath = tmpDirPath / "module.mlir";
    const std::string flatBufferFilePath = tmpDirPath / "module.ttnn";
    const std::string outReportPath = tmpDirPath / "module_graph_capture.json";

    if (!outputTTNNIRFile(mlirFilePath)) {
      getOperation().emitError("failed to write TTNN IR to ") << mlirFilePath;
      signalPassFailure();
      return;
    }

    if (!outputFlatBufferFile(mlirFilePath, flatBufferFilePath)) {
      getOperation().emitError("ttmlir-translate failed to produce ")
          << flatBufferFilePath;
      signalPassFailure();
      return;
    }

    if (!runGraphCapture(flatBufferFilePath, outReportPath)) {
      getOperation().emitError("ttrt graph capture run failed for ")
          << flatBufferFilePath;
      signalPassFailure();
      return;
    }

    if (!isValidGraphCaptureReport(outReportPath)) {
      // TODO (nobradovic/odjuricic): Handle recompile.
      // Intentionally not a pass failure yet: recompilation handling is
      // pending, so an invalid graph is currently only detected, not acted on.
    }
  }

private:
  // Serializes the current module to `mlirFilePath`.
  // Returns false if the file cannot be opened or a write error occurred.
  bool outputTTNNIRFile(const std::string &mlirFilePath) {
    ModuleOp module = getOperation();
    std::error_code ec;
    llvm::raw_fd_stream fs(mlirFilePath, ec);
    if (ec) {
      return false;
    }
    module.print(fs);
    fs.flush();
    return !fs.has_error();
  }

  // Invokes ttmlir-translate to convert the serialized MLIR into a TTNN
  // flatbuffer. Returns false when the tool exits with a non-zero status.
  // NOTE(review): the tool path is relative to the current working directory,
  // so this only works when invoked from the repo root — confirm.
  bool outputFlatBufferFile(const std::string &mlirFilePath,
                            const std::string &flatBufferFilePath) {
    const std::string cmd =
        "./build/bin/ttmlir-translate --ttnn-to-flatbuffer " + mlirFilePath +
        " -o " + flatBufferFilePath;

    return system(cmd.c_str()) == 0;
  }

  // Runs the flatbuffer under ttrt with graph capture, writing the per-op
  // results to `outReportFilePath`. Returns false on non-zero exit status.
  bool runGraphCapture(const std::string &flatBufferFilePath,
                       const std::string &outReportFilePath) {
    // TODO(mbezulj): Add required env variable to be able to run graph capture
    // with mockup device and without kernel compilation.
    const std::string cmd = "ttrt run " + flatBufferFilePath +
                            " --use-graph-capture --result-file " +
                            outReportFilePath;
    return system(cmd.c_str()) == 0;
  }

  // Returns true iff the report exists, parses as JSON, and every element
  // reports "result": "pass". A missing, truncated, or malformed report
  // counts as a failed validation rather than throwing out of the pass.
  static bool isValidGraphCaptureReport(const std::string &outReportPath) {
    std::ifstream reportFile(outReportPath);
    if (!reportFile.is_open()) {
      return false;
    }

    // Parse with exceptions disabled; a parse error yields a discarded value.
    nlohmann::json jsonData = nlohmann::json::parse(
        reportFile, /*cb=*/nullptr, /*allow_exceptions=*/false);
    if (jsonData.is_discarded()) {
      return false;
    }

    return std::all_of(jsonData.begin(), jsonData.end(), [](auto &jsonElement) {
      return jsonElement.contains("result") &&
             jsonElement["result"] == "pass";
    });
  }
};

class TTNNDecomposeLayouts
: public impl::TTNNDecomposeLayoutsBase<TTNNDecomposeLayouts> {

Expand Down Expand Up @@ -163,14 +236,11 @@ class TTNNDecomposeLayouts

void print() const {
llvm::errs() << "OpsToCreate{ \n"
<< "\t"
<< "CreateToDeviceOp: " << createToDeviceOp << "\n"
<< "\t"
<< "CreateFromDeviceOp: " << createFromDeviceOp << "\n"
<< "\t"
<< "CreateToLayoutOp: " << createToLayoutOp << "\n"
<< "\t"
<< "CreateTypecastOp: " << createTypecastOp << "\n"
<< "\t" << "CreateToDeviceOp: " << createToDeviceOp << "\n"
<< "\t" << "CreateFromDeviceOp: " << createFromDeviceOp
<< "\n"
<< "\t" << "CreateToLayoutOp: " << createToLayoutOp << "\n"
<< "\t" << "CreateTypecastOp: " << createTypecastOp << "\n"
<< "\t"
<< "CreateToMemoryConfigOp: " << createToMemoryConfigOp
<< "\n"
Expand Down
5 changes: 3 additions & 2 deletions runtime/include/tt/runtime/detail/ttnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,15 @@ void deallocateBuffers(Device device);

Event submit(Device device, Binary executable, std::uint32_t programIndex,
std::vector<Tensor> const &inputs,
std::vector<Tensor> const &outputs);
std::vector<Tensor> const &outputs, bool useGraphCapture);

void wait(Event event);

void runProgram(::ttnn::MeshDevice &meshDevice,
::tt::target::ttnn::Program const *program,
std::vector<::ttnn::Tensor *> const &inputs,
std::vector<::ttnn::Tensor *> const &outputs);
std::vector<::ttnn::Tensor *> const &outputs,
bool useGraphCapture);

} // namespace tt::runtime::ttnn

Expand Down
2 changes: 1 addition & 1 deletion runtime/include/tt/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ void closeDevice(Device device);

Event submit(Device device, Binary executable, std::uint32_t programIndex,
std::vector<Tensor> const &inputs,
std::vector<Tensor> const &outputs);
std::vector<Tensor> const &outputs, bool useGraphCapture = false);

void wait(Event event);

Expand Down
4 changes: 2 additions & 2 deletions runtime/lib/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,12 @@ void closeDevice(Device device) {
Event submit(Device deviceHandle, Binary executableHandle,
std::uint32_t programIndex,
std::vector<Tensor> const &inputHandles,
std::vector<Tensor> const &outputHandles) {
std::vector<Tensor> const &outputHandles, bool useGraphCapture) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::submit(deviceHandle, executableHandle,
programIndex, inputHandles,
outputHandles);
outputHandles, useGraphCapture);
}
#endif

Expand Down
73 changes: 66 additions & 7 deletions runtime/lib/ttnn/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,26 @@
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"
#include "ttnn/graph/graph_processor.hpp"
#include <exception>
#include <memory>
#include <optional>

namespace tt::runtime::ttnn {
using LogType = ::tt::runtime::logger::LogType;

struct ProgramExecutor {
class ProgramExecutor {
public:
ProgramExecutor(const TensorMap &liveTensors,
const std::unordered_set<uint32_t> &programInputs,
const std::unordered_set<uint32_t> &programOutputs,
::ttnn::MeshDevice *meshDevice)
: context(ProgramContext(liveTensors, programInputs, programOutputs,
meshDevice)) {}

void execute(const ::tt::target::ttnn::Program *program) {
virtual ~ProgramExecutor() = default;

virtual void execute(const ::tt::target::ttnn::Program *program) {
for (const ::tt::target::ttnn::Operation *op : *program->operations()) {
LOG_DEBUG(LogType::LogRuntimeTTNN,
"Executing operation: ", op->debug_info()->c_str());
Expand All @@ -50,12 +57,49 @@ struct ProgramExecutor {

ProgramContext &getContext() { return context; }

private:
protected:
ProgramContext context;
void runOperation(const ::tt::target::ttnn::Operation *op);
void runEltwiseOperation(const ::tt::target::ttnn::EltwiseOp *op);
};

// ProgramExecutor variant that runs the whole program under TTNN graph
// capture in NO_DISPATCH mode, so the op graph can be validated without
// dispatching kernels. On op failure it rethrows a runtime_error that
// identifies the offending op by index and debug info.
class GraphCaptureProgramExecutor : public ProgramExecutor {
public:
  using ProgramExecutor::ProgramExecutor;

  void execute(const ::tt::target::ttnn::Program *program) override {
    ::ttnn::graph::GraphProcessor::begin_graph_capture(
        tt::tt_metal::IGraphProcessor::RunMode::NO_DISPATCH);
    try {
      executeOperations(program);
    } catch (...) {
      // Always close the capture, even when an op throws; otherwise the
      // global graph-capture state would leak into subsequent programs.
      ::ttnn::graph::GraphProcessor::end_graph_capture();
      throw;
    }
    ::ttnn::graph::GraphProcessor::end_graph_capture();
  }

private:
  // Runs every operation in order, translating any std::exception from an op
  // into a runtime_error tagged with the op's position and debug info.
  void executeOperations(const ::tt::target::ttnn::Program *program) {
    unsigned int opIndex = 0;
    for (const ::tt::target::ttnn::Operation *op : *program->operations()) {
      LOG_DEBUG(LogType::LogRuntimeTTNN,
                "Executing operation: ", op->debug_info()->c_str());

      try {
        runOperation(op);
      } catch (const std::exception &ex) {
        // TODO(mbezulj): Replace opIndex with loc attribute of the operation
        // which failed (loc attribute needs to be propagated to the flat
        // buffer).
        std::stringstream ss;
        ss << "Failed on op " << std::to_string(opIndex) << "( "
           << op->debug_info()->c_str() << " ) "
           << " because of: " << ex.what();
        throw std::runtime_error(ss.str());
      }

      ++opIndex;
    }
  }
};

void ProgramExecutor::runEltwiseOperation(
const ::tt::target::ttnn::EltwiseOp *op) {
auto runUnaryOp = [&]() {
Expand Down Expand Up @@ -176,10 +220,25 @@ static bool handleNopProgram(::tt::target::ttnn::Program const *program,
return isNop;
}

// Factory selecting the executor implementation: the plain ProgramExecutor
// dispatches ops directly, while GraphCaptureProgramExecutor wraps execution
// in begin/end_graph_capture for validation runs.
std::unique_ptr<ProgramExecutor>
makeProgramExecutor(const TensorMap &liveTensors,
                    const std::unordered_set<uint32_t> &programInputs,
                    const std::unordered_set<uint32_t> &programOutputs,
                    ::ttnn::MeshDevice *meshDevice, bool useGraphCapture) {
  if (!useGraphCapture) {
    return std::make_unique<ProgramExecutor>(liveTensors, programInputs,
                                             programOutputs, meshDevice);
  }

  return std::make_unique<GraphCaptureProgramExecutor>(
      liveTensors, programInputs, programOutputs, meshDevice);
}

void runProgram(::ttnn::MeshDevice &meshDevice,
::tt::target::ttnn::Program const *program,
std::vector<::ttnn::Tensor *> const &inputs,
std::vector<::ttnn::Tensor *> const &outputs) {
std::vector<::ttnn::Tensor *> const &outputs,
bool useGraphCapture) {
if (handleNopProgram(program, inputs, outputs)) {
return;
}
Expand All @@ -205,9 +264,9 @@ void runProgram(::ttnn::MeshDevice &meshDevice,
LOG_ASSERT(inserted, "Duplicate output tensor");
programOutputs.emplace(output->global_id());
}
ProgramExecutor executor(liveTensors, programInputs, programOutputs,
&meshDevice);
executor.execute(program);
std::unique_ptr<ProgramExecutor> executor = makeProgramExecutor(
liveTensors, programInputs, programOutputs, &meshDevice, useGraphCapture);
executor->execute(program);
}

} // namespace tt::runtime::ttnn
Loading

0 comments on commit e9a9237

Please sign in to comment.