[Mlir-commits] [mlir] b7d47ed - [mlir][memref] Add support for 0-D transfer / subview fold.

Nicolas Vasilache llvmlistbot at llvm.org
Thu Sep 8 15:25:14 PDT 2022


Author: Nicolas Vasilache
Date: 2022-09-08T15:25:05-07:00
New Revision: b7d47ed1da974bdba4d2b74589c985e7190b8d21

URL: https://github.com/llvm/llvm-project/commit/b7d47ed1da974bdba4d2b74589c985e7190b8d21
DIFF: https://github.com/llvm/llvm-project/commit/b7d47ed1da974bdba4d2b74589c985e7190b8d21.diff

LOG: [mlir][memref] Add support for 0-D transfer / subview fold.

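The 0-D case needs no special handling: the subview's offsets are simply
forwarded as indices into the source memref, so the existing folding logic
works out of the box and the rank-0 early-exit bail-outs can be removed.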

Differential Revision: https://reviews.llvm.org/D133536

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
    mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 875ef126e45ed..549c8d0960518 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -330,16 +330,13 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
   if (failed(resolveSourceIndicesSubView(loadOp.getLoc(), rewriter, subViewOp,
                                          indices, sourceIndices)))
     return failure();
+
   llvm::TypeSwitch<Operation *, void>(loadOp)
       .Case<AffineLoadOp, memref::LoadOp>([&](auto op) {
         rewriter.replaceOpWithNewOp<decltype(op)>(loadOp, subViewOp.source(),
                                                   sourceIndices);
       })
       .Case([&](vector::TransferReadOp transferReadOp) {
-        if (transferReadOp.getTransferRank() == 0) {
-          // TODO: Propagate the error.
-          return;
-        }
         rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
             transferReadOp, transferReadOp.getVectorType(), subViewOp.source(),
             sourceIndices,
@@ -439,15 +436,13 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
   if (failed(resolveSourceIndicesSubView(storeOp.getLoc(), rewriter, subViewOp,
                                          indices, sourceIndices)))
     return failure();
+
   llvm::TypeSwitch<Operation *, void>(storeOp)
       .Case<AffineStoreOp, memref::StoreOp>([&](auto op) {
         rewriter.replaceOpWithNewOp<decltype(op)>(
             storeOp, storeOp.getValue(), subViewOp.source(), sourceIndices);
       })
       .Case([&](vector::TransferWriteOp op) {
-        // TODO: support 0-d corner case.
-        if (op.getTransferRank() == 0)
-          return;
         rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
             op, op.getValue(), subViewOp.source(), sourceIndices,
             getPermutationMapAttr(rewriter.getContext(), subViewOp,

diff  --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 9ee6893f8b1f7..393f0f49e15f7 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -81,6 +81,28 @@ func.func @fold_dynamic_stride_subview_with_store(%arg0 : memref<12x32xf32>, %ar
 
 // -----
 
+func.func @fold_subview_with_transfer_read_0d(
+  %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index)
+    -> vector<f32> {
+  %f1 = arith.constant 1.0 : f32
+  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][2, %arg3] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
+  %1 = vector.transfer_read %0[], %f1 : memref<f32, strided<[], offset: ?>>, vector<f32>
+  return %1 : vector<f32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+//      CHECK: func @fold_subview_with_transfer_read_0d
+// CHECK-SAME:   %[[MEM:[a-zA-Z0-9_]+]]: memref<12x32xf32>
+// CHECK-SAME:   %[[SZ0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[SZ1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ST1:[a-zA-Z0-9_]+]]: index
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[I1:.+]] = affine.apply #[[MAP0]](%[[C0]])[%[[SZ0]]]
+//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP1]](%[[C0]])[%[[ST1]], %[[SZ1]]]
+//      CHECK:   vector.transfer_read %[[MEM]][%[[I1]], %[[I2]]]
+
+// -----
+
 func.func @fold_subview_with_transfer_read(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index) -> vector<4xf32> {
   %f1 = arith.constant 1.0 : f32
   %0 = memref.subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] : memref<12x32xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>>
@@ -102,6 +124,29 @@ func.func @fold_subview_with_transfer_read(%arg0 : memref<12x32xf32>, %arg1 : in
 
 // -----
 
+func.func @fold_static_stride_subview_with_transfer_write_0d(
+    %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, 
+    %v : vector<f32>) {
+  %f1 = arith.constant 1.0 : f32
+  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][2, %arg3] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
+  vector.transfer_write %v, %0[] {in_bounds = []} : vector<f32>, memref<f32, strided<[], offset: ?>>
+  return
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+//      CHECK: func @fold_static_stride_subview_with_transfer_write_0d
+// CHECK-SAME:   %[[MEM:[a-zA-Z0-9_]+]]: memref<12x32xf32>
+// CHECK-SAME:   %[[SZ0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[SZ1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ST1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[V:[a-zA-Z0-9_]+]]: vector<f32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[I1:.+]] = affine.apply #[[MAP0]](%[[C0]])[%[[SZ0]]]
+//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP1]](%[[C0]])[%[[ST1]], %[[SZ1]]]
+//      CHECK:   vector.transfer_write %[[V]], %[[MEM]][%[[I1]], %[[I2]]]
+
+// -----
+
 func.func @fold_static_stride_subview_with_transfer_write(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5: index, %arg6 : index, %arg7 : vector<4xf32>) {
   %0 = memref.subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] :
     memref<12x32xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>>


        


More information about the Mlir-commits mailing list