[Mlir-commits] [mlir] 4abb9e5 - [mlir][vector] Clean up and generalize lowering of warp_execute to scf

Thomas Raoux llvmlistbot at llvm.org
Wed Sep 14 10:36:29 PDT 2022


Author: Thomas Raoux
Date: 2022-09-14T17:36:16Z
New Revision: 4abb9e5d2054be9d1a9d2d859675aa9bb9c9a105

URL: https://github.com/llvm/llvm-project/commit/4abb9e5d2054be9d1a9d2d859675aa9bb9c9a105
DIFF: https://github.com/llvm/llvm-project/commit/4abb9e5d2054be9d1a9d2d859675aa9bb9c9a105.diff

LOG: [mlir][vector] Clean up and generalize lowering of warp_execute to scf

Simplify the lowering of warp_execute_on_lane_0 to scf.if by making the
logic more generic. Also remove the assumption that the innermost
dimension is the distributed dimension.

Differential Revision: https://reviews.llvm.org/D133826
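
For context, a minimal sketch (adapted from the test added at the bottom of
this patch; the op and types are taken from that test) of a case this change
enables, where the distributed dimension is not the innermost one:

```mlir
%r = vector.warp_execute_on_lane_0(%laneid)[32]
    args(%v1 : vector<1x2x128xf32>) -> (vector<1x2x128xf32>) {
^bb0(%arg1: vector<1x64x128xf32>):
  %0 = "some_def_1"(%arg1) : (vector<1x64x128xf32>) -> vector<1x64x128xf32>
  vector.yield %0 : vector<1x64x128xf32>
}
```

Here the distributed dimension is d1 (64 -> 2) rather than the innermost d2,
which the previous lowering assumed.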

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 3064bdf17bb79..7bc1799a1fe58 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -20,20 +20,33 @@
 using namespace mlir;
 using namespace mlir::vector;
 
-/// TODO: add an analysis step that determines which vector dimension should be
-/// used for distribution.
-static llvm::Optional<int64_t>
-getDistributedVectorDim(VectorType distributedVectorType) {
-  return (distributedVectorType)
-             ? llvm::Optional<int64_t>(distributedVectorType.getRank() - 1)
-             : llvm::None;
-}
-
-static llvm::Optional<int64_t>
-getDistributedSize(VectorType distributedVectorType) {
-  auto dim = getDistributedVectorDim(distributedVectorType);
-  return (dim) ? llvm::Optional<int64_t>(distributedVectorType.getDimSize(*dim))
-               : llvm::None;
+/// Currently the distribution map is implicit based on the vector shape. In the
+/// future it will be part of the op.
+/// Example:
+/// ```
+/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
+///   ...
+///   vector.yield %3 : vector<32x16x64xf32>
+/// }
+/// ```
+/// Would have an implicit map of:
+/// `(d0, d1, d2) -> (d0, d2)`
+static AffineMap calculateImplicitMap(VectorType sequentialType,
+                                      VectorType distributedType) {
+  SmallVector<AffineExpr> perm;
+  perm.reserve(1);
+  // Check which dimensions of the sequential type are different than the
+  // dimensions of the distributed type to know the distributed dimensions. Then
+  // associate each distributed dimension to an ID in order.
+  for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
+    if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
+      perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
+  }
+  auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
+                            distributedType.getContext());
+  assert(map.getNumResults() <= 1 &&
+         "only support distribution along one dimension for now.");
+  return map;
 }
 
 namespace {
@@ -42,28 +55,23 @@ namespace {
 /// through the parallel / sequential and the sequential / parallel boundaries
 /// when performing `rewriteWarpOpToScfFor`.
 ///
-/// All this assumes the vector distribution occurs along the most minor
-/// distributed vector dimension.
-/// TODO: which is expected to be a multiple of the warp size ?
-/// TODO: add an analysis step that determines which vector dimension should
-/// be used for distribution.
+/// The vector distribution dimension is inferred from the vector types.
 struct DistributedLoadStoreHelper {
   DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
                              Value laneId, Value zero)
       : sequentialVal(sequentialVal), distributedVal(distributedVal),
         laneId(laneId), zero(zero) {
-    sequentialType = sequentialVal.getType();
-    distributedType = distributedVal.getType();
-    sequentialVectorType = sequentialType.dyn_cast<VectorType>();
-    distributedVectorType = distributedType.dyn_cast<VectorType>();
+    sequentialVectorType = sequentialVal.getType().dyn_cast<VectorType>();
+    distributedVectorType = distributedVal.getType().dyn_cast<VectorType>();
+    if (sequentialVectorType && distributedVectorType)
+      distributionMap =
+          calculateImplicitMap(sequentialVectorType, distributedVectorType);
   }
 
-  Value buildDistributedOffset(RewriterBase &b, Location loc) {
-    auto maybeDistributedSize = getDistributedSize(distributedVectorType);
-    assert(maybeDistributedSize &&
-           "at this point, a distributed size must be determined");
+  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
+    int64_t distributedSize = distributedVectorType.getDimSize(index);
     AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
-    return b.createOrFold<AffineApplyOp>(loc, tid * (*maybeDistributedSize),
+    return b.createOrFold<AffineApplyOp>(loc, tid * distributedSize,
                                          ArrayRef<Value>{laneId});
   }
 
@@ -79,27 +87,24 @@ struct DistributedLoadStoreHelper {
     assert((val == distributedVal || val == sequentialVal) &&
            "Must store either the preregistered distributed or the "
            "preregistered sequential value.");
+    // Scalar case can directly use memref.store.
+    if (!val.getType().isa<VectorType>())
+      return b.create<memref::StoreOp>(loc, val, buffer, zero);
+
     // Vector case must use vector::TransferWriteOp which will later lower to
     //   vector.store or memref.store depending on further lowerings.
-    if (val.getType().isa<VectorType>()) {
-      int64_t rank = sequentialVectorType.getRank();
-      if (rank == 0) {
-        return b.create<vector::TransferWriteOp>(loc, val, buffer, ValueRange{},
-                                                 ArrayRef<bool>{});
+    int64_t rank = sequentialVectorType.getRank();
+    SmallVector<Value> indices(rank, zero);
+    if (val == distributedVal) {
+      for (auto dimExpr : distributionMap.getResults()) {
+        int64_t index = dimExpr.cast<AffineDimExpr>().getPosition();
+        indices[index] = buildDistributedOffset(b, loc, index);
       }
-      SmallVector<Value> indices(rank, zero);
-      auto maybeDistributedDim = getDistributedVectorDim(distributedVectorType);
-      assert(maybeDistributedDim && "must be able to deduce distributed dim");
-      if (val == distributedVal)
-        indices[*maybeDistributedDim] =
-            (val == distributedVal) ? buildDistributedOffset(b, loc) : zero;
-      SmallVector<bool> inBounds(indices.size(), true);
-      return b.create<vector::TransferWriteOp>(
-          loc, val, buffer, indices,
-          ArrayRef<bool>(inBounds.begin(), inBounds.end()));
     }
-    // Scalar case can directly use memref.store.
-    return b.create<memref::StoreOp>(loc, val, buffer, zero);
+    SmallVector<bool> inBounds(indices.size(), true);
+    return b.create<vector::TransferWriteOp>(
+        loc, val, buffer, indices,
+        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
   }
 
   /// Create a load during the process of distributing the
@@ -122,36 +127,24 @@ struct DistributedLoadStoreHelper {
   ///   // Both types are f32. The constant %cst is broadcasted to all lanes.
   /// ```
   /// This behavior is described in more detail in the documentation of the op.
-  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer,
-                  bool broadcastMode = false) {
-    if (broadcastMode) {
-      // Broadcast mode may occur for either scalar or vector operands.
-      auto vectorType = type.dyn_cast<VectorType>();
-      auto shape = buffer.getType().cast<MemRefType>();
-      if (vectorType) {
-        SmallVector<bool> inBounds(shape.getRank(), true);
-        return b.create<vector::TransferReadOp>(
-            loc, vectorType, buffer,
-            /*indices=*/SmallVector<Value>(shape.getRank(), zero),
-            ArrayRef<bool>(inBounds.begin(), inBounds.end()));
-      }
+  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {
+
+    // Scalar case can directly use memref.load.
+    if (!type.isa<VectorType>())
       return b.create<memref::LoadOp>(loc, buffer, zero);
-    }
 
     // Other cases must be vector atm.
     // Vector case must use vector::TransferReadOp which will later lower to
     //   vector.load or memref.load depending on further lowerings.
-    assert(type.isa<VectorType>() && "must be a vector type");
     assert((type == distributedVectorType || type == sequentialVectorType) &&
            "Must store either the preregistered distributed or the "
            "preregistered sequential type.");
-    auto maybeDistributedDim = getDistributedVectorDim(distributedVectorType);
-    assert(maybeDistributedDim && "must be able to deduce distributed dim");
     SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
     if (type == distributedVectorType) {
-      indices[*maybeDistributedDim] = buildDistributedOffset(b, loc);
-    } else {
-      indices[*maybeDistributedDim] = zero;
+      for (auto dimExpr : distributionMap.getResults()) {
+        int64_t index = dimExpr.cast<AffineDimExpr>().getPosition();
+        indices[index] = buildDistributedOffset(b, loc, index);
+      }
     }
     SmallVector<bool> inBounds(indices.size(), true);
     return b.create<vector::TransferReadOp>(
@@ -160,8 +153,8 @@ struct DistributedLoadStoreHelper {
   }
 
   Value sequentialVal, distributedVal, laneId, zero;
-  Type sequentialType, distributedType;
   VectorType sequentialVectorType, distributedVectorType;
+  AffineMap distributionMap;
 };
 
 } // namespace
@@ -262,32 +255,6 @@ static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
   return rewriter.create(res);
 }
 
-/// Currently the distribution map is implicit based on the vector shape. In the
-/// future it will be part of the op.
-/// Example:
-/// ```
-/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
-///   ...
-///   vector.yield %3 : vector<32x16x64xf32>
-/// }
-/// ```
-/// Would have an implicit map of:
-/// `(d0, d1, d2) -> (d0, d2)`
-static AffineMap calculateImplicitMap(Value yield, Value ret) {
-  auto srcType = yield.getType().cast<VectorType>();
-  auto dstType = ret.getType().cast<VectorType>();
-  SmallVector<AffineExpr> perm;
-  // Check which dimensions of the yield value are different than the dimensions
-  // of the result to know the distributed dimensions. Then associate each
-  // distributed dimension to an ID in order.
-  for (unsigned i = 0, e = srcType.getRank(); i < e; i++) {
-    if (srcType.getDimSize(i) != dstType.getDimSize(i))
-      perm.push_back(getAffineDimExpr(i, yield.getContext()));
-  }
-  auto map = AffineMap::get(srcType.getRank(), 0, perm, yield.getContext());
-  return map;
-}
-
 namespace {
 
 /// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
@@ -318,13 +285,10 @@ namespace {
 ///
 /// All this assumes the vector distribution occurs along the most minor
 /// distributed vector dimension.
-/// TODO: which is expected to be a multiple of the warp size ?
-/// TODO: add an analysis step that determines which vector dimension should be
-/// used for distribution.
-struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
-  WarpOpToScfForPattern(MLIRContext *context,
-                        const WarpExecuteOnLane0LoweringOptions &options,
-                        PatternBenefit benefit = 1)
+struct WarpOpToScfIfPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  WarpOpToScfIfPattern(MLIRContext *context,
+                       const WarpExecuteOnLane0LoweringOptions &options,
+                       PatternBenefit benefit = 1)
       : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
         options(options) {}
 
@@ -364,10 +328,8 @@ struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
       helper.buildStore(rewriter, loc, distributedVal, buffer);
       // Load sequential vector from buffer, inside the ifOp.
       rewriter.setInsertionPointToStart(ifOp.thenBlock());
-      bool broadcastMode =
-          (sequentialVal.getType() == distributedVal.getType());
-      bbArgReplacements.push_back(helper.buildLoad(
-          rewriter, loc, sequentialVal.getType(), buffer, broadcastMode));
+      bbArgReplacements.push_back(
+          helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
     }
 
     // Step 3. Insert sync after all the stores and before all the loads.
@@ -404,8 +366,6 @@ struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
 
       // Load distributed value from buffer, after  the warpOp.
       rewriter.setInsertionPointAfter(ifOp);
-      bool broadcastMode =
-          (sequentialVal.getType() == distributedVal.getType());
       // Result type and yielded value type are the same. This is a broadcast.
       // E.g.:
       // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
@@ -413,8 +373,8 @@ struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
       // }
       // Both types are f32. The constant %cst is broadcasted to all lanes.
       // This is described in more detail in the documentation of the op.
-      replacements.push_back(helper.buildLoad(
-          rewriter, loc, distributedVal.getType(), buffer, broadcastMode));
+      replacements.push_back(
+          helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
     }
 
     // Step 6. Insert sync after all the stores and before all the loads.
@@ -758,7 +718,9 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
 
     SmallVector<Value, 4> indices(read.getIndices().begin(),
                                   read.getIndices().end());
-    AffineMap map = calculateImplicitMap(read.getResult(), distributedVal);
+    auto sequentialType = read.getResult().getType().cast<VectorType>();
+    auto distributedType = distributedVal.getType().cast<VectorType>();
+    AffineMap map = calculateImplicitMap(sequentialType, distributedType);
     AffineMap indexMap = map.compose(read.getPermutationMap());
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(warpOp);
@@ -1118,7 +1080,7 @@ struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
     RewritePatternSet &patterns,
     const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) {
-  patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options, benefit);
+  patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
 }
 
 void mlir::vector::populateDistributeTransferWriteOpPatterns(

diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 20bc623064f9f..200033d608ec8 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -734,3 +734,45 @@ func.func @warp_execute_has_broadcast_semantics(%laneid: index, %s0: f32, %v0: v
   // CHECK-SCF-IF: return %[[RS0]], %[[RV0]], %[[RV1]], %[[RV2]] : f32, vector<f32>, vector<1xf32>, vector<1x1xf32>
   return %r#0, %r#1, %r#2, %r#3 : f32, vector<f32>, vector<1xf32>, vector<1x1xf32>
 }
+
+// -----
+
+// CHECK-SCF-IF-DAG: #[[$TIMES2:.*]] = affine_map<()[s0] -> (s0 * 2)>
+
+// CHECK-SCF-IF:   func @warp_execute_nd_distribute
+// CHECK-SCF-IF-SAME: (%[[LANEID:.*]]: index
+func.func @warp_execute_nd_distribute(%laneid: index, %v0: vector<1x64x1xf32>, %v1: vector<1x2x128xf32>) 
+    -> (vector<1x64x1xf32>, vector<1x2x128xf32>) {
+  // CHECK-SCF-IF-DAG: %[[C0:.*]] = arith.constant 0 : index
+
+  // CHECK-SCF-IF:  vector.transfer_write %{{.*}}, %{{.*}}[%[[LANEID]], %c0, %c0] {in_bounds = [true, true, true]} : vector<1x64x1xf32>, memref<32x64x1xf32, 3>
+  // CHECK-SCF-IF:  %[[RID:.*]] = affine.apply #[[$TIMES2]]()[%[[LANEID]]]
+  // CHECK-SCF-IF:  vector.transfer_write %{{.*}}, %{{.*}}[%[[C0]], %[[RID]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x2x128xf32>, memref<1x64x128xf32, 3>
+  // CHECK-SCF-IF:  gpu.barrier
+
+  // CHECK-SCF-IF: scf.if{{.*}}{
+  %r:2 = vector.warp_execute_on_lane_0(%laneid)[32]
+      args(%v0, %v1 : vector<1x64x1xf32>, vector<1x2x128xf32>) -> (vector<1x64x1xf32>, vector<1x2x128xf32>) {
+    ^bb0(%arg0: vector<32x64x1xf32>, %arg1: vector<1x64x128xf32>):
+
+  // CHECK-SCF-IF-DAG: %[[SR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} {in_bounds = [true, true, true]} : memref<32x64x1xf32, 3>, vector<32x64x1xf32>
+  // CHECK-SCF-IF-DAG: %[[SR1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} {in_bounds = [true, true, true]} : memref<1x64x128xf32, 3>, vector<1x64x128xf32>
+  //     CHECK-SCF-IF: %[[W0:.*]] = "some_def_0"(%[[SR0]]) : (vector<32x64x1xf32>) -> vector<32x64x1xf32>
+  //     CHECK-SCF-IF: %[[W1:.*]] = "some_def_1"(%[[SR1]]) : (vector<1x64x128xf32>) -> vector<1x64x128xf32>
+  // CHECK-SCF-IF-DAG: vector.transfer_write %[[W0]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<32x64x1xf32>, memref<32x64x1xf32, 3>
+  // CHECK-SCF-IF-DAG: vector.transfer_write %[[W1]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x64x128xf32>, memref<1x64x128xf32, 3>
+
+      %r0 = "some_def_0"(%arg0) : (vector<32x64x1xf32>) -> vector<32x64x1xf32>
+      %r1 = "some_def_1"(%arg1) : (vector<1x64x128xf32>) -> vector<1x64x128xf32>
+
+      // CHECK-SCF-IF-NOT: vector.yield
+      vector.yield %r0, %r1 : vector<32x64x1xf32>, vector<1x64x128xf32>
+  }
+
+  //     CHECK-SCF-IF: gpu.barrier
+  //     CHECK-SCF-IF: %[[WID:.*]] = affine.apply #[[$TIMES2]]()[%[[LANEID]]]
+  // CHECK-SCF-IF-DAG: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[LANEID]], %[[C0]], %[[C0]]], %cst {in_bounds = [true, true, true]} : memref<32x64x1xf32, 3>, vector<1x64x1xf32>
+  // CHECK-SCF-IF-DAG: %[[R1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[WID]], %[[C0]]], %cst {in_bounds = [true, true, true]} : memref<1x64x128xf32, 3>, vector<1x2x128xf32>
+  //     CHECK-SCF-IF: return %[[R0]], %[[R1]] : vector<1x64x1xf32>, vector<1x2x128xf32>
+  return %r#0, %r#1 : vector<1x64x1xf32>, vector<1x2x128xf32>
+}
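
Reading the new `calculateImplicitMap` documentation together with this test,
the second operand can be worked through as follows (an informal sketch; the
map and offset are the ones the CHECK-SCF-IF lines above already verify):

```mlir
// sequential type:  vector<1x64x128xf32> (block argument %arg1)
// distributed type: vector<1x2x128xf32>  (result %r#1)
// Only d1 differs (64 vs. 2), so the implicit map is (d0, d1, d2) -> (d1).
// The distributed size along d1 is 2, hence the per-lane offset laneid * 2:
#times2 = affine_map<()[s0] -> (s0 * 2)>
%offset = affine.apply #times2()[%laneid]
// The distributed slice is then written and read at [%c0, %offset, %c0] in
// the shared memref<1x64x128xf32, 3> buffer, matching #[[$TIMES2]] above.
```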


        

