[Mlir-commits] [mlir] abfc358 - [mlir][arith] Add `sitofp` support to WIE
Jakub Kuderski
llvmlistbot at llvm.org
Wed Mar 22 16:13:07 PDT 2023
Author: Jakub Kuderski
Date: 2023-03-22T19:12:16-04:00
New Revision: abfc358cff0c0cfc8ffbc6c164d97e13a18a1685
URL: https://github.com/llvm/llvm-project/commit/abfc358cff0c0cfc8ffbc6c164d97e13a18a1685
DIFF: https://github.com/llvm/llvm-project/commit/abfc358cff0c0cfc8ffbc6c164d97e13a18a1685.diff
LOG: [mlir][arith] Add `sitofp` support to WIE
This depends on the handling of `uitofp` in D146606.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D146597
Added:
mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-sitofp-i32.mlir
Modified:
mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
mlir/test/Dialect/Arith/emulate-wide-int.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index 83f01397c4490..781ea3d3eca63 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -16,6 +16,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APInt.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
@@ -907,6 +908,52 @@ struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertSIToFP
+//===----------------------------------------------------------------------===//
+
+struct ConvertSIToFP final : OpConversionPattern<arith::SIToFPOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::SIToFPOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+
+ Value in = op.getIn();
+ Type oldTy = in.getType();
+ auto newTy =
+ dyn_cast_or_null<VectorType>(getTypeConverter()->convertType(oldTy));
+ if (!newTy)
+ return rewriter.notifyMatchFailure(
+ loc, llvm::formatv("unsupported type: {0}", oldTy));
+
+ unsigned oldBitWidth = getElementTypeOrSelf(oldTy).getIntOrFloatBitWidth();
+ Value zeroCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 0);
+ Value oneCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 1);
+ Value allOnesCst = createScalarOrSplatConstant(
+ rewriter, loc, oldTy, APInt::getAllOnes(oldBitWidth));
+
+ // To avoid operating on very large unsigned numbers, perform the
+ // conversion on the absolute value. Then, decide whether to negate the
+ // result or not based on the sign of the input. We assume two's complement
+ // and implement negation by flipping all bits and adding 1.
+ // Note that this relies on the other conversion patterns to legalize
+ // created ops and narrow the bit widths.
+ Value isNeg = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+ in, zeroCst);
+ Value bitwiseNeg = rewriter.create<arith::XOrIOp>(loc, in, allOnesCst);
+ Value neg = rewriter.create<arith::AddIOp>(loc, bitwiseNeg, oneCst);
+ Value abs = rewriter.create<arith::SelectOp>(loc, isNeg, neg, in);
+
+ Value absResult = rewriter.create<arith::UIToFPOp>(loc, op.getType(), abs);
+ Value negResult = rewriter.create<arith::NegFOp>(loc, absResult);
+ rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNeg, negResult,
+ absResult);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// ConvertUIToFP
//===----------------------------------------------------------------------===//
@@ -1146,5 +1193,5 @@ void arith::populateArithWideIntEmulationPatterns(
ConvertIndexCastIntToIndex<arith::IndexCastUIOp>,
ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,
ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>,
- ConvertUIToFP>(typeConverter, patterns.getContext());
+ ConvertSIToFP, ConvertUIToFP>(typeConverter, patterns.getContext());
}
diff --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
index 55b4e7f89b0ac..9fb5478d7e94f 100644
--- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
@@ -964,3 +964,46 @@ func.func @uitofp_i64_f16(%a : i64) -> f16 {
%r = arith.uitofp %a : i64 to f16
return %r : f16
}
+
+// CHECK-LABEL: func @sitofp_i64_f64
+// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> f64
+// CHECK: [[VONES:%.+]] = arith.constant dense<-1> : vector<2xi32>
+// CHECK: [[ONES1:%.+]] = vector.extract [[VONES]][0] : vector<2xi32>
+// CHECK-NEXT: [[ONES2:%.+]] = vector.extract [[VONES]][1] : vector<2xi32>
+// CHECK: arith.xori {{%.+}}, [[ONES1]] : i32
+// CHECK-NEXT: arith.xori {{%.+}}, [[ONES2]] : i32
+// CHECK: [[CST0:%.+]] = arith.constant 0 : i32
+// CHECK: [[HIEQ0:%.+]] = arith.cmpi eq, [[HI:%.+]], [[CST0]] : i32
+// CHECK-NEXT: [[LOWFP:%.+]] = arith.uitofp [[LOW:%.+]] : i32 to f64
+// CHECK-NEXT: [[HIFP:%.+]] = arith.uitofp [[HI]] : i32 to f64
+// CHECK-NEXT: [[POW:%.+]] = arith.constant 0x41F0000000000000 : f64
+// CHECK-NEXT: [[RESHI:%.+]] = arith.mulf [[HIFP]], [[POW]] : f64
+// CHECK-NEXT: [[RES:%.+]] = arith.addf [[LOWFP]], [[RESHI]] : f64
+// CHECK-NEXT: [[SEL:%.+]] = arith.select [[HIEQ0]], [[LOWFP]], [[RES]] : f64
+// CHECK-NEXT: [[NEG:%.+]] = arith.negf [[SEL]] : f64
+// CHECK-NEXT: [[FINAL:%.+]] = arith.select %{{.+}}, [[NEG]], [[SEL]] : f64
+// CHECK-NEXT: return [[FINAL]] : f64
+func.func @sitofp_i64_f64(%a : i64) -> f64 {
+ %r = arith.sitofp %a : i64 to f64
+ return %r : f64
+}
+
+// CHECK-LABEL: func @sitofp_i64_f64_vector
+// CHECK-SAME: ([[ARG:%.+]]: vector<3x2xi32>) -> vector<3xf64>
+// CHECK: [[VONES:%.+]] = arith.constant dense<-1> : vector<3x2xi32>
+// CHECK: arith.xori
+// CHECK-NEXT: arith.xori
+// CHECK: [[HIEQ0:%.+]] = arith.cmpi eq, [[HI:%.+]], [[CST0:%.+]] : vector<3xi32>
+// CHECK-NEXT: [[LOWFP:%.+]] = arith.uitofp [[LOW:%.+]] : vector<3xi32> to vector<3xf64>
+// CHECK-NEXT: [[HIFP:%.+]] = arith.uitofp [[HI:%.+]] : vector<3xi32> to vector<3xf64>
+// CHECK-NEXT: [[POW:%.+]] = arith.constant dense<0x41F0000000000000> : vector<3xf64>
+// CHECK-NEXT: [[RESHI:%.+]] = arith.mulf [[HIFP]], [[POW]] : vector<3xf64>
+// CHECK-NEXT: [[RES:%.+]] = arith.addf [[LOWFP]], [[RESHI]] : vector<3xf64>
+// CHECK-NEXT: [[SEL:%.+]] = arith.select [[HIEQ0]], [[LOWFP]], [[RES]] : vector<3xi1>, vector<3xf64>
+// CHECK-NEXT: [[NEG:%.+]] = arith.negf [[SEL]] : vector<3xf64>
+// CHECK-NEXT: [[FINAL:%.+]] = arith.select %{{.+}}, [[NEG]], [[SEL]] : vector<3xi1>, vector<3xf64>
+// CHECK-NEXT: return [[FINAL]] : vector<3xf64>
+func.func @sitofp_i64_f64_vector(%a : vector<3xi64>) -> vector<3xf64> {
+ %r = arith.sitofp %a : vector<3xi64> to vector<3xf64>
+ return %r : vector<3xf64>
+}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-sitofp-i32.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-sitofp-i32.mlir
new file mode 100644
index 0000000000000..3fc008705f111
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-sitofp-i32.mlir
@@ -0,0 +1,68 @@
+// Check that the wide integer `arith.sitofp` emulation produces the same result as wide
+// `arith.sitofp`. Emulate i32 ops with i16 ops.
+
+// RUN: mlir-opt %s --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: --shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s --match-full-lines
+
+// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=16" \
+// RUN: --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: --shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s --match-full-lines
+
+// Ops in this function *only* will be emulated using i16 types.
+func.func @emulate_sitofp(%arg: i32) -> f32 {
+ %res = arith.sitofp %arg : i32 to f32
+ return %res : f32
+}
+
+func.func @check_sitofp(%arg : i32) -> () {
+ %res = func.call @emulate_sitofp(%arg) : (i32) -> (f32)
+ vector.print %res : f32
+ return
+}
+
+func.func @entry() {
+ %cst0 = arith.constant 0 : i32
+ %cst1 = arith.constant 1 : i32
+ %cst2 = arith.constant 2 : i32
+ %cst7 = arith.constant 7 : i32
+ %cst1337 = arith.constant 1337 : i32
+
+ %cst_n1 = arith.constant -1 : i32
+ %cst_n13 = arith.constant -13 : i32
+ %cst_n1337 = arith.constant -1337 : i32
+
+ %cst_i16_min = arith.constant -32768 : i32
+
+ %cst_f32_int_max = arith.constant 16777217 : i32
+ %cst_f32_int_min = arith.constant -16777217 : i32
+
+ // CHECK: 0
+ func.call @check_sitofp(%cst0) : (i32) -> ()
+ // CHECK-NEXT: 1
+ func.call @check_sitofp(%cst1) : (i32) -> ()
+ // CHECK-NEXT: 2
+ func.call @check_sitofp(%cst2) : (i32) -> ()
+ // CHECK-NEXT: 7
+ func.call @check_sitofp(%cst7) : (i32) -> ()
+ // CHECK-NEXT: 1337
+ func.call @check_sitofp(%cst1337) : (i32) -> ()
+ // CHECK-NEXT: -1
+ func.call @check_sitofp(%cst_n1) : (i32) -> ()
+ // CHECK-NEXT: -1337
+ func.call @check_sitofp(%cst_n1337) : (i32) -> ()
+
+ // CHECK-NEXT: -32768
+ func.call @check_sitofp(%cst_i16_min) : (i32) -> ()
+ // CHECK-NEXT: 1.6{{.+}}e+07
+ func.call @check_sitofp(%cst_f32_int_max) : (i32) -> ()
+ // CHECK-NEXT: -1.6{{.+}}e+07
+ func.call @check_sitofp(%cst_f32_int_min) : (i32) -> ()
+
+ return
+}
More information about the Mlir-commits
mailing list