[Mlir-commits] [mlir] 7fbb067 - [mlir][VectorToGPU] Add support for elementwise mma to vector to GPU
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Nov 2 08:01:32 PDT 2021
Author: thomasraoux
Date: 2021-11-02T08:01:04-07:00
New Revision: 7fbb0678fa4d6a8920fe7ddf3e734fba4406bd24
URL: https://github.com/llvm/llvm-project/commit/7fbb0678fa4d6a8920fe7ddf3e734fba4406bd24
DIFF: https://github.com/llvm/llvm-project/commit/7fbb0678fa4d6a8920fe7ddf3e734fba4406bd24.diff
LOG: [mlir][VectorToGPU] Add support for elementwise mma to vector to GPU
Differential Revision: https://reviews.llvm.org/D112960
Added:
Modified:
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 8e037ecf5c852..b97a04638a653 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -130,6 +130,26 @@ static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
broadcastOp.source().getType().isa<FloatType>();
}
+/// Map `op` to the gpu::MMAElementwiseOp kind that performs the same
+/// computation on MMA matrices. Only arith.addf, arith.mulf, maxf and minf are
+/// recognized; for any other op `llvm::None` is returned.
+static llvm::Optional<gpu::MMAElementwiseOp>
+convertElementwiseOpToMMA(Operation *op) {
+  // Stays None unless one of the supported op kinds matches below.
+  llvm::Optional<gpu::MMAElementwiseOp> kind;
+  if (isa<arith::AddFOp>(op))
+    kind = gpu::MMAElementwiseOp::ADDF;
+  else if (isa<arith::MulFOp>(op))
+    kind = gpu::MMAElementwiseOp::MULF;
+  else if (isa<MaxFOp>(op))
+    kind = gpu::MMAElementwiseOp::MAXF;
+  else if (isa<MinFOp>(op))
+    kind = gpu::MMAElementwiseOp::MINF;
+  return kind;
+}
+
+/// Return true when `op` is an elementwise op that has a direct
+/// gpu.subgroup_mma_elementwise equivalent, i.e. when
+/// convertElementwiseOpToMMA yields a kind for it.
+static bool elementwiseSupportsMMAMatrixType(Operation *op) {
+  if (convertElementwiseOpToMMA(op))
+    return true;
+  return false;
+}
+
static bool supportsMMaMatrixType(Operation *op) {
if (isa<scf::ForOp, scf::YieldOp>(op))
return true;
@@ -143,7 +163,7 @@ static bool supportsMMaMatrixType(Operation *op) {
return constantSupportsMMAMatrixType(constant);
if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
return broadcastSupportsMMAMatrixType(broadcast);
- return false;
+ return elementwiseSupportsMMAMatrixType(op);
}
// Analyze slice of operations based on convert op to figure out if the whole
@@ -423,6 +443,18 @@ static void convertYieldOp(scf::YieldOp op,
op.erase();
}
+/// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
+///
+/// Every operand of `op` is replaced by the MMA matrix value recorded for it in
+/// `valueMapping`. A gpu.subgroup_mma_elementwise op of kind `opType` is built
+/// right before `op`, with its result type taken from the first converted
+/// operand. The new value is then recorded in `valueMapping` as the
+/// replacement for `op`'s first result.
+///
+/// If some operand has no entry in `valueMapping`, `op` is left untouched and
+/// no mapping is recorded for its result. An example is presumably a value not
+/// produced by a converted op, such as a block argument; TODO confirm against
+/// the slice analysis. NOTE(review): later consumers that look up this result
+/// must tolerate the missing entry.
+static void convertElementwiseOp(Operation *op, gpu::MMAElementwiseOp opType,
+                                 llvm::DenseMap<Value, Value> &valueMapping) {
+  SmallVector<Value> matrixOperands;
+  matrixOperands.reserve(op->getNumOperands());
+  for (Value operand : op->getOperands()) {
+    auto it = valueMapping.find(operand);
+    // Dereferencing end() would be undefined behavior in release builds, so
+    // bail out instead of assuming every operand was already converted.
+    if (it == valueMapping.end())
+      return;
+    matrixOperands.push_back(it->second);
+  }
+  // The result type is read from matrixOperands[0]; never index an empty list.
+  if (matrixOperands.empty())
+    return;
+  OpBuilder b(op);
+  Value newOp = b.create<gpu::SubgroupMmaElementwiseOp>(
+      op->getLoc(), matrixOperands[0].getType(), matrixOperands, opType);
+  valueMapping[op->getResult(0)] = newOp;
+}
+
namespace mlir {
void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) {
@@ -448,6 +480,8 @@ void convertVectorToMMAOps(FuncOp funcOp) {
convertForOp(forOp, valueMapping);
} else if (auto yiledOp = dyn_cast<scf::YieldOp>(op)) {
convertYieldOp(yiledOp, valueMapping);
+ } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
+ convertElementwiseOp(op, *elementwiseType, valueMapping);
}
}
}
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index 3a7c89343cab1..2ca899fa5bac4 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -83,3 +83,26 @@ func @matmul_loop(%arg0: memref<128x128xf16>, %arg1: memref<128x128xf16>, %arg2:
vector.transfer_write %14, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<128x128xf16>
return
}
+
+// An arith.addf consuming the result of a vector.contract must be rewritten to
+// gpu.subgroup_mma_elementwise with operation "ADDF" on the accumulator matrix.
+// CHECK-LABEL: func @matmul_fused_elementwise
+// CHECK-DAG: %[[CST_0:.+]] = arith.constant 0.000000e+00 : f16
+// CHECK-DAG: %[[CST_1:.+]] = arith.constant 1.000000e+00 : f16
+// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+// CHECK-DAG: %[[C0:.+]] = gpu.subgroup_mma_constant_matrix %[[CST_0]] : !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK-DAG: %[[C1:.+]] = gpu.subgroup_mma_constant_matrix %[[CST_1]] : !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C0]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: %[[E:.+]] = gpu.subgroup_mma_elementwise %[[D]], %[[C1]] {operation = "ADDF"} : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: gpu.subgroup_mma_store_matrix %[[E]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+func @matmul_fused_elementwise(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>) {
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+ %cst_1 = arith.constant dense<1.000000e+00> : vector<16x16xf16>
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f16
+ %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+ %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+ %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+ %E = arith.addf %D, %cst_1 : vector<16x16xf16>
+ vector.transfer_write %E, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+ return
+}
More information about the Mlir-commits mailing list