[Mlir-commits] [mlir] 807e546 - [mlir] Add canonicalization for tensor_cast + tensor_to_memref

Thomas Raoux llvmlistbot at llvm.org
Tue Feb 16 07:12:12 PST 2021


Author: Thomas Raoux
Date: 2021-02-16T07:11:09-08:00
New Revision: 807e5467f3e1b115f53377ea36ecad5625ce8280

URL: https://github.com/llvm/llvm-project/commit/807e5467f3e1b115f53377ea36ecad5625ce8280
DIFF: https://github.com/llvm/llvm-project/commit/807e5467f3e1b115f53377ea36ecad5625ce8280.diff

LOG: [mlir] Add canonicalization for tensor_cast + tensor_to_memref

This helps bufferization passes by removing tensor_cast operations.

Differential Revision: https://reviews.llvm.org/D96745

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index dd760af563f5..4e6ff2e359c0 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -3078,6 +3078,7 @@ def TensorToMemrefOp : Std_Op<"tensor_to_memref",
   let assemblyFormat = "$tensor attr-dict `:` type($memref)";
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 3ef48ced1b65..49082912b803 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3558,6 +3558,37 @@ OpFoldResult TensorToMemrefOp::fold(ArrayRef<Attribute>) {
   return {};
 }
 
+namespace {
+/// Replace tensor_cast + tensor_to_memref by tensor_to_memref + memref_cast.
+struct TensorCastToMemref : public OpRewritePattern<TensorToMemrefOp> {
+  using OpRewritePattern<TensorToMemrefOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TensorToMemrefOp tensorToMemRef,
+                                PatternRewriter &rewriter) const final {
+    auto tensorCastOperand =
+        tensorToMemRef.getOperand().getDefiningOp<tensor::CastOp>();
+    if (!tensorCastOperand) // Operand is not produced by a tensor cast.
+      return failure();
+    auto srcTensorType =
+        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
+    if (!srcTensorType) // Cast source is not a ranked tensor.
+      return failure();
+    auto memrefType = MemRefType::get(srcTensorType.getShape(),
+                                      srcTensorType.getElementType()); // NOTE(review): only shape + element type are forwarded; confirm layout/memory space of the result type need no handling.
+    Value memref = rewriter.create<TensorToMemrefOp>(
+        tensorToMemRef.getLoc(), memrefType, tensorCastOperand.getOperand()); // Bufferize the pre-cast tensor directly.
+    rewriter.replaceOpWithNewOp<MemRefCastOp>(tensorToMemRef,
+                                              tensorToMemRef.getType(), memref); // Cast back to the op's original result type.
+    return success();
+  }
+};
+} // namespace
+
+void TensorToMemrefOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<TensorCastToMemref>(context); // Only pattern added here.
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 8187c2f3215d..7b54938b0c48 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -131,3 +131,15 @@ func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
   %2 = dim %0, %c1 : tensor<?x?xf32>
   return %1, %2: index, index
 }
+
+// CHECK-LABEL: func @tensor_cast_to_memref
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<4x6x16x32xi8>
+//       CHECK:   %[[M:.+]] = tensor_to_memref %[[ARG0]] : memref<4x6x16x32xi8>
+//       CHECK:   %[[M1:.+]] = memref_cast %[[M]] : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
+//       CHECK:   return %[[M1]] : memref<?x?x16x32xi8>
+func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
+  memref<?x?x16x32xi8> {
+  %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8> // rewritten into a memref_cast
+  %1 = tensor_to_memref %0 : memref<?x?x16x32xi8>
+  return %1 : memref<?x?x16x32xi8>
+}


        


More information about the Mlir-commits mailing list