[Mlir-commits] [mlir] abfc358 - [mlir][arith] Add `sitofp` support to WIE

Jakub Kuderski llvmlistbot at llvm.org
Wed Mar 22 16:13:07 PDT 2023


Author: Jakub Kuderski
Date: 2023-03-22T19:12:16-04:00
New Revision: abfc358cff0c0cfc8ffbc6c164d97e13a18a1685

URL: https://github.com/llvm/llvm-project/commit/abfc358cff0c0cfc8ffbc6c164d97e13a18a1685
DIFF: https://github.com/llvm/llvm-project/commit/abfc358cff0c0cfc8ffbc6c164d97e13a18a1685.diff

LOG: [mlir][arith] Add `sitofp` support to WIE

This depends on the handling of `uitofp` in D146606.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D146597

Added: 
    mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-sitofp-i32.mlir

Modified: 
    mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
    mlir/test/Dialect/Arith/emulate-wide-int.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index 83f01397c4490..781ea3d3eca63 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -16,6 +16,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APInt.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
 #include <cassert>
@@ -907,6 +908,52 @@ struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertSIToFP
+//===----------------------------------------------------------------------===//
+
+struct ConvertSIToFP final : OpConversionPattern<arith::SIToFPOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::SIToFPOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    Value in = op.getIn();
+    Type oldTy = in.getType();
+    auto newTy =
+        dyn_cast_or_null<VectorType>(getTypeConverter()->convertType(oldTy));
+    if (!newTy)
+      return rewriter.notifyMatchFailure(
+          loc, llvm::formatv("unsupported type: {0}", oldTy));
+
+    unsigned oldBitWidth = getElementTypeOrSelf(oldTy).getIntOrFloatBitWidth();
+    Value zeroCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 0);
+    Value oneCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 1);
+    Value allOnesCst = createScalarOrSplatConstant(
+        rewriter, loc, oldTy, APInt::getAllOnes(oldBitWidth));
+
+    // Lower as: isNeg = in < 0; abs = isNeg ? -in : in;
+    // result = isNeg ? -uitofp(abs) : uitofp(abs). Converting the absolute
+    // value reuses the unsigned lowering without feeding it huge unsigned
+    // values. Negation is two's complement (in ^ -1) + 1, which maps the most
+    // negative value to itself; uitofp then reads it as 2^(oldBitWidth - 1).
+    // Relies on the other patterns to legalize and narrow these wide ops.
+    Value isNeg = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                                                 in, zeroCst);
+    Value bitwiseNeg = rewriter.create<arith::XOrIOp>(loc, in, allOnesCst);
+    Value neg = rewriter.create<arith::AddIOp>(loc, bitwiseNeg, oneCst);
+    Value abs = rewriter.create<arith::SelectOp>(loc, isNeg, neg, in);
+
+    Value absResult = rewriter.create<arith::UIToFPOp>(loc, op.getType(), abs);
+    Value negResult = rewriter.create<arith::NegFOp>(loc, absResult);
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNeg, negResult,
+                                                 absResult);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertUIToFP
 //===----------------------------------------------------------------------===//
@@ -1146,5 +1193,5 @@ void arith::populateArithWideIntEmulationPatterns(
       ConvertIndexCastIntToIndex<arith::IndexCastUIOp>,
       ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,
       ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>,
-      ConvertUIToFP>(typeConverter, patterns.getContext());
+      ConvertSIToFP, ConvertUIToFP>(typeConverter, patterns.getContext());
 }

diff  --git a/mlir/test/Dialect/Arith/emulate-wide-int.mlir b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
index 55b4e7f89b0ac..9fb5478d7e94f 100644
--- a/mlir/test/Dialect/Arith/emulate-wide-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-wide-int.mlir
@@ -964,3 +964,46 @@ func.func @uitofp_i64_f16(%a : i64) -> f16 {
     %r = arith.uitofp %a : i64 to f16
     return %r : f16
 }
+
+// CHECK-LABEL: func @sitofp_i64_f64
+// CHECK-SAME:    ([[ARG:%.+]]: vector<2xi32>) -> f64
+// CHECK:         [[VONES:%.+]]  = arith.constant dense<-1> : vector<2xi32>
+// CHECK:         [[ONES1:%.+]]  = vector.extract [[VONES]][0] : vector<2xi32>
+// CHECK-NEXT:    [[ONES2:%.+]]  = vector.extract [[VONES]][1] : vector<2xi32>
+// CHECK:                          arith.xori {{%.+}}, [[ONES1]] : i32
+// CHECK-NEXT:                     arith.xori {{%.+}}, [[ONES2]] : i32
+// CHECK:         [[CST0:%.+]]   = arith.constant 0 : i32
+// CHECK:         [[HIEQ0:%.+]]  = arith.cmpi eq, [[HI:%.+]], [[CST0]] : i32
+// CHECK-NEXT:    [[LOWFP:%.+]]  = arith.uitofp [[LOW:%.+]] : i32 to f64
+// CHECK-NEXT:    [[HIFP:%.+]]   = arith.uitofp [[HI]] : i32 to f64
+// CHECK-NEXT:    [[POW:%.+]]    = arith.constant 0x41F0000000000000 : f64
+// CHECK-NEXT:    [[RESHI:%.+]]  = arith.mulf [[HIFP]], [[POW]] : f64
+// CHECK-NEXT:    [[RES:%.+]]    = arith.addf [[LOWFP]], [[RESHI]] : f64
+// CHECK-NEXT:    [[SEL:%.+]]    = arith.select [[HIEQ0]], [[LOWFP]], [[RES]] : f64
+// CHECK-NEXT:    [[NEG:%.+]]    = arith.negf [[SEL]] : f64
+// CHECK-NEXT:    [[FINAL:%.+]]  = arith.select %{{.+}}, [[NEG]], [[SEL]] : f64
+// CHECK-NEXT:    return [[FINAL]] : f64
+func.func @sitofp_i64_f64(%a : i64) -> f64 {
+    %r = arith.sitofp %a : i64 to f64
+    return %r : f64
+}
+
+// CHECK-LABEL: func @sitofp_i64_f64_vector
+// CHECK-SAME:    ([[ARG:%.+]]: vector<3x2xi32>) -> vector<3xf64>
+// CHECK:         [[VONES:%.+]]  = arith.constant dense<-1> : vector<3x2xi32>
+// CHECK:                          arith.xori
+// CHECK-NEXT:                     arith.xori
+// CHECK:         [[HIEQ0:%.+]]  = arith.cmpi eq, [[HI:%.+]], [[CST0:%.+]] : vector<3xi32>
+// CHECK-NEXT:    [[LOWFP:%.+]]  = arith.uitofp [[LOW:%.+]] : vector<3xi32> to vector<3xf64>
+// CHECK-NEXT:    [[HIFP:%.+]]   = arith.uitofp [[HI]] : vector<3xi32> to vector<3xf64>
+// CHECK-NEXT:    [[POW:%.+]]    = arith.constant dense<0x41F0000000000000> : vector<3xf64>
+// CHECK-NEXT:    [[RESHI:%.+]]  = arith.mulf [[HIFP]], [[POW]] : vector<3xf64>
+// CHECK-NEXT:    [[RES:%.+]]    = arith.addf [[LOWFP]], [[RESHI]] : vector<3xf64>
+// CHECK-NEXT:    [[SEL:%.+]]    = arith.select [[HIEQ0]], [[LOWFP]], [[RES]] : vector<3xi1>, vector<3xf64>
+// CHECK-NEXT:    [[NEG:%.+]]    = arith.negf [[SEL]] : vector<3xf64>
+// CHECK-NEXT:    [[FINAL:%.+]]  = arith.select %{{.+}}, [[NEG]], [[SEL]] : vector<3xi1>, vector<3xf64>
+// CHECK-NEXT:    return [[FINAL]] : vector<3xf64>
+func.func @sitofp_i64_f64_vector(%a : vector<3xi64>) -> vector<3xf64> {
+    %r = arith.sitofp %a : vector<3xi64> to vector<3xf64>
+    return %r : vector<3xf64>
+}

diff  --git a/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-sitofp-i32.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-sitofp-i32.mlir
new file mode 100644
index 0000000000000..3fc008705f111
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-wide-int-emulation-sitofp-i32.mlir
@@ -0,0 +1,76 @@
+// Check that the wide integer `arith.sitofp` emulation produces the same result as wide
+// `arith.sitofp`. Emulate i32 ops with i16 ops.
+
+// RUN: mlir-opt %s --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN:             --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN:   mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:                   --shared-libs=%mlir_c_runner_utils | \
+// RUN:   FileCheck %s --match-full-lines
+
+// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=16" \
+// RUN:             --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN:             --convert-func-to-llvm --convert-arith-to-llvm | \
+// RUN:   mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:                   --shared-libs=%mlir_c_runner_utils | \
+// RUN:   FileCheck %s --match-full-lines
+
+// Ops in this function *only* will be emulated using i16 types.
+func.func @emulate_sitofp(%arg: i32) -> f32 {
+  %res = arith.sitofp %arg : i32 to f32
+  return %res : f32
+}
+
+func.func @check_sitofp(%arg : i32) -> () {
+  %res = func.call @emulate_sitofp(%arg) : (i32) -> (f32)
+  vector.print %res : f32
+  return
+}
+
+func.func @entry() {
+  %cst0 = arith.constant 0 : i32
+  %cst1 = arith.constant 1 : i32
+  %cst2 = arith.constant 2 : i32
+  %cst7 = arith.constant 7 : i32
+  %cst1337 = arith.constant 1337 : i32
+
+  %cst_n1 = arith.constant -1 : i32
+  %cst_n13 = arith.constant -13 : i32
+  %cst_n1337 = arith.constant -1337 : i32
+
+  %cst_i16_min = arith.constant -32768 : i32
+
+  // 2^24 + 1 is the smallest positive integer f32 cannot represent exactly.
+  %cst_f32_int_max = arith.constant 16777217 : i32
+  %cst_f32_int_min = arith.constant -16777217 : i32
+
+  // The two's complement negation of INT32_MIN overflows back to INT32_MIN.
+  %cst_i32_min = arith.constant -2147483648 : i32
+
+  // CHECK:      0
+  func.call @check_sitofp(%cst0) : (i32) -> ()
+  // CHECK-NEXT: 1
+  func.call @check_sitofp(%cst1) : (i32) -> ()
+  // CHECK-NEXT: 2
+  func.call @check_sitofp(%cst2) : (i32) -> ()
+  // CHECK-NEXT: 7
+  func.call @check_sitofp(%cst7) : (i32) -> ()
+  // CHECK-NEXT: 1337
+  func.call @check_sitofp(%cst1337) : (i32) -> ()
+  // CHECK-NEXT: -1
+  func.call @check_sitofp(%cst_n1) : (i32) -> ()
+  // CHECK-NEXT: -13
+  func.call @check_sitofp(%cst_n13) : (i32) -> ()
+  // CHECK-NEXT: -1337
+  func.call @check_sitofp(%cst_n1337) : (i32) -> ()
+
+  // CHECK-NEXT: -32768
+  func.call @check_sitofp(%cst_i16_min) : (i32) -> ()
+  // CHECK-NEXT: 1.6{{.+}}e+07
+  func.call @check_sitofp(%cst_f32_int_max) : (i32) -> ()
+  // CHECK-NEXT: -1.6{{.+}}e+07
+  func.call @check_sitofp(%cst_f32_int_min) : (i32) -> ()
+  // CHECK-NEXT: -2.1{{.+}}e+09
+  func.call @check_sitofp(%cst_i32_min) : (i32) -> ()
+
+  return
+}


        


More information about the Mlir-commits mailing list