[Mlir-commits] [mlir] 7fbb067 - [mlir][VectorToGPU] Add support for elementwise mma to vector to GPU

llvmlistbot at llvm.org
Tue Nov 2 08:01:32 PDT 2021


Author: thomasraoux
Date: 2021-11-02T08:01:04-07:00
New Revision: 7fbb0678fa4d6a8920fe7ddf3e734fba4406bd24

URL: https://github.com/llvm/llvm-project/commit/7fbb0678fa4d6a8920fe7ddf3e734fba4406bd24
DIFF: https://github.com/llvm/llvm-project/commit/7fbb0678fa4d6a8920fe7ddf3e734fba4406bd24.diff

LOG: [mlir][VectorToGPU] Add support for elementwise mma to vector to GPU

Differential Revision: https://reviews.llvm.org/D112960
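
With this change, elementwise arith ops on vectors (addf, mulf, maxf, minf)
that appear in the same slice as a vector.contract no longer block the
conversion: they are rewritten to gpu.subgroup_mma_elementwise so the whole
computation stays on !gpu.mma_matrix values. As a minimal sketch of the new
rewrite, using the same f16 16x16 tile shapes as the test added below (SSA
names are illustrative):

    %e = arith.addf %d, %c1 : vector<16x16xf16>

becomes

    %e = gpu.subgroup_mma_elementwise %d, %c1 {operation = "ADDF"}
           : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">)
           -> !gpu.mma_matrix<16x16xf16, "COp">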

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
    mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 8e037ecf5c852..b97a04638a653 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -130,6 +130,26 @@ static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
          broadcastOp.source().getType().isa<FloatType>();
 }
 
+/// Return the MMA elementwise enum associated with `op` if it is supported.
+/// Return `llvm::None` otherwise.
+static llvm::Optional<gpu::MMAElementwiseOp>
+convertElementwiseOpToMMA(Operation *op) {
+  if (isa<arith::AddFOp>(op))
+    return gpu::MMAElementwiseOp::ADDF;
+  if (isa<arith::MulFOp>(op))
+    return gpu::MMAElementwiseOp::MULF;
+  if (isa<MaxFOp>(op))
+    return gpu::MMAElementwiseOp::MAXF;
+  if (isa<MinFOp>(op))
+    return gpu::MMAElementwiseOp::MINF;
+  return llvm::None;
+}
+
+/// Return true if the op is supported as elementwise op on MMAMatrix type.
+static bool elementwiseSupportsMMAMatrixType(Operation *op) {
+  return convertElementwiseOpToMMA(op).hasValue();
+}
+
 static bool supportsMMaMatrixType(Operation *op) {
   if (isa<scf::ForOp, scf::YieldOp>(op))
     return true;
@@ -143,7 +163,7 @@ static bool supportsMMaMatrixType(Operation *op) {
     return constantSupportsMMAMatrixType(constant);
   if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
     return broadcastSupportsMMAMatrixType(broadcast);
-  return false;
+  return elementwiseSupportsMMAMatrixType(op);
 }
 
 // Analyze slice of operations based on convert op to figure out if the whole
@@ -423,6 +443,18 @@ static void convertYieldOp(scf::YieldOp op,
   op.erase();
 }
 
+/// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
+static void convertElementwiseOp(Operation *op, gpu::MMAElementwiseOp opType,
+                                 llvm::DenseMap<Value, Value> &valueMapping) {
+  OpBuilder b(op);
+  SmallVector<Value> matrixOperands;
+  for (Value operand : op->getOperands())
+    matrixOperands.push_back(valueMapping.find(operand)->second);
+  Value newOp = b.create<gpu::SubgroupMmaElementwiseOp>(
+      op->getLoc(), matrixOperands[0].getType(), matrixOperands, opType);
+  valueMapping[op->getResult(0)] = newOp;
+}
+
 namespace mlir {
 
 void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) {
@@ -448,6 +480,8 @@ void convertVectorToMMAOps(FuncOp funcOp) {
       convertForOp(forOp, valueMapping);
     } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
       convertYieldOp(yieldOp, valueMapping);
+    } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
+      convertElementwiseOp(op, *elementwiseType, valueMapping);
     }
   }
 }
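
Note that convertElementwiseOp above assumes every operand of the elementwise
op has already been converted (it is looked up in valueMapping without a
presence check) and reuses the first operand's fragment type as the result
type, so all operands are expected to carry the same !gpu.mma_matrix type.
A minimal sketch of what this permits, a chain of fused elementwise ops on
"COp" fragments (names illustrative; only the ADDF case is exercised by the
test in this commit):

    %sum  = gpu.subgroup_mma_elementwise %d, %c1 {operation = "ADDF"}
              : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">)
              -> !gpu.mma_matrix<16x16xf16, "COp">
    %relu = gpu.subgroup_mma_elementwise %sum, %c0 {operation = "MAXF"}
              : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">)
              -> !gpu.mma_matrix<16x16xf16, "COp">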

diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index 3a7c89343cab1..2ca899fa5bac4 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -83,3 +83,26 @@ func @matmul_loop(%arg0: memref<128x128xf16>, %arg1: memref<128x128xf16>, %arg2:
   vector.transfer_write %14, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<128x128xf16>
   return
 }
+
+// CHECK-LABEL: func @matmul_fused_elementwise
+//   CHECK-DAG:   %[[CST_0:.+]] = arith.constant 0.000000e+00 : f16
+//   CHECK-DAG:   %[[CST_1:.+]] = arith.constant 1.000000e+00 : f16
+//   CHECK-DAG:   %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+//   CHECK-DAG:   %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+//   CHECK-DAG:   %[[C0:.+]] = gpu.subgroup_mma_constant_matrix %[[CST_0]] : !gpu.mma_matrix<16x16xf16, "COp">
+//   CHECK-DAG:   %[[C1:.+]] = gpu.subgroup_mma_constant_matrix %[[CST_1]] : !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:   %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C0]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:   %[[E:.+]] = gpu.subgroup_mma_elementwise %[[D]], %[[C1]] {operation = "ADDF"} : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:   gpu.subgroup_mma_store_matrix %[[E]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+func @matmul_fused_elementwise(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+  %cst_1 = arith.constant dense<1.000000e+00> : vector<16x16xf16>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+  %E = arith.addf %D, %cst_1 : vector<16x16xf16>
+  vector.transfer_write %E, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+  return
+}
