[Mlir-commits] [mlir] [mlir][nvvm] Expand sitofp/uitofp to faster ops (PR #107001)
Christian Sigg
llvmlistbot at llvm.org
Tue Sep 3 04:57:00 PDT 2024
https://github.com/chsigg updated https://github.com/llvm/llvm-project/pull/107001
From ecd33d5df64357bc491cdfa5ff4b15ebfef7f157 Mon Sep 17 00:00:00 2001
From: Christian Sigg <csigg at google.com>
Date: Fri, 30 Aug 2024 16:21:56 +0200
Subject: [PATCH 1/2] [mlir][nvvm] Expand sitofp/uitofp to faster ops
`sitofp` and `uitofp` are lowered to `cvt.rn` PTX instructions by the LLVM NVPTX backend, and those conversions have lower throughput than integer and floating-point arithmetic ops.
Doing this optimization in LLVM would only work for i16->fp32, because the NVPTX backend has no i8 registers and promotes i8 values to i16.
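For readers skimming the patch, the narrow-source case boils down to the classic magic-number trick. A standalone scalar sketch (my illustration, not part of the change), for u16 -> f32 with the assumed magic constant 1.5 * 2^23 = 0x4B400000, matching the ui16_to_f32 test below:

#include <cassert>
#include <cstdint>
#include <cstring>

// Adding the zero-extended integer to the bit pattern of 1.5 * 2^23 only
// perturbs mantissa bits, so the bitcast is exact and subtracting the same
// magic float recovers the value. This mirrors the zext + add + bitcast +
// fsub sequence emitted by the pattern.
float u16ToF32(uint16_t x) {
  uint32_t bits = 0x4B400000u + x;   // zext + add
  float f;
  std::memcpy(&f, &bits, sizeof f);  // bitcast i32 -> f32
  return f - 12582912.0f;            // fsub 1.5 * 2^23
}

int main() {
  for (uint32_t i = 0; i <= 0xFFFF; ++i)
    assert(u16ToF32(static_cast<uint16_t>(i)) == static_cast<float>(i));
}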
---
.../LLVMIR/Transforms/OptimizeForNVVM.cpp | 97 +++++++++-
.../Dialect/LLVMIR/optimize-for-nvvm.mlir | 178 ++++++++++++++++++
2 files changed, 274 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
index 8c33148d1d2d78..de3295ead2c3cd 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -39,6 +40,17 @@ struct ExpandDivF16 : public OpRewritePattern<LLVM::FDivOp> {
PatternRewriter &rewriter) const override;
};
+// Replaces sitofp or uitofp on src types no wider than the dst type mantissa
+// with a faster combination of bit ops and add/sub.
+template <typename OpTy> // OpTy should be LLVM::SIToFPOp or LLVM::UIToFPOp.
+struct ExpandIToFP : public OpRewritePattern<OpTy> {
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+
+private:
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override;
+};
+
struct NVVMOptimizeForTarget
: public NVVM::impl::NVVMOptimizeForTargetBase<NVVMOptimizeForTarget> {
void runOnOperation() override;
@@ -92,10 +104,93 @@ LogicalResult ExpandDivF16::matchAndRewrite(LLVM::FDivOp op,
return success();
}
+template <typename OpTy>
+LogicalResult
+ExpandIToFP<OpTy>::matchAndRewrite(OpTy op, PatternRewriter &rewriter) const {
+ Type srcType = op.getOperand().getType();
+ auto intType = dyn_cast<IntegerType>(getElementTypeOrSelf(srcType));
+ if (!intType)
+ return rewriter.notifyMatchFailure(op, "src type is not integer");
+ Type dstType = op.getType();
+ auto floatType = dyn_cast<FloatType>(getElementTypeOrSelf(dstType));
+ if (!floatType)
+ return rewriter.notifyMatchFailure(op, "dst type is not float");
+
+ // Mantissa width includes the integer bit, e.g. 24 for fp32.
+ auto mantissaWidth = floatType.getFPMantissaWidth();
+ if (mantissaWidth < 2)
+ return rewriter.notifyMatchFailure(op, "mantissa is less than 2 bits");
+ auto intWidth = intType.getWidth();
+ if (intWidth > mantissaWidth)
+ return rewriter.notifyMatchFailure(op, "src is wider than dst mantissa");
+
+ Type extType = IntegerType::get(rewriter.getContext(), floatType.getWidth(),
+ intType.getSignedness());
+ if (ShapedType shapedType = dyn_cast<ShapedType>(srcType))
+ extType = shapedType.clone(extType);
+ auto getAttr = [&](APInt value) -> TypedAttr {
+ if (ShapedType shapedType = dyn_cast<ShapedType>(extType))
+ return DenseElementsAttr::get(shapedType, value);
+ return IntegerAttr::get(extType, value);
+ };
+ ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
+
+ if (intWidth == mantissaWidth) {
+ // Create a float bit-pattern with zero biased-exponent and zero mantissa.
+ APFloat::integerPart intPart = 1ull << (mantissaWidth - 1);
+ APFloat floatBits(floatType.getFloatSemantics(), intPart);
+ if (floatBits.bitcastToAPInt()[mantissaWidth - 1])
+ return rewriter.notifyMatchFailure(op, "bias exponent lsb bit is set");
+ TypedAttr intAttr = getAttr(floatBits.bitcastToAPInt());
+
+ // Combine zero-extended src and float bit-pattern. The msb of src becomes
+ // the lsb of the exponent.
+ Value zext = builder.create<LLVM::ZExtOp>(extType, op.getOperand());
+ Value intConst = builder.create<LLVM::ConstantOp>(intAttr);
+ Value pattern = builder.create<LLVM::OrOp>(zext, intConst);
+
+ // Mask the exponent-lsb and the mantissa to get two separate values.
+ auto mask = APInt::getBitsSetFrom(floatType.getWidth(), mantissaWidth - 1);
+ Value exponentMask = builder.create<LLVM::ConstantOp>(getAttr(mask));
+ Value mantissaMask = builder.create<LLVM::ConstantOp>(getAttr(mask - 1));
+ Value exponentAnd = builder.create<LLVM::AndOp>(pattern, exponentMask);
+ Value mantissaAnd = builder.create<LLVM::AndOp>(pattern, mantissaMask);
+
+ // Bitcast these values to float and subtract or add them.
+ Value exponentCast = builder.create<LLVM::BitcastOp>(dstType, exponentAnd);
+ Value mantissaCast = builder.create<LLVM::BitcastOp>(dstType, mantissaAnd);
+ using SubOrAddOp = std::conditional_t<std::is_same_v<OpTy, LLVM::SIToFPOp>,
+ LLVM::FSubOp, LLVM::FAddOp>;
+ rewriter.replaceOpWithNewOp<SubOrAddOp>(op, mantissaCast, exponentCast);
+ return success();
+ }
+
+ // Create a float with zero biased-exponent and msb-set mantissa.
+ APFloat::integerPart intPart = 3ull << (mantissaWidth - 2);
+ APFloat floatBits(floatType.getFloatSemantics(), intPart);
+ TypedAttr intAttr = getAttr(floatBits.bitcastToAPInt());
+ TypedAttr floatAttr = FloatAttr::get(floatType, floatBits);
+ if (ShapedType shapedType = dyn_cast<ShapedType>(dstType))
+ floatAttr = DenseElementsAttr::get(shapedType, floatAttr);
+
+ // Add extended src and bit-pattern of float, then subtract float.
+ using ExtOp = std::conditional_t<std::is_same_v<OpTy, LLVM::SIToFPOp>,
+ LLVM::SExtOp, LLVM::ZExtOp>;
+ Value ext = builder.create<ExtOp>(extType, op.getOperand());
+ Value intConst = builder.create<LLVM::ConstantOp>(intAttr);
+ Value add = builder.create<LLVM::AddOp>(ext, intConst);
+ Value bitcast = builder.create<LLVM::BitcastOp>(dstType, add);
+ Value floatConst = builder.create<LLVM::ConstantOp>(floatAttr);
+ rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, bitcast, floatConst);
+ return success();
+}
+
void NVVMOptimizeForTarget::runOnOperation() {
MLIRContext *ctx = getOperation()->getContext();
RewritePatternSet patterns(ctx);
- patterns.add<ExpandDivF16>(ctx);
+ patterns.add<ExpandDivF16, ExpandIToFP<LLVM::SIToFPOp>,
+ ExpandIToFP<LLVM::UIToFPOp>>(ctx);
+
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
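For reference, a scalar analogue (my sketch, not part of the patch) of the intWidth == mantissaWidth branch above. It uses a 24-bit signed value held in an int32_t so that f32 (mantissa width 24) can stand in for the destination type; the source's sign bit lands on the exponent lsb of the OR-ed pattern, and subtracting the two masked halves as floats yields the signed value:

#include <cassert>
#include <cstdint>
#include <cstring>

// Scalar version of zext + or + and/and + bitcast/bitcast + fsub.
float si24ToF32(int32_t x) {
  uint32_t zext = static_cast<uint32_t>(x) & 0xFFFFFFu; // 24-bit source
  uint32_t pattern = zext | 0x4B000000u;    // or with the bits of 2^23
  uint32_t expPart = pattern & 0xFF800000u; // exponent lsb and above
  uint32_t manPart = pattern & 0xFF7FFFFFu; // everything but the exponent lsb
  float e, m;
  std::memcpy(&e, &expPart, sizeof(float));
  std::memcpy(&m, &manPart, sizeof(float));
  return m - e; // fsub for the sitofp flavor
}

int main() {
  for (int32_t x : {0, 1, -1, (1 << 23) - 1, -(1 << 23)})
    assert(si24ToF32(x) == static_cast<float>(x));
}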
diff --git a/mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir b/mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir
index b98d2e08b75486..a77d98a1b71a9c 100644
--- a/mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir
@@ -22,3 +22,181 @@ llvm.func @fdiv_fp16(%arg0 : f16, %arg1 : f16) -> f16 {
// CHECK: llvm.return %[[result]] : f16
llvm.return %result : f16
}
+
+// CHECK-LABEL: llvm.func @ui16_to_f32
+llvm.func @ui16_to_f32(%arg0 : i16) -> f32 {
+ // CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : i16 to i32
+ // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(1262485504 : i32) : i32
+ // CHECK-DAG: %[[add:.*]] = llvm.add %[[zext]], %[[bias]] : i32
+ // CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[add]] : i32 to f32
+ // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(0x4B400000 : f32) : f32
+ // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[cast]], %[[bias]] : f32
+ %result = llvm.uitofp %arg0 : i16 to f32
+ // CHECK: llvm.return %[[result]] : f32
+ llvm.return %result : f32
+}
+
+// Checks that expansion only applies to integer width up to mantissa width.
+// CHECK-LABEL: llvm.func @si32_to_float
+llvm.func @si32_to_float_no_rewrite(%arg0 : i32) -> f32 {
+ // CHECK: %[[result:.*]] = llvm.sitofp %arg0 : i32 to f32
+ %result = llvm.sitofp %arg0 : i32 to f32
+ // CHECK: llvm.return %[[result]] : f32
+ llvm.return %result : f32
+}
+
+// CHECK-LABEL: llvm.func @si8_to_f16
+llvm.func @si8_to_f16(%arg0 : i8) -> f16 {
+ // CHECK-DAG: %[[sext:.*]] = llvm.sext %arg0 : i8 to i16
+ // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(26112 : i16) : i16
+ // CHECK-DAG: %[[add:.*]] = llvm.add %[[sext]], %[[bias]] : i16
+ // CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[add]] : i16 to f16
+ // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(1.536000e+03 : f16) : f16
+ // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[cast]], %[[bias]] : f16
+ %result = llvm.sitofp %arg0 : i8 to f16
+ // CHECK: llvm.return %[[result]] : f16
+ llvm.return %result : f16
+}
+
+// CHECK-LABEL: llvm.func @vec_ui4_to_bf16
+llvm.func @vec_ui4_to_bf16(%arg0 : vector<4xi4>) -> vector<4xbf16> {
+ // CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : vector<4xi4> to vector<4xi16>
+ // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(dense<17216> : vector<4xi16>) : vector<4xi16>
+ // CHECK-DAG: %[[add:.*]] = llvm.add %[[zext]], %[[bias]] : vector<4xi16>
+ // CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[add]] : vector<4xi16> to vector<4xbf16>
+ // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(dense<1.920000e+02> : vector<4xbf16>) : vector<4xbf16>
+ // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[cast]], %[[bias]] : vector<4xbf16>
+ %result = llvm.uitofp %arg0 : vector<4xi4> to vector<4xbf16>
+ // CHECK: llvm.return %[[result]] : vector<4xbf16>
+ llvm.return %result : vector<4xbf16>
+}
+
+// Checks code path when integer width is equal to mantissa width.
+// CHECK-LABEL: llvm.func @vec_si8_to_bf16
+llvm.func @vec_si8_to_bf16(%arg0 : vector<4xi8>) -> vector<4xbf16> {
+ // CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : vector<4xi8> to vector<4xi16>
+ // CHECK-DAG: %[[const:.*]] = llvm.mlir.constant(dense<17152> : vector<4xi16>) : vector<4xi16>
+ // CHECK-DAG: %[[or:.*]] = llvm.or %[[zext]], %[[const]] : vector<4xi16>
+ // CHECK-DAG: %[[exp_mask:.*]] = llvm.mlir.constant(dense<-128> : vector<4xi16>) : vector<4xi16>
+ // CHECK-DAG: %[[man_mask:.*]] = llvm.mlir.constant(dense<-129> : vector<4xi16>) : vector<4xi16>
+ // CHECK-DAG: %[[exp_and:.*]] = llvm.and %[[or]], %[[exp_mask]] : vector<4xi16>
+ // CHECK-DAG: %[[man_and:.*]] = llvm.and %[[or]], %[[man_mask]] : vector<4xi16>
+ // CHECK-DAG: %[[exp_cast:.*]] = llvm.bitcast %[[exp_and]] : vector<4xi16> to vector<4xbf16>
+ // CHECK-DAG: %[[man_cast:.*]] = llvm.bitcast %[[man_and]] : vector<4xi16> to vector<4xbf16>
+ // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[man_cast]], %[[exp_cast]] : vector<4xbf16>
+ %result = llvm.sitofp %arg0 : vector<4xi8> to vector<4xbf16>
+ // CHECK: llvm.return %[[result]] : vector<4xbf16>
+ llvm.return %result : vector<4xbf16>
+}
+
+// Checks code path when integer width is equal to mantissa width.
+// CHECK-LABEL: llvm.func @ui8_to_bf16
+llvm.func @ui8_to_bf16(%arg0 : i8) -> bf16 {
+ // CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : i8 to i16
+ // CHECK-DAG: %[[const:.*]] = llvm.mlir.constant(17152 : i16) : i16
+ // CHECK-DAG: %[[or:.*]] = llvm.or %[[zext]], %[[const]] : i16
+ // CHECK-DAG: %[[exp_mask:.*]] = llvm.mlir.constant(-128 : i16) : i16
+ // CHECK-DAG: %[[man_mask:.*]] = llvm.mlir.constant(-129 : i16) : i16
+ // CHECK-DAG: %[[exp_and:.*]] = llvm.and %[[or]], %[[exp_mask]] : i16
+ // CHECK-DAG: %[[man_and:.*]] = llvm.and %[[or]], %[[man_mask]] : i16
+ // CHECK-DAG: %[[exp_cast:.*]] = llvm.bitcast %[[exp_and]] : i16 to bf16
+ // CHECK-DAG: %[[man_cast:.*]] = llvm.bitcast %[[man_and]] : i16 to bf16
+ // CHECK-DAG: %[[result:.*]] = llvm.fadd %[[man_cast]], %[[exp_cast]] : bf16
+ %result = llvm.uitofp %arg0 : i8 to bf16
+ // CHECK: llvm.return %[[result]] : bf16
+ llvm.return %result : bf16
+}
+
+// Checks that expansion does not apply when exponent bias lsb is set.
+// CHECK-LABEL: llvm.func @ui11_to_f16
+llvm.func @ui11_to_f16(%arg0 : i11) -> f16 {
+ // CHECK: %[[result:.*]] = llvm.uitofp %arg0 : i11 to f16
+ %result = llvm.uitofp %arg0 : i11 to f16
+ // CHECK: llvm.return %[[result]] : f16
+ llvm.return %result : f16
+}
+
+// CHECK-LABEL: llvm.func @ui16_to_f32
+llvm.func @ui16_to_f32(%arg0 : i16) -> f32 {
+ // CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : i16 to i32
+ // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(1262485504 : i32) : i32
+ // CHECK-DAG: %[[add:.*]] = llvm.add %[[zext]], %[[bias]] : i32
+ // CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[add]] : i32 to f32
+ // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(0x4B400000 : f32) : f32
+ // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[cast]], %[[bias]] : f32
+ %result = llvm.uitofp %arg0 : i16 to f32
+ // CHECK: llvm.return %[[result]] : f32
+ llvm.return %result : f32
+}
+
+// Checks that expansion only applies to integer width up to mantissa width.
+// CHECK-LABEL: llvm.func @si32_to_float
+llvm.func @si32_to_float_no_rewrite(%arg0 : i32) -> f32 {
+ // CHECK: %[[result:.*]] = llvm.sitofp %arg0 : i32 to f32
+ %result = llvm.sitofp %arg0 : i32 to f32
+ // CHECK: llvm.return %[[result]] : f32
+ llvm.return %result : f32
+}
+
+// CHECK-LABEL: llvm.func @si8_to_f16
+llvm.func @si8_to_f16(%arg0 : i8) -> f16 {
+ // CHECK-DAG: %[[sext:.*]] = llvm.sext %arg0 : i8 to i16
+ // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(26112 : i16) : i16
+ // CHECK-DAG: %[[add:.*]] = llvm.add %[[sext]], %[[bias]] : i16
+ // CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[add]] : i16 to f16
+ // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(1.536000e+03 : f16) : f16
+ // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[cast]], %[[bias]] : f16
+ %result = llvm.sitofp %arg0 : i8 to f16
+ // CHECK: llvm.return %[[result]] : f16
+ llvm.return %result : f16
+}
+
+// CHECK-LABEL: llvm.func @vec_ui4_to_bf16
+llvm.func @vec_ui4_to_bf16(%arg0 : vector<4xi4>) -> vector<4xbf16> {
+ // CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : vector<4xi4> to vector<4xi16>
+ // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(dense<17216> : vector<4xi16>) : vector<4xi16>
+ // CHECK-DAG: %[[add:.*]] = llvm.add %[[zext]], %[[bias]] : vector<4xi16>
+ // CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[add]] : vector<4xi16> to vector<4xbf16>
+ // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(dense<1.920000e+02> : vector<4xbf16>) : vector<4xbf16>
+ // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[cast]], %[[bias]] : vector<4xbf16>
+ %result = llvm.uitofp %arg0 : vector<4xi4> to vector<4xbf16>
+ // CHECK: llvm.return %[[result]] : vector<4xbf16>
+ llvm.return %result : vector<4xbf16>
+}
+
+// Checks code path when integer width is equal to mantissa width.
+// CHECK-LABEL: llvm.func @vec_si8_to_bf16
+llvm.func @vec_si8_to_bf16(%arg0 : vector<4xi8>) -> vector<4xbf16> {
+ // CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : vector<4xi8> to vector<4xi16>
+ // CHECK-DAG: %[[const:.*]] = llvm.mlir.constant(dense<17152> : vector<4xi16>) : vector<4xi16>
+ // CHECK-DAG: %[[or:.*]] = llvm.or %[[zext]], %[[const]] : vector<4xi16>
+ // CHECK-DAG: %[[exp_mask:.*]] = llvm.mlir.constant(dense<-128> : vector<4xi16>) : vector<4xi16>
+ // CHECK-DAG: %[[man_mask:.*]] = llvm.mlir.constant(dense<-129> : vector<4xi16>) : vector<4xi16>
+ // CHECK-DAG: %[[exp_and:.*]] = llvm.and %[[or]], %[[exp_mask]] : vector<4xi16>
+ // CHECK-DAG: %[[man_and:.*]] = llvm.and %[[or]], %[[man_mask]] : vector<4xi16>
+ // CHECK-DAG: %[[exp_cast:.*]] = llvm.bitcast %[[exp_and]] : vector<4xi16> to vector<4xbf16>
+ // CHECK-DAG: %[[man_cast:.*]] = llvm.bitcast %[[man_and]] : vector<4xi16> to vector<4xbf16>
+ // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[man_cast]], %[[exp_cast]] : vector<4xbf16>
+ %result = llvm.sitofp %arg0 : vector<4xi8> to vector<4xbf16>
+ // CHECK: llvm.return %[[result]] : vector<4xbf16>
+ llvm.return %result : vector<4xbf16>
+}
+
+// Checks that expansion does not apply when unsigned integer width is equal to
+// mantissa width.
+// CHECK-LABEL: llvm.func @ui8_to_bf16
+llvm.func @ui8_to_bf16(%arg0 : i8) -> bf16 {
+ // CHECK: %[[result:.*]] = llvm.uitofp %arg0 : i8 to bf16
+ %result = llvm.uitofp %arg0 : i8 to bf16
+ // CHECK: llvm.return %[[result]] : bf16
+ llvm.return %result : bf16
+}
+
+// Checks that expansion does not apply when exponent bias lsb is set.
+// CHECK-LABEL: llvm.func @ui11_to_f16
+llvm.func @ui11_to_f16(%arg0 : i11) -> f16 {
+ // CHECK: %[[result:.*]] = llvm.uitofp %arg0 : i11 to f16
+ %result = llvm.uitofp %arg0 : i11 to f16
+ // CHECK: llvm.return %[[result]] : f16
+ llvm.return %result : f16
+}
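A side note on the constants in the vec_ui4_to_bf16 checks above (my verification, not part of the patch): bf16 is the upper half of an f32 bit pattern, so the integer bias 17216 (0x4340) is just the top 16 bits of 192.0f, and 1.92e+02 is the float constant subtracted afterwards:

#include <cassert>
#include <cstdint>
#include <cstring>

// Derives the bf16 bit pattern of a float by truncating to the top 16 bits;
// 192.0f has no low-order mantissa bits, so truncation is exact here.
uint16_t bf16Bits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

int main() {
  assert(bf16Bits(192.0f) == 17216); // 0x4340, matches llvm.mlir.constant
}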
From bd48a7a9fdc96aa4fccf48b7c39aa2f231bd3ab3 Mon Sep 17 00:00:00 2001
From: Christian Sigg <csigg at google.com>
Date: Tue, 3 Sep 2024 13:56:51 +0200
Subject: [PATCH 2/2] Add condition.
---
mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
index de3295ead2c3cd..7e5574fa124644 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
@@ -136,6 +136,10 @@ ExpandIToFP<OpTy>::matchAndRewrite(OpTy op, PatternRewriter &rewriter) const {
ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
if (intWidth == mantissaWidth) {
+ if (std::is_same_v<OpTy, LLVM::UIToFPOp>) {
+ return rewriter.notifyMatchFailure(
+ op, "unsigned src is as wide as dst mantissa");
+ }
// Create a float bit-pattern with zero biased-exponent and zero mantissa.
APFloat::integerPart intPart = 1ull << (mantissaWidth - 1);
APFloat floatBits(floatType.getFloatSemantics(), intPart);
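The patch message only says "Add condition"; my reading (an assumption, not stated by the author) is that the or/and/fadd combination from the first revision does not round-trip an unsigned source that is as wide as the mantissa. Replaying that combination in scalar C++ for a 24-bit unsigned value and f32 illustrates the offset by 2^24:

#include <cassert>
#include <cstdint>
#include <cstring>

// Replays the first revision's zext/or/and/fadd sequence for a 24-bit
// unsigned source and f32. The result is the input plus 2^24 rather than
// the input, which would explain the new bail-out.
float ui24FAddPath(uint32_t x) {
  uint32_t pattern = (x & 0xFFFFFFu) | 0x4B000000u; // zext + or
  uint32_t expPart = pattern & 0xFF800000u;         // exponent-lsb mask
  uint32_t manPart = pattern & 0xFF7FFFFFu;         // mantissa mask
  float e, m;
  std::memcpy(&e, &expPart, sizeof(float));
  std::memcpy(&m, &manPart, sizeof(float));
  return m + e; // fadd (the uitofp flavor of the branch)
}

int main() {
  assert(ui24FAddPath(4) == 16777220.0f); // 2^24 + 4, not 4
}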