[Mlir-commits] [mlir] [mlir][vector] Add support for multi-dim reduction vector distribution (PR #71193)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 8 02:26:03 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Kunwar Grover (Groverkss)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/71193.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h (+11-2) 
- (modified) mlir/include/mlir/IR/AffineMap.h (+5) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+46-11) 
- (modified) mlir/lib/IR/AffineMap.cpp (+21) 
- (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+65) 
- (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+8-6) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
index a76a58eb5ec6d3c..3b1ae34a3acdac9 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -15,6 +15,17 @@ namespace mlir {
 class RewritePatternSet;
 namespace vector {
 
+/// Given a value of shaped type, returns the distribution map for that
+/// value. The distribution map specifies the order in which the dimensions
+/// of the shape should be distributed. The map is expected to be a
+/// projection of the shape dimensions. Examples of distribution maps that
+/// can be returned:
+///
+/// - Type: vector<16x32x64xf32>,
+///   Map: (d0, d1, d2) -> (d1, d2) : Distribute d1 first, then d2.
+/// - Type: vector<16x32x64xf32>,
+///   Map: (d0, d1, d2) -> (d0, d1, d2) : Distribute d0 first, then d1, then d2.
+using DistributionMapFn = std::function<AffineMap(Value)>;
+
 struct WarpExecuteOnLane0LoweringOptions {
   /// Lambda function to let users allocate memory needed for the lowering of
   /// WarpExecuteOnLane0Op.
@@ -40,8 +51,6 @@ void populateWarpExecuteOnLane0OpToScfForPattern(
     const WarpExecuteOnLane0LoweringOptions &options,
     PatternBenefit benefit = 1);
 
-using DistributionMapFn = std::function<AffineMap(Value)>;
-
 /// Distribute transfer_write ops based on the affine map returned by
 /// `distributionMapFn`.
 /// Example:
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 5af7835258f6bd2..b78a6c45360580a 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -330,6 +330,11 @@ class AffineMap {
   /// returns the resulting values. `this` must be symbol-less.
   SmallVector<int64_t, 4> compose(ArrayRef<int64_t> values) const;
 
+  /// Returns true if the AffineMap represents a subset (i.e. a projection) of
+  /// a symbol-less identity map: every result is a dimension expression, and
+  /// the dimension positions are strictly increasing.
+  bool isProjection() const;
+
   /// Returns true if the AffineMap represents a subset (i.e. a projection) of a
   /// symbol-less permutation map. `allowZeroInResults` allows projected
   /// permutation maps with constant zero result expressions.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 8b4575e96875409..70353cf19e07d14 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -423,25 +423,55 @@ static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
 }
 
 /// Return the distributed vector type based on the original type and the
-/// distribution map. The map is expected to have a dimension equal to the
-/// original type rank and should be a projection where the results are the
-/// distributed dimensions. The number of results should be equal to the number
-/// of warp sizes which is currently limited to 1.
-/// Example: For a vector<16x32x64> distributed with a map(d0, d1, d2) -> (d1)
-/// and a warp size of 16 would distribute the second dimension (associated to
-/// d1) and return vector<16x2x64>
+/// distribution map. The vector should be completely distributable, i.e. the
+/// linearized shape should be a multiple of the warp size. If all threads are
+/// consumed while distributing the first few dimensions, the remaining
+/// dimensions are left undistributed.
+///
+/// Example (single-dim): a vector<16x32x64> distributed with map
+/// (d0, d1, d2) -> (d1) and a warp size of 16 distributes the second
+/// dimension (associated with d1) and returns vector<16x2x64>.
+///
+/// Example (multi-dim): a vector<16x32x64> distributed with map
+/// (d0, d1, d2) -> (d1, d2) and a warp size of 128 first distributes the
+/// second dimension down to 1 (consuming 32 threads), then splits the third
+/// dimension across the remaining 4 threads, returning vector<16x1x16>.
 static VectorType getDistributedType(VectorType originalType, AffineMap map,
                                      int64_t warpSize) {
-  if (map.getNumResults() != 1)
+  if (!map.isProjection()) {
+    assert(false && "expected distribution map to be a projection");
     return VectorType();
+  }
+
   SmallVector<int64_t> targetShape(originalType.getShape().begin(),
                                    originalType.getShape().end());
+  // Distribute the vector based on the order of dimensions in the affine map.
+  int64_t availableThreads = warpSize;
   for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
     unsigned position = map.getDimPosition(i);
-    if (targetShape[position] % warpSize != 0)
-      return VectorType();
-    targetShape[position] = targetShape[position] / warpSize;
+    int64_t &dimSize = targetShape[position];
+    if (availableThreads > dimSize) {
+      // More threads are available than the size of this dimension, so the
+      // dimension shrinks to 1 and the leftover threads carry over to the
+      // next distributed dimension.
+      if (availableThreads % dimSize != 0)
+        return VectorType();
+      availableThreads = availableThreads / dimSize;
+      dimSize = 1;
+    } else {
+      // The dimension is at least as large as the number of available
+      // threads, so it absorbs all remaining threads and distribution stops
+      // here.
+      if (dimSize % availableThreads != 0)
+        return VectorType();
+      dimSize = dimSize / availableThreads;
+      availableThreads = 1;
+      break;
+    }
   }
+
+  // If we could not distribute the whole vector, we fail.
+  if (availableThreads != 1)
+    return VectorType();
+
   VectorType targetType =
       VectorType::get(targetShape, originalType.getElementType());
   return targetType;
@@ -710,6 +740,7 @@ bool delinearizeLaneId(OpBuilder &builder, Location loc,
                        ArrayRef<int64_t> originalShape,
                        ArrayRef<int64_t> distributedShape, int64_t warpSize,
                        Value laneId, SmallVectorImpl<Value> &delinearizedIds) {
+
   // If the original shape and the distributed shape is the same, we don't
   // distribute at all--every thread is handling the whole. For such case, we
   // should not rely on lane IDs later. So just return an empty lane ID vector.
@@ -1485,6 +1516,10 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
           }
         });
 
+    // Check if any types could not be distributed.
+    if (llvm::any_of(distTypes, [](Type t) { return !t; }))
+      return failure();
+
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 3bd1181b6c7bbd8..a0c9d908833b882 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -532,6 +532,27 @@ SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) const {
   return res;
 }
 
+bool AffineMap::isProjection() const {
+  if (getNumSymbols() > 0)
+    return false;
+
+  // A projection cannot have more results than inputs.
+  if (getNumResults() > getNumInputs())
+    return false;
+
+  // Every result must be a dimension expression, and the dimension positions
+  // must be strictly increasing.
+  int64_t current = -1;
+  for (auto expr : getResults()) {
+    auto dim = expr.dyn_cast<AffineDimExpr>();
+    if (!dim)
+      return false;
+    if (dim.getPosition() <= current)
+      return false;
+    current = dim.getPosition();
+  }
+
+  return true;
+}
+
 bool AffineMap::isProjectedPermutation(bool allowZeroInResults) const {
   if (getNumSymbols() > 0)
     return false;
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 3bb981c7a623886..e668caf889563ee 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -494,6 +494,71 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2
 
 // -----
 
+//   CHECK-PROP-LABEL:   func @warp_scf_for_multi_reduce(
+//     CHECK-PROP-NOT:   vector.warp_execute_on_lane_0
+//         CHECK-PROP:   scf.for {{.*}} -> (vector<1x4xf32>) {
+//         CHECK-PROP:     scf.for {{.*}} -> (vector<1x4xf32>) {
+//         CHECK-PROP:       vector.transfer_read {{.*}} : memref<2x32x40x384xf32>, vector<1x4xf32>
+//         CHECK-PROP:     }
+//         CHECK-PROP:   }
+//         CHECK-PROP:   vector.reduction <add>
+// CHECK-PROP-COUNT-8:   gpu.shuffle
+//
+//         CHECK-PROP:   scf.for {{.*}} {
+//         CHECK-PROP:     vector.transfer_read
+//         CHECK-PROP:     scf.for {{.*}} {
+//         CHECK-PROP:       vector.warp_execute_on_lane_0
+//         CHECK-PROP:         vector.transfer_read
+//         CHECK-PROP:         vector.transfer_write
+//         CHECK-PROP:       }
+//         CHECK-PROP:     }
+#map = affine_map<(d0, d1) -> (0, 0)>
+func.func @warp_scf_for_multi_reduce(%arg0: memref<2x32x40x384xf32>, %arg1: memref<2x32x40x384xf16>, %arg2: memref<2x32xf32>, %arg3: memref<2x32x40x384xf16>) {
+  %cst = arith.constant dense<1.536000e+04> : vector<8x128xf32>
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<8x128xf32>
+  %cst_1 = arith.constant 9.99999997E-7 : f32
+  %c128 = arith.constant 128 : index
+  %c8 = arith.constant 8 : index
+  %c0 = arith.constant 0 : index
+  %c40 = arith.constant 40 : index
+  %c384 = arith.constant 384 : index
+  %cst_2 = arith.constant 0.000000e+00 : f16
+  %cst_3 = arith.constant 0.000000e+00 : f32
+  %0 = gpu.thread_id  x
+  %1 = arith.truncf %cst_1 : f32 to f16
+  vector.warp_execute_on_lane_0(%0)[256] {
+    %2 = scf.for %arg4 = %c0 to %c40 step %c8 iter_args(%arg5 = %cst_0) -> (vector<8x128xf32>) {
+      %11 = scf.for %arg6 = %c0 to %c384 step %c128 iter_args(%arg7 = %arg5) -> (vector<8x128xf32>) {
+        %12 = vector.transfer_read %arg0[%c0, %c0, %arg4, %arg6], %cst_3 {in_bounds = [true, true]} : memref<2x32x40x384xf32>, vector<8x128xf32>
+        %13 = arith.addf %12, %arg7 : vector<8x128xf32>
+        scf.yield %13 : vector<8x128xf32>
+      }
+      scf.yield %11 : vector<8x128xf32>
+    }
+    %3 = vector.shape_cast %2 : vector<8x128xf32> to vector<1024xf32>
+    %4 = vector.reduction <add>, %3, %cst_3 : vector<1024xf32> into f32
+    %5 = vector.broadcast %4 : f32 to vector<8x128xf32>
+    %6 = arith.divf %5, %cst : vector<8x128xf32>
+    %7 = arith.truncf %6 : vector<8x128xf32> to vector<8x128xf16>
+    %8 = vector.broadcast %1 : f16 to vector<8x128xf16>
+    %9 = arith.addf %7, %8 : vector<8x128xf16>
+    %10 = math.rsqrt %9 : vector<8x128xf16>
+    scf.for %arg4 = %c0 to %c40 step %c8 {
+      %11 = vector.transfer_read %arg2[%c0, %c0], %cst_3 {in_bounds = [true, true], permutation_map = #map} : memref<2x32xf32>, vector<8x128xf32>
+      %12 = arith.truncf %11 : vector<8x128xf32> to vector<8x128xf16>
+      scf.for %arg5 = %c0 to %c384 step %c128 {
+        %13 = vector.transfer_read %arg1[%c0, %c0, %arg4, %arg5], %cst_2 {in_bounds = [true, true]} : memref<2x32x40x384xf16>, vector<8x128xf16>
+        %14 = arith.subf %13, %12 : vector<8x128xf16>
+        %15 = arith.mulf %14, %10 : vector<8x128xf16>
+        vector.transfer_write %15, %arg3[%c0, %c0, %arg4, %arg5] {in_bounds = [true, true]} : vector<8x128xf16>, memref<2x32x40x384xf16>
+      }
+    }
+  }
+  return
+}
+
+// -----
+
 // CHECK-PROP-LABEL: func @vector_reduction(
 //  CHECK-PROP-SAME:     %[[laneid:.*]]: index)
 //   CHECK-PROP-DAG:   %[[c1:.*]] = arith.constant 1 : i32
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 2fbf1babf437f08..b996d87be396077 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -569,15 +569,17 @@ struct TestVectorDistribution
     });
     MLIRContext *ctx = &getContext();
     auto distributionFn = [](Value val) {
-      // Create a map (d0, d1) -> (d1) to distribute along the inner
-      // dimension. Once we support n-d distribution we can add more
-      // complex cases.
+      // Create an identity map (d0, ..., dn-1) -> (d0, ..., dn-1) so that
+      // dimensions are distributed in order, outermost first.
       VectorType vecType = dyn_cast<VectorType>(val.getType());
       int64_t vecRank = vecType ? vecType.getRank() : 0;
       OpBuilder builder(val.getContext());
-      if (vecRank == 0)
-        return AffineMap::get(val.getContext());
-      return AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
+      SmallVector<AffineExpr, 4> vecDims =
+          llvm::map_to_vector(llvm::seq<int64_t>(0, vecRank), [&](int64_t i) {
+            return builder.getAffineDimExpr(i);
+          });
+      return AffineMap::get(vecRank, /*symbolCount=*/0, vecDims,
+                            builder.getContext());
     };
     auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
                         Value srcIdx, int64_t warpSz) {

``````````
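
The core of the change is the new shape computation in `getDistributedType`. To make it concrete, here is a minimal standalone C++ sketch of that loop, with the distribution order passed as a plain list of dimension positions instead of an `AffineMap`; the name `distributeShape` and its signature are illustrative stand-ins, not part of the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Mirrors the loop in getDistributedType: walk the dimensions in the order
// given by the distribution map and split the available threads across them.
std::optional<std::vector<int64_t>>
distributeShape(std::vector<int64_t> shape, const std::vector<unsigned> &order,
                int64_t warpSize) {
  int64_t availableThreads = warpSize;
  for (unsigned pos : order) {
    int64_t &dimSize = shape[pos];
    if (availableThreads > dimSize) {
      // More threads than elements: the dimension shrinks to 1 and the
      // leftover threads carry over to the next distributed dimension.
      if (availableThreads % dimSize != 0)
        return std::nullopt;
      availableThreads /= dimSize;
      dimSize = 1;
    } else {
      // Enough elements for every remaining thread: each thread keeps
      // dimSize / availableThreads elements and distribution stops.
      if (dimSize % availableThreads != 0)
        return std::nullopt;
      dimSize /= availableThreads;
      availableThreads = 1;
      break;
    }
  }
  // Fail if some threads were never assigned any elements.
  if (availableThreads != 1)
    return std::nullopt;
  return shape;
}

int main() {
  // Single-dim example: vector<16x32x64>, map (d0, d1, d2) -> (d1),
  // warp size 16 -> vector<16x2x64>.
  assert((*distributeShape({16, 32, 64}, {1}, 16) ==
          std::vector<int64_t>{16, 2, 64}));
  // Multi-dim example: vector<16x32x64>, map (d0, d1, d2) -> (d1, d2),
  // warp size 128: d1 absorbs 32 threads and becomes 1, the remaining
  // 4 threads split d2 into 16 -> vector<16x1x16>.
  assert((*distributeShape({16, 32, 64}, {1, 2}, 128) ==
          std::vector<int64_t>{16, 1, 16}));
  // 48 threads cannot evenly distribute d1 = 32 (48 % 32 != 0).
  assert((!distributeShape({16, 32, 64}, {1, 2}, 48)));
}
```

This also explains the new test: `vector<8x128xf32>` with an identity distribution map and a warp size of 256 becomes `vector<1x4xf32>`, since d0 = 8 absorbs 8 threads and the remaining 32 threads split d1 = 128 into 4.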
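The `AffineMap::isProjection` helper that gates the rewrite accepts exactly the maps whose results are dimension expressions at strictly increasing positions. A hedged sketch of the same check over a list of result positions (the helper name below is illustrative, not the patch's API):

```cpp
#include <cstdint>
#include <vector>

// Accepts result dimension positions such as {1, 2} for
// (d0, d1, d2) -> (d1, d2); rejects {1, 0} for (d0, d1) -> (d1, d0)
// because the positions are not strictly increasing.
bool positionsFormProjection(const std::vector<int64_t> &resultDims,
                             size_t numInputs) {
  // A projection cannot have more results than inputs.
  if (resultDims.size() > numInputs)
    return false;
  int64_t current = -1;
  for (int64_t pos : resultDims) {
    if (pos <= current)
      return false;
    current = pos;
  }
  return true;
}
```

Note that this rules out permuted maps such as `(d0, d1) -> (d1, d0)`: the distribution order always follows the original dimension order, and the map only selects which dimensions participate.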

</details>


https://github.com/llvm/llvm-project/pull/71193


More information about the Mlir-commits mailing list