Skip to content

Commit

Permalink
[mlir][sparse] Cleanup sparse_tensor::LvlOp's folder (llvm#71085)
Browse files Browse the repository at this point in the history
Reuse the util function instead.
  • Loading branch information
PeimingLiu authored Nov 2, 2023
1 parent e65721c commit deedf55
Showing 1 changed file with 3 additions and 32 deletions.
35 changes: 3 additions & 32 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1293,38 +1293,9 @@ OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
};

// TODO: we can remove this after SparseTensorEncoding always returns non-null
// dimToLvl map.
ArrayRef<Size> shape = stt.getDimShape();
if (stt.isPermutation()) {
Dimension dim = toOrigDim(stt, lvl);
if (!ShapedType::isDynamic(shape[dim])) {
return getIndexAttr(shape[dim]);
}
return {};
}

// Non-permutation dim2lvl/lvl2dim maps.
AffineExpr lvlExpr = stt.getDimToLvl().getResult(lvl);
if (auto binExpr = lvlExpr.dyn_cast<AffineBinaryOpExpr>()) {
if (lvlExpr.getKind() == AffineExprKind::Mod) {
// j % block_sz, the level size equals to the block size.
int64_t lvlSz = binExpr.getRHS().cast<AffineConstantExpr>().getValue();
return getIndexAttr(lvlSz);
}
if (lvlExpr.getKind() == AffineExprKind::FloorDiv) {
// j / block_sz, the level size equals to dim[j] / block_sz.
Dimension dim = binExpr.getLHS().cast<AffineDimExpr>().getPosition();
int64_t blockSz = binExpr.getRHS().cast<AffineConstantExpr>().getValue();
if (ShapedType::isDynamic(shape[dim]))
return {};
return getIndexAttr(shape[dim] / blockSz);
}
}

auto dim = lvlExpr.cast<AffineDimExpr>().getPosition();
if (!ShapedType::isDynamic(dim))
return getIndexAttr(shape[dim]);
SmallVector<Size> lvlShape = stt.getLvlShape();
if (!ShapedType::isDynamic(lvlShape[lvl]))
return getIndexAttr(lvlShape[lvl]);

return {};
}
Expand Down

0 comments on commit deedf55

Please sign in to comment.