[Mlir-commits] [mlir] 428a62f - [mlir][gpu] Add op to create MMA constant matrix

Thu Jun 10 08:34:16 PDT 2021

Author: thomasraoux
Date: 2021-06-10T08:34:04-07:00
New Revision: 428a62f65f16f1640b1bfe033d20e6a4f545dd3e

URL: https://github.com/llvm/llvm-project/commit/428a62f65f16f1640b1bfe033d20e6a4f545dd3e
DIFF: https://github.com/llvm/llvm-project/commit/428a62f65f16f1640b1bfe033d20e6a4f545dd3e.diff

LOG: [mlir][gpu] Add op to create MMA constant matrix

This allows creating a matrix with all elements set to a given value. This is
needed to be able to implement a simple dot op.

Differential Revision: https://reviews.llvm.org/D103870




diff  --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index 8e2520b675ae6..1e78e4af4d51a 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -1022,4 +1022,49 @@ def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute",
   let verifier = [{ return ::verify(*this); }];
+def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
+    [NoSideEffect,
+     TypesMatchWith<"value type matches element type of mma_matrix",
+                    "res", "value",
+                    "$_self.cast<gpu::MMAMatrixType>().getElementType()">]>{
+  let summary = "GPU warp synchronous constant matrix";
+  let description = [{
+    The `gpu.subgroup_mma_constant_matrix` creates a `!gpu.mma_matrix` with
+    constant elements.
+    The operation takes a scalar input and returns a `!gpu.mma_matrix` where
+    each element is equal to the operand constant. The destination mma_matrix
+    type must have an element type equal to the constant type. Since the layout
+    of `!gpu.mma_matrix` is opaque, this only supports setting all the elements
+    to the same value.
+    This op is meant to be used along with `gpu.subgroup_mma_compute`.
+    Example:
+    ```mlir
+     %0 = gpu.subgroup_mma_constant_matrix %a :
+       !gpu.mma_matrix<16x16xf16, "AOp">
+     %1 = gpu.subgroup_mma_constant_matrix %b :
+       !gpu.mma_matrix<16x16xf32, "COp">
+    ```
+  }];
+  let arguments = (ins AnyTypeOf<[F16, F32]>:$value);
+  let results = (outs GPU_MMAMatrix:$res);
+  let extraClassDeclaration = [{
+    gpu::MMAMatrixType getType() {
+      return res().getType().cast<gpu::MMAMatrixType>();
+    }
+  }];
+  let assemblyFormat = [{
+    $value attr-dict `:` type($res)
+  }];
 #endif // GPU_OPS

diff  --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index d72c8c217f86c..d46a185dec22c 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -348,12 +348,52 @@ struct WmmaMmaOpToNVVMLowering
+/// Convert GPU MMA ConstantMatrixOp to a chain of InsertValueOp.
+struct WmmaConstantOpToNVVMLowering
+    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaConstantMatrixOp> {
+  using ConvertOpToLLVMPattern<
+      gpu::SubgroupMmaConstantMatrixOp>::ConvertOpToLLVMPattern;
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantOp,
+                  ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(areAllLLVMTypes(subgroupMmaConstantOp.getOperation(), operands,
+                               rewriter)))
+      return failure(); // Pattern does not apply unless areAllLLVMTypes succeeds.
+    Location loc = subgroupMmaConstantOp.getLoc();
+    Value cst = operands[0]; // The scalar value to replicate (the op's only operand).
+    LLVM::LLVMStructType type = convertMMAToLLVMType(
+        subgroupMmaConstantOp.getType().cast<gpu::MMAMatrixType>()); // Struct type the result matrix lowers to.
+    // If the struct's first member is a vector, splat the operand into such a vector first.
+    if (auto vecType = type.getBody()[0].dyn_cast<VectorType>()) {
+      Value vecCst = rewriter.create<LLVM::UndefOp>(loc, vecType); // Start from undef and fill lane by lane.
+      for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
+        Value idx = rewriter.create<LLVM::ConstantOp>(
+            loc, typeConverter->convertType(rewriter.getIntegerType(32)),
+            rewriter.getI32ArrayAttr(vecEl)); // NOTE(review): array attribute on an i32-typed constant; getI32IntegerAttr looks intended — confirm.
+        vecCst = rewriter.create<LLVM::InsertElementOp>(loc, vecType, vecCst,
+                                                        cst, idx); // Write the scalar into lane `vecEl`.
+      }
+      cst = vecCst; // From here on `cst` is the splatted vector.
+    }
+    Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, type); // Start from an undef struct.
+    for (size_t i : llvm::seq(size_t(0), type.getBody().size())) { // Visit every struct member index.
+      matrixStruct = rewriter.create<LLVM::InsertValueOp>(
+          loc, matrixStruct, cst, rewriter.getI32ArrayAttr(i)); // Same value goes into member `i`; each insert yields a new struct value.
+    }
+    rewriter.replaceOp(subgroupMmaConstantOp, matrixStruct); // The fully populated struct replaces the op's result.
+    return success();
+  }
 } // anonymous namespace
 namespace mlir {
 void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                              RewritePatternSet &patterns) {
   patterns.insert<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
-                  WmmaStoreOpToNVVMLowering>(converter);
+                  WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering>(
+      converter);
 } // namespace mlir

diff  --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
index de5d0d3fcf1c0..f692dffdfcbad 100644
--- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -151,3 +151,28 @@ gpu.module @test_module {
+// -----
+gpu.module @test_module {
+// CHECK-LABEL: func @gpu_wmma_constant_op
+//       CHECK: %[[CST:.+]] = llvm.mlir.constant(1.000000e+00 : f16) : f16
+//       CHECK: %[[V0:.+]] = llvm.mlir.undef : vector<2xf16>
+//       CHECK: %[[C0:.+]] = llvm.mlir.constant([0 : i32]) : i32
+//       CHECK: %[[V1:.+]] = llvm.insertelement %[[CST]], %[[V0]][%[[C0]] : i32] : vector<2xf16>
+//       CHECK: %[[C1:.+]] = llvm.mlir.constant([1 : i32]) : i32
+//       CHECK: %[[V2:.+]] = llvm.insertelement %[[CST]], %[[V1]][%[[C1]] : i32] : vector<2xf16>
+//       CHECK: %[[M0:.+]] = llvm.mlir.undef : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK: %[[M1:.+]] = llvm.insertvalue %[[V2]], %[[M0]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK: %[[M2:.+]] = llvm.insertvalue %[[V2]], %[[M1]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK: %[[M3:.+]] = llvm.insertvalue %[[V2]], %[[M2]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK: %[[M4:.+]] = llvm.insertvalue %[[V2]], %[[M3]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK: llvm.return %[[M4]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+  func @gpu_wmma_constant_op()  ->(!gpu.mma_matrix<16x16xf16, "COp">) {
+    %cst = constant 1.0 : f16
+    %C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">
+    return %C : !gpu.mma_matrix<16x16xf16, "COp">
+  }

diff  --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index a98fe1c496838..1bed13c4b21a4 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -201,8 +201,12 @@ module attributes {gpu.container_module} {
     // CHECK: %[[wg:.*]] = memref.alloca()
     %i = constant 16 : index
     // CHECK: %[[i:.*]] = constant 16 : index
+     %cst = constant 1.000000e+00 : f32
+    // CHECK: %[[cst:.*]] = constant 1.000000e+00 : f32
     %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
     // CHECK: gpu.subgroup_mma_load_matrix %[[wg]][%[[i]], %[[i]]] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    %1 = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf32, "COp">
+    // CHECK: gpu.subgroup_mma_constant_matrix %[[cst]] : !gpu.mma_matrix<16x16xf32, "COp">


