[Mlir-commits] [mlir] a3f22d0 - [mlir] Add a pattern to bufferize linalg.tensor_reshape.

Alexander Belyaev llvmlistbot at llvm.org
Fri May 7 12:31:23 PDT 2021


Author: Alexander Belyaev
Date: 2021-05-07T21:31:17+02:00
New Revision: a3f22d020b2709b2b4897ae3450c33834e646329

URL: https://github.com/llvm/llvm-project/commit/a3f22d020b2709b2b4897ae3450c33834e646329
DIFF: https://github.com/llvm/llvm-project/commit/a3f22d020b2709b2b4897ae3450c33834e646329.diff

LOG: [mlir] Add a pattern to bufferize linalg.tensor_reshape.

Differential Revision: https://reviews.llvm.org/D102089

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
    mlir/test/Dialect/Linalg/bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index 892942eeaef0..bd2fdac9aa4f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -149,6 +149,23 @@ class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
   }
 };
 
+/// Conversion pattern that replaces `linalg.tensor_reshape` with
+/// `linalg.reshape`.
+///
+/// `operands` holds the operand values remapped by the conversion driver, so
+/// the rewrite only has to compute the converted (memref) result type and
+/// forward the `reassociation` attribute of the original op unchanged.
+class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> {
+public:
+  using OpConversionPattern<TensorReshapeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TensorReshapeOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    linalg::TensorReshapeOpAdaptor adaptor(operands, op->getAttrDictionary());
+    // `convertType` yields a null Type when conversion fails; report a match
+    // failure instead of asserting inside `cast<MemRefType>()`.
+    auto resultType = getTypeConverter()
+                          ->convertType(op.getType())
+                          .dyn_cast_or_null<MemRefType>();
+    if (!resultType)
+      return rewriter.notifyMatchFailure(
+          op, "failed to convert result type to a memref type");
+    rewriter.replaceOpWithNewOp<linalg::ReshapeOp>(
+        op, resultType, adaptor.src(), adaptor.reassociation());
+    return success();
+  }
+};
+
 /// Conversion pattern that bufferizes `linalg.fill` operation.
 class BufferizeFillOp : public OpConversionPattern<FillOp> {
 public:
@@ -336,6 +353,7 @@ void mlir::linalg::populateLinalgBufferizePatterns(
       BufferizeAnyLinalgOp,
       BufferizeFillOp,
       BufferizeInitTensorOp,
+      BufferizeTensorReshapeOp,
       SubTensorOpConverter,
       SubTensorInsertOpConverter
     >(typeConverter, patterns.getContext());

diff  --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index 198936abc2f2..0270c5e4ab79 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -278,3 +278,18 @@ func @bufferize_fill(%arg0: tensor<?xf32>) -> tensor<?xf32> {
   %0 = linalg.fill(%arg0, %c0) : tensor<?xf32>, f32 -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
+
+// -----
+
+// Collapsing reshape (4x5 -> 20): the tensor operand is bufferized with
+// memref.buffer_cast, reshaped with linalg.reshape using the same
+// reassociation, and the result is converted back to a tensor with
+// memref.tensor_load.
+// CHECK-LABEL: func @bufferize_tensor_reshape(
+// CHECK-SAME:    %[[IN:.*]]: tensor<4x5xf32>
+func @bufferize_tensor_reshape(%arg0: tensor<4x5xf32>) -> tensor<20xf32> {
+  %out = linalg.tensor_reshape %arg0 [[0, 1]] :
+     tensor<4x5xf32> into tensor<20xf32>
+  return %out : tensor<20xf32>
+}
+// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[IN]] : memref<4x5xf32>
+// CHECK: %[[RESHAPE:.*]] = linalg.reshape %[[MEMREF]] {{\[}}[0, 1]]
+// CHECK-SAME: : memref<4x5xf32> into memref<20xf32>
+// CHECK: %[[TENSOR:.*]] = memref.tensor_load %[[RESHAPE]] : memref<20xf32>
+// CHECK: return %[[TENSOR]]


        


More information about the Mlir-commits mailing list