[Mlir-commits] [mlir] 2b23e6c - [mlir][nvgpu] Add `nvgpu.rcp` OP (#100965)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jul 30 00:20:53 PDT 2024
Author: Observer007
Date: 2024-07-30T09:20:49+02:00
New Revision: 2b23e6c8d6a28ce1c134d1b6b322e35eec9d98e4
URL: https://github.com/llvm/llvm-project/commit/2b23e6c8d6a28ce1c134d1b6b322e35eec9d98e4
DIFF: https://github.com/llvm/llvm-project/commit/2b23e6c8d6a28ce1c134d1b6b322e35eec9d98e4.diff
LOG: [mlir][nvgpu] Add `nvgpu.rcp` OP (#100965)
This PR introduces a new OP that computes the reciprocal of `vector`
values by lowering to `nvvm.rcp` OPs. Currently, only f32 element types are supported.
---------
Co-authored-by: jingzec <jingzec at nvidia.com>
Added:
Modified:
mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
mlir/test/Dialect/NVGPU/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index dda8f31e688fe..1f52f6b91617c 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -20,6 +20,7 @@
#ifndef NVGPU
#define NVGPU
+include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"
@@ -109,10 +110,22 @@ def TensorMapInterleaveKind : I32EnumAttr<"TensorMapInterleaveKind",
let cppNamespace = "::mlir::nvgpu";
}
+def RcpApprox : I32EnumAttrCase<"APPROX", 0, "approx">;
+def RcpRN : I32EnumAttrCase<"RN", 1, "rn">;
+def RcpRZ : I32EnumAttrCase<"RZ", 2, "rz">;
+def RcpRM : I32EnumAttrCase<"RM", 3, "rm">;
+def RcpRP : I32EnumAttrCase<"RP", 4, "rp">;
+def RcpRoundingMode : I32EnumAttr<"RcpRoundingMode", "Rounding mode of rcp",
+ [RcpApprox, RcpRN, RcpRZ, RcpRM, RcpRP]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::nvgpu";
+}
+
def TensorMapSwizzleAttr : EnumAttr<NVGPU_Dialect, TensorMapSwizzleKind, "swizzle">;
def TensorMapL2PromoAttr : EnumAttr<NVGPU_Dialect, TensorMapL2PromoKind, "l2promo">;
def TensorMapOOBAttr : EnumAttr<NVGPU_Dialect, TensorMapOOBKind, "oob">;
def TensorMapInterleaveAttr : EnumAttr<NVGPU_Dialect, TensorMapInterleaveKind, "interleave">;
+def RcpRoundingModeAttr : EnumAttr<NVGPU_Dialect, RcpRoundingMode, "rcp_rounding_mode">;
//===----------------------------------------------------------------------===//
// NVGPU Type Definitions
@@ -802,4 +815,24 @@ def NVGPU_WarpgroupMmaInitAccumulatorOp : NVGPU_Op<"warpgroup.mma.init.accumulat
let hasVerifier = 1;
}
+def NVGPU_RcpOp : NVGPU_Op<"rcp", [Pure,
+ SameOperandsAndResultType]> {
+ let summary = "The reciprocal calculation for vector types";
+ let description = [{
+ Reciprocal calculation for `vector` types using `nvvm.rcp` OPs.
+
+ Currently, only the `approx` rounding mode and `ftz` are supported, and only for the `f32` type.
+
+ The input and output must be of the same vector type and shape.
+ }];
+ let arguments = (ins VectorOf<[F32]>:$in,
+ DefaultValuedAttr<RcpRoundingModeAttr, "RcpRoundingMode::APPROX">:$rounding,
+ UnitAttr:$ftz);
+ let results = (outs VectorOf<[F32]>:$out);
+ let assemblyFormat = [{
+ $in `{` `rounding` `=` $rounding (`,` `ftz` $ftz^)? `}`
+ attr-dict `:` type($out)
+ }];
+ let hasVerifier = 1;
+}
#endif // NVGPU
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
index 19070f6f062a0..aad2ac6f4dd2b 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -17,6 +17,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUEnums.h.inc"
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 11d29754aa760..cf984ca3293c0 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -11,6 +11,7 @@
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -1666,6 +1667,40 @@ struct NVGPUTmaPrefetchOpLowering
}
};
+struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
+ using ConvertOpToLLVMPattern<nvgpu::RcpOp>::ConvertOpToLLVMPattern;
+ LogicalResult
+ matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+ auto i64Ty = b.getI64Type();
+ auto f32Ty = b.getF32Type();
+ VectorType inTy = op.getIn().getType();
+ // Apply rcp.approx.ftz.f to each element of the 1-D vector.
+ auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
+ Value ret1DVec = b.create<LLVM::UndefOp>(llvm1DVectorTy);
+ int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
+ for (int i = 0; i < numElems; i++) {
+ Value idx = b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(i));
+ Value elem = b.create<LLVM::ExtractElementOp>(inVec, idx);
+ Value dst = b.create<NVVM::RcpApproxFtzF32Op>(f32Ty, elem);
+ ret1DVec = b.create<LLVM::InsertElementOp>(ret1DVec, dst, idx);
+ }
+ return ret1DVec;
+ };
+ if (inTy.getRank() == 1) {
+ rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
+ return success();
+ }
+ return LLVM::detail::handleMultidimensionalVectors(
+ op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
+ [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
+ OpAdaptor adaptor(operands);
+ return convert1DVec(llvm1DVectorTy, adaptor.getIn());
+ },
+ rewriter);
+ }
+};
} // namespace
void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
@@ -1688,5 +1723,5 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
- NVGPUMmaSparseSyncLowering>(converter);
+ NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
}
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 26f831f10a4e4..de9bbcbace692 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -644,6 +644,21 @@ LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// RcpOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult RcpOp::verify() {
+ RcpRoundingModeAttr rounding = getRoundingAttr();
+ bool ftz = getFtz();
+ // Currently, only the `approx` rounding mode combined with `ftz` is supported.
+ if (rounding.getValue() != RcpRoundingMode::APPROX || !ftz) {
+ return emitOpError() << "has a limitation. " << rounding
+ << " or non-ftz is not supported yet.";
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd dialect, type, and op definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index 86a552c03a473..156a8a468d5b4 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -1339,3 +1339,18 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// CHECK-LABEL: @rcp_approx_ftz_f32
+// CHECK-SAME: %[[IN:.*]]: vector<32x16xf32>
+func.func @rcp_approx_ftz_f32(%in: vector<32x16xf32>) {
+ // CHECK: %[[IN_LLVM:.*]] = builtin.unrealized_conversion_cast %[[IN]] : vector<32x16xf32> to !llvm.array<32 x vector<16xf32>>
+ // CHECK: %[[IN1DVEC:.*]] = llvm.extractvalue %[[IN_LLVM]][0] : !llvm.array<32 x vector<16xf32>>
+ // CHECK: %[[OUT1DVEC:.*]] = llvm.mlir.undef : vector<16xf32>
+ // CHECK: %[[IDX_0:.+]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK: %[[ELEM_0:.*]] = llvm.extractelement %[[IN1DVEC]][%[[IDX_0]] : i64]
+ // CHECK: %[[ELEM_RCP0:.*]] = nvvm.rcp.approx.ftz.f %[[ELEM_0]] : f32
+ // CHECK: llvm.insertelement %[[ELEM_RCP0]], %[[OUT1DVEC]][%[[IDX_0]] : i64] : vector<16xf32>
+ // CHECK-COUNT-511: nvvm.rcp.approx.ftz.f
+ %out = nvgpu.rcp %in {rounding = approx, ftz} : vector<32x16xf32>
+ return
+}
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index c3aed35153241..f7db1140794e5 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -336,3 +336,21 @@ func.func @tma_generate_descriptor_incorrect_last_dim(%desc: !desc, %buffer2: m
nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer2 : !desc, !mbarrier -> memref<64x128xf32,3>
return
}
+// -----
+
+func.func @rcp_unsupported_rounding_0(%in : vector<16xf32>) {
+ // expected-error @+1 {{'nvgpu.rcp' op has a limitation. #nvgpu<rcp_rounding_mode rn> or non-ftz is not supported yet.}}
+ %out = nvgpu.rcp %in {rounding = rn, ftz} : vector<16xf32>
+}
+// -----
+
+func.func @rcp_unsupported_rounding_1(%in : vector<16xf32>) {
+ // expected-error @+1 {{'nvgpu.rcp' op has a limitation. #nvgpu<rcp_rounding_mode rz> or non-ftz is not supported yet.}}
+ %out = nvgpu.rcp %in {rounding = rz} : vector<16xf32>
+}
+// -----
+
+func.func @rcp_unsupported_ftz(%in : vector<16xf32>) {
+ // expected-error @+1 {{'nvgpu.rcp' op has a limitation. #nvgpu<rcp_rounding_mode approx> or non-ftz is not supported yet.}}
+ %out = nvgpu.rcp %in {rounding = approx} : vector<16xf32>
+}
More information about the Mlir-commits
mailing list