[Mlir-commits] [mlir] a552deb - [mlir] Add patterns for vector.transfer_read/write to Linalg bufferization.

Alexander Belyaev llvmlistbot at llvm.org
Fri Aug 6 11:25:03 PDT 2021


Author: Alexander Belyaev
Date: 2021-08-06T20:24:44+02:00
New Revision: a552debdcf01d422ef8d28c0bdc153cff350710a

URL: https://github.com/llvm/llvm-project/commit/a552debdcf01d422ef8d28c0bdc153cff350710a
DIFF: https://github.com/llvm/llvm-project/commit/a552debdcf01d422ef8d28c0bdc153cff350710a.diff

LOG: [mlir] Add patterns for vector.transfer_read/write to Linalg bufferization.

Differential Revision: https://reviews.llvm.org/D107643

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
    mlir/test/Dialect/Linalg/bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index b918b98a76b80..cd0445f1ff87d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -309,6 +309,47 @@ class InsertSliceOpConverter
     return success();
   }
 };
+
+class VectorTransferReadOpConverter
+    : public OpConversionPattern<vector::TransferReadOp> {
+public:
+  using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferReadOp readOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    if (readOp.getShapedType().isa<MemRefType>())
+      return failure();
+    vector::TransferReadOp::Adaptor adaptor(operands,
+                                            readOp->getAttrDictionary());
+    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+        readOp, readOp.getType(), adaptor.source(), adaptor.indices(),
+        adaptor.permutation_map(), adaptor.padding(), adaptor.mask(),
+        adaptor.in_bounds());
+    return success();
+  }
+};
+
+class VectorTransferWriteOpConverter
+    : public OpConversionPattern<vector::TransferWriteOp> {
+public:
+  using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferWriteOp writeOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    if (writeOp.getShapedType().isa<MemRefType>())
+      return failure();
+    vector::TransferWriteOp::Adaptor adaptor(operands,
+                                             writeOp->getAttrDictionary());
+    rewriter.create<vector::TransferWriteOp>(
+        writeOp.getLoc(), adaptor.vector(), adaptor.source(), adaptor.indices(),
+        adaptor.permutation_map(),
+        adaptor.in_bounds() ? adaptor.in_bounds() : ArrayAttr());
+    rewriter.replaceOp(writeOp, adaptor.source());
+    return success();
+  }
+};
 } // namespace
 
 namespace {
@@ -332,10 +373,10 @@ struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
       return typeConverter.isLegal(op);
     };
     target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
-    target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);
+    target.addDynamicallyLegalOp<ConstantOp, vector::TransferReadOp,
+                                 vector::TransferWriteOp>(isLegalOperation);
 
     RewritePatternSet patterns(&context);
-    patterns.add<GeneralizePadTensorOpPattern>(patterns.getContext());
     populateLinalgBufferizePatterns(typeConverter, patterns);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
@@ -359,7 +400,10 @@ void mlir::linalg::populateLinalgBufferizePatterns(
       BufferizeTensorReshapeOp<TensorExpandShapeOp>,
       BufferizeTensorReshapeOp<TensorCollapseShapeOp>,
       ExtractSliceOpConverter,
-      InsertSliceOpConverter
+      InsertSliceOpConverter,
+      VectorTransferReadOpConverter,
+      VectorTransferWriteOpConverter
     >(typeConverter, patterns.getContext());
   // clang-format on
+  patterns.add<GeneralizePadTensorOpPattern>(patterns.getContext());
 }

diff  --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index fa95deb1dbaa3..93dbf4a563675 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -299,3 +299,20 @@ func @pad_tensor_dynamic_shape(%arg0: tensor<4x?x2x?xf32>, %arg1: index) -> tens
 // CHECK:           %[[OUT_TENSOR:.*]] = memref.tensor_load %[[OUT]] : memref<4x?x?x?xf32>
 // CHECK:           return %[[OUT_TENSOR]] : tensor<4x?x?x?xf32>
 // CHECK:         }
+
+
+// -----
+
+// CHECK-LABEL:   func @vector_transfer
+func @vector_transfer(%in: tensor<4xf32>, %out: tensor<4xf32>) {
+  %c0 = constant 0 : index
+  %cst = constant 0.000000e+00 : f32
+  %read = vector.transfer_read %in[%c0], %cst {in_bounds = [true]}
+      : tensor<4xf32>, vector<4xf32>
+  %tanh = math.tanh %read : vector<4xf32>
+  %write = vector.transfer_write %tanh, %out[%c0] {in_bounds = [true]}
+      : vector<4xf32>, tensor<4xf32>
+  return
+  // CHECK: vector.transfer_read {{.*}} : memref<4xf32>, vector<4xf32>
+  // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, memref<4xf32>
+}


        


More information about the Mlir-commits mailing list