[Mlir-commits] [mlir] [mlir][nvvm] Expand sitofp/uitofp to faster ops (PR #107001)

Christian Sigg llvmlistbot at llvm.org
Tue Sep 3 05:37:45 PDT 2024


https://github.com/chsigg updated https://github.com/llvm/llvm-project/pull/107001

>From ecd33d5df64357bc491cdfa5ff4b15ebfef7f157 Mon Sep 17 00:00:00 2001
From: Christian Sigg <csigg at google.com>
Date: Fri, 30 Aug 2024 16:21:56 +0200
Subject: [PATCH 1/3] [mlir][nvvm] Expand sitofp/uitofp to faster ops

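`sitofp` and `uitofp` are lowered to `cvt.rn` PTX instructions by the LLVM NVPTX backend, and these have lower throughput than integer and floating-point arithmetic instructions. This change expands the conversions into integer extension, bit operations and a floating-point add/sub whenever the source integer is narrow enough for the result to be exact.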

Doing this optimization in LLVM would only work for i16->fp32 because the NVPTX backend has no i8 registers and promotes them to i16.
---
 .../LLVMIR/Transforms/OptimizeForNVVM.cpp     |  97 +++++++++-
 .../Dialect/LLVMIR/optimize-for-nvvm.mlir     | 178 ++++++++++++++++++
 2 files changed, 274 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
index 8c33148d1d2d78..de3295ead2c3cd 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -39,6 +40,17 @@ struct ExpandDivF16 : public OpRewritePattern<LLVM::FDivOp> {
                                 PatternRewriter &rewriter) const override;
 };
 
+// Replaces sitofp/uitofp whose src fits in the dst mantissa with integer ext,
+// bit ops and fadd/fsub, which have higher throughput than NVPTX's cvt.rn.
+template <typename OpTy> // OpTy should be LLVM::SIToFPOp or LLVM::UIToFPOp.
+struct ExpandIToFP : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+private:
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override;
+};
+
 struct NVVMOptimizeForTarget
     : public NVVM::impl::NVVMOptimizeForTargetBase<NVVMOptimizeForTarget> {
   void runOnOperation() override;
@@ -92,10 +104,109 @@ LogicalResult ExpandDivF16::matchAndRewrite(LLVM::FDivOp op,
   return success();
 }
 
+// Two expansions are used, both exact for the src widths they accept:
+//  - intWidth < mantissaWidth: the sign-/zero-extended src is added to the bit
+//    pattern of 1.5 * 2^(mantissaWidth - 1); bitcasting the sum to float and
+//    subtracting that same float yields src.
+//  - intWidth == mantissaWidth (signed only): zext(src) is or'ed into the bit
+//    pattern of 2^(mantissaWidth - 1) so the src msb lands in the exponent lsb;
+//    splitting that bit off and subtracting the two halves yields src.
+template <typename OpTy>
+LogicalResult
+ExpandIToFP<OpTy>::matchAndRewrite(OpTy op, PatternRewriter &rewriter) const {
+  Type srcType = op.getOperand().getType();
+  auto intType = dyn_cast<IntegerType>(getElementTypeOrSelf(srcType));
+  if (!intType)
+    return rewriter.notifyMatchFailure(op, "src type is not integer");
+  Type dstType = op.getType();
+  auto floatType = dyn_cast<FloatType>(getElementTypeOrSelf(dstType));
+  if (!floatType)
+    return rewriter.notifyMatchFailure(op, "dst type is not float");
+
+  // Mantissa width includes the integer bit, e.g. 24 for fp32.
+  auto mantissaWidth = floatType.getFPMantissaWidth();
+  if (mantissaWidth < 2)
+    return rewriter.notifyMatchFailure(op, "mantissa is less than 2 bits");
+  auto intWidth = intType.getWidth();
+  if (intWidth > mantissaWidth)
+    return rewriter.notifyMatchFailure(op, "src is wider than dst mantissa");
+
+  // An unsigned src must be at most mantissaWidth - 2 bits wide: a larger
+  // zext(src) added to the bit pattern of 1.5 * 2^(mantissaWidth - 1) carries
+  // out of the mantissa into the exponent, and the equal-width path below
+  // treats the src msb as a sign bit.
+  if (std::is_same_v<OpTy, LLVM::UIToFPOp> &&
+      static_cast<int>(intWidth) > mantissaWidth - 2)
+    return rewriter.notifyMatchFailure(
+        op, "unsigned src is too wide for dst mantissa");
+
+  Type extType = IntegerType::get(rewriter.getContext(), floatType.getWidth(),
+                                  intType.getSignedness());
+  if (ShapedType shapedType = dyn_cast<ShapedType>(srcType))
+    extType = shapedType.clone(extType);
+  auto getAttr = [&](APInt value) -> TypedAttr {
+    if (ShapedType shapedType = dyn_cast<ShapedType>(extType))
+      return DenseElementsAttr::get(shapedType, value);
+    return IntegerAttr::get(extType, value);
+  };
+  ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
+
+  if (intWidth == mantissaWidth) {
+    // Create a float bit-pattern with zero biased-exponent and zero mantissa.
+    APFloat::integerPart intPart = 1ull << (mantissaWidth - 1);
+    APFloat floatBits(floatType.getFloatSemantics(), intPart);
+    if (floatBits.bitcastToAPInt()[mantissaWidth - 1])
+      return rewriter.notifyMatchFailure(op, "bias exponent lsb bit is set");
+    TypedAttr intAttr = getAttr(floatBits.bitcastToAPInt());
+
+    // Combine zero-extended src and float bit-pattern. The msb of src becomes
+    // the lsb of the exponent.
+    Value zext = builder.create<LLVM::ZExtOp>(extType, op.getOperand());
+    Value intConst = builder.create<LLVM::ConstantOp>(intAttr);
+    Value pattern = builder.create<LLVM::OrOp>(zext, intConst);
+
+    // Mask the exponent-lsb and the mantissa to get two separate values.
+    auto mask = APInt::getBitsSetFrom(floatType.getWidth(), mantissaWidth - 1);
+    Value exponentMask = builder.create<LLVM::ConstantOp>(getAttr(mask));
+    Value mantissaMask = builder.create<LLVM::ConstantOp>(getAttr(mask - 1));
+    Value exponentAnd = builder.create<LLVM::AndOp>(pattern, exponentMask);
+    Value mantissaAnd = builder.create<LLVM::AndOp>(pattern, mantissaMask);
+
+    // Bitcast these values to float and subtract or add them.
+    Value exponentCast = builder.create<LLVM::BitcastOp>(dstType, exponentAnd);
+    Value mantissaCast = builder.create<LLVM::BitcastOp>(dstType, mantissaAnd);
+    using SubOrAddOp = std::conditional_t<std::is_same_v<OpTy, LLVM::SIToFPOp>,
+                                          LLVM::FSubOp, LLVM::FAddOp>;
+    rewriter.replaceOpWithNewOp<SubOrAddOp>(op, mantissaCast, exponentCast);
+    return success();
+  }
+
+  // Create a float with zero biased-exponent and msb-set mantissa.
+  APFloat::integerPart intPart = 3ull << (mantissaWidth - 2);
+  APFloat floatBits(floatType.getFloatSemantics(), intPart);
+  TypedAttr intAttr = getAttr(floatBits.bitcastToAPInt());
+  TypedAttr floatAttr = FloatAttr::get(floatType, floatBits);
+  if (ShapedType shapedType = dyn_cast<ShapedType>(dstType))
+    floatAttr = DenseElementsAttr::get(shapedType, floatAttr);
+
+  // Add extended src and bit-pattern of float, then subtract float.
+  using ExtOp = std::conditional_t<std::is_same_v<OpTy, LLVM::SIToFPOp>,
+                                   LLVM::SExtOp, LLVM::ZExtOp>;
+  Value ext = builder.create<ExtOp>(extType, op.getOperand());
+  Value intConst = builder.create<LLVM::ConstantOp>(intAttr);
+  Value add = builder.create<LLVM::AddOp>(ext, intConst);
+  Value bitcast = builder.create<LLVM::BitcastOp>(dstType, add);
+  Value floatConst = builder.create<LLVM::ConstantOp>(floatAttr);
+  rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, bitcast, floatConst);
+  return success();
+}
+
 void NVVMOptimizeForTarget::runOnOperation() {
   MLIRContext *ctx = getOperation()->getContext();
   RewritePatternSet patterns(ctx);
-  patterns.add<ExpandDivF16>(ctx);
+  patterns.add<ExpandDivF16, ExpandIToFP<LLVM::SIToFPOp>,
+               ExpandIToFP<LLVM::UIToFPOp>>(ctx);
+
   if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
     return signalPassFailure();
 }
diff --git a/mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir b/mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir
index b98d2e08b75486..a77d98a1b71a9c 100644
--- a/mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir
@@ -22,3 +22,103 @@ llvm.func @fdiv_fp16(%arg0 : f16, %arg1 : f16) -> f16 {
   // CHECK: llvm.return %[[result]] : f16
   llvm.return %result : f16
 }
+
+// CHECK-LABEL: llvm.func @ui16_to_f32
+llvm.func @ui16_to_f32(%arg0 : i16) -> f32 {
+  // CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : i16 to i32
+  // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(1262485504 : i32) : i32
+  // CHECK-DAG: %[[add:.*]] = llvm.add %[[zext]], %[[bias]] : i32
+  // CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[add]] : i32 to f32
+  // CHECK-DAG: %[[fbias:.*]] = llvm.mlir.constant(0x4B400000 : f32) : f32
+  // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[cast]], %[[fbias]]  : f32
+  %result = llvm.uitofp %arg0 : i16 to f32
+  // CHECK: llvm.return %[[result]] : f32
+  llvm.return %result : f32
+}
+
+// Checks that expansion only applies to integer width up to mantissa width.
+// CHECK-LABEL: llvm.func @si32_to_float_no_rewrite
+llvm.func @si32_to_float_no_rewrite(%arg0 : i32) -> f32 {
+  // CHECK: %[[result:.*]] = llvm.sitofp %arg0 : i32 to f32
+  %result = llvm.sitofp %arg0 : i32 to f32
+  // CHECK: llvm.return %[[result]] : f32
+  llvm.return %result : f32
+}
+
+// CHECK-LABEL: llvm.func @si8_to_f16
+llvm.func @si8_to_f16(%arg0 : i8) -> f16 {
+  // CHECK-DAG: %[[sext:.*]] = llvm.sext %arg0 : i8 to i16
+  // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(26112 : i16) : i16
+  // CHECK-DAG: %[[add:.*]] = llvm.add %[[sext]], %[[bias]] : i16
+  // CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[add]] : i16 to f16
+  // CHECK-DAG: %[[fbias:.*]] = llvm.mlir.constant(1.536000e+03 : f16) : f16
+  // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[cast]], %[[fbias]]  : f16
+  %result = llvm.sitofp %arg0 : i8 to f16
+  // CHECK: llvm.return %[[result]] : f16
+  llvm.return %result : f16
+}
+
+// Checks the add/sub path at its widest signed src (one bit narrower than the
+// mantissa).
+// CHECK-LABEL: llvm.func @si7_to_bf16
+llvm.func @si7_to_bf16(%arg0 : i7) -> bf16 {
+  // CHECK-DAG: %[[sext:.*]] = llvm.sext %arg0 : i7 to i16
+  // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(17216 : i16) : i16
+  // CHECK-DAG: %[[add:.*]] = llvm.add %[[sext]], %[[bias]] : i16
+  // CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[add]] : i16 to bf16
+  // CHECK-DAG: %[[fbias:.*]] = llvm.mlir.constant(1.920000e+02 : bf16) : bf16
+  // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[cast]], %[[fbias]]  : bf16
+  %result = llvm.sitofp %arg0 : i7 to bf16
+  // CHECK: llvm.return %[[result]] : bf16
+  llvm.return %result : bf16
+}
+
+// CHECK-LABEL: llvm.func @vec_ui4_to_bf16
+llvm.func @vec_ui4_to_bf16(%arg0 : vector<4xi4>) -> vector<4xbf16> {
+  // CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : vector<4xi4> to vector<4xi16>
+  // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(dense<17216> : vector<4xi16>) : vector<4xi16>
+  // CHECK-DAG: %[[add:.*]] = llvm.add %[[zext]], %[[bias]] : vector<4xi16>
+  // CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[add]] : vector<4xi16> to vector<4xbf16>
+  // CHECK-DAG: %[[fbias:.*]] = llvm.mlir.constant(dense<1.920000e+02> : vector<4xbf16>) : vector<4xbf16>
+  // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[cast]], %[[fbias]]  : vector<4xbf16>
+  %result = llvm.uitofp %arg0 : vector<4xi4> to vector<4xbf16>
+  // CHECK: llvm.return %[[result]] : vector<4xbf16>
+  llvm.return %result : vector<4xbf16>
+}
+
+// Checks code path when integer width is equal to mantissa width.
+// CHECK-LABEL: llvm.func @vec_si8_to_bf16
+llvm.func @vec_si8_to_bf16(%arg0 : vector<4xi8>) -> vector<4xbf16> {
+  // CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : vector<4xi8> to vector<4xi16>
+  // CHECK-DAG: %[[const:.*]] = llvm.mlir.constant(dense<17152> : vector<4xi16>) : vector<4xi16>
+  // CHECK-DAG: %[[or:.*]] = llvm.or %[[zext]], %[[const]] : vector<4xi16>
+  // CHECK-DAG: %[[exp_mask:.*]] = llvm.mlir.constant(dense<-128> : vector<4xi16>) : vector<4xi16>
+  // CHECK-DAG: %[[man_mask:.*]] = llvm.mlir.constant(dense<-129> : vector<4xi16>) : vector<4xi16>
+  // CHECK-DAG: %[[exp_and:.*]] = llvm.and %[[or]], %[[exp_mask]]  : vector<4xi16>
+  // CHECK-DAG: %[[man_and:.*]] = llvm.and %[[or]], %[[man_mask]]  : vector<4xi16>
+  // CHECK-DAG: %[[exp_cast:.*]] = llvm.bitcast %[[exp_and]] : vector<4xi16> to vector<4xbf16>
+  // CHECK-DAG: %[[man_cast:.*]] = llvm.bitcast %[[man_and]] : vector<4xi16> to vector<4xbf16>
+  // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[man_cast]], %[[exp_cast]]  : vector<4xbf16>
+  %result = llvm.sitofp %arg0 : vector<4xi8> to vector<4xbf16>
+  // CHECK: llvm.return %[[result]] : vector<4xbf16>
+  llvm.return %result : vector<4xbf16>
+}
+
+// Checks that expansion does not apply when unsigned integer width is equal to
+// mantissa width.
+// CHECK-LABEL: llvm.func @ui8_to_bf16
+llvm.func @ui8_to_bf16(%arg0 : i8) -> bf16 {
+  // CHECK: %[[result:.*]] = llvm.uitofp %arg0 : i8 to bf16
+  %result = llvm.uitofp %arg0 : i8 to bf16
+  // CHECK: llvm.return %[[result]] : bf16
+  llvm.return %result : bf16
+}
+
+// Checks that expansion does not apply when exponent bias lsb is set. Uses a
+// signed src so that the bias check, not the unsigned check, rejects it.
+// CHECK-LABEL: llvm.func @si11_to_f16
+llvm.func @si11_to_f16(%arg0 : i11) -> f16 {
+  // CHECK: %[[result:.*]] = llvm.sitofp %arg0 : i11 to f16
+  %result = llvm.sitofp %arg0 : i11 to f16
+  // CHECK: llvm.return %[[result]] : f16
+  llvm.return %result : f16
+}

>From bd48a7a9fdc96aa4fccf48b7c39aa2f231bd3ab3 Mon Sep 17 00:00:00 2001
From: Christian Sigg <csigg at google.com>
Date: Tue, 3 Sep 2024 13:56:51 +0200
Subject: [PATCH 2/3] Reject uitofp when src is as wide as the dst mantissa.

---
 mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
index de3295ead2c3cd..7e5574fa124644 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
@@ -136,6 +136,10 @@ ExpandIToFP<OpTy>::matchAndRewrite(OpTy op, PatternRewriter &rewriter) const {
   ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
 
   if (intWidth == mantissaWidth) {
+    if constexpr (std::is_same_v<OpTy, LLVM::UIToFPOp>) {
+      return rewriter.notifyMatchFailure(
+          op, "unsigned src is as wide as dst mantissa");
+    }
     // Create a float bit-pattern with zero biased-exponent and zero mantissa.
     APFloat::integerPart intPart = 1ull << (mantissaWidth - 1);
     APFloat floatBits(floatType.getFloatSemantics(), intPart);

>From 53ac7a63e0b6152ea611b54f94921369601945ac Mon Sep 17 00:00:00 2001
From: Christian Sigg <csigg at google.com>
Date: Tue, 3 Sep 2024 14:37:36 +0200
Subject: [PATCH 3/3] Always emit fsub in the equal-width expansion.

---
 mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
index 7e5574fa124644..fbaae515787d32 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
@@ -160,12 +160,10 @@ ExpandIToFP<OpTy>::matchAndRewrite(OpTy op, PatternRewriter &rewriter) const {
     Value exponentAnd = builder.create<LLVM::AndOp>(pattern, exponentMask);
     Value mantissaAnd = builder.create<LLVM::AndOp>(pattern, mantissaMask);
 
-    // Bitcast these values to float and subtract or add them.
+    // Bitcast these values to float and subtract them.
     Value exponentCast = builder.create<LLVM::BitcastOp>(dstType, exponentAnd);
     Value mantissaCast = builder.create<LLVM::BitcastOp>(dstType, mantissaAnd);
-    using SubOrAddOp = std::conditional_t<std::is_same_v<OpTy, LLVM::SIToFPOp>,
-                                          LLVM::FSubOp, LLVM::FAddOp>;
-    rewriter.replaceOpWithNewOp<SubOrAddOp>(op, mantissaCast, exponentCast);
+    rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, mantissaCast, exponentCast);
     return success();
   }
 



More information about the Mlir-commits mailing list