[Mlir-commits] [mlir] 6413226 - [mlir][VectorToGPU] Add conversion for splat constant to MMA const matrix

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 24 15:38:34 PDT 2021


Author: thomasraoux
Date: 2021-06-24T15:38:12-07:00
New Revision: 6413226dce060581dbc2ce805e9d3311d7245e22

URL: https://github.com/llvm/llvm-project/commit/6413226dce060581dbc2ce805e9d3311d7245e22
DIFF: https://github.com/llvm/llvm-project/commit/6413226dce060581dbc2ce805e9d3311d7245e22.diff

LOG: [mlir][VectorToGPU] Add conversion for splat constant to MMA const matrix

Differential Revision: https://reviews.llvm.org/D104133

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
    mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 1c0b0313c0dfc..0fc7944d5df73 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -113,6 +113,15 @@ transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
   return true;
 }
 
+/// Return true if the constant is a splat to a 2D vector so that it can be
+/// converted to a MMA constant matrix op.
+static bool constantSupportsMMAMatrixType(ConstantOp constantOp) {
+  auto vecType = constantOp.getType().dyn_cast<VectorType>();
+  if (!vecType || vecType.getRank() != 2)
+    return false;
+  return constantOp.value().isa<SplatElementsAttr>();
+}
+
 static bool supportsMMaMatrixType(Operation *op) {
   if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
     return transferReadSupportsMMAMatrixType(transferRead);
@@ -120,6 +129,8 @@ static bool supportsMMaMatrixType(Operation *op) {
     return transferWriteSupportsMMAMatrixType(transferWrite);
   if (auto contract = dyn_cast<vector::ContractionOp>(op))
     return contractSupportsMMAMatrixType(contract);
+  if (auto constant = dyn_cast<ConstantOp>(op))
+    return constantSupportsMMAMatrixType(constant);
   return false;
 }
 
@@ -241,10 +252,11 @@ struct CombineTransferReadOpTranspose final
 } // namespace
 
 // MMA types have different layout based on how they are used in matmul ops.
-// Figure the right layout to use by looking at Transfer op uses.
+// Figure the right layout to use by looking at op uses.
 // TODO: Change the GPU dialect to abstract the layout at the this level and
 // only care about it during lowering to NVVM.
-static const char *inferFragType(vector::TransferReadOp op) {
+template <typename OpTy>
+static const char *inferFragType(OpTy op) {
   for (Operation *users : op->getUsers()) {
     auto contract = dyn_cast<vector::ContractionOp>(users);
     if (!contract)
@@ -297,6 +309,23 @@ static void convertContractOp(vector::ContractionOp op,
   valueMapping[op.getResult()] = matmul;
 }
 
+/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
+static void convertConstantOp(ConstantOp op,
+                              llvm::DenseMap<Value, Value> &valueMapping) {
+  assert(constantSupportsMMAMatrixType(op));
+  OpBuilder b(op);
+  Attribute splat = op.getValue().cast<SplatElementsAttr>().getSplatValue();
+  auto scalarConstant =
+      b.create<ConstantOp>(op.getLoc(), splat.getType(), splat);
+  const char *fragType = inferFragType(op);
+  auto vecType = op.getType().cast<VectorType>();
+  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
+      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
+  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
+                                                           scalarConstant);
+  valueMapping[op.getResult()] = matrix;
+}
+
 namespace mlir {
 
 void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) {
@@ -314,6 +343,8 @@ void convertVectorToMMAOps(FuncOp funcOp) {
       convertTransferWriteOp(transferWrite, valueMapping);
     } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
       convertContractOp(contractOp, valueMapping);
+    } else if (auto constantOp = dyn_cast<ConstantOp>(op)) {
+      convertConstantOp(constantOp, valueMapping);
     }
   }
 }

diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index 5005cc6c6b228..d0b7d68e8c829 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -23,6 +23,24 @@ func @matmul(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<1
   return
 }
 
+// CHECK-LABEL: func @matmul_cst
+//   CHECK-DAG:   %[[CST:.+]] = constant 0.000000e+00 : f16
+//   CHECK-DAG:   %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+//   CHECK-DAG:   %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+//   CHECK-DAG:   %[[C:.+]] = gpu.subgroup_mma_constant_matrix %[[CST]] : !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:   %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:   gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+func @matmul_cst(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>) {
+  %cst_0 = constant dense<0.000000e+00> : vector<16x16xf16>
+  %c0 = constant 0 : index
+  %cst = constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+  return
+}
+
 // Negative test until scf.for support is added.
 // CHECK-LABEL: func @matmul_loop
 //       CHECK:   vector.contract


        


More information about the Mlir-commits mailing list