[Mlir-commits] [mlir] 199442e - [mlir][vector] Fix uniform transfer_read distribution

Lei Zhang llvmlistbot at llvm.org
Thu Aug 17 17:50:08 PDT 2023


Author: Lei Zhang
Date: 2023-08-17T17:38:55-07:00
New Revision: 199442ea2c128004980cac9ac8713f2f82db5dbb

URL: https://github.com/llvm/llvm-project/commit/199442ea2c128004980cac9ac8713f2f82db5dbb
DIFF: https://github.com/llvm/llvm-project/commit/199442ea2c128004980cac9ac8713f2f82db5dbb.diff

LOG: [mlir][vector] Fix uniform transfer_read distribution

If the original shape and the distributed shape are the same,
we don't distribute at all--every thread handles the whole vector.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D158235

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2182a2ebf7f081..61c2b2d580d03a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -710,6 +710,14 @@ bool delinearizeLaneId(OpBuilder &builder, Location loc,
                        ArrayRef<int64_t> originalShape,
                        ArrayRef<int64_t> distributedShape, int64_t warpSize,
                        Value laneId, SmallVectorImpl<Value> &delinearizedIds) {
+  // If the original shape and the distributed shape is the same, we don't
+  // distribute at all--every thread is handling the whole. For such case, we
+  // should not rely on lane IDs later. So just return an empty lane ID vector.
+  if (originalShape == distributedShape) {
+    delinearizedIds.clear();
+    return true;
+  }
+
   SmallVector<int64_t> sizes;
   for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {
     if (large % small != 0)
@@ -794,8 +802,9 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
                            warpOp.getLaneid(), delinearizedIds))
       return rewriter.notifyMatchFailure(
           read, "cannot delinearize lane ID for distribution");
+    assert(!delinearizedIds.empty() || map.getNumResults() == 0);
 
-    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
+    for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
       AffineExpr d0, d1;
       bindDims(read.getContext(), d0, d1);
       auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();

diff  --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index cd0a14eb5f7211..d69be9dcca1673 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1236,3 +1236,18 @@ func.func @warp_propagate_shape_cast(%laneid: index, %src: memref<32x4x32xf32>)
 // CHECK-PROP:   %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<1x1x4xf32> to vector<4xf32>
 // CHECK-PROP:   return %[[CAST]] : vector<4xf32>
 
+// -----
+
+func.func @warp_propagate_uniform_transfer_read(%laneid: index, %src: memref<4096xf32>, %index: index) -> vector<1xf32> {
+  %f0 = arith.constant 0.000000e+00 : f32
+  %r = vector.warp_execute_on_lane_0(%laneid)[64] -> (vector<1xf32>) {
+    %1 = vector.transfer_read %src[%index], %f0 {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
+    vector.yield %1 : vector<1xf32>
+  }
+  return %r : vector<1xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_uniform_transfer_read
+//  CHECK-PROP-SAME: (%{{.+}}: index, %[[SRC:.+]]: memref<4096xf32>, %[[INDEX:.+]]: index)
+//       CHECK-PROP:   %[[READ:.+]] = vector.transfer_read %[[SRC]][%[[INDEX]]], %cst {in_bounds = [true]} : memref<4096xf32>, vector<1xf32>
+//       CHECK-PROP:   return %[[READ]] : vector<1xf32>


        


More information about the Mlir-commits mailing list