[Mlir-commits] [mlir] 82c1fb5 - [mlir] Fix invalid handling of AllocOp symbolOperands by SimplifyAllocConst.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jun 22 06:00:56 PDT 2021


Author: Butygin
Date: 2021-06-22T15:39:53+03:00
New Revision: 82c1fb575034f81f861a299d8280a5668854a2bc

URL: https://github.com/llvm/llvm-project/commit/82c1fb575034f81f861a299d8280a5668854a2bc
DIFF: https://github.com/llvm/llvm-project/commit/82c1fb575034f81f861a299d8280a5668854a2bc.diff

LOG: [mlir] Fix invalid handling of AllocOp symbolOperands by SimplifyAllocConst.

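SimplifyAllocConst mishandled symbolOperands: it scanned all operands (not just the dynamic sizes) for constants, and it dropped the symbol operands when rebuilding the alloc op. Also slightly improve the diagnostic emitted by verifyAllocLikeOp, which now reports the expected and actual symbol operand counts.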

Differential Revision: https://reviews.llvm.org/D104260

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir
    mlir/test/Transforms/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1443296f40ccb..8d003577eb533 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -119,8 +119,9 @@ static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
   if (!memRefType.getAffineMaps().empty())
     numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
   if (op.symbolOperands().size() != numSymbols)
-    return op.emitOpError(
-        "symbol operand count does not equal memref symbol count");
+    return op.emitOpError("symbol operand count does not equal memref symbol "
+                          "count: expected ")
+           << numSymbols << ", got " << op.symbolOperands().size();
 
   return success();
 }
@@ -146,7 +147,7 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
                                 PatternRewriter &rewriter) const override {
     // Check to see if any dimensions operands are constants.  If so, we can
     // substitute and drop them.
-    if (llvm::none_of(alloc.getOperands(), [](Value operand) {
+    if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) {
           return matchPattern(operand, matchConstantIndex());
         }))
       return failure();
@@ -157,7 +158,7 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
     // and keep track of the resultant memref type to build.
     SmallVector<int64_t, 4> newShapeConstants;
     newShapeConstants.reserve(memrefType.getRank());
-    SmallVector<Value, 4> newOperands;
+    SmallVector<Value, 4> dynamicSizes;
 
     unsigned dynamicDimPos = 0;
     for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
@@ -167,14 +168,15 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
         newShapeConstants.push_back(dimSize);
         continue;
       }
-      auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp();
+      auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos];
+      auto *defOp = dynamicSize.getDefiningOp();
       if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
         // Dynamic shape dimension will be folded.
         newShapeConstants.push_back(constantIndexOp.getValue());
       } else {
-        // Dynamic shape dimension not folded; copy operand from old memref.
+        // Dynamic shape dimension not folded; copy dynamicSize from old memref.
         newShapeConstants.push_back(-1);
-        newOperands.push_back(alloc.getOperand(dynamicDimPos));
+        dynamicSizes.push_back(dynamicSize);
       }
       dynamicDimPos++;
     }
@@ -182,12 +184,13 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
     // Create new memref type (which will have fewer dynamic dimensions).
     MemRefType newMemRefType =
         MemRefType::Builder(memrefType).setShape(newShapeConstants);
-    assert(static_cast<int64_t>(newOperands.size()) ==
+    assert(static_cast<int64_t>(dynamicSizes.size()) ==
            newMemRefType.getNumDynamicDims());
 
     // Create and insert the alloc op for the new memref.
     auto newAlloc = rewriter.create<AllocLikeOp>(
-        alloc.getLoc(), newMemRefType, newOperands, alloc.alignmentAttr());
+        alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(),
+        alloc.alignmentAttr());
     // Insert a cast so we have the same type as the old alloc.
     auto resultCast =
         rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 7c5fcd2503a8c..cbf2126a9ea2f 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -367,3 +367,56 @@ func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
   %1 = memref.buffer_cast %0 : memref<?x?x16x32xi8>
   return %1 : memref<?x?x16x32xi8>
 }
+
+// -----
+
+// CHECK-LABEL: func @alloc_const_fold
+func @alloc_const_fold() -> memref<?xf32> {
+  // CHECK-NEXT: %0 = memref.alloc() : memref<4xf32>
+  %c4 = constant 4 : index
+  %a = memref.alloc(%c4) : memref<?xf32>
+
+  // CHECK-NEXT: %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
+  // CHECK-NEXT: return %1 : memref<?xf32>
+  return %a : memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @alloc_alignment_const_fold
+func @alloc_alignment_const_fold() -> memref<?xf32> {
+  // CHECK-NEXT: %0 = memref.alloc() {alignment = 4096 : i64} : memref<4xf32>
+  %c4 = constant 4 : index
+  %a = memref.alloc(%c4) {alignment = 4096 : i64} : memref<?xf32>
+
+  // CHECK-NEXT: %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
+  // CHECK-NEXT: return %1 : memref<?xf32>
+  return %a : memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @alloc_const_fold_with_symbols1(
+//  CHECK: %[[c1:.+]] = constant 1 : index
+//  CHECK: %[[mem1:.+]] = memref.alloc({{.*}})[%[[c1]], %[[c1]]] : memref<?xi32, #map>
+//  CHECK: return %[[mem1]] : memref<?xi32, #map>
+#map0 = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+func @alloc_const_fold_with_symbols1(%arg0 : index) -> memref<?xi32, #map0> {
+  %c1 = constant 1 : index
+  %0 = memref.alloc(%arg0)[%c1, %c1] : memref<?xi32, #map0>
+  return %0 : memref<?xi32, #map0>
+}
+
+// -----
+
+// CHECK-LABEL: func @alloc_const_fold_with_symbols2(
+//  CHECK: %[[c1:.+]] = constant 1 : index
+//  CHECK: %[[mem1:.+]] = memref.alloc()[%[[c1]], %[[c1]]] : memref<1xi32, #map>
+//  CHECK: %[[mem2:.+]] = memref.cast %[[mem1]] : memref<1xi32, #map> to memref<?xi32, #map>
+//  CHECK: return %[[mem2]] : memref<?xi32, #map>
+#map0 = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+func @alloc_const_fold_with_symbols2() -> memref<?xi32, #map0> {
+  %c1 = constant 1 : index
+  %0 = memref.alloc(%c1)[%c1, %c1] : memref<?xi32, #map0>
+  return %0 : memref<?xi32, #map0>
+}

diff  --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 860e55dc04a2b..38009693b40da 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -357,28 +357,6 @@ func @fold_memref_cast_chain(%0: memref<42x42xf64>) {
   return
 }
 
-// CHECK-LABEL: func @alloc_const_fold
-func @alloc_const_fold() -> memref<?xf32> {
-  // CHECK-NEXT: %0 = memref.alloc() : memref<4xf32>
-  %c4 = constant 4 : index
-  %a = memref.alloc(%c4) : memref<?xf32>
-
-  // CHECK-NEXT: %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
-  // CHECK-NEXT: return %1 : memref<?xf32>
-  return %a : memref<?xf32>
-}
-
-// CHECK-LABEL: func @alloc_alignment_const_fold
-func @alloc_alignment_const_fold() -> memref<?xf32> {
-  // CHECK-NEXT: %0 = memref.alloc() {alignment = 4096 : i64} : memref<4xf32>
-  %c4 = constant 4 : index
-  %a = memref.alloc(%c4) {alignment = 4096 : i64} : memref<?xf32>
-
-  // CHECK-NEXT: %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
-  // CHECK-NEXT: return %1 : memref<?xf32>
-  return %a : memref<?xf32>
-}
-
 // CHECK-LABEL: func @dead_alloc_fold
 func @dead_alloc_fold() {
   // CHECK-NEXT: return


        


More information about the Mlir-commits mailing list