[Mlir-commits] [mlir] 909e5ce - [mlir][arith] Add `uitofp` support to WIE
Jakub Kuderski
llvmlistbot at llvm.org
Wed Mar 22 16:08:34 PDT 2023
Author: Jakub Kuderski
Date: 2023-03-22T19:04:10-04:00
New Revision: 909e5ce47a70181dead332826e93f89b2928f0c0
URL: https://github.com/llvm/llvm-project/commit/909e5ce47a70181dead332826e93f89b2928f0c0
DIFF: https://github.com/llvm/llvm-project/commit/909e5ce47a70181dead332826e93f89b2928f0c0.diff
LOG: [mlir][arith] Add `uitofp` support to WIE
This includes standard LIT tests and integration tests with the LLVM CPU
runner.
I plan to use this to implement `sitofp` in D146597.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D146606
Added:
mlir/test/Dialect/Arith/emulate-wide-int-canonicalization.mlir
mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-uitofp-i32.mlir
Modified:
mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
mlir/test/Dialect/Arith/emulate-wide-int.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index db3ddab483b5a..83f01397c4490 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
+#include <cmath>
@@ -906,6 +907,70 @@ struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertUIToFP
+//===----------------------------------------------------------------------===//
+
+/// Lowers `arith.uitofp` whose integer operand has to be emulated with a
+/// narrower integer type. The converted operand is a vector whose trailing
+/// dimension holds the low and high halves; each half is converted to floating
+/// point separately and the two results are recombined as
+/// `uitofp(low) + uitofp(hi) * 2^BW`.
+struct ConvertUIToFP final : OpConversionPattern<arith::UIToFPOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  /// Replaces `op` with the emulated sequence built from the converted operand
+  /// in `adaptor`. Reports a match failure when the operand type does not
+  /// convert to a vector type.
+  LogicalResult
+  matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    Type oldTy = op.getIn().getType();
+    auto newTy =
+        dyn_cast_or_null<VectorType>(getTypeConverter()->convertType(oldTy));
+    if (!newTy)
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", oldTy));
+    unsigned newBitWidth = newTy.getElementTypeBitWidth();
+
+    auto [low, hi] = extractLastDimHalves(rewriter, loc, adaptor.getIn());
+    Value lowInt = dropTrailingX1Dim(rewriter, loc, low);
+    Value hiInt = dropTrailingX1Dim(rewriter, loc, hi);
+    Value zeroCst =
+        createScalarOrSplatConstant(rewriter, loc, hiInt.getType(), 0);
+
+    // The final result has the following form:
+    //   if (hi == 0) return uitofp(low)
+    //   else         return uitofp(low) + uitofp(hi) * 2^BW
+    //
+    // where `BW` is the bitwidth of the narrowed integer type. We emit a
+    // select to make it easier to fold away the `hi` part calculation when it
+    // is known to be zero.
+    //
+    // Note 1: The emulation is precise only for input values that have an
+    // exact integer representation in the result floating point type, and may
+    // lead to loss of precision otherwise.
+    //
+    // Note 2: We do not strictly need the `hi == 0` case, but it makes
+    // constant folding easier.
+    Value hiEqZero = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, hiInt, zeroCst);
+
+    Type resultTy = op.getType();
+    Type resultElemTy = getElementTypeOrSelf(resultTy);
+    Value lowFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, lowInt);
+    Value hiFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, hiInt);
+
+    // Build 2^BW directly as a floating-point value. Shifting a 64-bit integer
+    // by `newBitWidth` is undefined once the narrow type is 64 bits wide (and
+    // overflows `int64_t` at 63 bits), whereas `std::ldexp` yields the exact
+    // power of two for every width a `double` exponent can hold.
+    double pow2 = std::ldexp(1.0, static_cast<int>(newBitWidth));
+    Attribute pow2Attr = rewriter.getFloatAttr(resultElemTy, pow2);
+    if (auto vecTy = dyn_cast<VectorType>(resultTy))
+      pow2Attr = SplatElementsAttr::get(vecTy, pow2Attr);
+
+    Value pow2Val = rewriter.create<arith::ConstantOp>(loc, resultTy, pow2Attr);
+
+    Value hiVal = rewriter.create<arith::MulFOp>(loc, hiFp, pow2Val);
+    Value result = rewriter.create<arith::AddFOp>(loc, lowFp, hiVal);
+
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, hiEqZero, lowFp, result);
+    return success();
+  }
+};
+
//===----------------------------------------------------------------------===//
// ConvertTruncI
//===----------------------------------------------------------------------===//
@@ -1080,6 +1145,6 @@ void arith::populateArithWideIntEmulationPatterns(
ConvertIndexCastIntToIndex<arith::IndexCastOp>,
ConvertIndexCastIntToIndex<arith::IndexCastUIOp>,
ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,
- ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>>(
- typeConverter, patterns.getContext());
+ ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>,
+ ConvertUIToFP>(typeConverter, patterns.getContext());
}
diff --git a/mlir/test/Dialect/Arith/emulate-wide-int-canonicalization.mlir b/mlir/test/Dialect/Arith/emulate-wide-int-canonicalization.mlir
new file mode 100644
index 0000000000000..0c95ab8284afa
--- /dev/null
+++ b/mlir/test/Dialect/Arith/emulate-wide-int-canonicalization.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt --arith-emulate-wide-int="widest-int-supported=32" --canonicalize %s | FileCheck %s
+
+// Check that we can fold away the 'hi' part calculation when it is known to be zero.
+//
+// CHECK-LABEL: func @uitofp_i16_ext_f64
+// CHECK-SAME: ([[ARG:%.+]]: i16) -> f64
+// CHECK-NEXT: [[EXT:%.+]] = arith.extui [[ARG]] : i16 to i32
+// CHECK-NEXT: [[FP:%.+]] = arith.uitofp [[EXT]] : i32 to f64
+// CHECK-NEXT: return [[FP]] : f64
+func.func @uitofp_i16_ext_f64(%narrow : i16) -> f64 {
+  // Zero-extending from i16 guarantees that the upper half of the i64 is zero.
+  %wide = arith.extui %narrow : i16 to i64
+  %fp = arith.uitofp %wide : i64 to f64
+  return %fp : f64
+}
diff --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
index 80edc6f2ad001..55b4e7f89b0ac 100644
--- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
@@ -908,3 +908,59 @@ func.func @xori_vector_a_b(%a : vector<3xi64>, %b : vector<3xi64>) -> vector<3xi
%x = arith.xori %a, %b : vector<3xi64>
return %x : vector<3xi64>
}
+
+// CHECK-LABEL: func @uitofp_i64_f64
+// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> f64
+// CHECK-NEXT: [[LOW:%.+]] = vector.extract [[ARG]][0] : vector<2xi32>
+// CHECK-NEXT: [[HI:%.+]] = vector.extract [[ARG]][1] : vector<2xi32>
+// CHECK-NEXT: [[CST0:%.+]] = arith.constant 0 : i32
+// CHECK-NEXT: [[HIEQ0:%.+]] = arith.cmpi eq, [[HI]], [[CST0]] : i32
+// CHECK-NEXT: [[LOWFP:%.+]] = arith.uitofp [[LOW]] : i32 to f64
+// CHECK-NEXT: [[HIFP:%.+]] = arith.uitofp [[HI]] : i32 to f64
+// CHECK-NEXT: [[POW:%.+]] = arith.constant 0x41F0000000000000 : f64
+// CHECK-NEXT: [[RESHI:%.+]] = arith.mulf [[HIFP]], [[POW]] : f64
+// CHECK-NEXT: [[RES:%.+]] = arith.addf [[LOWFP]], [[RESHI]] : f64
+// CHECK-NEXT: [[SEL:%.+]] = arith.select [[HIEQ0]], [[LOWFP]], [[RES]] : f64
+// CHECK-NEXT: return [[SEL]] : f64
+func.func @uitofp_i64_f64(%wide : i64) -> f64 {
+  // Scalar i64 operand converted to f64.
+  %fp = arith.uitofp %wide : i64 to f64
+  return %fp : f64
+}
+
+// CHECK-LABEL: func @uitofp_i64_f64_vector
+// CHECK-SAME: ([[ARG:%.+]]: vector<3x2xi32>) -> vector<3xf64>
+// CHECK-NEXT: [[EXTLOW:%.+]] = vector.extract_strided_slice [[ARG]] {offsets = [0, 0], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
+// CHECK-NEXT: [[EXTHI:%.+]] = vector.extract_strided_slice [[ARG]] {offsets = [0, 1], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
+// CHECK-NEXT: [[LOW:%.+]] = vector.shape_cast [[EXTLOW]] : vector<3x1xi32> to vector<3xi32>
+// CHECK-NEXT: [[HI:%.+]] = vector.shape_cast [[EXTHI]] : vector<3x1xi32> to vector<3xi32>
+// CHECK-NEXT: [[CST0:%.+]] = arith.constant dense<0> : vector<3xi32>
+// CHECK-NEXT: [[HIEQ0:%.+]] = arith.cmpi eq, [[HI]], [[CST0]] : vector<3xi32>
+// CHECK-NEXT: [[LOWFP:%.+]] = arith.uitofp [[LOW]] : vector<3xi32> to vector<3xf64>
+// CHECK-NEXT: [[HIFP:%.+]] = arith.uitofp [[HI]] : vector<3xi32> to vector<3xf64>
+// CHECK-NEXT: [[POW:%.+]] = arith.constant dense<0x41F0000000000000> : vector<3xf64>
+// CHECK-NEXT: [[RESHI:%.+]] = arith.mulf [[HIFP]], [[POW]] : vector<3xf64>
+// CHECK-NEXT: [[RES:%.+]] = arith.addf [[LOWFP]], [[RESHI]] : vector<3xf64>
+// CHECK-NEXT: [[SEL:%.+]] = arith.select [[HIEQ0]], [[LOWFP]], [[RES]] : vector<3xi1>, vector<3xf64>
+// CHECK-NEXT: return [[SEL]] : vector<3xf64>
+func.func @uitofp_i64_f64_vector(%wide : vector<3xi64>) -> vector<3xf64> {
+  // Element-wise conversion of a vector of i64 values to f64.
+  %fp = arith.uitofp %wide : vector<3xi64> to vector<3xf64>
+  return %fp : vector<3xf64>
+}
+
+// CHECK-LABEL: func @uitofp_i64_f16
+// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> f16
+// CHECK-NEXT: [[LOW:%.+]] = vector.extract [[ARG]][0] : vector<2xi32>
+// CHECK-NEXT: [[HI:%.+]] = vector.extract [[ARG]][1] : vector<2xi32>
+// CHECK-NEXT: [[CST0:%.+]] = arith.constant 0 : i32
+// CHECK-NEXT: [[HIEQ0:%.+]] = arith.cmpi eq, [[HI]], [[CST0]] : i32
+// CHECK-NEXT: [[LOWFP:%.+]] = arith.uitofp [[LOW]] : i32 to f16
+// CHECK-NEXT: [[HIFP:%.+]] = arith.uitofp [[HI]] : i32 to f16
+// CHECK-NEXT: [[POW:%.+]] = arith.constant 0x7C00 : f16
+// CHECK-NEXT: [[RESHI:%.+]] = arith.mulf [[HIFP]], [[POW]] : f16
+// CHECK-NEXT: [[RES:%.+]] = arith.addf [[LOWFP]], [[RESHI]] : f16
+// CHECK-NEXT: [[SEL:%.+]] = arith.select [[HIEQ0]], [[LOWFP]], [[RES]] : f16
+// CHECK-NEXT: return [[SEL]] : f16
+func.func @uitofp_i64_f16(%wide : i64) -> f16 {
+  // Scalar i64 operand converted to the much narrower f16 result type.
+  %fp = arith.uitofp %wide : i64 to f16
+  return %fp : f16
+}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-uitofp-i32.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-uitofp-i32.mlir
new file mode 100644
index 0000000000000..c3d7db0de6d20
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-uitofp-i32.mlir
@@ -0,0 +1,77 @@
+// Check that the wide integer `arith.uitofp` emulation produces the same result as wide
+// `arith.uitofp`. Emulate i32 ops with i16 ops.
+
+// RUN: mlir-opt %s --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: --shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s --match-full-lines
+
+// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=16" \
+// RUN: --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: --shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s --match-full-lines
+
+// Ops in this function *only* will be emulated using i16 types.
+func.func @emulate_uitofp(%value: i32) -> f32 {
+  // The single op under test: unsigned i32 to f32 conversion.
+  %converted = arith.uitofp %value : i32 to f32
+  return %converted : f32
+}
+
+func.func @check_uitofp(%value : i32) -> () {
+  // Run the conversion and print the result so that it can be matched against
+  // the expected output lines in @entry.
+  %converted = func.call @emulate_uitofp(%value) : (i32) -> (f32)
+  vector.print %converted : f32
+  return
+}
+
+func.func @entry() {
+  // Small values that fit entirely in the low i16 half, plus the first value
+  // that needs the high half.
+  %cst0 = arith.constant 0 : i32
+  %cst1 = arith.constant 1 : i32
+  %cst2 = arith.constant 2 : i32
+  %cst7 = arith.constant 7 : i32
+  %cst1337 = arith.constant 1337 : i32
+  %cst_i16_max = arith.constant 65535 : i32
+  %cst_i16_overflow = arith.constant 65536 : i32
+
+  // Negative constants: `arith.uitofp` reads their bits as large unsigned
+  // values.
+  %cst_n1 = arith.constant -1 : i32
+  %cst_n13 = arith.constant -13 : i32
+  %cst_n1337 = arith.constant -1337 : i32
+
+  %cst_i16_min = arith.constant -32768 : i32
+
+  // 2^24 + 1 is the smallest positive integer without an exact f32
+  // representation; its negation is again read as a large unsigned value.
+  %cst_f32_int_max = arith.constant 16777217 : i32
+  %cst_f32_int_min = arith.constant -16777217 : i32
+
+  // CHECK: 0
+  func.call @check_uitofp(%cst0) : (i32) -> ()
+  // CHECK-NEXT: 1
+  func.call @check_uitofp(%cst1) : (i32) -> ()
+  // CHECK-NEXT: 2
+  func.call @check_uitofp(%cst2) : (i32) -> ()
+  // CHECK-NEXT: 7
+  func.call @check_uitofp(%cst7) : (i32) -> ()
+  // CHECK-NEXT: 1337
+  func.call @check_uitofp(%cst1337) : (i32) -> ()
+  // CHECK-NEXT: 65535
+  func.call @check_uitofp(%cst_i16_max) : (i32) -> ()
+  // CHECK-NEXT: 65536
+  func.call @check_uitofp(%cst_i16_overflow) : (i32) -> ()
+
+  // CHECK-NEXT: 4.2{{.+}}e+09
+  func.call @check_uitofp(%cst_n1) : (i32) -> ()
+  // CHECK-NEXT: 4.2{{.+}}e+09
+  func.call @check_uitofp(%cst_n13) : (i32) -> ()
+  // CHECK-NEXT: 4.2{{.+}}e+09
+  func.call @check_uitofp(%cst_n1337) : (i32) -> ()
+
+  // CHECK-NEXT: 4.2{{.+}}e+09
+  func.call @check_uitofp(%cst_i16_min) : (i32) -> ()
+  // CHECK-NEXT: 1.6{{.+}}e+07
+  func.call @check_uitofp(%cst_f32_int_max) : (i32) -> ()
+  // CHECK-NEXT: 4.2{{.+}}e+09
+  func.call @check_uitofp(%cst_f32_int_min) : (i32) -> ()
+
+  return
+}
More information about the Mlir-commits
mailing list