[Mlir-commits] [mlir] 807e546 - [mlir] Add canonicalization for tensor_cast + tensor_to_memref
Thomas Raoux
llvmlistbot at llvm.org
Tue Feb 16 07:12:12 PST 2021
Author: Thomas Raoux
Date: 2021-02-16T07:11:09-08:00
New Revision: 807e5467f3e1b115f53377ea36ecad5625ce8280
URL: https://github.com/llvm/llvm-project/commit/807e5467f3e1b115f53377ea36ecad5625ce8280
DIFF: https://github.com/llvm/llvm-project/commit/807e5467f3e1b115f53377ea36ecad5625ce8280.diff
LOG: [mlir] Add canonicalization for tensor_cast + tensor_to_memref
This helps bufferization passes by removing tensor_cast operations.
Differential Revision: https://reviews.llvm.org/D96745
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index dd760af563f5..4e6ff2e359c0 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -3078,6 +3078,7 @@ def TensorToMemrefOp : Std_Op<"tensor_to_memref",
let assemblyFormat = "$tensor attr-dict `:` type($memref)";
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 3ef48ced1b65..49082912b803 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3558,6 +3558,37 @@ OpFoldResult TensorToMemrefOp::fold(ArrayRef<Attribute>) {
return {};
}
+namespace {
+/// Rewrite tensor_to_memref(tensor.cast(x)) as memref_cast(tensor_to_memref(x)).
+struct TensorCastToMemref : public OpRewritePattern<TensorToMemrefOp> {
+ using OpRewritePattern<TensorToMemrefOp>::OpRewritePattern;
+ /// Returns failure() unless the operand is a tensor.cast of a ranked tensor.
+ LogicalResult matchAndRewrite(TensorToMemrefOp tensorToMemRef,
+ PatternRewriter &rewriter) const final {
+ auto tensorCastOperand = // falsy unless a tensor::CastOp defines the operand
+ tensorToMemRef.getOperand().getDefiningOp<tensor::CastOp>();
+ if (!tensorCastOperand)
+ return failure();
+ auto srcTensorType = // falsy when the cast's source is not a RankedTensorType
+ tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
+ if (!srcTensorType)
+ return failure();
+ auto memrefType = MemRefType::get(srcTensorType.getShape(), // pre-cast shape
+ srcTensorType.getElementType());
+ Value memref = rewriter.create<TensorToMemrefOp>( // convert the cast's source
+ tensorToMemRef.getLoc(), memrefType, tensorCastOperand.getOperand());
+ rewriter.replaceOpWithNewOp<MemRefCastOp>(tensorToMemRef, // keep result type
+ tensorToMemRef.getType(), memref);
+ return success();
+ }
+};
+} // namespace
+
+void TensorToMemrefOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<TensorCastToMemref>(context); // only pattern registered here
+}
+
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 8187c2f3215d..7b54938b0c48 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -131,3 +131,15 @@ func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
%2 = dim %0, %c1 : tensor<?x?xf32>
return %1, %2: index, index
}
+
+// CHECK-LABEL: func @tensor_cast_to_memref
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8>
+// CHECK: %[[M:.+]] = tensor_to_memref %[[ARG0]] : memref<4x6x16x32xi8>
+// CHECK: %[[M1:.+]] = memref_cast %[[M]] : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
+// CHECK: return %[[M1]] : memref<?x?x16x32xi8>
+func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
+ memref<?x?x16x32xi8> {
+ %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8> // dims 0 and 1 become dynamic
+ %1 = tensor_to_memref %0 : memref<?x?x16x32xi8> // input to the rewrite under test
+ return %1 : memref<?x?x16x32xi8>
+}
More information about the Mlir-commits
mailing list