[Mlir-commits] [mlir] fc9b37d - [mlir][bufferization] Do not canonicalize to_tensor(to_memref(x))

Matthias Springer llvmlistbot at llvm.org
Sat Jul 9 00:27:01 PDT 2022


Author: Matthias Springer
Date: 2022-07-09T09:16:52+02:00
New Revision: fc9b37dd532dc68018c0c5947030b34ebcf68d14

URL: https://github.com/llvm/llvm-project/commit/fc9b37dd532dc68018c0c5947030b34ebcf68d14
DIFF: https://github.com/llvm/llvm-project/commit/fc9b37dd532dc68018c0c5947030b34ebcf68d14.diff

LOG: [mlir][bufferization] Do not canonicalize to_tensor(to_memref(x))

This is a partial revert of D128615.

to_memref(to_tensor(x)) can always be folded to x. But to_tensor(to_memref(x)) cannot be folded in the general case, because writes to the intermediate memref may go unnoticed.
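To illustrate (a hand-written sketch, not part of this commit; the function name and values are made up): if something writes to the intermediate memref between the two casts, folding the pair away would drop that write from the value read back as a tensor.

    func.func @unsafe_fold(%t: tensor<f32>, %f: f32) -> tensor<f32> {
      %m = bufferization.to_memref %t : memref<f32>
      // This store would be lost if %t2 were folded to %t.
      memref.store %f, %m[] : memref<f32>
      %t2 = bufferization.to_tensor %m : memref<f32>
      return %t2 : tensor<f32>
    }

In the other direction, replacing to_memref(to_tensor(x)) with x cannot lose any such writes, which is why that fold is kept.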

Differential Revision: https://reviews.llvm.org/D129354

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir
    mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 35f6f1b6a97f4..4ab904ea39309 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -539,20 +539,6 @@ OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
 }
 
 namespace {
-/// Canonicalize bufferization.to_tensor + bufferization.to_memref.
-struct ToTensorToMemrefFolding : public OpRewritePattern<ToTensorOp> {
-  using OpRewritePattern<ToTensorOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ToTensorOp toTensorOp,
-                                PatternRewriter &rewriter) const final {
-    auto toMemrefOp = toTensorOp.getMemref().getDefiningOp<ToMemrefOp>();
-    if (!toMemrefOp)
-      return failure();
-    rewriter.replaceOp(toTensorOp, toMemrefOp.getTensor());
-    return success();
-  }
-};
-
 struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
 
@@ -571,7 +557,7 @@ struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
 
 void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
-  results.add<DimOfToTensorFolder, ToTensorToMemrefFolding>(context);
+  results.add<DimOfToTensorFolder>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 535a00706100f..8e087fc0f38a4 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -787,7 +787,8 @@ func.func @last_value(%t0: tensor<128x128xf32>, %t1: tensor<128x128xf32>,
   }
 
   // CHECK-NEXT: %[[R0:.*]] = bufferization.to_tensor %[[M0]] : memref<128x128xf32>
-  // CHECK-NEXT: return %[[R0]], %[[T1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
+  // CHECK-NEXT: %[[R1:.*]] = bufferization.to_tensor %[[M1]] : memref<128x128xf32>
+  // CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
   return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
 }
 

diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
index f24048e60e07c..df55b8373e0ee 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
@@ -109,7 +109,8 @@
 // CHECK:             scf.yield %[[VAL_84]] : f64
 // CHECK:           }
 // CHECK:           memref.store %[[VAL_86:.*]], %[[VAL_15]][] : memref<f64>
-// CHECK:           return %[[VAL_0]] : tensor<f64>
+// CHECK:           %[[VAL_87:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<f64>
+// CHECK:           return %[[VAL_87]] : tensor<f64>
 // CHECK:         }
 func.func @sparse_matrix_sum(%argx: tensor<f64> {linalg.inplaceable = true},
                          %arga: tensor<64x32xf64, #SparseMatrix>,

