[Mlir-commits] [mlir] 5b6b2ca - [mlir][vector] Handle memory space conflicts in VectorTransferSplit patterns
Quinn Dawkins
llvmlistbot at llvm.org
Tue Jul 11 20:27:06 PDT 2023
Author: Quinn Dawkins
Date: 2023-07-11T22:58:23-04:00
New Revision: 5b6b2caf3c6ec3cdf5565f935263361fbb2013cd
URL: https://github.com/llvm/llvm-project/commit/5b6b2caf3c6ec3cdf5565f935263361fbb2013cd
DIFF: https://github.com/llvm/llvm-project/commit/5b6b2caf3c6ec3cdf5565f935263361fbb2013cd.diff
LOG: [mlir][vector] Handle memory space conflicts in VectorTransferSplit patterns
Currently the transfer splitting patterns will generate an invalid cast
when the source memref for a transfer op has a non-default memory space.
This is handled by first introducing a `memref.memory_space_cast` in
such cases.
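For reference, the rewrite now produces a cast chain roughly like the following
for a source in a non-default memory space (a sketch based on the new tests; the
strided target type is whatever `getCastCompatibleMemRefType` computes):

  %0 = memref.memory_space_cast %A : memref<?x8xf32, 3> to memref<?x8xf32>
  %1 = memref.cast %0 : memref<?x8xf32> to memref<?x8xf32, strided<[8, 1]>>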
Differential Revision: https://reviews.llvm.org/D154515
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index 88253f1c520680..1d240374f8a8f7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -166,6 +166,24 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
}
+/// Casts the given memref to a compatible memref type. If the source memref has
+/// a different address space than the target type, a `memref.memory_space_cast`
+/// is first inserted, followed by a `memref.cast`.
+static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
+ MemRefType compatibleMemRefType) {
+ MemRefType sourceType = memref.getType().cast<MemRefType>();
+ Value res = memref;
+ if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
+ sourceType = MemRefType::get(
+ sourceType.getShape(), sourceType.getElementType(),
+ sourceType.getLayout(), compatibleMemRefType.getMemorySpace());
+ res = b.create<memref::MemorySpaceCastOp>(memref.getLoc(), sourceType, res);
+ }
+ if (sourceType == compatibleMemRefType)
+ return res;
+ return b.create<memref::CastOp>(memref.getLoc(), compatibleMemRefType, res);
+}
+
/// Operates under a scoped context to build the intersection between the
/// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
// TODO: view intersection/union/differences should be a proper std op.
@@ -215,6 +233,7 @@ createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
/// Produce IR resembling:
/// ```
/// %1:3 = scf.if (%inBounds) {
+/// (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
/// %view = memref.cast %A: memref<A...> to compatibleMemRefType
/// scf.yield %view, ... : compatibleMemRefType, index, index
/// } else {
@@ -237,9 +256,7 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
return b.create<scf::IfOp>(
loc, inBoundsCond,
[&](OpBuilder &b, Location loc) {
- Value res = memref;
- if (compatibleMemRefType != xferOp.getShapedType())
- res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
+ Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
scf::ValueVector viewAndIndices{res};
viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
xferOp.getIndices().end());
@@ -256,7 +273,7 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
alloc);
b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
Value casted =
- b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
+ castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
scf::ValueVector viewAndIndices{casted};
viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
zero);
@@ -270,6 +287,7 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
/// Produce IR resembling:
/// ```
/// %1:3 = scf.if (%inBounds) {
+/// (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
/// memref.cast %A: memref<A...> to compatibleMemRefType
/// scf.yield %view, ... : compatibleMemRefType, index, index
/// } else {
@@ -292,9 +310,7 @@ static scf::IfOp createFullPartialVectorTransferRead(
return b.create<scf::IfOp>(
loc, inBoundsCond,
[&](OpBuilder &b, Location loc) {
- Value res = memref;
- if (compatibleMemRefType != xferOp.getShapedType())
- res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
+ Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
scf::ValueVector viewAndIndices{res};
viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
xferOp.getIndices().end());
@@ -309,7 +325,7 @@ static scf::IfOp createFullPartialVectorTransferRead(
loc, MemRefType::get({}, vector.getType()), alloc));
Value casted =
- b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
+ castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
scf::ValueVector viewAndIndices{casted};
viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
zero);
@@ -343,9 +359,8 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
.create<scf::IfOp>(
loc, inBoundsCond,
[&](OpBuilder &b, Location loc) {
- Value res = memref;
- if (compatibleMemRefType != xferOp.getShapedType())
- res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
+ Value res =
+ castToCompatibleMemRefType(b, memref, compatibleMemRefType);
scf::ValueVector viewAndIndices{res};
viewAndIndices.insert(viewAndIndices.end(),
xferOp.getIndices().begin(),
@@ -354,7 +369,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
},
[&](OpBuilder &b, Location loc) {
Value casted =
- b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
+ castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
scf::ValueVector viewAndIndices{casted};
viewAndIndices.insert(viewAndIndices.end(),
xferOp.getTransferRank(), zero);
diff --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
index 5c22c0cf6014ed..956bbd47ebbf1a 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
@@ -101,6 +101,37 @@ func.func @split_vector_transfer_read_strided_2d(
return %1 : vector<4x8xf32>
}
+func.func @split_vector_transfer_read_mem_space(%A: memref<?x8xf32, 3>, %i: index, %j: index) -> vector<4x8xf32> {
+ %c0 = arith.constant 0 : index
+ %f0 = arith.constant 0.0 : f32
+
+ // CHECK: scf.if {{.*}} -> (memref<?x8xf32, strided<[8, 1]>>, index, index) {
+ // inBounds with a different memory space
+ // CHECK: %[[space_cast:.*]] = memref.memory_space_cast %{{.*}} :
+ // CHECK-SAME: memref<?x8xf32, 3> to memref<?x8xf32>
+ // CHECK: %[[cast:.*]] = memref.cast %[[space_cast]] :
+ // CHECK-SAME: memref<?x8xf32> to memref<?x8xf32, strided<[8, 1]>>
+ // CHECK: scf.yield %[[cast]], {{.*}} : memref<?x8xf32, strided<[8, 1]>>, index, index
+ // CHECK: } else {
+ // slow path, fill tmp alloc and yield a memref_casted version of it
+ // CHECK: %[[slow:.*]] = vector.transfer_read %[[A]][%[[i]], %[[j]]], %cst :
+ // CHECK-SAME: memref<?x8xf32, 3>, vector<4x8xf32>
+ // CHECK: %[[cast_alloc:.*]] = vector.type_cast %[[alloc]] :
+ // CHECK-SAME: memref<4x8xf32> to memref<vector<4x8xf32>>
+ // CHECK: store %[[slow]], %[[cast_alloc]][] : memref<vector<4x8xf32>>
+ // CHECK: %[[yielded:.*]] = memref.cast %[[alloc]] :
+ // CHECK-SAME: memref<4x8xf32> to memref<?x8xf32, strided<[8, 1]>>
+ // CHECK: scf.yield %[[yielded]], %[[c0]], %[[c0]] :
+ // CHECK-SAME: memref<?x8xf32, strided<[8, 1]>>, index, index
+ // CHECK: }
+ // CHECK: %[[res:.*]] = vector.transfer_read %[[ifres]]#0[%[[ifres]]#1, %[[ifres]]#2], %cst
+ // CHECK-SAME: {in_bounds = [true, true]} : memref<?x8xf32, strided<[8, 1]>>, vector<4x8xf32>
+
+ %1 = vector.transfer_read %A[%i, %j], %f0 : memref<?x8xf32, 3>, vector<4x8xf32>
+
+ return %1: vector<4x8xf32>
+}
+
transform.sequence failures(propagate) {
^bb1(%func_op: !transform.op<"func.func">):
transform.apply_patterns to %func_op {
@@ -228,6 +259,40 @@ transform.sequence failures(propagate) {
} : !transform.op<"func.func">
}
+// -----
+
+func.func @split_vector_transfer_write_mem_space(%V: vector<4x8xf32>, %A: memref<?x8xf32, 3>, %i: index, %j: index) {
+ vector.transfer_write %V, %A[%i, %j] :
+ vector<4x8xf32>, memref<?x8xf32, 3>
+ return
+}
+
+// CHECK: func @split_vector_transfer_write_mem_space(
+// CHECK: scf.if {{.*}} -> (memref<?x8xf32, strided<[8, 1]>>, index, index) {
+// CHECK: %[[space_cast:.*]] = memref.memory_space_cast %{{.*}} :
+// CHECK-SAME: memref<?x8xf32, 3> to memref<?x8xf32>
+// CHECK: %[[cast:.*]] = memref.cast %[[space_cast]] :
+// CHECK-SAME: memref<?x8xf32> to memref<?x8xf32, strided<[8, 1]>>
+// CHECK: scf.yield %[[cast]], {{.*}} : memref<?x8xf32, strided<[8, 1]>>, index, index
+// CHECK: } else {
+// CHECK: %[[VAL_15:.*]] = memref.cast %[[TEMP]]
+// CHECK-SAME: : memref<4x8xf32> to memref<?x8xf32, strided<[8, 1]>>
+// CHECK: scf.yield %[[VAL_15]], %[[C0]], %[[C0]]
+// CHECK-SAME: : memref<?x8xf32, strided<[8, 1]>>, index, index
+// CHECK: }
+// CHECK: vector.transfer_write %[[VEC]],
+// CHECK-SAME: %[[IN_BOUND_DEST:.*]]#0[%[[IN_BOUND_DEST]]#1, %[[IN_BOUND_DEST]]#2]
+// CHECK-SAME: {in_bounds = [true, true]} : vector<4x8xf32>, memref<?x8xf32, strided<[8, 1]>>
+
+
+transform.sequence failures(propagate) {
+^bb1(%func_op: !transform.op<"func.func">):
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer"
+ } : !transform.op<"func.func">
+}
+
+
// -----
func.func private @fake_side_effecting_fun(%0: vector<2x2xf32>) -> ()