[Mlir-commits] [mlir] 8c80d01 - [mlir][NVGPU] NFC - Add a more convenient C++ builder for nvgpu::MmaSyncOp

Nicolas Vasilache llvmlistbot at llvm.org
Mon Jun 19 07:31:01 PDT 2023


Author: Nicolas Vasilache
Date: 2023-06-19T13:54:31Z
New Revision: 8c80d01a95ba0d75c29191de0ea38cce48c9978f

URL: https://github.com/llvm/llvm-project/commit/8c80d01a95ba0d75c29191de0ea38cce48c9978f
DIFF: https://github.com/llvm/llvm-project/commit/8c80d01a95ba0d75c29191de0ea38cce48c9978f.diff

LOG: [mlir][NVGPU] NFC - Add a more convenient C++ builder for nvgpu::MmaSyncOp

Added:
    (none)

Modified: 
    mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
    mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp

Removed:
    (none)


################################################################################
diff  --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 5bb02b082575a..e595e9dffbe0b 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -158,8 +158,7 @@ def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> {
                        AnyVector:$matrixB,
                        AnyVector:$matrixC,
                        I64ArrayAttr:$mmaShape,
-                       OptionalAttr<UnitAttr>:$tf32Enabled
-                       );
+                       OptionalAttr<UnitAttr>:$tf32Enabled);
 
   let results = (outs AnyVector:$res);
 
@@ -167,7 +166,12 @@ def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> {
     OpBuilder<(ins "Value":$matrixA,
                    "Value":$matrixB,
                    "Value":$matrixC,
-                   "ArrayAttr":$mmaShape)>
+                   "ArrayAttr":$mmaShape)>,
+    OpBuilder<(ins "Value":$matrixA,
+                   "Value":$matrixB,
+                   "Value":$matrixC,
+                   "ArrayRef<int64_t>":$mmaShape,
+                   CArg<"bool", "false">:$tf32Enabled)>
   ];
 
   let assemblyFormat = [{

diff  --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 77c853a2c35f4..0472d27906ead 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -96,6 +96,15 @@ void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
         mmaShape, UnitAttr());
 }
 
+void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
+                      ::mlir::OperationState &odsState, Value matrixA,
+                      Value matrixB, Value matrixC, ArrayRef<int64_t> mmaShape,
+                      bool tf32Enabled) {
+  UnitAttr tf32Attr = tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr();
+  build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
+        odsBuilder.getI64ArrayAttr(mmaShape), tf32Attr);
+}
+
 /// Performs verification for MmaSyncOp and MmaSparseSyncOp.
 static LogicalResult verifyMmaSyncOp(Operation *op,
                                      TypedValue<VectorType> matrixA,


        


More information about the Mlir-commits mailing list