[Mlir-commits] [mlir] 030b18f - [mlir][vector] Clean up some dimension size checks

Matthias Springer llvmlistbot at llvm.org
Mon Jul 3 00:14:40 PDT 2023


Author: Matthias Springer
Date: 2023-07-03T09:10:00+02:00
New Revision: 030b18fe148b4e49de66574b08efbc2f95dcf242

URL: https://github.com/llvm/llvm-project/commit/030b18fe148b4e49de66574b08efbc2f95dcf242
DIFF: https://github.com/llvm/llvm-project/commit/030b18fe148b4e49de66574b08efbc2f95dcf242.diff

LOG: [mlir][vector] Clean up some dimension size checks

* Add `memref::getMixedSize` (same as in the tensor dialect).
* Simplify in-bounds check in `VectorTransferSplitRewritePatterns.cpp` and fix off-by-one error in the static in-bounds check.
* Use `memref::DimOp` instead of `createOrFoldDimOp` when possible.

Differential Revision: https://reviews.llvm.org/D154218

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index 0c2ea33496862c..72463dca715ca3 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -59,6 +59,10 @@ Type getTensorTypeFromMemRefType(Type type);
 /// single deallocate if it exists or nullptr.
 std::optional<Operation *> findDealloc(Value allocValue);
 
+/// Return the dimension of the given memref value.
+OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value,
+                          int64_t dim);
+
 /// Return the dimensions of the given memref value.
 SmallVector<OpFoldResult> getMixedSizes(OpBuilder &builder, Location loc,
                                         Value value);

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 01ad2dd20e7cad..6f5b8693be9585 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -108,18 +108,22 @@ Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
   return NoneType::get(type.getContext());
 }
 
+OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
+                                  int64_t dim) {
+  auto memrefType = llvm::cast<MemRefType>(value.getType());
+  SmallVector<OpFoldResult> result;
+  if (memrefType.isDynamicDim(dim))
+    return builder.createOrFold<memref::DimOp>(loc, value, dim);
+
+  return builder.getIndexAttr(memrefType.getDimSize(dim));
+}
+
 SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
                                                 Location loc, Value value) {
   auto memrefType = llvm::cast<MemRefType>(value.getType());
   SmallVector<OpFoldResult> result;
-  for (int64_t i = 0; i < memrefType.getRank(); ++i) {
-    if (memrefType.isDynamicDim(i)) {
-      Value size = builder.create<memref::DimOp>(loc, value, i);
-      result.push_back(size);
-    } else {
-      result.push_back(builder.getIndexAttr(memrefType.getDimSize(i)));
-    }
-  }
+  for (int64_t i = 0; i < memrefType.getRank(); ++i)
+    result.push_back(getMixedSize(builder, loc, value, i));
   return result;
 }
 

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index 3aa5cd01928d77..88253f1c520680 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -38,26 +38,6 @@
 using namespace mlir;
 using namespace mlir::vector;
 
-static std::optional<int64_t> extractConstantIndex(Value v) {
-  if (auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>())
-    return cstOp.value();
-  if (auto affineApplyOp = v.getDefiningOp<affine::AffineApplyOp>())
-    if (affineApplyOp.getAffineMap().isSingleConstant())
-      return affineApplyOp.getAffineMap().getSingleConstantResult();
-  return std::nullopt;
-}
-
-// Missing foldings of scf.if make it necessary to perform poor man's folding
-// eagerly, especially in the case of unrolling. In the future, this should go
-// away once scf.if folds properly.
-static Value createFoldedSLE(RewriterBase &b, Value v, Value ub) {
-  auto maybeCstV = extractConstantIndex(v);
-  auto maybeCstUb = extractConstantIndex(ub);
-  if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
-    return Value();
-  return b.create<arith::CmpIOp>(v.getLoc(), arith::CmpIPredicate::sle, v, ub);
-}
-
 /// Build the condition to ensure that a particular VectorTransferOpInterface
 /// is in-bounds.
 static Value createInBoundsCond(RewriterBase &b,
@@ -74,14 +54,19 @@ static Value createInBoundsCond(RewriterBase &b,
     // Fold or create the check that `index + vector_size` <= `memref_size`.
     Location loc = xferOp.getLoc();
     int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
-    auto d0 = getAffineDimExpr(0, xferOp.getContext());
-    auto vs = getAffineConstantExpr(vectorSize, xferOp.getContext());
-    Value sum = affine::makeComposedAffineApply(b, loc, d0 + vs,
-                                                {xferOp.indices()[indicesIdx]});
-    Value cond = createFoldedSLE(
-        b, sum, vector::createOrFoldDimOp(b, loc, xferOp.source(), indicesIdx));
-    if (!cond)
+    OpFoldResult sum = affine::makeComposedFoldedAffineApply(
+        b, loc, b.getAffineDimExpr(0) + b.getAffineConstantExpr(vectorSize),
+        {xferOp.indices()[indicesIdx]});
+    OpFoldResult dimSz =
+        memref::getMixedSize(b, loc, xferOp.source(), indicesIdx);
+    auto maybeCstSum = getConstantIntValue(sum);
+    auto maybeCstDimSz = getConstantIntValue(dimSz);
+    if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)
       return;
+    Value cond =
+        b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle,
+                                getValueOrCreateConstantIndexOp(b, loc, sum),
+                                getValueOrCreateConstantIndexOp(b, loc, dimSz));
     // Conjunction over all dims for which we are in-bounds.
     if (inBoundsCond)
       inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
@@ -199,8 +184,8 @@ createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
   auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
   xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
     using MapList = ArrayRef<ArrayRef<AffineExpr>>;
-    Value dimMemRef = vector::createOrFoldDimOp(b, xferOp.getLoc(),
-                                                xferOp.source(), indicesIdx);
+    Value dimMemRef =
+        b.create<memref::DimOp>(xferOp.getLoc(), xferOp.source(), indicesIdx);
     Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
     Value index = xferOp.indices()[indicesIdx];
     AffineExpr i, j, k;


        


More information about the Mlir-commits mailing list