[Mlir-commits] [mlir] 965ad79 - [MLIR][MemRef] Only allow fold of cast for the pointer operand, not the value
William S. Moses
llvmlistbot at llvm.org
Tue Jun 8 08:43:35 PDT 2021
Author: William S. Moses
Date: 2021-06-08T11:43:09-04:00
New Revision: 965ad79ea7d0b98f905a27785a6fd0091b904218
URL: https://github.com/llvm/llvm-project/commit/965ad79ea7d0b98f905a27785a6fd0091b904218
DIFF: https://github.com/llvm/llvm-project/commit/965ad79ea7d0b98f905a27785a6fd0091b904218.diff
LOG: [MLIR][MemRef] Only allow fold of cast for the pointer operand, not the value
Currently, canonicalizing a store whose operands are produced by memref.cast ops folds every such cast into the store.
When the value being stored is itself the result of a cast, this fold is illegal because it changes the type of the value being stored.
This patch fixes this by excluding the stored value from cast folding, so that only the memref (pointer) operand is folded.
Depends on https://reviews.llvm.org/D103828
Differential Revision: https://reviews.llvm.org/D103829
Added:
Modified:
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/Affine/canonicalize.mlir
mlir/test/Dialect/MemRef/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index ef990b70f3575..480b53811483f 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -942,11 +942,12 @@ void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
-static LogicalResult foldMemRefCast(Operation *op) {
+static LogicalResult foldMemRefCast(Operation *op, Value ignore = nullptr) {
bool folded = false;
for (OpOperand &operand : op->getOpOperands()) {
auto cast = operand.get().getDefiningOp<memref::CastOp>();
- if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
+ if (cast && operand.get() != ignore &&
+ !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
operand.set(cast.getOperand());
folded = true;
}
@@ -2270,7 +2271,7 @@ void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results) {
/// store(memrefcast) -> store
- return foldMemRefCast(*this);
+ return foldMemRefCast(*this, getValueToStore());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a4ab6c1d0859f..f20234bd1d686 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -73,11 +73,12 @@ static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
-static LogicalResult foldMemRefCast(Operation *op) {
+static LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr) {
bool folded = false;
for (OpOperand &operand : op->getOpOperands()) {
auto cast = operand.get().getDefiningOp<CastOp>();
- if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
+ if (cast && operand.get() != inner &&
+ !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
operand.set(cast.getOperand());
folded = true;
}
@@ -1425,7 +1426,7 @@ static LogicalResult verify(StoreOp op) {
LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results) {
/// store(memrefcast) -> store
- return foldMemRefCast(*this);
+ return foldMemRefCast(*this, getValueToStore());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index 0a47285e18c49..3d6bd57c27ffc 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -924,3 +924,15 @@ func @compose_into_affine_vector_load_vector_store(%A : memref<1024xf32>, %u : i
}
return
}
+
+// -----
+
+// CHECK-LABEL: func @no_fold_of_store
+// CHECK: %[[cst:.+]] = memref.cast %arg
+// CHECK: affine.store %[[cst]]
+func @no_fold_of_store(%arg : memref<32xi8>, %holder: memref<memref<?xi8>>) {
+ %0 = memref.cast %arg : memref<32xi8> to memref<?xi8>
+ affine.store %0, %holder[] : memref<memref<?xi8>>
+ return
+}
+
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 354be2237ec30..140cd43ede147 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -206,4 +206,14 @@ func @dim_of_sized_view(%arg : memref<?xi8>, %size: index) -> index {
return %1 : index
}
+// -----
+
+// CHECK-LABEL: func @no_fold_of_store
+// CHECK: %[[cst:.+]] = memref.cast %arg
+// CHECK: memref.store %[[cst]]
+func @no_fold_of_store(%arg : memref<32xi8>, %holder: memref<memref<?xi8>>) {
+ %0 = memref.cast %arg : memref<32xi8> to memref<?xi8>
+ memref.store %0, %holder[] : memref<memref<?xi8>>
+ return
+}
More information about the Mlir-commits mailing list