From 83bdd9fdf05eb918e0fad815a6cb88413666573d Mon Sep 17 00:00:00 2001 From: tjthereal Date: Sun, 8 Oct 2023 13:44:19 +0800 Subject: [PATCH] fp8-cali-table-added --- .../Top/Transforms/ImportCalibration.cpp | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/lib/Dialect/Top/Transforms/ImportCalibration.cpp b/lib/Dialect/Top/Transforms/ImportCalibration.cpp index e907b850d..965db80a1 100644 --- a/lib/Dialect/Top/Transforms/ImportCalibration.cpp +++ b/lib/Dialect/Top/Transforms/ImportCalibration.cpp @@ -22,6 +22,12 @@ typedef struct { double max; } cali_info; +typedef struct { + double threshold; + double mean; + double max; +} fp8_cali_info; + class ImportCalibrationTablePass : public ImportCalibrationTableBase { public: @@ -37,6 +43,8 @@ class ImportCalibrationTablePass OpBuilder builder(mOp); std::map calibration_map; std::map calibration_map_int4; + std::map calibration_map_e4; //e4m3 map added + std::map calibration_map_e5; //e5m2 map added std::map per_chan_scales_map; std::ifstream infile(this->tableFile); if (!infile) { @@ -47,6 +55,8 @@ class ImportCalibrationTablePass std::regex info_pattern("#.*"); bool weight_scale_meeted = false; bool int4_th_meeted = false; + bool e4m3_th_meeted = false; //e4m3 th added + bool e5m2_th_meeted = false; //e5m2 th added while (std::getline(infile, line)) { if (line.back() == '\r') { line.pop_back(); @@ -67,6 +77,24 @@ class ImportCalibrationTablePass vScales->data()[i] = value; } per_chan_scales_map[name] = vScales; + } else if (e5m2_th_meeted) { + if (std::regex_match(line, cali_pattern)) { + fp8_cali_info fp8_info = {0, 0, 0}; + if (!(iss >> name >> fp8_info.threshold >> fp8_info.mean >> fp8_info.max)) { + llvm::errs() << line; + llvm_unreachable("\n => not match required format\n"); + } + calibration_map_e5[name] = fp8_info; + } + } else if (e4m3_th_meeted) { + if (std::regex_match(line, cali_pattern)) { + fp8_cali_info fp8_info = {0, 0, 0}; + if (!(iss >> name >> fp8_info.threshold >> fp8_info.mean >> fp8_info.max)) { + llvm::errs() << line; + llvm_unreachable("\n => not match required format\n"); + } + calibration_map_e4[name] = fp8_info; + } } else if (int4_th_meeted) { // second run, read int4 th if (std::regex_match(line, cali_pattern)) { cali_info info = {0, 0, 0}; @@ -91,6 +119,12 @@ class ImportCalibrationTablePass } else if (std::regex_match(line, info_pattern) && std::string::npos != line.find("#int4_th")) { int4_th_meeted = true; + } else if (std::regex_match(line, info_pattern) && + std::string::npos != line.find("E4M3")) { + e4m3_th_meeted = true; + } else if (std::regex_match(line, info_pattern) && + std::string::npos != line.find("E5M2")) { + e5m2_th_meeted = true; } } }