[Mlir-commits] [mlir] cb47124 - [mlir][bufferize] Improve to_tensor/to_memref folding

Matthias Springer llvmlistbot at llvm.org
Mon Jun 27 12:42:54 PDT 2022


Author: Matthias Springer
Date: 2022-06-27T21:42:39+02:00
New Revision: cb471241797b64964dae719382ce43c51bd2441a

URL: https://github.com/llvm/llvm-project/commit/cb471241797b64964dae719382ce43c51bd2441a
DIFF: https://github.com/llvm/llvm-project/commit/cb471241797b64964dae719382ce43c51bd2441a.diff

LOG: [mlir][bufferize] Improve to_tensor/to_memref folding

Differential Revision: https://reviews.llvm.org/D128615
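
In short, this teaches ToTensorOp's canonicalizer to fold
to_tensor(to_memref(t)) back to t, and it simplifies the existing
to_memref(to_tensor(x)) folding by dropping its allowSameType flag.
A hypothetical example of the new fold, not taken from this commit:

    // Before canonicalization:
    %m = bufferization.to_memref %t : memref<4xf32>
    %u = bufferization.to_tensor %m : memref<4xf32>
    // ... uses of %u ...

    // After: uses of %u are replaced with %t; if %m has no other
    // users, the to_memref/to_tensor pair becomes dead code.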

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
    mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
    mlir/test/Dialect/SCF/canonicalize.mlir
    mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
index 5c72b5ea5c664..42d88156300c2 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -55,8 +55,7 @@ FailureOr<Value> castOrReallocMemRefValue(OpBuilder &b, Value value,
 /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
 /// to_memref op are different, a memref.cast is needed.
 LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
-                                       ToMemrefOp toMemref,
-                                       bool allowSameType = true);
+                                       ToMemrefOp toMemref);
 
 } // namespace bufferization
 } // namespace mlir
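
For callers, the signature change is mechanical: the allowSameType
flag is gone because the helper now handles the same-type case
unconditionally. A minimal call-site sketch, assuming a
PatternRewriter `rewriter` and a ToMemrefOp `toMemref` in scope:

    if (failed(bufferization::foldToMemrefToTensorPair(rewriter, toMemref)))
      return failure();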

diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 679c4e8bbba3f..62cf424e6fef5 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -84,8 +84,9 @@ mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
 
 /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
 /// to_memref op are different, a memref.cast is needed.
-LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
-    RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
+LogicalResult
+mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
+                                              ToMemrefOp toMemref) {
   auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
   if (!memrefToTensor)
     return failure();
@@ -95,9 +96,6 @@ LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
 
   // Directly rewrite if the type did not change.
   if (srcType == destType) {
-    // Function can be configured to only handle cases where a cast is needed.
-    if (!allowSameType)
-      return failure();
     rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
     return success();
   }
@@ -541,6 +539,19 @@ OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
 }
 
 namespace {
+/// Canonicalize bufferization.to_tensor + bufferization.to_memref.
+struct ToTensorToMemrefFolding : public OpRewritePattern<ToTensorOp> {
+  using OpRewritePattern<ToTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ToTensorOp toTensorOp,
+                                PatternRewriter &rewriter) const final {
+    auto toMemrefOp = toTensorOp.getMemref().getDefiningOp<ToMemrefOp>();
+    if (!toMemrefOp)
+      return failure();
+    rewriter.replaceOp(toTensorOp, toMemrefOp.getTensor());
+    return success();
+  }
+};
 
 struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
@@ -556,12 +567,11 @@ struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
     return success();
   }
 };
-
 } // namespace
 
 void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
-  results.add<DimOfToTensorFolder>(context);
+  results.add<DimOfToTensorFolder, ToTensorToMemrefFolding>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -601,17 +611,14 @@ struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
   }
 };
 
-/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
-/// memref.cast when type mismatches prevent `ToMemrefOp::fold` to kick in.
-struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
+/// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a
+/// cast if necessary.
+struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
   using OpRewritePattern<ToMemrefOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                 PatternRewriter &rewriter) const final {
-    // Only handle cases where a cast is needed. The other case is handled by
-    // the folder.
-    return foldToMemrefToTensorPair(rewriter, toMemref,
-                                    /*allowSameType=*/false);
+    return foldToMemrefToTensorPair(rewriter, toMemref);
   }
 };
 
@@ -651,8 +658,8 @@ struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
 
 void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
-  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
-      context);
+  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast,
+              ToMemrefToTensorFolding>(context);
 }
 
 LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
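
When x's type and the to_memref result type differ, the pair still
folds, but through a memref.cast, as the documentation comment above
states. A hypothetical sketch, assuming cast-compatible memref types:

    // Input:
    %t  = bufferization.to_tensor %m : memref<4xf32>
    %m2 = bufferization.to_memref %t : memref<?xf32>

    // Rewritten to:
    %m2 = memref.cast %m : memref<4xf32> to memref<?xf32>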

diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index ad5afa9c36015..da7d043133020 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -787,8 +787,7 @@ func.func @last_value(%t0: tensor<128x128xf32>, %t1: tensor<128x128xf32>,
   }
 
   // CHECK-NEXT: %[[R0:.*]] = bufferization.to_tensor %[[M0]] : memref<128x128xf32>
-  // CHECK-NEXT: %[[R1:.*]] = bufferization.to_tensor %[[M1]] : memref<128x128xf32>
-  // CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
+  // CHECK-NEXT: return %[[R0]], %[[T1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
   return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
 }
 

diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
index df55b8373e0ee..f24048e60e07c 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
@@ -109,8 +109,7 @@
 // CHECK:             scf.yield %[[VAL_84]] : f64
 // CHECK:           }
 // CHECK:           memref.store %[[VAL_86:.*]], %[[VAL_15]][] : memref<f64>
-// CHECK:           %[[VAL_87:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<f64>
-// CHECK:           return %[[VAL_87]] : tensor<f64>
+// CHECK:           return %[[VAL_0]] : tensor<f64>
 // CHECK:         }
 func.func @sparse_matrix_sum(%argx: tensor<f64> {linalg.inplaceable = true},
                          %arga: tensor<64x32xf64, #SparseMatrix>,
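
The sparse test change exercises the new fold end to end: in the full
test (not shown in this hunk), %VAL_15 is presumably produced by a
to_memref of the inplaceable argument %VAL_0, so to_tensor %VAL_15
folds back to %VAL_0 and the function now returns its argument
directly. A reduced, hypothetical version of the same shape:

    func.func @f(%argx: tensor<f64> {linalg.inplaceable = true},
                 %v: f64) -> tensor<f64> {
      %m = bufferization.to_memref %argx : memref<f64>
      memref.store %v, %m[] : memref<f64>
      %t = bufferization.to_tensor %m : memref<f64>
      return %t : tensor<f64>
    }
    // After canonicalization: return %argx : tensor<f64>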


        

