[Mlir-commits] [mlir] b7d47ed - [mlir][memref] Add support for 0-D transfer / subview fold.
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Sep 8 15:25:14 PDT 2022
Author: Nicolas Vasilache
Date: 2022-09-08T15:25:05-07:00
New Revision: b7d47ed1da974bdba4d2b74589c985e7190b8d21
URL: https://github.com/llvm/llvm-project/commit/b7d47ed1da974bdba4d2b74589c985e7190b8d21
DIFF: https://github.com/llvm/llvm-project/commit/b7d47ed1da974bdba4d2b74589c985e7190b8d21.diff
LOG: [mlir][memref] Add support for 0-D transfer / subview fold.
The 0-d case simply forwards the indexing from the source memref and
works out of the box.
Differential Revision: https://reviews.llvm.org/D133536
Added:
Modified:
mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 875ef126e45ed..549c8d0960518 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -330,16 +330,13 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
if (failed(resolveSourceIndicesSubView(loadOp.getLoc(), rewriter, subViewOp,
indices, sourceIndices)))
return failure();
+
llvm::TypeSwitch<Operation *, void>(loadOp)
.Case<AffineLoadOp, memref::LoadOp>([&](auto op) {
rewriter.replaceOpWithNewOp<decltype(op)>(loadOp, subViewOp.source(),
sourceIndices);
})
.Case([&](vector::TransferReadOp transferReadOp) {
- if (transferReadOp.getTransferRank() == 0) {
- // TODO: Propagate the error.
- return;
- }
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
transferReadOp, transferReadOp.getVectorType(), subViewOp.source(),
sourceIndices,
@@ -439,15 +436,13 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
if (failed(resolveSourceIndicesSubView(storeOp.getLoc(), rewriter, subViewOp,
indices, sourceIndices)))
return failure();
+
llvm::TypeSwitch<Operation *, void>(storeOp)
.Case<AffineStoreOp, memref::StoreOp>([&](auto op) {
rewriter.replaceOpWithNewOp<decltype(op)>(
storeOp, storeOp.getValue(), subViewOp.source(), sourceIndices);
})
.Case([&](vector::TransferWriteOp op) {
- // TODO: support 0-d corner case.
- if (op.getTransferRank() == 0)
- return;
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
op, op.getValue(), subViewOp.source(), sourceIndices,
getPermutationMapAttr(rewriter.getContext(), subViewOp,
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 9ee6893f8b1f7..393f0f49e15f7 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -81,6 +81,28 @@ func.func @fold_dynamic_stride_subview_with_store(%arg0 : memref<12x32xf32>, %ar
// -----
+func.func @fold_subview_with_transfer_read_0d(
+ %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index)
+ -> vector<f32> {
+ %f1 = arith.constant 1.0 : f32
+ %0 = memref.subview %arg0[%arg1, %arg2][1, 1][2, %arg3] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
+ %1 = vector.transfer_read %0[], %f1 : memref<f32, strided<[], offset: ?>>, vector<f32>
+ return %1 : vector<f32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK: func @fold_subview_with_transfer_read_0d
+// CHECK-SAME: %[[MEM:[a-zA-Z0-9_]+]]: memref<12x32xf32>
+// CHECK-SAME: %[[SZ0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[SZ1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ST1:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP0]](%[[C0]])[%[[SZ0]]]
+// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP1]](%[[C0]])[%[[ST1]], %[[SZ1]]]
+// CHECK: vector.transfer_read %[[MEM]][%[[I1]], %[[I2]]]
+
+// -----
+
func.func @fold_subview_with_transfer_read(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index) -> vector<4xf32> {
%f1 = arith.constant 1.0 : f32
%0 = memref.subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] : memref<12x32xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>>
@@ -102,6 +124,29 @@ func.func @fold_subview_with_transfer_read(%arg0 : memref<12x32xf32>, %arg1 : in
// -----
+func.func @fold_static_stride_subview_with_transfer_write_0d(
+ %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index,
+ %v : vector<f32>) {
+ %f1 = arith.constant 1.0 : f32
+ %0 = memref.subview %arg0[%arg1, %arg2][1, 1][2, %arg3] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
+ vector.transfer_write %v, %0[] {in_bounds = []} : vector<f32>, memref<f32, strided<[], offset: ?>>
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK: func @fold_static_stride_subview_with_transfer_write_0d
+// CHECK-SAME: %[[MEM:[a-zA-Z0-9_]+]]: memref<12x32xf32>
+// CHECK-SAME: %[[SZ0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[SZ1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ST1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[V:[a-zA-Z0-9_]+]]: vector<f32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP0]](%[[C0]])[%[[SZ0]]]
+// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP1]](%[[C0]])[%[[ST1]], %[[SZ1]]]
+// CHECK: vector.transfer_write %[[V]], %[[MEM]][%[[I1]], %[[I2]]]
+
+// -----
+
func.func @fold_static_stride_subview_with_transfer_write(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5: index, %arg6 : index, %arg7 : vector<4xf32>) {
%0 = memref.subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] :
memref<12x32xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>>
More information about the Mlir-commits mailing list