[Mlir-commits] [mlir] 8c80d01 - [mlir][NVGPU] NFC - Add a more convenient C++ builder for nvgpu::MmaSyncOp
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Jun 19 07:31:01 PDT 2023
Author: Nicolas Vasilache
Date: 2023-06-19T13:54:31Z
New Revision: 8c80d01a95ba0d75c29191de0ea38cce48c9978f
URL: https://github.com/llvm/llvm-project/commit/8c80d01a95ba0d75c29191de0ea38cce48c9978f
DIFF: https://github.com/llvm/llvm-project/commit/8c80d01a95ba0d75c29191de0ea38cce48c9978f.diff
LOG: [mlir][NVGPU] NFC - Add a more convenient C++ builder for nvgpu::MmaSyncOp
Added:
Modified:
mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 5bb02b082575a..e595e9dffbe0b 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -158,8 +158,7 @@ def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> {
AnyVector:$matrixB,
AnyVector:$matrixC,
I64ArrayAttr:$mmaShape,
- OptionalAttr<UnitAttr>:$tf32Enabled
- );
+ OptionalAttr<UnitAttr>:$tf32Enabled);
let results = (outs AnyVector:$res);
@@ -167,7 +166,12 @@ def NVGPU_MmaSyncOp : NVGPU_MmaSyncOp<"mma.sync"> {
OpBuilder<(ins "Value":$matrixA,
"Value":$matrixB,
"Value":$matrixC,
- "ArrayAttr":$mmaShape)>
+ "ArrayAttr":$mmaShape)>,
+ OpBuilder<(ins "Value":$matrixA,
+ "Value":$matrixB,
+ "Value":$matrixC,
+ "ArrayRef<int64_t>":$mmaShape,
+ CArg<"bool", "false">:$tf32Enabled)>
];
let assemblyFormat = [{
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 77c853a2c35f4..0472d27906ead 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -96,6 +96,15 @@ void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
mmaShape, UnitAttr());
}
+/// Convenience builder for `nvgpu.mma.sync` that takes plain C++ values
+/// instead of attributes:
+///   - `mmaShape` is given as an `ArrayRef<int64_t>` and wrapped into an
+///     I64ArrayAttr via `getI64ArrayAttr`.
+///   - `tf32Enabled` is a `bool` (declared with a default of `false` in the
+///     ODS builder list). When true, a UnitAttr is attached. When false, a
+///     null `UnitAttr()` is passed, so the optional `tf32Enabled` attribute is
+///     left unset.
+/// The op's result type is set to `matrixC`'s type. All operands are then
+/// forwarded to the builder overload that takes an explicit result type.
+void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
+ ::mlir::OperationState &odsState, Value matrixA,
+ Value matrixB, Value matrixC, ArrayRef<int64_t> mmaShape,
+ bool tf32Enabled) {
+ build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
+ odsBuilder.getI64ArrayAttr(mmaShape),
+ tf32Enabled ? odsBuilder.getUnitAttr() : UnitAttr());
+}
+
/// Performs verification for MmaSyncOp and MmaSparseSyncOp.
static LogicalResult verifyMmaSyncOp(Operation *op,
TypedValue<VectorType> matrixA,
More information about the Mlir-commits mailing list