[Mlir-commits] [mlir] 41e731f - [mlir][vector] Add additional scalar vector transfer foldings

Matthias Springer <llvmlistbot at llvm.org>
Thu Dec 22 05:43:58 PST 2022


Author: Matthias Springer
Date: 2022-12-22T14:43:48+01:00
New Revision: 41e731f2a4328f37b5d4e14d912814493e206f32

URL: https://github.com/llvm/llvm-project/commit/41e731f2a4328f37b5d4e14d912814493e206f32
DIFF: https://github.com/llvm/llvm-project/commit/41e731f2a4328f37b5d4e14d912814493e206f32.diff

LOG: [mlir][vector] Add additional scalar vector transfer foldings

* Rewrite vector.transfer_write of single-element vectors (e.g.,
  vector<1x1xf32>) to memref.store (see the sketch below)
* Rewrite vector.extract(vector.transfer_read) to memref.load
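
The following before/after sketch illustrates both rewrites. The IR is
illustrative only (it is not taken from this patch's test files), and all
value names (%m, %i, %f, %pad, ...) are made up:

  // Before: transfer_write of a single-element vector.
  %v = vector.broadcast %f : f32 to vector<1x1xf32>
  vector.transfer_write %v, %m[%i, %i, %i] : vector<1x1xf32>, memref<?x?x?xf32>

  // After: extract the scalar and store it with memref.store.
  %s = vector.extract %v[0, 0] : vector<1x1xf32>
  memref.store %s, %m[%i, %i, %i] : memref<?x?x?xf32>

  // Before: scalar extract of an in-bounds transfer_read with a single use.
  %r = vector.transfer_read %m[%i, %i, %i], %pad {in_bounds = [true]}
      : memref<?x?x?xf32>, vector<8xf32>
  %e = vector.extract %r[3] : vector<8xf32>

  // After: fold the extract position into the index and load the scalar.
  %j = affine.apply affine_map<()[s0] -> (s0 + 3)>()[%i]
  %e = memref.load %m[%i, %i, %j] : memref<?x?x?xf32>

The patterns only replace the transfer ops themselves; any leftover
vector.broadcast or vector.extract feeding/fed by them may be folded away by
existing canonicalizations.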

Differential Revision: https://reviews.llvm.org/D140391

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
    mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 727a356210a38..38062b9893f1a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -567,7 +567,7 @@ class FlattenContiguousRowMajorTransferWritePattern
 /// %2 = vector.extractelement %0[%b : index] : vector<1024xf32>
 /// Rewriting such IR (replacing one vector load with multiple scalar loads) may
 /// negatively affect performance.
-class FoldScalarExtractOfTransferRead
+class RewriteScalarExtractElementOfTransferRead
     : public OpRewritePattern<vector::ExtractElementOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -618,17 +618,79 @@ class FoldScalarExtractOfTransferRead
   }
 };
 
-/// Rewrite scalar transfer_write(broadcast) to memref.store.
-class FoldScalarTransferWriteOfBroadcast
-    : public OpRewritePattern<vector::TransferWriteOp> {
+/// Rewrite extract(transfer_read) to memref.load.
+///
+/// Rewrite only if the extract op is the single user of the transfer op.
+/// E.g., do not rewrite IR such as:
+/// %0 = vector.transfer_read ... : vector<1024xf32>
+/// %1 = vector.extract %0[0] : vector<1024xf32>
+/// %2 = vector.extract %0[5] : vector<1024xf32>
+/// Rewriting such IR (replacing one vector load with multiple scalar loads) may
+/// negatively affect performance.
+class RewriteScalarExtractOfTransferRead
+    : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // Only match scalar extracts.
+    if (extractOp.getType().isa<VectorType>())
+      return failure();
+    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
+    if (!xferOp)
+      return failure();
+    // xfer result must have a single use. Otherwise, it may be better to
+    // perform a vector load.
+    if (!extractOp.getVector().hasOneUse())
+      return failure();
+    // Mask not supported.
+    if (xferOp.getMask())
+      return failure();
+    // Map not supported.
+    if (!xferOp.getPermutationMap().isMinorIdentity())
+      return failure();
+    // Cannot rewrite if the indices may be out of bounds. The starting point is
+    // always inbounds, so we don't care in case of 0d transfers.
+    if (xferOp.hasOutOfBoundsDim() && xferOp.getType().getRank() > 0)
+      return failure();
+    // Construct scalar load.
+    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
+                                  xferOp.getIndices().end());
+    for (const auto &it : llvm::enumerate(extractOp.getPosition())) {
+      int64_t offset = it.value().cast<IntegerAttr>().getInt();
+      int64_t idx =
+          newIndices.size() - extractOp.getPosition().size() + it.index();
+      OpFoldResult ofr = makeComposedFoldedAffineApply(
+          rewriter, extractOp.getLoc(),
+          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
+      if (ofr.is<Value>()) {
+        newIndices[idx] = ofr.get<Value>();
+      } else {
+        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
+            extractOp.getLoc(), *getConstantIntValue(ofr));
+      }
+    }
+    if (xferOp.getSource().getType().isa<MemRefType>()) {
+      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
+                                                  newIndices);
+    } else {
+      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
+          extractOp, xferOp.getSource(), newIndices);
+    }
+    return success();
+  }
+};
+
+/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
+/// to memref.store.
+class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                 PatternRewriter &rewriter) const override {
     // Must be a scalar write.
     auto vecType = xferOp.getVectorType();
-    if (vecType.getRank() != 0 &&
-        (vecType.getRank() != 1 || vecType.getShape()[0] != 1))
+    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
       return failure();
     // Mask not supported.
     if (xferOp.getMask())
@@ -636,19 +698,25 @@ class FoldScalarTransferWriteOfBroadcast
     // Map not supported.
     if (!xferOp.getPermutationMap().isMinorIdentity())
       return failure();
-    // Must be a broadcast of a scalar.
-    auto broadcastOp = xferOp.getVector().getDefiningOp<vector::BroadcastOp>();
-    if (!broadcastOp || broadcastOp.getSource().getType().isa<VectorType>())
-      return failure();
+    // Extract the scalar value to be stored.
+    Value scalar;
+    if (vecType.getRank() == 0) {
+      // vector.extract does not support vector<f32> etc., so use
+      // vector.extractelement instead.
+      scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
+                                                         xferOp.getVector());
+    } else {
+      SmallVector<int64_t> pos(vecType.getRank(), 0);
+      scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
+                                                  xferOp.getVector(), pos);
+    }
     // Construct a scalar store.
     if (xferOp.getSource().getType().isa<MemRefType>()) {
       rewriter.replaceOpWithNewOp<memref::StoreOp>(
-          xferOp, broadcastOp.getSource(), xferOp.getSource(),
-          xferOp.getIndices());
+          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
     } else {
       rewriter.replaceOpWithNewOp<tensor::InsertOp>(
-          xferOp, broadcastOp.getSource(), xferOp.getSource(),
-          xferOp.getIndices());
+          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
     }
     return success();
   }
@@ -673,9 +741,9 @@ void mlir::vector::transferOpflowOpt(Operation *rootOp) {
 
 void mlir::vector::populateScalarVectorTransferLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns
-      .add<FoldScalarExtractOfTransferRead, FoldScalarTransferWriteOfBroadcast>(
-          patterns.getContext(), benefit);
+  patterns.add<RewriteScalarExtractElementOfTransferRead,
+               RewriteScalarExtractOfTransferRead, RewriteScalarWrite>(
+      patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateVectorTransferDropUnitDimsPatterns(

diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index d34b9c3091f69..f3fd6e0d25cb5 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -44,7 +44,9 @@ func.func @tensor_transfer_read_0d(%t: tensor<?x?x?xf32>, %idx: index) -> f32 {
 
 // CHECK-LABEL: func @transfer_write_0d(
 //  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
-//       CHECK:   memref.store %[[f]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+//       CHECK:   %[[bc:.*]] = vector.broadcast %[[f]] : f32 to vector<f32>
+//       CHECK:   %[[extract:.*]] = vector.extractelement %[[bc]][] : vector<f32>
+//       CHECK:   memref.store %[[extract]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
 func.func @transfer_write_0d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {
   %0 = vector.broadcast %f : f32 to vector<f32>
   vector.transfer_write %0, %m[%idx, %idx, %idx] : vector<f32>, memref<?x?x?xf32>
@@ -66,10 +68,43 @@ func.func @transfer_write_1d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {
 
 // CHECK-LABEL: func @tensor_transfer_write_0d(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
-//       CHECK:   %[[r:.*]] = tensor.insert %[[f]] into %[[t]][%[[idx]], %[[idx]], %[[idx]]]
+//       CHECK:   %[[bc:.*]] = vector.broadcast %[[f]] : f32 to vector<f32>
+//       CHECK:   %[[extract:.*]] = vector.extractelement %[[bc]][] : vector<f32>
+//       CHECK:   %[[r:.*]] = tensor.insert %[[extract]] into %[[t]][%[[idx]], %[[idx]], %[[idx]]]
 //       CHECK:   return %[[r]]
 func.func @tensor_transfer_write_0d(%t: tensor<?x?x?xf32>, %idx: index, %f: f32) -> tensor<?x?x?xf32> {
   %0 = vector.broadcast %f : f32 to vector<f32>
   %1 = vector.transfer_write %0, %t[%idx, %idx, %idx] : vector<f32>, tensor<?x?x?xf32>
   return %1 : tensor<?x?x?xf32>
 }
+
+// -----
+
+//       CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 8)>
+//       CHECK: #[[$map1:.*]] = affine_map<()[s0] -> (s0 + 1)>
+// CHECK-LABEL: func @transfer_read_2d_extract(
+//  CHECK-SAME:     %[[m:.*]]: memref<?x?x?x?xf32>, %[[idx:.*]]: index, %[[idx2:.*]]: index
+//       CHECK:   %[[added:.*]] = affine.apply #[[$map]]()[%[[idx]]]
+//       CHECK:   %[[added1:.*]] = affine.apply #[[$map1]]()[%[[idx]]]
+//       CHECK:   %[[r:.*]] = memref.load %[[m]][%[[idx]], %[[idx]], %[[added]], %[[added1]]]
+//       CHECK:   return %[[r]]
+func.func @transfer_read_2d_extract(%m: memref<?x?x?x?xf32>, %idx: index, %idx2: index) -> f32 {
+  %cst = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %0 = vector.transfer_read %m[%idx, %idx, %idx, %idx], %cst {in_bounds = [true, true]} : memref<?x?x?x?xf32>, vector<10x5xf32>
+  %1 = vector.extract %0[8, 1] : vector<10x5xf32>
+  return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_arith_constant(
+//  CHECK-SAME:     %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index
+//       CHECK:   %[[cst:.*]] = arith.constant dense<5.000000e+00> : vector<1x1xf32>
+//       CHECK:   %[[extract:.*]] = vector.extract %[[cst]][0, 0] : vector<1x1xf32>
+//       CHECK:   memref.store %[[extract]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+func.func @transfer_write_arith_constant(%m: memref<?x?x?xf32>, %idx: index) {
+  %cst = arith.constant dense<5.000000e+00> : vector<1x1xf32>
+  vector.transfer_write %cst, %m[%idx, %idx, %idx] : vector<1x1xf32>, memref<?x?x?xf32>
+  return
+}
