[Mlir-commits] [mlir] 030b18f - [mlir][vector] Clean up some dimension size checks
Matthias Springer
llvmlistbot at llvm.org
Mon Jul 3 00:14:40 PDT 2023
Author: Matthias Springer
Date: 2023-07-03T09:10:00+02:00
New Revision: 030b18fe148b4e49de66574b08efbc2f95dcf242
URL: https://github.com/llvm/llvm-project/commit/030b18fe148b4e49de66574b08efbc2f95dcf242
DIFF: https://github.com/llvm/llvm-project/commit/030b18fe148b4e49de66574b08efbc2f95dcf242.diff
LOG: [mlir][vector] Clean up some dimension size checks
* Add `memref::getMixedSize` (same as in the tensor dialect).
* Simplify in-bounds check in `VectorTransferSplitRewritePatterns.cpp` and fix off-by-one error in the static in-bounds check.
* Use `memref::DimOp` instead of `createOrFoldDimOp` when possible.
Differential Revision: https://reviews.llvm.org/D154218
Added:
Modified:
mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index 0c2ea33496862c..72463dca715ca3 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -59,6 +59,10 @@ Type getTensorTypeFromMemRefType(Type type);
/// single deallocate if it exists or nullptr.
std::optional<Operation *> findDealloc(Value allocValue);
+/// Return the dimension of the given memref value.
+OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value,
+ int64_t dim);
+
/// Return the dimensions of the given memref value.
SmallVector<OpFoldResult> getMixedSizes(OpBuilder &builder, Location loc,
Value value);
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 01ad2dd20e7cad..6f5b8693be9585 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -108,18 +108,22 @@ Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
return NoneType::get(type.getContext());
}
+OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
+ int64_t dim) {
+ auto memrefType = llvm::cast<MemRefType>(value.getType());
+ SmallVector<OpFoldResult> result;
+ if (memrefType.isDynamicDim(dim))
+ return builder.createOrFold<memref::DimOp>(loc, value, dim);
+
+ return builder.getIndexAttr(memrefType.getDimSize(dim));
+}
+
SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
Location loc, Value value) {
auto memrefType = llvm::cast<MemRefType>(value.getType());
SmallVector<OpFoldResult> result;
- for (int64_t i = 0; i < memrefType.getRank(); ++i) {
- if (memrefType.isDynamicDim(i)) {
- Value size = builder.create<memref::DimOp>(loc, value, i);
- result.push_back(size);
- } else {
- result.push_back(builder.getIndexAttr(memrefType.getDimSize(i)));
- }
- }
+ for (int64_t i = 0; i < memrefType.getRank(); ++i)
+ result.push_back(getMixedSize(builder, loc, value, i));
return result;
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index 3aa5cd01928d77..88253f1c520680 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -38,26 +38,6 @@
using namespace mlir;
using namespace mlir::vector;
-static std::optional<int64_t> extractConstantIndex(Value v) {
- if (auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>())
- return cstOp.value();
- if (auto affineApplyOp = v.getDefiningOp<affine::AffineApplyOp>())
- if (affineApplyOp.getAffineMap().isSingleConstant())
- return affineApplyOp.getAffineMap().getSingleConstantResult();
- return std::nullopt;
-}
-
-// Missing foldings of scf.if make it necessary to perform poor man's folding
-// eagerly, especially in the case of unrolling. In the future, this should go
-// away once scf.if folds properly.
-static Value createFoldedSLE(RewriterBase &b, Value v, Value ub) {
- auto maybeCstV = extractConstantIndex(v);
- auto maybeCstUb = extractConstantIndex(ub);
- if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
- return Value();
- return b.create<arith::CmpIOp>(v.getLoc(), arith::CmpIPredicate::sle, v, ub);
-}
-
/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds.
static Value createInBoundsCond(RewriterBase &b,
@@ -74,14 +54,19 @@ static Value createInBoundsCond(RewriterBase &b,
// Fold or create the check that `index + vector_size` <= `memref_size`.
Location loc = xferOp.getLoc();
int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
- auto d0 = getAffineDimExpr(0, xferOp.getContext());
- auto vs = getAffineConstantExpr(vectorSize, xferOp.getContext());
- Value sum = affine::makeComposedAffineApply(b, loc, d0 + vs,
- {xferOp.indices()[indicesIdx]});
- Value cond = createFoldedSLE(
- b, sum, vector::createOrFoldDimOp(b, loc, xferOp.source(), indicesIdx));
- if (!cond)
+ OpFoldResult sum = affine::makeComposedFoldedAffineApply(
+ b, loc, b.getAffineDimExpr(0) + b.getAffineConstantExpr(vectorSize),
+ {xferOp.indices()[indicesIdx]});
+ OpFoldResult dimSz =
+ memref::getMixedSize(b, loc, xferOp.source(), indicesIdx);
+ auto maybeCstSum = getConstantIntValue(sum);
+ auto maybeCstDimSz = getConstantIntValue(dimSz);
+ if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)
return;
+ Value cond =
+ b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle,
+ getValueOrCreateConstantIndexOp(b, loc, sum),
+ getValueOrCreateConstantIndexOp(b, loc, dimSz));
// Conjunction over all dims for which we are in-bounds.
if (inBoundsCond)
inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
@@ -199,8 +184,8 @@ createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
- Value dimMemRef = vector::createOrFoldDimOp(b, xferOp.getLoc(),
- xferOp.source(), indicesIdx);
+ Value dimMemRef =
+ b.create<memref::DimOp>(xferOp.getLoc(), xferOp.source(), indicesIdx);
Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
Value index = xferOp.indices()[indicesIdx];
AffineExpr i, j, k;
More information about the Mlir-commits
mailing list