[Mlir-commits] [mlir] 965ad79 - [MLIR][MemRef] Only allow fold of cast for the pointer operand, not the value

William S. Moses llvmlistbot at llvm.org
Tue Jun 8 08:43:35 PDT 2021


Author: William S. Moses
Date: 2021-06-08T11:43:09-04:00
New Revision: 965ad79ea7d0b98f905a27785a6fd0091b904218

URL: https://github.com/llvm/llvm-project/commit/965ad79ea7d0b98f905a27785a6fd0091b904218
DIFF: https://github.com/llvm/llvm-project/commit/965ad79ea7d0b98f905a27785a6fd0091b904218.diff

LOG: [MLIR][MemRef] Only allow fold of cast for the pointer operand, not the value

Currently, canonicalization of a store whose operands are produced by casts tries to fold all of those casts into the store.

When the value being stored is itself the result of a cast, this is illegal because folding the cast would change the type
of the value being stored. This PR fixes the problem by excluding the stored value when looking for casts to fold.

Depends on https://reviews.llvm.org/D103828

Differential Revision: https://reviews.llvm.org/D103829

Added: 
    

Modified: 
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/Affine/canonicalize.mlir
    mlir/test/Dialect/MemRef/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index ef990b70f3575..480b53811483f 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -942,11 +942,12 @@ void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
 /// This is a common class used for patterns of the form
 /// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
 /// into the root operation directly.
-static LogicalResult foldMemRefCast(Operation *op) {
+static LogicalResult foldMemRefCast(Operation *op, Value ignore = nullptr) {
   bool folded = false;
   for (OpOperand &operand : op->getOpOperands()) {
     auto cast = operand.get().getDefiningOp<memref::CastOp>();
-    if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
+    if (cast && operand.get() != ignore &&
+        !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
       operand.set(cast.getOperand());
       folded = true;
     }
@@ -2270,7 +2271,7 @@ void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
 LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
                                   SmallVectorImpl<OpFoldResult> &results) {
   /// store(memrefcast) -> store
-  return foldMemRefCast(*this);
+  return foldMemRefCast(*this, getValueToStore());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a4ab6c1d0859f..f20234bd1d686 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -73,11 +73,12 @@ static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
 /// This is a common class used for patterns of the form
 /// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
 /// into the root operation directly.
-static LogicalResult foldMemRefCast(Operation *op) {
+static LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr) {
   bool folded = false;
   for (OpOperand &operand : op->getOpOperands()) {
     auto cast = operand.get().getDefiningOp<CastOp>();
-    if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
+    if (cast && operand.get() != inner &&
+        !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
       operand.set(cast.getOperand());
       folded = true;
     }
@@ -1425,7 +1426,7 @@ static LogicalResult verify(StoreOp op) {
 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
                             SmallVectorImpl<OpFoldResult> &results) {
   /// store(memrefcast) -> store
-  return foldMemRefCast(*this);
+  return foldMemRefCast(*this, getValueToStore());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index 0a47285e18c49..3d6bd57c27ffc 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -924,3 +924,15 @@ func @compose_into_affine_vector_load_vector_store(%A : memref<1024xf32>, %u : i
   }
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @no_fold_of_store
+//  CHECK:   %[[cst:.+]] = memref.cast %arg
+//  CHECK:   affine.store %[[cst]]
+func @no_fold_of_store(%arg : memref<32xi8>, %holder: memref<memref<?xi8>>) {
+  %0 = memref.cast %arg : memref<32xi8> to memref<?xi8>
+  affine.store %0, %holder[] : memref<memref<?xi8>>
+  return
+}
+

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 354be2237ec30..140cd43ede147 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -206,4 +206,14 @@ func @dim_of_sized_view(%arg : memref<?xi8>, %size: index) -> index {
   return %1 : index
 }
 
+// -----
+
+// CHECK-LABEL: func @no_fold_of_store
+//  CHECK:   %[[cst:.+]] = memref.cast %arg
+//  CHECK:   memref.store %[[cst]]
+func @no_fold_of_store(%arg : memref<32xi8>, %holder: memref<memref<?xi8>>) {
+  %0 = memref.cast %arg : memref<32xi8> to memref<?xi8>
+  memref.store %0, %holder[] : memref<memref<?xi8>>
+  return
+}
 


        


More information about the Mlir-commits mailing list