[Mlir-commits] [mlir] 8a992b2 - [mlir][gpu] Add basic support to do elementwise ops on mma matrix type

llvmlistbot at llvm.org
Mon Nov 1 11:52:09 PDT 2021


Author: thomasraoux
Date: 2021-11-01T11:51:19-07:00
New Revision: 8a992b20dba54a061717a14eab86ccbe097da4c0

URL: https://github.com/llvm/llvm-project/commit/8a992b20dba54a061717a14eab86ccbe097da4c0
DIFF: https://github.com/llvm/llvm-project/commit/8a992b20dba54a061717a14eab86ccbe097da4c0.diff

LOG: [mlir][gpu] Add basic support to do elementwise ops on mma matrix type

In order to support fusion with the mma matrix type we need to be able to
execute elementwise operations on it. This adds an op to support some basic
elementwise operations. It is not a full solution, as it only covers a
limited set of operations; ideally we would want to be able to fuse with
more kinds of operations.
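
For illustration, a minimal sketch (not part of this commit) of how the new
op is meant to compose with `gpu.subgroup_mma_compute`, e.g. fusing a bias
add into a matmul. The value names (%a, %b, %c, %bias) are hypothetical and
the `subgroup_mma_compute` line only approximates that op's syntax:

```mlir
%d = gpu.subgroup_mma_compute %a, %b, %c :
       !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
       -> !gpu.mma_matrix<16x16xf16, "COp">
%e = gpu.subgroup_mma_elementwise %d, %bias { operation = "ADDF" } :
       (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">)
       -> !gpu.mma_matrix<16x16xf16, "COp">
```

The conversion added in WmmaOpsToNvvm.cpp below then expands each such
elementwise op into per-element llvm.extractvalue / scalar op /
llvm.insertvalue sequences, as checked by the new wmma-ops-to-nvvm.mlir test.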

Differential Revision: https://reviews.llvm.org/D112857

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/CMakeLists.txt
    mlir/include/mlir/Dialect/GPU/GPUBase.td
    mlir/include/mlir/Dialect/GPU/GPUDialect.h
    mlir/include/mlir/Dialect/GPU/GPUOps.td
    mlir/include/mlir/IR/OpBase.td
    mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
    mlir/lib/Dialect/GPU/CMakeLists.txt
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
    mlir/test/Dialect/GPU/ops.mlir
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/CMakeLists.txt b/mlir/include/mlir/Dialect/GPU/CMakeLists.txt
index 73aa1d92ffc1e..4808ec53e4e75 100644
--- a/mlir/include/mlir/Dialect/GPU/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/GPU/CMakeLists.txt
@@ -22,4 +22,9 @@ mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix GPU)
 mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix GPU)
 add_public_tablegen_target(MLIRGPUPassIncGen)
 
+set(LLVM_TARGET_DEFINITIONS GPUOps.td)
+mlir_tablegen(GPUOpsEnums.h.inc -gen-enum-decls)
+mlir_tablegen(GPUOpsEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRGPUOpsEnumsGen)
+
 add_mlir_doc(Passes GPUPasses ./ -gen-pass-doc)

diff --git a/mlir/include/mlir/Dialect/GPU/GPUBase.td b/mlir/include/mlir/Dialect/GPU/GPUBase.td
index a7bd8ece6a1c7..6c2fa43679d23 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUBase.td
@@ -115,18 +115,4 @@ def GPU_AsyncOpInterface : OpInterface<"AsyncOpInterface"> {
   ];
 }
 
-// Cases of the String enum Attribute for SubgroupMmaOpLayout, representing
-// the layouts of the operands supported by the ops that use this attribute.
-def RowMajor: StrEnumAttrCase<"RowMajor", 0>;
-def ColMajor: StrEnumAttrCase<"ColMajor", 1>;
-
-// Specifies a String enum Attribute for Warp wide matrix operations,
-// representing the layout of respective operands. The layout later governs
-// the lowerings to appropriate intrinsics.
-def SubgroupMmaOpLayout: StrEnumAttr<"Layout", "Specifies whether op is row/col major",
-                           [RowMajor, ColMajor]> {
-  let stringToSymbolFnName = "LayoutStrToEnum";
-  let symbolToStringFnName = "EnumToLayoutStr";
-}
-
 #endif // GPU_BASE

diff --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
index 79e8dca5af9c1..5c1b9db33c563 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
@@ -166,6 +166,8 @@ void addAsyncDependency(Operation *op, Value token);
 } // end namespace gpu
 } // end namespace mlir
 
+#include "mlir/Dialect/GPU/GPUOpsEnums.h.inc"
+
 #include "mlir/Dialect/GPU/GPUOpsDialect.h.inc"
 
 #include "mlir/Dialect/GPU/GPUOpInterfaces.h.inc"

diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index b92d315b19ffb..18b5adfd2445c 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -591,13 +591,13 @@ def GPU_YieldOp : GPU_Op<"yield", [NoSideEffect, Terminator]>,
 }
 
 // add, mul mirror the XLA ComparisonDirection enum.
-def GPU_AllReduceOpAdd : StrEnumAttrCase<"add">;
-def GPU_AllReduceOpAnd : StrEnumAttrCase<"and">;
-def GPU_AllReduceOpMax : StrEnumAttrCase<"max">;
-def GPU_AllReduceOpMin : StrEnumAttrCase<"min">;
-def GPU_AllReduceOpMul : StrEnumAttrCase<"mul">;
-def GPU_AllReduceOpOr : StrEnumAttrCase<"or">;
-def GPU_AllReduceOpXor : StrEnumAttrCase<"xor">;
+def GPU_AllReduceOpAdd : StrEnumAttrCase<"ADD", -1, "add">;
+def GPU_AllReduceOpAnd : StrEnumAttrCase<"AND", -1, "and">;
+def GPU_AllReduceOpMax : StrEnumAttrCase<"MAX", -1, "max">;
+def GPU_AllReduceOpMin : StrEnumAttrCase<"MIN", -1, "min">;
+def GPU_AllReduceOpMul : StrEnumAttrCase<"MUL", -1, "mul">;
+def GPU_AllReduceOpOr : StrEnumAttrCase<"OR", -1, "or">;
+def GPU_AllReduceOpXor : StrEnumAttrCase<"XOR", -1, "xor">;
 
 def GPU_AllReduceOperationAttr : StrEnumAttr<"AllReduceOperationAttr",
     "built-in reduction operations supported by gpu.allreduce.",
@@ -644,7 +644,7 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
   let verifier = [{ return ::verifyAllReduce(*this); }];
 }
 
-def GPU_ShuffleOpXor : StrEnumAttrCase<"xor">;
+def GPU_ShuffleOpXor : StrEnumAttrCase<"XOR", -1, "xor">;
 
 def GPU_ShuffleModeAttr : StrEnumAttr<"ShuffleModeAttr",
     "Indexing modes supported by gpu.shuffle.",
@@ -1121,4 +1121,60 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
   }];
 }
 
+def GPU_ELEMENTWISE_OP_ADD : StrEnumAttrCase<"ADDF">;
+def GPU_ELEMENTWISE_OP_MUL : StrEnumAttrCase<"MULF">;
+def GPU_ELEMENTWISE_OP_MAXF : StrEnumAttrCase<"MAXF">;
+def GPU_ELEMENTWISE_OP_MINF : StrEnumAttrCase<"MINF">;
+
+def MMAElementWiseAttr : StrEnumAttr<"MMAElementwiseOp",
+  "elementwise operation to apply to mma matrix",
+  [GPU_ELEMENTWISE_OP_ADD, GPU_ELEMENTWISE_OP_MUL,
+   GPU_ELEMENTWISE_OP_MAXF, GPU_ELEMENTWISE_OP_MINF]> {
+  let cppNamespace = "::mlir::gpu";
+  let storageType = "::mlir::StringAttr";
+  let returnType = "::mlir::gpu::MMAElementwiseOp";
+  let convertFromStorage = "*symbolizeMMAElementwiseOp($_self.getValue())";
+  let constBuilderCall = "$_builder.getStringAttr(stringifyEnum($0))";
+}
+
+def GPU_SubgroupMmaElementwiseOp : GPU_Op<"subgroup_mma_elementwise",
+    [NoSideEffect,
+     AllTypesMatch<["args"]>]>{
+
+  let summary = "GPU warp elementwise operation on a matrix";
+
+  let description = [{
+    The `gpu.subgroup_mma_elementwise` op takes `!gpu.mma_matrix` inputs and
+    computes a new `!gpu.mma_matrix` by applying an elementwise operation to
+    each element.
+
+    Since the operation is elementwise and the matrix type must match, the
+    matrix elements are processed independently of the matrix layout.
+
+    This op is meant to be used along with `gpu.subgroup_mma_compute`.
+
+    Example:
+
+    ```mlir
+     %0 = gpu.subgroup_mma_elementwise %A, %B { operation = "ADDF" } :
+      (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">)
+      -> !gpu.mma_matrix<16x16xf16, "COp">
+    ```
+  }];
+
+  let arguments = (ins Variadic<GPU_MMAMatrix>:$args, MMAElementWiseAttr:$operation);
+
+  let results = (outs GPU_MMAMatrix:$res);
+
+  let extraClassDeclaration = [{
+    gpu::MMAMatrixType getType() {
+      return res().getType().cast<gpu::MMAMatrixType>();
+    }
+  }];
+
+  let assemblyFormat = [{
+    $args attr-dict `:` functional-type($args, $res)
+  }];
+}
+
 #endif // GPU_OPS

diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index ec0d5355dcf18..e63a2672bf316 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1187,11 +1187,11 @@ class EnumAttrCaseInfo<string sym, int intVal, string strVal> {
 }
 
 // An enum attribute case stored with StringAttr.
-class StrEnumAttrCase<string sym, int val = -1> :
-    EnumAttrCaseInfo<sym, val, sym>,
+class StrEnumAttrCase<string sym, int val = -1, string str = sym> :
+    EnumAttrCaseInfo<sym, val, str>,
     StringBasedAttr<
-      CPred<"$_self.cast<::mlir::StringAttr>().getValue() == \"" # sym # "\"">,
-      "case " # sym>;
+      CPred<"$_self.cast<::mlir::StringAttr>().getValue() == \"" # str # "\"">,
+      "case " # str>;
 
 // An enum attribute case stored with IntegerAttr, which has an integer value,
 // its representation as a string and a C++ symbol name which may be different.

diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 49d48bf2d630b..878d0cf22fd8f 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/IR/TypeUtilities.h"
 
 using namespace mlir;
 
@@ -352,13 +353,90 @@ struct WmmaConstantOpToNVVMLowering
   }
 };
 
+static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
+                           Value rhs, bool isMin) {
+  auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
+  Type i1Type = builder.getI1Type();
+  if (auto vecType = lhs.getType().dyn_cast<VectorType>())
+    i1Type = VectorType::get(vecType.getShape(), i1Type);
+  Value cmp = builder.create<LLVM::FCmpOp>(
+      loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
+      lhs, rhs);
+  Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
+  Value isNan = builder.create<LLVM::FCmpOp>(
+      loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
+  Value nan = builder.create<LLVM::ConstantOp>(
+      loc, lhs.getType(),
+      builder.getFloatAttr(floatType,
+                           APFloat::getQNaN(floatType.getFloatSemantics())));
+  return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
+}
+
+static Value createScalarOp(OpBuilder &builder, Location loc,
+                            gpu::MMAElementwiseOp op,
+                            ArrayRef<Value> operands) {
+  switch (op) {
+  case gpu::MMAElementwiseOp::ADDF:
+    return builder.create<LLVM::FAddOp>(loc, operands[0].getType(), operands);
+  case gpu::MMAElementwiseOp::MULF:
+    return builder.create<LLVM::FMulOp>(loc, operands[0].getType(), operands);
+  case gpu::MMAElementwiseOp::MAXF:
+    return createMinMaxF(builder, loc, operands[0], operands[1],
+                         /*isMin=*/false);
+  case gpu::MMAElementwiseOp::MINF:
+    return createMinMaxF(builder, loc, operands[0], operands[1],
+                         /*isMin=*/true);
+  }
+  llvm_unreachable("unknown op");
+}
+
+/// Convert GPU MMA elementwise ops to extract + op + insert.
+struct WmmaElementwiseOpToNVVMLowering
+    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaElementwiseOp> {
+  using ConvertOpToLLVMPattern<
+      gpu::SubgroupMmaElementwiseOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(areAllLLVMTypes(subgroupMmaElementwiseOp.getOperation(),
+                               adaptor.getOperands(), rewriter)))
+      return failure();
+    Location loc = subgroupMmaElementwiseOp.getLoc();
+    size_t numOperands = adaptor.getOperands().size();
+    LLVM::LLVMStructType destType = convertMMAToLLVMType(
+        subgroupMmaElementwiseOp.getType().cast<gpu::MMAMatrixType>());
+    Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, destType);
+    for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
+      SmallVector<Value> extractedOperands;
+      for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
+        Type elementType = adaptor.getOperands()[opIdx]
+                               .getType()
+                               .cast<LLVM::LLVMStructType>()
+                               .getBody()[i];
+        extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
+            loc, elementType, adaptor.getOperands()[opIdx],
+            rewriter.getI32ArrayAttr(i)));
+      }
+      Value element =
+          createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.operation(),
+                         extractedOperands);
+      matrixStruct = rewriter.create<LLVM::InsertValueOp>(
+          loc, matrixStruct, element, rewriter.getI32ArrayAttr(i));
+    }
+    rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct);
+    return success();
+  }
+};
+
 } // anonymous namespace
 
 namespace mlir {
 void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                              RewritePatternSet &patterns) {
   patterns.insert<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
-                  WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering>(
-      converter);
+                  WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
+                  WmmaElementwiseOpToNVVMLowering>(converter);
 }
 } // namespace mlir
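
As an aside for readers of the lowering above: MAXF/MINF are not mapped to a
single LLVM operation; createMinMaxF builds a NaN-propagating compare/select
chain. A hand-written sketch (not compiler output; value names are made up)
of roughly what one MAXF element of the <2 x f16> struct expands to:

```mlir
%cmp   = llvm.fcmp "ogt" %a, %b : vector<2xf16>
%sel   = llvm.select %cmp, %a, %b : vector<2xi1>, vector<2xf16>
%isnan = llvm.fcmp "uno" %a, %b : vector<2xf16>
%qnan  = llvm.mlir.constant(dense<0x7E00> : vector<2xf16>) : vector<2xf16>
%res   = llvm.select %isnan, %qnan, %sel : vector<2xi1>, vector<2xf16>
```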

diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index 2beb7ea7bc882..14520ce6767d8 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRGPUOps
 
   DEPENDS
   MLIRGPUOpsIncGen
+  MLIRGPUOpsEnumsGen
   MLIRGPUOpInterfacesIncGen
 
   LINK_LIBS PUBLIC

diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index ba1710b57a919..9baff7f53ca8f 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1185,6 +1185,7 @@ void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 #include "mlir/Dialect/GPU/GPUOpInterfaces.cpp.inc"
+#include "mlir/Dialect/GPU/GPUOpsEnums.cpp.inc"
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/GPU/GPUOps.cpp.inc"

diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
index 4c035acaf7383..c0ac8a050288f 100644
--- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -220,3 +220,33 @@ gpu.module @test_module {
     return %C : !gpu.mma_matrix<16x16xf16, "COp">
   }
 }
+
+// -----
+
+gpu.module @test_module {
+
+// CHECK-LABEL: func @gpu_wmma_elementwise
+//       CHECK: %[[M0:.*]] = llvm.mlir.undef : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK: %[[A0:.*]] = llvm.extractvalue %{{.*}}[0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK: %[[B0:.*]] = llvm.extractvalue %{{.*}}[0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK: %[[C0:.*]] = llvm.fadd %[[A0]], %[[B0]]  : vector<2xf16>
+//       CHECK: %[[M1:.*]] = llvm.insertvalue %[[C0]], %[[M0]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK: %[[A1:.*]] = llvm.extractvalue %{{.*}}[1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK: %[[B1:.*]] = llvm.extractvalue %{{.*}}[1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK: %[[C1:.*]] = llvm.fadd %[[A1]], %[[B1]]  : vector<2xf16>
+//       CHECK: %[[M2:.*]] = llvm.insertvalue %[[C1]], %[[M1]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK: %[[A2:.*]] = llvm.extractvalue %{{.*}}[2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK: %[[B2:.*]] = llvm.extractvalue %{{.*}}[2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK: %[[C2:.*]] = llvm.fadd %[[A2]], %[[B2]]  : vector<2xf16>
+//       CHECK: %[[M3:.*]] = llvm.insertvalue %[[C2]], %[[M2]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK: %[[A3:.*]] = llvm.extractvalue %{{.*}}[3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK: %[[B3:.*]] = llvm.extractvalue %{{.*}}[3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK: %[[C3:.*]] = llvm.fadd %[[A3]], %[[B3]]  : vector<2xf16>
+//       CHECK: %[[M4:.*]] = llvm.insertvalue %[[C3]], %[[M3]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK: llvm.return %[[M4]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+  builtin.func @gpu_wmma_elementwise(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">)  ->(!gpu.mma_matrix<16x16xf16, "COp">) {
+    %C = gpu.subgroup_mma_elementwise %A, %B { operation = "ADDF" } :
+      (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+    return %C : !gpu.mma_matrix<16x16xf16, "COp">
+  }
+}

diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 297fb5fe6fe20..c24fd7bf8a818 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -220,7 +220,10 @@ module attributes {gpu.container_module} {
     %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
     // CHECK: gpu.subgroup_mma_load_matrix %[[wg]][%[[i]], %[[i]]] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
     %1 = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf32, "COp">
-    // CHECK: gpu.subgroup_mma_constant_matrix %[[cst]] : !gpu.mma_matrix<16x16xf32, "COp">
+    // CHECK: gpu.subgroup_mma_elementwise %{{.*}}, %{{.*}} {operation = "ADDF"} : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
+    %2 = gpu.subgroup_mma_elementwise %1, %1 {operation = "ADDF"} : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
+    // CHECK: gpu.subgroup_mma_elementwise %{{.*}}, %{{.*}} {operation = "MAXF"} : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
+    %3 = gpu.subgroup_mma_elementwise %2, %1 {operation = "MAXF"} : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
     return
   }
 }

diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index cd1e34d1964d2..a5a59eb9cd632 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2768,6 +2768,14 @@ gentbl_cc_library(
             ["-gen-op-defs"],
             "include/mlir/Dialect/GPU/GPUOps.cpp.inc",
         ),
+        (
+            ["-gen-enum-decls"],
+            "include/mlir/Dialect/GPU/GPUOpsEnums.h.inc",
+        ),
+        (
+            ["-gen-enum-defs"],
+            "include/mlir/Dialect/GPU/GPUOpsEnums.cpp.inc",
+        ),
     ],
     tblgen = ":mlir-tblgen",
     td_file = "include/mlir/Dialect/GPU/GPUOps.td",

