[Mlir-commits] [mlir] ff8775f - [mlir][GPU] Add op for unrolling contractions to a native size

Quinn Dawkins llvmlistbot at llvm.org
Tue Jul 25 10:13:06 PDT 2023


Author: Quinn Dawkins
Date: 2023-07-25T13:11:32-04:00
New Revision: ff8775f3fffea787922eb238c0af3ed5715485d9

URL: https://github.com/llvm/llvm-project/commit/ff8775f3fffea787922eb238c0af3ed5715485d9
DIFF: https://github.com/llvm/llvm-project/commit/ff8775f3fffea787922eb238c0af3ed5715485d9.diff

LOG: [mlir][GPU] Add op for unrolling contractions to a native size

Adds `apply_patterns.gpu.unroll_vectors_subgroup_mma` which allows
specifying a native MMA shape of `m`, `n`, and `k` to unroll to,
greedily unrolling the innermost dimension of contractions and other
vector operations based on expected usage.

Differential Revision: https://reviews.llvm.org/D156079

Added: 
    mlir/test/Dialect/GPU/subgroup-mma-vector-unroll.mlir

Modified: 
    mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
    mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
index 97c0f0f18a22db..c7ffbafeefd023 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
@@ -14,6 +14,29 @@ include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
+def ApplyUnrollVectorsSubgroupMmaOp : Op<Transform_Dialect,
+    "apply_patterns.gpu.unroll_vectors_subgroup_mma",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Unrolls contractions to the target `m`, `n`, and `k` native vector size,
+    along with other vector operations based on expected usage. `transfer_read`
+    ops unroll based on the extract slice shape introduced by unrolling the
+    contractions, while elementwise and `transfer_write` ops unroll to the shape of
+    the C matrix (`m x n`).
+
+    `m`, `n`, and `k` must be strictly positive.
+
+    This operation applies to pure vector operations and should be applied before
+    lowering to subgroup_mma ops.
+  }];
+
+  // The native sizes are handed to the unroll patterns as target shapes, so
+  // reject zero or negative values at verification time.
+  let arguments = (ins ConfinedAttr<I64Attr, [IntPositive]>:$m,
+                       ConfinedAttr<I64Attr, [IntPositive]>:$n,
+                       ConfinedAttr<I64Attr, [IntPositive]>:$k);
+
+  let assemblyFormat = [{
+    `[` $m `,` $n `,` $k `]` attr-dict
+  }];
+}
+
 def EliminateBarriersOp :
   Op<Transform_Dialect, "apply_patterns.gpu.eliminate_barriers",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {

diff  --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 07470c24ae2d69..ddbe5d47ff4456 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -46,6 +47,132 @@ using namespace mlir::transform::gpu;
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 #define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
 
+//===----------------------------------------------------------------------===//
+// ApplyUnrollVectorsSubgroupMmaOp
+//===----------------------------------------------------------------------===//
+
+/// Computes the traversal order used when unrolling `contract`, chosen so
+/// that tensorcore operations can reuse the LHS register: reduction
+/// dimensions come first, then the parallel dimensions that index the LHS
+/// operand, and finally the remaining parallel dimensions. Within each group
+/// the dimensions keep their original relative order.
+static std::optional<SmallVector<int64_t>>
+gpuMmaUnrollOrder(vector::ContractionOp contract) {
+  // Iteration dimensions referenced by the LHS (first) indexing map.
+  llvm::SmallDenseSet<int64_t> lhsDims;
+  for (AffineExpr lhsExpr : contract.getIndexingMapsArray()[0].getResults())
+    lhsDims.insert(lhsExpr.cast<AffineDimExpr>().getPosition());
+
+  // Bucket every iteration dimension by its role in a single pass.
+  SmallVector<int64_t> order;
+  SmallVector<int64_t> lhsParallel;
+  SmallVector<int64_t> otherParallel;
+  for (auto [dim, iterType] : llvm::enumerate(contract.getIteratorTypes())) {
+    if (vector::isReductionIterator(iterType))
+      order.push_back(dim);
+    else if (vector::isParallelIterator(iterType))
+      (lhsDims.count(dim) ? lhsParallel : otherParallel).push_back(dim);
+  }
+
+  // Reductions are outermost, followed by LHS parallel dims, then the rest.
+  order.append(lhsParallel.begin(), lhsParallel.end());
+  order.append(otherParallel.begin(), otherParallel.end());
+  return order;
+}
+
+/// Returns the native (unrolled) vector shape for `op` given the native MMA
+/// shape specified with `m`, `n`, and `k`, or `std::nullopt` if `op` should
+/// not be unrolled.
+///
+///  * `vector.contract`: `[1, ..., 1, m, n, k]` over its iteration space;
+///    contractions with fewer than 3 iterators are not unrolled.
+///  * `vector.transfer_write`: `[1, ..., 1, m, n]`; vectors of rank < 2 are
+///    not unrolled.
+///  * `vector.transfer_read`: the result shape shared by all of its
+///    `vector.extract_strided_slice` users; not unrolled if it has no users,
+///    a user of another kind, or users with differing slice types.
+///  * Single-result elementwise-mappable ops producing a vector of rank >= 2:
+///    the slice shape shared by their `vector.extract_strided_slice` users if
+///    there are any, `[1, ..., 1, m, n]` if the op has no users, and not
+///    unrolled otherwise.
+static std::optional<SmallVector<int64_t>>
+getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) {
+  if (auto contract = dyn_cast<vector::ContractionOp>(op)) {
+    int64_t contractRank = contract.getIteratorTypes().size();
+    if (contractRank < 3)
+      return std::nullopt;
+    // All but the last three iteration dimensions are unrolled to size 1.
+    SmallVector<int64_t> nativeSize(contractRank - 3, 1);
+    nativeSize.append({m, n, k});
+    return nativeSize;
+  }
+  if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
+    int64_t writeRank = writeOp.getVectorType().getRank();
+    if (writeRank < 2)
+      return std::nullopt;
+    // All but the last two vector dimensions are unrolled to size 1.
+    SmallVector<int64_t> nativeSize(writeRank - 2, 1);
+    nativeSize.append({m, n});
+    return nativeSize;
+  }
+  if (isa<vector::TransferReadOp>(op)) {
+    // Transfer read ops may need different shapes based on how they are being
+    // used. For simplicity just match the shape used by the extract strided op.
+    VectorType sliceType;
+    for (Operation *user : op->getUsers()) {
+      auto extract = dyn_cast<vector::ExtractStridedSliceOp>(user);
+      if (!extract)
+        return std::nullopt;
+      auto userSliceType = extract.getResult().getType().cast<VectorType>();
+      if (sliceType && sliceType != userSliceType)
+        return std::nullopt;
+      sliceType = userSliceType;
+    }
+    // A read without users leaves `sliceType` null: there is no shape to
+    // infer, so leave the op alone instead of dereferencing a null type.
+    if (!sliceType)
+      return std::nullopt;
+    return llvm::to_vector(sliceType.getShape());
+  }
+  if ((OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)) {
+    if (auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>()) {
+      // TODO: The condition for unrolling elementwise should be restricted
+      // only to operations that need unrolling (connected to the contract).
+      if (vecType.getRank() < 2)
+        return std::nullopt;
+
+      // First check whether there is a slice to infer the shape from. This is
+      // required for cases where the accumulator type differs from the input
+      // types, in which case we will see an `arith.ext_` between the contract
+      // and transfer_read which needs to be unrolled.
+      VectorType sliceType;
+      for (Operation *user : op->getUsers()) {
+        auto extract = dyn_cast<vector::ExtractStridedSliceOp>(user);
+        if (!extract)
+          return std::nullopt;
+        auto userSliceType = extract.getResult().getType().cast<VectorType>();
+        if (sliceType && sliceType != userSliceType)
+          return std::nullopt;
+        sliceType = userSliceType;
+      }
+      if (sliceType)
+        return llvm::to_vector(sliceType.getShape());
+
+      // Else unroll for trailing elementwise.
+      // NOTE(review): as written this point is only reached when the op has
+      // no users, since any non-slice user returns early above. Confirm
+      // whether elementwise ops with other users were meant to get here.
+      SmallVector<int64_t> nativeSize(vecType.getRank() - 2, 1);
+      // Map elementwise ops to the output shape.
+      nativeSize.append({m, n});
+      return nativeSize;
+    }
+  }
+  return std::nullopt;
+}
+
+/// Registers the vector unrolling patterns, configured with the native
+/// `m`/`n`/`k` shape carried by this op and with the traversal order computed
+/// by `gpuMmaUnrollOrder` for contractions.
+void transform::ApplyUnrollVectorsSubgroupMmaOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  // Copy the attribute values so the callback does not capture `this`.
+  const int64_t nativeM = getM();
+  const int64_t nativeN = getN();
+  const int64_t nativeK = getK();
+
+  vector::UnrollVectorOptions unrollOptions;
+  unrollOptions.setNativeShapeFn(
+      [nativeM, nativeN,
+       nativeK](Operation *op) -> std::optional<SmallVector<int64_t>> {
+        return getSubgroupMmaNativeVectorSize(op, nativeM, nativeN, nativeK);
+      });
+  // Only contractions get a custom traversal order.
+  unrollOptions.setUnrollTraversalOrderFn(
+      [](Operation *op) -> std::optional<SmallVector<int64_t>> {
+        if (auto contract = dyn_cast<vector::ContractionOp>(op))
+          return gpuMmaUnrollOrder(contract);
+        return std::nullopt;
+      });
+
+  vector::populateVectorUnrollPatterns(patterns, unrollOptions);
+}
+
 //===----------------------------------------------------------------------===//
 // EliminateBarriersOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/GPU/subgroup-mma-vector-unroll.mlir b/mlir/test/Dialect/GPU/subgroup-mma-vector-unroll.mlir
new file mode 100644
index 00000000000000..dded4993f467c0
--- /dev/null
+++ b/mlir/test/Dialect/GPU/subgroup-mma-vector-unroll.mlir
@@ -0,0 +1,98 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
+
+// The 16x16x16 contraction is unrolled to the native [m, n, k] = [16, 16, 8]
+// shape, i.e. split in two along the reduction dimension.
+// CHECK-LABEL: func.func @matmul
+func.func @matmul(%lhs: memref<32x32xf32>, %rhs: memref<32x32xf32>, %out: memref<32x32xf32>) {
+  %c8 = arith.constant 8 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.000000e+00> : vector<16x16xf32>
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %3 = gpu.thread_id  x
+  %4 = gpu.thread_id  y
+  %5 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%4]
+  %6 = affine.apply affine_map<()[s0] -> ((s0 floordiv 32) * 16)>()[%3]
+  // CHECK:         scf.for {{.*}} -> (vector<16x16xf32>) {
+  // CHECK-COUNT-2:   vector.transfer_read {{.*}} vector<16x8xf32>
+  // CHECK-COUNT-2:   vector.transfer_read {{.*}} vector<8x16xf32>
+  // CHECK-COUNT-2:   vector.contract {{.*}} vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
+  // CHECK:           scf.yield {{.*}} : vector<16x16xf32>
+  // CHECK:         }
+  %7 = scf.for %arg0 = %c0 to %c32 step %c16 iter_args(%arg1 = %cst) -> (vector<16x16xf32>) {
+    %10 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%c0)[%5]
+    %11 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%c0)[%arg0]
+    %12 = vector.transfer_read %lhs[%10, %11], %cst_0 {in_bounds = [true, true]} : memref<32x32xf32>, vector<16x16xf32>
+    %16 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%c0)[%6]
+    %17 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%c0)[%arg0]
+    %18 = vector.transfer_read %rhs[%17, %16], %cst_0 {in_bounds = [true, true]} : memref<32x32xf32>, vector<16x16xf32>
+    %22 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %12, %18, %arg1 : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32>
+    scf.yield %22 : vector<16x16xf32>
+  }
+  %8 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%c0)[%5]
+  %9 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%c0)[%6]
+  vector.transfer_write %7, %out[%8, %9] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32>
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%func_op: !transform.op<"func.func">):
+  transform.apply_patterns to %func_op {
+    transform.apply_patterns.gpu.unroll_vectors_subgroup_mma [16, 16, 8]
+  } : !transform.op<"func.func">
+}
+
+// -----
+
+// Unrolls the 16x16x16 contraction to the native [m, n, k] = [8, 16, 4] shape.
+// The CHECK lines also expect the 4x4 `vector.gather` whose result is written
+// to `%alloc` to keep its original shape.
+// CHECK-LABEL: func.func @gathered_matmul
+func.func @gathered_matmul(%lhs: memref<32x32xf32>, %rhs: memref<32x32xf32>, %out: memref<32x32xf32>) {
+  %c8 = arith.constant 8 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.000000e+00> : vector<16x16xf32>
+  %cst_mask = arith.constant dense<true> : vector<4x4xi1>
+  %cst_pt = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %cst_1 = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+  %cst_2 = arith.constant dense<1> : vector<4x4xindex>
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32>
+  %3 = gpu.thread_id  x
+  %4 = gpu.thread_id  y
+  %5 = affine.apply affine_map<()[s0] -> (s0 * 16)>()[%4]
+  %6 = affine.apply affine_map<()[s0] -> ((s0 floordiv 32) * 16)>()[%3]
+  // CHECK:         scf.for {{.*}} -> (vector<16x16xf32>) {
+  // CHECK:           arith.addi {{.*}} : vector<4xindex>
+  // CHECK:           vector.gather {{.*}} : memref<32x32xf32>, vector<4x4xindex>, vector<4x4xi1>, vector<4x4xf32> into vector<4x4xf32>
+  // CHECK-COUNT-8:   vector.transfer_read {{.*}} vector<8x4xf32>
+  // CHECK-COUNT-4:   vector.transfer_read {{.*}} vector<4x16xf32>
+  // CHECK-COUNT-8:   vector.contract {{.*}} vector<8x4xf32>, vector<4x16xf32> into vector<8x16xf32>
+  // CHECK:           scf.yield {{.*}} : vector<16x16xf32>
+  // CHECK:         }
+  %7 = scf.for %arg0 = %c0 to %c32 step %c16 iter_args(%arg1 = %cst) -> (vector<16x16xf32>) {
+    %10 = vector.broadcast %arg0 : index to vector<4xindex>
+    %11 = arith.addi %10, %cst_1 : vector<4xindex>
+    %12 = vector.broadcast %11 : vector<4xindex> to vector<4x4xindex>
+    %13 = arith.addi %12, %cst_2 : vector<4x4xindex>
+    %14 = vector.gather %lhs[%c0, %c0] [%13], %cst_mask, %cst_pt : memref<32x32xf32>, vector<4x4xindex>, vector<4x4xi1>, vector<4x4xf32> into vector<4x4xf32>
+    vector.transfer_write %14, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, memref<32x32xf32>
+    gpu.barrier
+    %15 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%c0)[%5]
+    %16 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%c0)[%arg0]
+    %17 = vector.transfer_read %alloc[%15, %16], %cst_0 {in_bounds = [true, true]} : memref<32x32xf32>, vector<16x16xf32>
+    %18 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%c0)[%6]
+    %19 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%c0)[%arg0]
+    %20 = vector.transfer_read %rhs[%19, %18], %cst_0 {in_bounds = [true, true]} : memref<32x32xf32>, vector<16x16xf32>
+    %21 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %17, %20, %arg1 : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32>
+    scf.yield %21 : vector<16x16xf32>
+  }
+  %8 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%c0)[%5]
+  %9 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%c0)[%6]
+  vector.transfer_write %7, %out[%8, %9] {in_bounds = [true, true]} : vector<16x16xf32>, memref<32x32xf32>
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%func_op: !transform.op<"func.func">):
+  transform.apply_patterns to %func_op {
+    transform.apply_patterns.gpu.unroll_vectors_subgroup_mma [8, 16, 4]
+  } : !transform.op<"func.func">
+}


        


More information about the Mlir-commits mailing list