[Mlir-commits] [mlir] 6413226 - [mlir][VectorToGPU] Add conversion for splat constant to MMA const matrix
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jun 24 15:38:34 PDT 2021
Author: thomasraoux
Date: 2021-06-24T15:38:12-07:00
New Revision: 6413226dce060581dbc2ce805e9d3311d7245e22
URL: https://github.com/llvm/llvm-project/commit/6413226dce060581dbc2ce805e9d3311d7245e22
DIFF: https://github.com/llvm/llvm-project/commit/6413226dce060581dbc2ce805e9d3311d7245e22.diff
LOG: [mlir][VectorToGPU] Add conversion for splat constant to MMA const matrix
Differential Revision: https://reviews.llvm.org/D104133
Added:
Modified:
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 1c0b0313c0dfc..0fc7944d5df73 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -113,6 +113,15 @@ transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
return true;
}
+/// Returns true when `constantOp` yields a rank-2 vector whose value is a
+/// SplatElementsAttr; only that form can be rewritten into an MMA constant
+/// matrix (a single scalar broadcast to every element).
+static bool constantSupportsMMAMatrixType(ConstantOp constantOp) {
+  auto resultType = constantOp.getType().dyn_cast<VectorType>();
+  // Short-circuit keeps the attribute query off non-vector / non-2D results.
+  bool isMatrixShaped = resultType && resultType.getRank() == 2;
+  return isMatrixShaped && constantOp.value().isa<SplatElementsAttr>();
+}
+
+
static bool supportsMMaMatrixType(Operation *op) {
if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
return transferReadSupportsMMAMatrixType(transferRead);
@@ -120,6 +129,8 @@ static bool supportsMMaMatrixType(Operation *op) {
return transferWriteSupportsMMAMatrixType(transferWrite);
if (auto contract = dyn_cast<vector::ContractionOp>(op))
return contractSupportsMMAMatrixType(contract);
+ if (auto constant = dyn_cast<ConstantOp>(op))
+ return constantSupportsMMAMatrixType(constant);
return false;
}
@@ -241,10 +252,11 @@ struct CombineTransferReadOpTranspose final
} // namespace
// MMA types have different layout based on how they are used in matmul ops.
-// Figure the right layout to use by looking at Transfer op uses.
+// Figure the right layout to use by looking at op uses.
// TODO: Change the GPU dialect to abstract the layout at the this level and
// only care about it during lowering to NVVM.
-static const char *inferFragType(vector::TransferReadOp op) {
+template <typename OpTy>
+static const char *inferFragType(OpTy op) {
for (Operation *users : op->getUsers()) {
auto contract = dyn_cast<vector::ContractionOp>(users);
if (!contract)
@@ -297,6 +309,23 @@ static void convertContractOp(vector::ContractionOp op,
valueMapping[op.getResult()] = matmul;
}
+/// Rewrites a rank-2 splat ConstantOp into a scalar ConstantOp holding the
+/// splat element, feeding a gpu.subgroup_mma_constant_matrix op. The new IR is
+/// inserted right before `op`, and the matrix value is recorded in
+/// `valueMapping` as the replacement for `op`'s vector result. `op` itself is
+/// left in place by this function.
+static void convertConstantOp(ConstantOp op,
+                              llvm::DenseMap<Value, Value> &valueMapping) {
+  assert(constantSupportsMMAMatrixType(op));
+  auto loc = op.getLoc();
+  OpBuilder builder(op);
+
+  // Materialize the single splat element as a scalar constant; it is the
+  // fill value of the MMA matrix.
+  auto splatAttr = op.getValue().cast<SplatElementsAttr>();
+  Attribute element = splatAttr.getSplatValue();
+  auto fillValue = builder.create<ConstantOp>(loc, element.getType(), element);
+
+  // The fragment kind is inferred from how the vector.contract users of this
+  // constant consume it.
+  const char *fragType = inferFragType(op);
+  auto vectorType = op.getType().cast<VectorType>();
+  gpu::MMAMatrixType matrixType =
+      gpu::MMAMatrixType::get(vectorType.getShape(),
+                              vectorType.getElementType(),
+                              llvm::StringRef(fragType));
+
+  auto constantMatrix = builder.create<gpu::SubgroupMmaConstantMatrixOp>(
+      loc, matrixType, fillValue);
+  valueMapping[op.getResult()] = constantMatrix;
+}
+
+
namespace mlir {
void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) {
@@ -314,6 +343,8 @@ void convertVectorToMMAOps(FuncOp funcOp) {
convertTransferWriteOp(transferWrite, valueMapping);
} else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
convertContractOp(contractOp, valueMapping);
+ } else if (auto constantOp = dyn_cast<ConstantOp>(op)) {
+ convertConstantOp(constantOp, valueMapping);
}
}
}
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index 5005cc6c6b228..d0b7d68e8c829 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -23,6 +23,24 @@ func @matmul(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<1
return
}
+// A splat vector constant used as the contraction accumulator must be
+// rewritten into a scalar constant feeding gpu.subgroup_mma_constant_matrix
+// with the "COp" fragment type. SSA names are matched with wildcards so the
+// test does not depend on the printer's value naming.
+// CHECK-LABEL: func @matmul_cst
+// CHECK-DAG: %[[CST:.+]] = constant 0.000000e+00 : f16
+// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_constant_matrix %[[CST]] : !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+func @matmul_cst(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>) {
+  %cst_0 = constant dense<0.000000e+00> : vector<16x16xf16>
+  %c0 = constant 0 : index
+  %cst = constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+  return
+}
+
// Negative test until scf.for support is added.
// CHECK-LABEL: func @matmul_loop
// CHECK: vector.contract
More information about the Mlir-commits
mailing list