[PATCH] D152149: [mlir][Vector] Fix a propagation bug with transfer_read

Quentin Colombet via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 5 05:49:44 PDT 2023


qcolombet created this revision.
qcolombet added reviewers: ThomasRaoux, springerm, mravishankar.
qcolombet added a project: MLIR.
Herald added subscribers: bviyer, Moerafaat, zero9178, bzcheeseman, sdasgup3, Groverkss, wenzhicui, wrengr, jsetoain, cota, teijeong, rdzhabarov, tatianashp, msifontes, jurahul, Kayjukh, grosul1, Joonsoo, liufengdb, aartbik, mgester, arpith-jacob, antiagainst, shauheen, rriddle, mehdi_amini.
Herald added a reviewer: aartbik.
Herald added a project: All.
qcolombet requested review of this revision.
Herald added a reviewer: nicolasvasilache.
Herald added subscribers: stephenneuendorffer, nicolasvasilache.
Herald added a reviewer: dcaballe.

In the vector distribution patterns, we used to move
`vector.transfer_read`s out of `vector.warp_execute_on_lane_0`s
irrespective of how their operands were defined.

This could create `transfer_read` operations outside of the warpOp that
use values defined inside the warpOp's body, i.e., values that are no
longer in scope.
E.g.,

  warpop {
    %defined_in_body
    %read = transfer_read %defined_in_body
    vector.yield %read
  }

=>

  warpop {
    %defined_in_body
    vector.yield ...
  }
  // %defined_in_body is referenced outside of its scope.
  %read = transfer_read %defined_in_body

The fix checks that all the operands of the new `transfer_read` are
defined outside of the warpOp's body, and makes the pattern fail otherwise.
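A minimal sketch of the legality rule, in plain C++ rather than MLIR: `ToyValue`, `ToyOp`, and `canHoistOutOfWarpOp` are hypothetical stand-ins, not MLIR APIs. The patch itself performs this check with `llvm::all_of` over `newRead->getOperands()` and `warpOp.isDefinedOutsideOfRegion(value)`; the sketch only models the same "every operand must be defined outside" condition.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for an SSA value: it only records whether it is
// defined inside the warp op's region.
struct ToyValue {
  std::string name;
  bool definedInWarpBody;
};

// Hypothetical stand-in for an operation: just its operand list.
struct ToyOp {
  std::vector<ToyValue> operands;
};

// Mirrors the shape of the check in the patch: moving the read out of the
// warp op is only legal if every operand is defined outside of the region.
bool canHoistOutOfWarpOp(const ToyOp &read) {
  return std::all_of(read.operands.begin(), read.operands.end(),
                     [](const ToyValue &v) { return !v.definedInWarpBody; });
}
```

In the test case below, the `vector.extractelement` result (`%36`) plays the role of an operand with `definedInWarpBody == true`, so the read must stay inside the warp op.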

Note: We could do this check before creating any operation, but that would
require knowing what `affine::makeComposedAffineApply` actually does. The
current fix therefore trades some compile time (the new operation is
created before the check may reject it) for not coupling this propagation
to the implementation of `makeComposedAffineApply`.
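To illustrate why the check runs after the rewrite, here is a simplified model with hypothetical names (`ToyIndex`, `composeOutside`, `naiveEarlyCheck`, `checkAfterRewrite`). It assumes, as the comment in the C++ change describes, that composing an in-body `affine.apply` whose inputs are all defined outside yields a new index outside the warp op, while an opaque in-body value such as an `extractelement` result cannot be rebuilt. For brevity it treats every index as a candidate for composition.

```cpp
#include <cassert>
#include <vector>

// Hypothetical model of one transfer_read index (not an MLIR type).
struct ToyIndex {
  bool definedInWarpBody;  // where the index value itself is defined
  bool isAffineApply;      // produced by an affine.apply-like op?
  bool applyInputsOutside; // if so, are all of its inputs defined outside?
};

// Models composition: an apply whose inputs are all defined outside can be
// re-created after the warp op, so the new index no longer lives in the body.
ToyIndex composeOutside(const ToyIndex &idx) {
  if (idx.isAffineApply && idx.applyInputsOutside)
    return ToyIndex{false, true, true};
  return idx; // nothing to fold: the original value is reused as-is
}

// Checking the original operands rejects anything defined in the body.
bool naiveEarlyCheck(const std::vector<ToyIndex> &indices) {
  for (const ToyIndex &idx : indices)
    if (idx.definedInWarpBody)
      return false;
  return true;
}

// Checking the operands of the rewritten read, as the patch does.
bool checkAfterRewrite(const std::vector<ToyIndex> &indices) {
  for (const ToyIndex &idx : indices)
    if (composeOutside(idx).definedInWarpBody)
      return false;
  return true;
}
```

In this model the early check rejects the foldable `affine.apply` case even though the rewritten read would be legal; checking late keeps that case without duplicating the composition logic.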


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D152149

Files:
  mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
  mlir/test/Dialect/Vector/vector-warp-distribute.mlir


Index: mlir/test/Dialect/Vector/vector-warp-distribute.mlir
===================================================================
--- mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1109,3 +1109,46 @@
   }
   return %r : vector<4x96xf32>
 }
+
+// -----
+
+// Check that we don't propagate transfer_reads that depend on values
+// defined inside the warp_execute_on_lane_0.
+// In this case, propagating would create a transfer_read that depends on the
+// extractelement defined in the body.
+
+// CHECK-PROP-LABEL: func @transfer_read_no_prop(
+//  CHECK-PROP-SAME:     %[[IN2:[^ :]*]]: vector<1x2xindex>,
+//  CHECK-PROP-SAME:     %[[AR1:[^ :]*]]: memref<1x4x2xi32>,
+//  CHECK-PROP-SAME:     %[[AR2:[^ :]*]]: memref<1x4x1024xf32>)
+//   CHECK-PROP-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-PROP-DAG:   %[[THREADID:.*]] = gpu.thread_id  x
+//       CHECK-PROP:   %[[W:.*]] = vector.warp_execute_on_lane_0(%[[THREADID]])[32] args(%[[IN2]]
+//       CHECK-PROP:     %[[GATHER:.*]] = vector.gather %[[AR1]][{{.*}}]
+//       CHECK-PROP:     %[[EXTRACT:.*]] = vector.extract %[[GATHER]][0] : vector<1x64xi32>
+//       CHECK-PROP:     %[[CAST:.*]] = arith.index_cast %[[EXTRACT]] : vector<64xi32> to vector<64xindex>
+//       CHECK-PROP:     %[[EXTRACTELT:.*]] = vector.extractelement %[[CAST]][{{.*}}: i32] : vector<64xindex>
+//       CHECK-PROP:     %[[TRANSFERREAD:.*]] = vector.transfer_read %[[AR2]][%[[C0]], %[[EXTRACTELT]], %[[C0]]],
+//       CHECK-PROP:     vector.yield %[[TRANSFERREAD]] : vector<64xf32>
+//       CHECK-PROP:   return %[[W]]
+func.func @transfer_read_no_prop(%in2: vector<1x2xindex>, %ar1 :  memref<1x4x2xi32>, %ar2 : memref<1x4x1024xf32>)-> vector<2xf32> {
+  %0 = gpu.thread_id  x
+  %c0_i32 = arith.constant 0 : i32
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0> : vector<1x64xi32>
+  %cst_0 = arith.constant dense<true> : vector<1x64xi1>
+  %cst_1 = arith.constant dense<3> : vector<64xindex>
+  %cst_2 = arith.constant dense<0> : vector<64xindex>
+  %cst_6 = arith.constant 0.000000e+00 : f32
+
+  %18 = vector.warp_execute_on_lane_0(%0)[32] args(%in2 : vector<1x2xindex>) -> (vector<2xf32>) {
+  ^bb0(%arg4: vector<1x64xindex>):
+    %28 = vector.gather %ar1[%c0, %c0, %c0] [%arg4], %cst_0, %cst : memref<1x4x2xi32>, vector<1x64xindex>, vector<1x64xi1>, vector<1x64xi32> into vector<1x64xi32>
+    %29 = vector.extract %28[0] : vector<1x64xi32>
+    %30 = arith.index_cast %29 : vector<64xi32> to vector<64xindex>
+    %36 = vector.extractelement %30[%c0_i32 : i32] : vector<64xindex>
+    %37 = vector.transfer_read %ar2[%c0, %36, %c0], %cst_6 {in_bounds = [true]} : memref<1x4x1024xf32>, vector<64xf32>
+    vector.yield %37 : vector<64xf32>
+  }
+  return %18 : vector<2xf32>
+}
Index: mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
===================================================================
--- mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -757,10 +757,31 @@
           rewriter, read.getLoc(), d0 + scale * d1,
           {indices[indexPos], warpOp.getLaneid()});
     }
-    Value newRead = rewriter.create<vector::TransferReadOp>(
+    auto newRead = rewriter.create<vector::TransferReadOp>(
         read.getLoc(), distributedVal.getType(), read.getSource(), indices,
         read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
         read.getInBoundsAttr());
+
+    // Check that the produced operation is legal.
+    // The transfer op may be reading from values that are defined within
+    // warpOp's body, which is illegal once the op lives outside of it.
+    // We do the check late because the indices may be changed by
+    // makeComposedAffineApply. This rewrite may remove dependencies on
+    // warpOp's body.
+    // E.g., warpop {
+    //   %idx = affine.apply...[%outsideDef]
+    //   ... = transfer_read ...[%idx]
+    // }
+    // will be rewritten into:
+    // warpop {
+    // }
+    // %new_idx = affine.apply...[%outsideDef]
+    // ... = transfer_read ...[%new_idx]
+    if (!llvm::all_of(newRead->getOperands(), [&](Value value) {
+          return warpOp.isDefinedOutsideOfRegion(value);
+        }))
+      return failure();
+
     rewriter.replaceAllUsesWith(distributedVal, newRead);
     return success();
   }
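For context on the `d0 + scale * d1` expression at the top of this hunk: `d0` is bound to the original index and `d1` to `warpOp.getLaneid()`, so each lane starts reading at an offset proportional to its lane id. The sketch below assumes `scale` is the per-lane size of the distributed dimension; with the shapes used in the test above (a `vector<64xf32>` yielded from a 32-lane warp op and returned as `vector<2xf32>`), that gives a scale of 2. `distributedIndex` and `perLaneSize` are hypothetical helpers, not MLIR functions.

```cpp
#include <cassert>
#include <cstdint>

// Per-lane start index, mirroring the affine expression d0 + scale * d1,
// with d0 = the original (sequential) index and d1 = the lane id.
// Assumption: `scale` is the size of the distributed dimension per lane.
int64_t distributedIndex(int64_t originalIndex, int64_t laneId, int64_t scale) {
  return originalIndex + scale * laneId;
}

// Per-lane size when `sequentialSize` elements are split evenly across
// `warpSize` lanes (assumes exact divisibility).
int64_t perLaneSize(int64_t sequentialSize, int64_t warpSize) {
  assert(sequentialSize % warpSize == 0 && "expected an even split");
  return sequentialSize / warpSize;
}
```

For example, with a scale of 2 and an original index of 0, lane 5 starts at element 10 and reads elements 10 and 11.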

