[Mlir-commits] [mlir] f8a2cd6 - Support affine.load/store ops in fold-memref-subview-ops pass

Uday Bondhugula llvmlistbot at llvm.org
Sun Jan 30 21:11:14 PST 2022


Author: Uday Bondhugula
Date: 2022-01-31T10:10:49+05:30
New Revision: f8a2cd67b9ad414508235b9bd1489651ed9938e6

URL: https://github.com/llvm/llvm-project/commit/f8a2cd67b9ad414508235b9bd1489651ed9938e6
DIFF: https://github.com/llvm/llvm-project/commit/f8a2cd67b9ad414508235b9bd1489651ed9938e6.diff

LOG: Support affine.load/store ops in fold-memref-subview-ops pass

Support affine.load/store ops in the fold-memref-subview-ops pass. The
existing pass "inlines" the subview operation into loads/stores by
inserting affine.apply ops in front of the memref load/store ops; this
is by design always consistent with the semantics of affine.load/store
ops, and the same rewrite works even more naturally and intuitively with
the latter.
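
For example, on an affine.load through a subview, the fold looks roughly
like this (a minimal sketch with illustrative SSA names; the affine.apply
maps shown are assumptions based on the pattern exercised in the test
added below):

  // Before: load through a strided subview.
  %sv = memref.subview %A[%i, %j] [4, 4] [2, 3]
      : memref<12x32xf32> to memref<4x4xf32, offset: ?, strides: [64, 3]>
  %v = affine.load %sv[%k, %l] : memref<4x4xf32, offset: ?, strides: [64, 3]>

  // After: the subview is folded away and the load indices are rebased
  // onto the source memref via affine.apply ops.
  %i0 = affine.apply affine_map<(d0)[s0] -> (d0 * 2 + s0)>(%k)[%i]
  %i1 = affine.apply affine_map<(d0)[s0] -> (d0 * 3 + s0)>(%l)[%j]
  %v = affine.load %A[%i0, %i1] : memref<12x32xf32>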

Differential Revision: https://reviews.llvm.org/D118565

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
    mlir/test/Dialect/MemRef/fold-subview-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
index 1feda57d8de03..5ac3113f620e4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
@@ -90,12 +90,13 @@ resolveSourceIndices(Location loc, PatternRewriter &rewriter,
 }
 
 /// Helpers to access the memref operand for each op.
-static Value getMemRefOperand(memref::LoadOp op) { return op.memref(); }
+template <typename LoadOrStoreOpTy>
+static Value getMemRefOperand(LoadOrStoreOpTy op) {
+  return op.memref();
+}
 
 static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); }
 
-static Value getMemRefOperand(memref::StoreOp op) { return op.memref(); }
-
 static Value getMemRefOperand(vector::TransferWriteOp op) {
   return op.source();
 }
@@ -154,12 +155,12 @@ class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
                  PatternRewriter &rewriter) const;
 };
 
-template <>
-void LoadOpOfSubViewFolder<memref::LoadOp>::replaceOp(
-    memref::LoadOp loadOp, memref::SubViewOp subViewOp,
-    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
-  rewriter.replaceOpWithNewOp<memref::LoadOp>(loadOp, subViewOp.source(),
-                                              sourceIndices);
+template <typename LoadOpTy>
+void LoadOpOfSubViewFolder<LoadOpTy>::replaceOp(
+    LoadOpTy loadOp, memref::SubViewOp subViewOp, ArrayRef<Value> sourceIndices,
+    PatternRewriter &rewriter) const {
+  rewriter.replaceOpWithNewOp<LoadOpTy>(loadOp, subViewOp.source(),
+                                        sourceIndices);
 }
 
 template <>
@@ -178,12 +179,12 @@ void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
       /*mask=*/Value(), transferReadOp.in_boundsAttr());
 }
 
-template <>
-void StoreOpOfSubViewFolder<memref::StoreOp>::replaceOp(
-    memref::StoreOp storeOp, memref::SubViewOp subViewOp,
+template <typename StoreOpTy>
+void StoreOpOfSubViewFolder<StoreOpTy>::replaceOp(
+    StoreOpTy storeOp, memref::SubViewOp subViewOp,
     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
-  rewriter.replaceOpWithNewOp<memref::StoreOp>(
-      storeOp, storeOp.value(), subViewOp.source(), sourceIndices);
+  rewriter.replaceOpWithNewOp<StoreOpTy>(storeOp, storeOp.value(),
+                                         subViewOp.source(), sourceIndices);
 }
 
 template <>
@@ -239,8 +240,10 @@ StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
 }
 
 void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
-  patterns.add<LoadOpOfSubViewFolder<memref::LoadOp>,
+  patterns.add<LoadOpOfSubViewFolder<AffineLoadOp>,
+               LoadOpOfSubViewFolder<memref::LoadOp>,
                LoadOpOfSubViewFolder<vector::TransferReadOp>,
+               StoreOpOfSubViewFolder<AffineStoreOp>,
                StoreOpOfSubViewFolder<memref::StoreOp>,
                StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
       patterns.getContext());

diff --git a/mlir/test/Dialect/MemRef/fold-subview-ops.mlir b/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
index e177bb37f936d..fc06cd35cc8cd 100644
--- a/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
@@ -251,3 +251,24 @@ func @fold_vector_transfer_write_with_inner_rank_reduced_subview(
 //   CHECK-DAG:    %[[IDX1:.+]] = affine.apply #[[MAP1]](%[[ARG7]])[%[[ARG3]]]
 //   CHECK-DAG:    vector.transfer_write %[[ARG1]], %[[ARG0]][%[[IDX0]], %[[IDX1]], %[[C0]]]
 //  CHECK-SAME:    {in_bounds = [true], permutation_map = #[[MAP2]]} : vector<4xf32>, memref<?x?x?xf32
+
+// -----
+
+//  Test with affine.load/store ops. We only do a basic test here since the
+//  logic is identical to that with memref.load/store ops. The same affine.apply
+//  ops would be generated.
+
+// CHECK-LABEL: func @fold_static_stride_subview_with_affine_load_store
+func @fold_static_stride_subview_with_affine_load_store(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
+  %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
+  %1 = affine.load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
+  // CHECK-NEXT: affine.apply
+  // CHECK-NEXT: affine.apply
+  // CHECK-NEXT: affine.load
+  affine.store %1, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]>
+  // CHECK-NEXT: affine.apply
+  // CHECK-NEXT: affine.apply
+  // CHECK-NEXT: affine.store
+  // CHECK-NEXT: return
+  return %1 : f32
+}
