[Mlir-commits] [mlir] 82c1fb5 - [mlir] Fix invalid handling of AllocOp symbolOperands by SimplifyAllocConst.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jun 22 06:00:56 PDT 2021
Author: Butygin
Date: 2021-06-22T15:39:53+03:00
New Revision: 82c1fb575034f81f861a299d8280a5668854a2bc
URL: https://github.com/llvm/llvm-project/commit/82c1fb575034f81f861a299d8280a5668854a2bc
DIFF: https://github.com/llvm/llvm-project/commit/82c1fb575034f81f861a299d8280a5668854a2bc.diff
LOG: [mlir] Fix invalid handling of AllocOp symbolOperands by SimplifyAllocConst.
symbolOperands were completely ignored by SimplifyAllocConst. Also, slightly improved the diagnostic message emitted by verifyAllocLikeOp.
Differential Revision: https://reviews.llvm.org/D104260
Added:
Modified:
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir
mlir/test/Transforms/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1443296f40ccb..8d003577eb533 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -119,8 +119,9 @@ static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
if (!memRefType.getAffineMaps().empty())
numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
if (op.symbolOperands().size() != numSymbols)
- return op.emitOpError(
- "symbol operand count does not equal memref symbol count");
+ return op.emitOpError("symbol operand count does not equal memref symbol "
+ "count: expected ")
+ << numSymbols << ", got " << op.symbolOperands().size();
return success();
}
@@ -146,7 +147,7 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
PatternRewriter &rewriter) const override {
// Check to see if any dimensions operands are constants. If so, we can
// substitute and drop them.
- if (llvm::none_of(alloc.getOperands(), [](Value operand) {
+ if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) {
return matchPattern(operand, matchConstantIndex());
}))
return failure();
@@ -157,7 +158,7 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
// and keep track of the resultant memref type to build.
SmallVector<int64_t, 4> newShapeConstants;
newShapeConstants.reserve(memrefType.getRank());
- SmallVector<Value, 4> newOperands;
+ SmallVector<Value, 4> dynamicSizes;
unsigned dynamicDimPos = 0;
for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
@@ -167,14 +168,15 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
newShapeConstants.push_back(dimSize);
continue;
}
- auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp();
+ auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos];
+ auto *defOp = dynamicSize.getDefiningOp();
if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
// Dynamic shape dimension will be folded.
newShapeConstants.push_back(constantIndexOp.getValue());
} else {
- // Dynamic shape dimension not folded; copy operand from old memref.
+ // Dynamic shape dimension not folded; copy dynamicSize from old memref.
newShapeConstants.push_back(-1);
- newOperands.push_back(alloc.getOperand(dynamicDimPos));
+ dynamicSizes.push_back(dynamicSize);
}
dynamicDimPos++;
}
@@ -182,12 +184,13 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
// Create new memref type (which will have fewer dynamic dimensions).
MemRefType newMemRefType =
MemRefType::Builder(memrefType).setShape(newShapeConstants);
- assert(static_cast<int64_t>(newOperands.size()) ==
+ assert(static_cast<int64_t>(dynamicSizes.size()) ==
newMemRefType.getNumDynamicDims());
// Create and insert the alloc op for the new memref.
auto newAlloc = rewriter.create<AllocLikeOp>(
- alloc.getLoc(), newMemRefType, newOperands, alloc.alignmentAttr());
+ alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(),
+ alloc.alignmentAttr());
// Insert a cast so we have the same type as the old alloc.
auto resultCast =
rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 7c5fcd2503a8c..cbf2126a9ea2f 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -367,3 +367,56 @@ func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
%1 = memref.buffer_cast %0 : memref<?x?x16x32xi8>
return %1 : memref<?x?x16x32xi8>
}
+
+// -----
+
+// CHECK-LABEL: func @alloc_const_fold
+func @alloc_const_fold() -> memref<?xf32> {
+ // CHECK-NEXT: %0 = memref.alloc() : memref<4xf32>
+ %c4 = constant 4 : index
+ %a = memref.alloc(%c4) : memref<?xf32>
+
+ // CHECK-NEXT: %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
+ // CHECK-NEXT: return %1 : memref<?xf32>
+ return %a : memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @alloc_alignment_const_fold
+func @alloc_alignment_const_fold() -> memref<?xf32> {
+ // CHECK-NEXT: %0 = memref.alloc() {alignment = 4096 : i64} : memref<4xf32>
+ %c4 = constant 4 : index
+ %a = memref.alloc(%c4) {alignment = 4096 : i64} : memref<?xf32>
+
+ // CHECK-NEXT: %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
+ // CHECK-NEXT: return %1 : memref<?xf32>
+ return %a : memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @alloc_const_fold_with_symbols1(
+// CHECK: %[[c1:.+]] = constant 1 : index
+// CHECK: %[[mem1:.+]] = memref.alloc({{.*}})[%[[c1]], %[[c1]]] : memref<?xi32, #map>
+// CHECK: return %[[mem1]] : memref<?xi32, #map>
+#map0 = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+func @alloc_const_fold_with_symbols1(%arg0 : index) -> memref<?xi32, #map0> {
+ %c1 = constant 1 : index
+ %0 = memref.alloc(%arg0)[%c1, %c1] : memref<?xi32, #map0>
+ return %0 : memref<?xi32, #map0>
+}
+
+// -----
+
+// CHECK-LABEL: func @alloc_const_fold_with_symbols2(
+// CHECK: %[[c1:.+]] = constant 1 : index
+// CHECK: %[[mem1:.+]] = memref.alloc()[%[[c1]], %[[c1]]] : memref<1xi32, #map>
+// CHECK: %[[mem2:.+]] = memref.cast %[[mem1]] : memref<1xi32, #map> to memref<?xi32, #map>
+// CHECK: return %[[mem2]] : memref<?xi32, #map>
+#map0 = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+func @alloc_const_fold_with_symbols2() -> memref<?xi32, #map0> {
+ %c1 = constant 1 : index
+ %0 = memref.alloc(%c1)[%c1, %c1] : memref<?xi32, #map0>
+ return %0 : memref<?xi32, #map0>
+}
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 860e55dc04a2b..38009693b40da 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -357,28 +357,6 @@ func @fold_memref_cast_chain(%0: memref<42x42xf64>) {
return
}
-// CHECK-LABEL: func @alloc_const_fold
-func @alloc_const_fold() -> memref<?xf32> {
- // CHECK-NEXT: %0 = memref.alloc() : memref<4xf32>
- %c4 = constant 4 : index
- %a = memref.alloc(%c4) : memref<?xf32>
-
- // CHECK-NEXT: %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
- // CHECK-NEXT: return %1 : memref<?xf32>
- return %a : memref<?xf32>
-}
-
-// CHECK-LABEL: func @alloc_alignment_const_fold
-func @alloc_alignment_const_fold() -> memref<?xf32> {
- // CHECK-NEXT: %0 = memref.alloc() {alignment = 4096 : i64} : memref<4xf32>
- %c4 = constant 4 : index
- %a = memref.alloc(%c4) {alignment = 4096 : i64} : memref<?xf32>
-
- // CHECK-NEXT: %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
- // CHECK-NEXT: return %1 : memref<?xf32>
- return %a : memref<?xf32>
-}
-
// CHECK-LABEL: func @dead_alloc_fold
func @dead_alloc_fold() {
// CHECK-NEXT: return
More information about the Mlir-commits
mailing list