[Mlir-commits] [mlir] 5b6b2ca - [mlir][vector] Handle memory space conflicts in VectorTransferSplit patterns

Quinn Dawkins llvmlistbot at llvm.org
Tue Jul 11 20:27:06 PDT 2023


Author: Quinn Dawkins
Date: 2023-07-11T22:58:23-04:00
New Revision: 5b6b2caf3c6ec3cdf5565f935263361fbb2013cd

URL: https://github.com/llvm/llvm-project/commit/5b6b2caf3c6ec3cdf5565f935263361fbb2013cd
DIFF: https://github.com/llvm/llvm-project/commit/5b6b2caf3c6ec3cdf5565f935263361fbb2013cd.diff

LOG: [mlir][vector] Handle memory space conflicts in VectorTransferSplit patterns

Currently the transfer splitting patterns generate an invalid cast when
the source memref of a transfer op has a non-default memory space. This
is handled by first inserting a `memref.memory_space_cast` in such
cases, before the `memref.cast`.

Differential Revision: https://reviews.llvm.org/D154515
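
For illustration (a sketch based on the tests added below, not part of the
original commit message), the in-bounds branch of the split now emits a cast
chain resembling the following for a source in memory space 3; the SSA names
are placeholders:

    // Normalize the address space first, then cast to the compatible
    // strided layout expected by the full-transfer fast path.
    %view_space = memref.memory_space_cast %A
        : memref<?x8xf32, 3> to memref<?x8xf32>
    %view = memref.cast %view_space
        : memref<?x8xf32> to memref<?x8xf32, strided<[8, 1]>>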

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
    mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index 88253f1c520680..1d240374f8a8f7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -166,6 +166,24 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
       StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
 }
 
+/// Casts the given memref to a compatible memref type. If the source memref has
+/// a different address space than the target type, a `memref.memory_space_cast`
+/// is first inserted, followed by a `memref.cast`.
+static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
+                                        MemRefType compatibleMemRefType) {
+  MemRefType sourceType = memref.getType().cast<MemRefType>();
+  Value res = memref;
+  if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
+    sourceType = MemRefType::get(
+        sourceType.getShape(), sourceType.getElementType(),
+        sourceType.getLayout(), compatibleMemRefType.getMemorySpace());
+    res = b.create<memref::MemorySpaceCastOp>(memref.getLoc(), sourceType, res);
+  }
+  if (sourceType == compatibleMemRefType)
+    return res;
+  return b.create<memref::CastOp>(memref.getLoc(), compatibleMemRefType, res);
+}
+
 /// Operates under a scoped context to build the intersection between the
 /// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
 // TODO: view intersection/union/differences should be a proper std op.
@@ -215,6 +233,7 @@ createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
 /// Produce IR resembling:
 /// ```
 ///    %1:3 = scf.if (%inBounds) {
+///      (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
 ///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
 ///      scf.yield %view, ... : compatibleMemRefType, index, index
 ///    } else {
@@ -237,9 +256,7 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
   return b.create<scf::IfOp>(
       loc, inBoundsCond,
       [&](OpBuilder &b, Location loc) {
-        Value res = memref;
-        if (compatibleMemRefType != xferOp.getShapedType())
-          res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
+        Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
         scf::ValueVector viewAndIndices{res};
         viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
                               xferOp.getIndices().end());
@@ -256,7 +273,7 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
             alloc);
         b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
         Value casted =
-            b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
+            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
         scf::ValueVector viewAndIndices{casted};
         viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                               zero);
@@ -270,6 +287,7 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
 /// Produce IR resembling:
 /// ```
 ///    %1:3 = scf.if (%inBounds) {
+///      (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
 ///      memref.cast %A: memref<A...> to compatibleMemRefType
 ///      scf.yield %view, ... : compatibleMemRefType, index, index
 ///    } else {
@@ -292,9 +310,7 @@ static scf::IfOp createFullPartialVectorTransferRead(
   return b.create<scf::IfOp>(
       loc, inBoundsCond,
       [&](OpBuilder &b, Location loc) {
-        Value res = memref;
-        if (compatibleMemRefType != xferOp.getShapedType())
-          res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
+        Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
         scf::ValueVector viewAndIndices{res};
         viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
                               xferOp.getIndices().end());
@@ -309,7 +325,7 @@ static scf::IfOp createFullPartialVectorTransferRead(
                 loc, MemRefType::get({}, vector.getType()), alloc));
 
         Value casted =
-            b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
+            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
         scf::ValueVector viewAndIndices{casted};
         viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                               zero);
@@ -343,9 +359,8 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
       .create<scf::IfOp>(
           loc, inBoundsCond,
           [&](OpBuilder &b, Location loc) {
-            Value res = memref;
-            if (compatibleMemRefType != xferOp.getShapedType())
-              res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
+            Value res =
+                castToCompatibleMemRefType(b, memref, compatibleMemRefType);
             scf::ValueVector viewAndIndices{res};
             viewAndIndices.insert(viewAndIndices.end(),
                                   xferOp.getIndices().begin(),
@@ -354,7 +369,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
           },
           [&](OpBuilder &b, Location loc) {
             Value casted =
-                b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
+                castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
             scf::ValueVector viewAndIndices{casted};
             viewAndIndices.insert(viewAndIndices.end(),
                                   xferOp.getTransferRank(), zero);

diff --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
index 5c22c0cf6014ed..956bbd47ebbf1a 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
@@ -101,6 +101,37 @@ func.func @split_vector_transfer_read_strided_2d(
   return %1 : vector<4x8xf32>
 }
 
+func.func @split_vector_transfer_read_mem_space(%A: memref<?x8xf32, 3>, %i: index, %j: index) -> vector<4x8xf32> {
+  %c0 = arith.constant 0 : index
+  %f0 = arith.constant 0.0 : f32
+
+  //      CHECK: scf.if {{.*}} -> (memref<?x8xf32, strided<[8, 1]>>, index, index) {
+  //               inBounds with a different memory space
+  //      CHECK:   %[[space_cast:.*]] = memref.memory_space_cast %{{.*}} :
+  // CHECK-SAME:     memref<?x8xf32, 3> to memref<?x8xf32>
+  //      CHECK:   %[[cast:.*]] = memref.cast %[[space_cast]] :
+  // CHECK-SAME:     memref<?x8xf32> to memref<?x8xf32, strided<[8, 1]>>
+  //      CHECK:   scf.yield %[[cast]], {{.*}} : memref<?x8xf32, strided<[8, 1]>>, index, index
+  //      CHECK: } else {
+  //               slow path, fill tmp alloc and yield a memref_casted version of it
+  //      CHECK:   %[[slow:.*]] = vector.transfer_read %[[A]][%[[i]], %[[j]]], %cst :
+  // CHECK-SAME:     memref<?x8xf32, 3>, vector<4x8xf32>
+  //      CHECK:   %[[cast_alloc:.*]] = vector.type_cast %[[alloc]] :
+  // CHECK-SAME:     memref<4x8xf32> to memref<vector<4x8xf32>>
+  //      CHECK:   store %[[slow]], %[[cast_alloc]][] : memref<vector<4x8xf32>>
+  //      CHECK:   %[[yielded:.*]] = memref.cast %[[alloc]] :
+  // CHECK-SAME:     memref<4x8xf32> to memref<?x8xf32, strided<[8, 1]>>
+  //      CHECK:   scf.yield %[[yielded]], %[[c0]], %[[c0]] :
+  // CHECK-SAME:     memref<?x8xf32, strided<[8, 1]>>, index, index
+  //      CHECK: }
+  //      CHECK: %[[res:.*]] = vector.transfer_read %[[ifres]]#0[%[[ifres]]#1, %[[ifres]]#2], %cst
+  // CHECK-SAME:   {in_bounds = [true, true]} : memref<?x8xf32, strided<[8, 1]>>, vector<4x8xf32>
+
+  %1 = vector.transfer_read %A[%i, %j], %f0 : memref<?x8xf32, 3>, vector<4x8xf32>
+
+  return %1: vector<4x8xf32>
+}
+
 transform.sequence failures(propagate) {
 ^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
@@ -228,6 +259,40 @@ transform.sequence failures(propagate) {
   } : !transform.op<"func.func">
 }
 
+// -----
+
+func.func @split_vector_transfer_write_mem_space(%V: vector<4x8xf32>, %A: memref<?x8xf32, 3>, %i: index, %j: index) {
+  vector.transfer_write %V, %A[%i, %j] :
+    vector<4x8xf32>, memref<?x8xf32, 3>
+  return
+}
+
+// CHECK:     func @split_vector_transfer_write_mem_space(
+// CHECK:           scf.if {{.*}} -> (memref<?x8xf32, strided<[8, 1]>>, index, index) {
+// CHECK:             %[[space_cast:.*]] = memref.memory_space_cast %{{.*}} :
+// CHECK-SAME:          memref<?x8xf32, 3> to memref<?x8xf32>
+// CHECK:             %[[cast:.*]] = memref.cast %[[space_cast]] :
+// CHECK-SAME:          memref<?x8xf32> to memref<?x8xf32, strided<[8, 1]>>
+// CHECK:             scf.yield %[[cast]], {{.*}} : memref<?x8xf32, strided<[8, 1]>>, index, index
+// CHECK:           } else {
+// CHECK:             %[[VAL_15:.*]] = memref.cast %[[TEMP]]
+// CHECK-SAME:            : memref<4x8xf32> to memref<?x8xf32, strided<[8, 1]>>
+// CHECK:             scf.yield %[[VAL_15]], %[[C0]], %[[C0]]
+// CHECK-SAME:            : memref<?x8xf32, strided<[8, 1]>>, index, index
+// CHECK:           }
+// CHECK:           vector.transfer_write %[[VEC]],
+// CHECK-SAME:           %[[IN_BOUND_DEST:.*]]#0[%[[IN_BOUND_DEST]]#1, %[[IN_BOUND_DEST]]#2]
+// CHECK-SAME:           {in_bounds = [true, true]} : vector<4x8xf32>, memref<?x8xf32, strided<[8, 1]>>
+
+
+transform.sequence failures(propagate) {
+^bb1(%func_op: !transform.op<"func.func">):
+  transform.apply_patterns to %func_op {
+    transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer"
+  } : !transform.op<"func.func">
+}
+
+
 // -----
 
 func.func private @fake_side_effecting_fun(%0: vector<2x2xf32>) -> ()

