[Mlir-commits] [mlir] cb47124 - [mlir][bufferize] Improve to_tensor/to_memref folding
Matthias Springer
llvmlistbot at llvm.org
Mon Jun 27 12:42:54 PDT 2022
Author: Matthias Springer
Date: 2022-06-27T21:42:39+02:00
New Revision: cb471241797b64964dae719382ce43c51bd2441a
URL: https://github.com/llvm/llvm-project/commit/cb471241797b64964dae719382ce43c51bd2441a
DIFF: https://github.com/llvm/llvm-project/commit/cb471241797b64964dae719382ce43c51bd2441a.diff
LOG: [mlir][bufferize] Improve to_tensor/to_memref folding
Differential Revision: https://reviews.llvm.org/D128615
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
index 5c72b5ea5c664..42d88156300c2 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -55,8 +55,7 @@ FailureOr<Value> castOrReallocMemRefValue(OpBuilder &b, Value value,
/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
- ToMemrefOp toMemref,
- bool allowSameType = true);
+ ToMemrefOp toMemref);
} // namespace bufferization
} // namespace mlir
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 679c4e8bbba3f..62cf424e6fef5 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -84,8 +84,9 @@ mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
/// to_memref op are different, a memref.cast is needed.
-LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
- RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
+LogicalResult
+mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
+ ToMemrefOp toMemref) {
auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
if (!memrefToTensor)
return failure();
@@ -95,9 +96,6 @@ LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
// Directly rewrite if the type did not change.
if (srcType == destType) {
- // Function can be configured to only handle cases where a cast is needed.
- if (!allowSameType)
- return failure();
rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
return success();
}
@@ -541,6 +539,19 @@ OpFoldResult ToTensorOp::fold(ArrayRef<Attribute>) {
}
namespace {
+/// Canonicalize bufferization.to_tensor + bufferization.to_memref.
+struct ToTensorToMemrefFolding : public OpRewritePattern<ToTensorOp> {
+ using OpRewritePattern<ToTensorOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ToTensorOp toTensorOp,
+ PatternRewriter &rewriter) const final {
+ auto toMemrefOp = toTensorOp.getMemref().getDefiningOp<ToMemrefOp>();
+ if (!toMemrefOp)
+ return failure();
+ rewriter.replaceOp(toTensorOp, toMemrefOp.getTensor());
+ return success();
+ }
+};
struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
@@ -556,12 +567,11 @@ struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
return success();
}
};
-
} // namespace
void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<DimOfToTensorFolder>(context);
+ results.add<DimOfToTensorFolder, ToTensorToMemrefFolding>(context);
}
//===----------------------------------------------------------------------===//
@@ -601,17 +611,14 @@ struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
}
};
-/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
-/// memref.cast when type mismatches prevent `ToMemrefOp::fold` to kick in.
-struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
+/// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a
+/// cast if necessary.
+struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
using OpRewritePattern<ToMemrefOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ToMemrefOp toMemref,
PatternRewriter &rewriter) const final {
- // Only handle cases where a cast is needed. The other case is handled by
- // the folder.
- return foldToMemrefToTensorPair(rewriter, toMemref,
- /*allowSameType=*/false);
+ return foldToMemrefToTensorPair(rewriter, toMemref);
}
};
@@ -651,8 +658,8 @@ struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast, TensorLoadToMemref>(
- context);
+ results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast,
+ ToMemrefToTensorFolding>(context);
}
LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index ad5afa9c36015..da7d043133020 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -787,8 +787,7 @@ func.func @last_value(%t0: tensor<128x128xf32>, %t1: tensor<128x128xf32>,
}
// CHECK-NEXT: %[[R0:.*]] = bufferization.to_tensor %[[M0]] : memref<128x128xf32>
- // CHECK-NEXT: %[[R1:.*]] = bufferization.to_tensor %[[M1]] : memref<128x128xf32>
- // CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
+ // CHECK-NEXT: return %[[R0]], %[[T1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
index df55b8373e0ee..f24048e60e07c 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
@@ -109,8 +109,7 @@
// CHECK: scf.yield %[[VAL_84]] : f64
// CHECK: }
// CHECK: memref.store %[[VAL_86:.*]], %[[VAL_15]][] : memref<f64>
-// CHECK: %[[VAL_87:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<f64>
-// CHECK: return %[[VAL_87]] : tensor<f64>
+// CHECK: return %[[VAL_0]] : tensor<f64>
// CHECK: }
func.func @sparse_matrix_sum(%argx: tensor<f64> {linalg.inplaceable = true},
%arga: tensor<64x32xf64, #SparseMatrix>,
More information about the Mlir-commits
mailing list