[Mlir-commits] [mlir] c1ebefd - [mlir] Make polynomial approximation emit std instead of LLVM ops

Benjamin Kramer llvmlistbot at llvm.org
Wed Aug 11 07:38:48 PDT 2021


Author: Benjamin Kramer
Date: 2021-08-11T16:37:21+02:00
New Revision: c1ebefdf77f34cc0b23597071098c8f8a8d2839b

URL: https://github.com/llvm/llvm-project/commit/c1ebefdf77f34cc0b23597071098c8f8a8d2839b
DIFF: https://github.com/llvm/llvm-project/commit/c1ebefdf77f34cc0b23597071098c8f8a8d2839b.diff

LOG: [mlir] Make polynomial approximation emit std instead of LLVM ops

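Emitting std ops instead of LLVM dialect ops is a bit cleaner and removes
issues with 2-D vectors. It also has a big impact on constant folding (constant
bitcasts now fold into plain constants and the vector.broadcast checks go
away), hence the test changes.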

Differential Revision: https://reviews.llvm.org/D107896

Added: 
    

Modified: 
    mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
    mlir/test/Dialect/Math/polynomial-approximation.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index 6eece6e0f7b3b..77cbb9e1f770d 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -8,7 +8,6 @@ add_mlir_dialect_library(MLIRMathTransforms
 
   LINK_LIBS PUBLIC
   MLIRIR
-  MLIRLLVMIR
   MLIRMath
   MLIRPass
   MLIRStandard

diff  --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 7844d8c475d2d..08fb06087767b 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -11,8 +11,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
@@ -96,7 +94,7 @@ static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) {
 
 static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) {
   Value i32Value = i32Cst(builder, static_cast<int32_t>(bits));
-  return builder.create<LLVM::BitcastOp>(builder.getF32Type(), i32Value);
+  return builder.create<BitcastOp>(builder.getF32Type(), i32Value);
 }
 
 //----------------------------------------------------------------------------//
@@ -139,20 +137,19 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
   Value cstInvMantMask = f32FromBits(builder, ~0x7f800000u);
 
   // Bitcast to i32 for bitwise operations.
-  Value i32Half = builder.create<LLVM::BitcastOp>(i32, cstHalf);
-  Value i32InvMantMask = builder.create<LLVM::BitcastOp>(i32, cstInvMantMask);
-  Value i32Arg = builder.create<LLVM::BitcastOp>(i32Vec, arg);
+  Value i32Half = builder.create<BitcastOp>(i32, cstHalf);
+  Value i32InvMantMask = builder.create<BitcastOp>(i32, cstInvMantMask);
+  Value i32Arg = builder.create<BitcastOp>(i32Vec, arg);
 
   // Compute normalized fraction.
-  Value tmp0 = builder.create<LLVM::AndOp>(i32Arg, bcast(i32InvMantMask));
-  Value tmp1 = builder.create<LLVM::OrOp>(tmp0, bcast(i32Half));
-  Value normalizedFraction = builder.create<LLVM::BitcastOp>(f32Vec, tmp1);
+  Value tmp0 = builder.create<AndOp>(i32Arg, bcast(i32InvMantMask));
+  Value tmp1 = builder.create<OrOp>(tmp0, bcast(i32Half));
+  Value normalizedFraction = builder.create<BitcastOp>(f32Vec, tmp1);
 
   // Compute exponent.
   Value arg0 = is_positive ? arg : builder.create<AbsFOp>(arg);
   Value biasedExponentBits = builder.create<UnsignedShiftRightOp>(
-      builder.create<LLVM::BitcastOp>(i32Vec, arg0),
-      bcast(i32Cst(builder, 23)));
+      builder.create<BitcastOp>(i32Vec, arg0), bcast(i32Cst(builder, 23)));
   Value biasedExponent = builder.create<SIToFPOp>(f32Vec, biasedExponentBits);
   Value exponent = builder.create<SubFOp>(biasedExponent, bcast(cst126f));
 
@@ -178,7 +175,7 @@ static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
   Value biasedArg = builder.create<AddIOp>(arg, bias);
   Value exp2ValueInt =
       builder.create<ShiftLeftOp>(biasedArg, exponetBitLocation);
-  Value exp2ValueF32 = builder.create<LLVM::BitcastOp>(f32Vec, exp2ValueInt);
+  Value exp2ValueF32 = builder.create<BitcastOp>(f32Vec, exp2ValueInt);
 
   return exp2ValueF32;
 }
@@ -454,8 +451,8 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
   Value uInf = builder.create<CmpFOp>(CmpFPredicate::OEQ, u, logU);
   Value logLarge = builder.create<MulFOp>(
       x, builder.create<DivFOp>(logU, builder.create<SubFOp>(u, cstOne)));
-  Value approximation = builder.create<SelectOp>(
-      builder.create<LLVM::OrOp>(uSmall, uInf), x, logLarge);
+  Value approximation =
+      builder.create<SelectOp>(builder.create<OrOp>(uSmall, uInf), x, logLarge);
   rewriter.replaceOp(op, approximation);
   return success();
 }

diff  --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index 296f9fcdc2918..a3092bfeb2954 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -33,7 +33,7 @@
 // CHECK:           %[[VAL_26:.*]] = fptosi %[[VAL_16]] : f32 to i32
 // CHECK:           %[[VAL_27:.*]] = addi %[[VAL_26]], %[[VAL_13]] : i32
 // CHECK:           %[[VAL_28:.*]] = shift_left %[[VAL_27]], %[[VAL_8]] : i32
-// CHECK:           %[[VAL_29:.*]] = llvm.bitcast %[[VAL_28]] : i32 to f32
+// CHECK:           %[[VAL_29:.*]] = bitcast %[[VAL_28]] : i32 to f32
 // CHECK:           %[[VAL_30:.*]] = mulf %[[VAL_25]], %[[VAL_29]] : f32
 // CHECK:           %[[VAL_31:.*]] = cmpi sle, %[[VAL_26]], %[[VAL_13]] : i32
 // CHECK:           %[[VAL_32:.*]] = cmpi sge, %[[VAL_26]], %[[VAL_14]] : i32
@@ -97,10 +97,9 @@ func @expm1_scalar(%arg0: f32) -> f32 {
 // CHECK-NOT:       exp
 // CHECK-COUNT-3:   select
 // CHECK-NOT:       log
-// CHECK-COUNT-6:   vector.broadcast
 // CHECK-COUNT-5:   select
 // CHECK-NOT:       expm1
-// CHECK-COUNT-2:   select
+// CHECK-COUNT-3:   select
 // CHECK:           %[[VAL_115:.*]] = select
 // CHECK:           return %[[VAL_115]] : vector<8xf32>
 // CHECK:         }
@@ -110,14 +109,14 @@ func @expm1_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
 }
 
 // CHECK-LABEL:   func @log_scalar(
-// CHECK-SAME:                     %[[X:.*]]: f32) -> f32 {
+// CHECK-SAME:                             %[[X:.*]]: f32) -> f32 {
 // CHECK:           %[[VAL_1:.*]] = constant 0.000000e+00 : f32
 // CHECK:           %[[VAL_2:.*]] = constant 1.000000e+00 : f32
 // CHECK:           %[[VAL_3:.*]] = constant -5.000000e-01 : f32
-// CHECK:           %[[VAL_4:.*]] = constant 8388608 : i32
-// CHECK:           %[[VAL_5:.*]] = constant -8388608 : i32
-// CHECK:           %[[VAL_6:.*]] = constant 2139095040 : i32
-// CHECK:           %[[VAL_7:.*]] = constant 2143289344 : i32
+// CHECK:           %[[VAL_4:.*]] = constant 1.17549435E-38 : f32
+// CHECK:           %[[VAL_5:.*]] = constant 0xFF800000 : f32
+// CHECK:           %[[VAL_6:.*]] = constant 0x7F800000 : f32
+// CHECK:           %[[VAL_7:.*]] = constant 0x7FC00000 : f32
 // CHECK:           %[[VAL_8:.*]] = constant 0.707106769 : f32
 // CHECK:           %[[VAL_9:.*]] = constant 0.0703768358 : f32
 // CHECK:           %[[VAL_10:.*]] = constant -0.115146101 : f32
@@ -129,55 +128,48 @@ func @expm1_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
 // CHECK:           %[[VAL_16:.*]] = constant -0.24999994 : f32
 // CHECK:           %[[VAL_17:.*]] = constant 0.333333313 : f32
 // CHECK:           %[[VAL_18:.*]] = constant 1.260000e+02 : f32
-// CHECK:           %[[VAL_19:.*]] = constant 5.000000e-01 : f32
-// CHECK:           %[[VAL_20:.*]] = constant -2139095041 : i32
+// CHECK:           %[[VAL_19:.*]] = constant -2139095041 : i32
+// CHECK:           %[[VAL_20:.*]] = constant 1056964608 : i32
 // CHECK:           %[[VAL_21:.*]] = constant 23 : i32
-// CHECK:           %[[CST_LN2:.*]] = constant 0.693147182 : f32
-// CHECK:           %[[VAL_23:.*]] = llvm.bitcast %[[VAL_4]] : i32 to f32
-// CHECK:           %[[VAL_24:.*]] = llvm.bitcast %[[VAL_5]] : i32 to f32
-// CHECK:           %[[VAL_25:.*]] = llvm.bitcast %[[VAL_6]] : i32 to f32
-// CHECK:           %[[VAL_26:.*]] = llvm.bitcast %[[VAL_7]] : i32 to f32
-// CHECK:           %[[VAL_27:.*]] = cmpf ogt, %[[X]], %[[VAL_23]] : f32
-// CHECK:           %[[VAL_28:.*]] = select %[[VAL_27]], %[[X]], %[[VAL_23]] : f32
+// CHECK:           %[[VAL_22:.*]] = constant 0.693147182 : f32
+// CHECK:           %[[VAL_23:.*]] = cmpf ogt, %[[X]], %[[VAL_4]] : f32
+// CHECK:           %[[VAL_24:.*]] = select %[[VAL_23]], %[[X]], %[[VAL_4]] : f32
 // CHECK-NOT:       frexp
-// CHECK:           %[[VAL_29:.*]] = llvm.bitcast %[[VAL_20]] : i32 to f32
-// CHECK:           %[[VAL_30:.*]] = llvm.bitcast %[[VAL_19]] : f32 to i32
-// CHECK:           %[[VAL_31:.*]] = llvm.bitcast %[[VAL_29]] : f32 to i32
-// CHECK:           %[[VAL_32:.*]] = llvm.bitcast %[[VAL_28]] : f32 to i32
-// CHECK:           %[[VAL_33:.*]] = llvm.and %[[VAL_32]], %[[VAL_31]]  : i32
-// CHECK:           %[[VAL_34:.*]] = llvm.or %[[VAL_33]], %[[VAL_30]]  : i32
-// CHECK:           %[[VAL_35:.*]] = llvm.bitcast %[[VAL_34]] : i32 to f32
-// CHECK:           %[[VAL_36:.*]] = llvm.bitcast %[[VAL_28]] : f32 to i32
-// CHECK:           %[[VAL_37:.*]] = shift_right_unsigned %[[VAL_36]], %[[VAL_21]] : i32
-// CHECK:           %[[FREXP_X:.*]] = sitofp %[[VAL_37]] : i32 to f32
-// CHECK:           %[[VAL_39:.*]] = subf %[[FREXP_X]], %[[VAL_18]] : f32
-// CHECK:           %[[VAL_40:.*]] = cmpf olt, %[[VAL_35]], %[[VAL_8]] : f32
-// CHECK:           %[[VAL_41:.*]] = select %[[VAL_40]], %[[VAL_35]], %[[VAL_1]] : f32
-// CHECK:           %[[VAL_42:.*]] = subf %[[VAL_35]], %[[VAL_2]] : f32
-// CHECK:           %[[VAL_43:.*]] = select %[[VAL_40]], %[[VAL_2]], %[[VAL_1]] : f32
-// CHECK:           %[[VAL_44:.*]] = subf %[[VAL_39]], %[[VAL_43]] : f32
-// CHECK:           %[[VAL_45:.*]] = addf %[[VAL_42]], %[[VAL_41]] : f32
-// CHECK:           %[[VAL_46:.*]] = mulf %[[VAL_45]], %[[VAL_45]] : f32
-// CHECK:           %[[VAL_47:.*]] = mulf %[[VAL_46]], %[[VAL_45]] : f32
-// CHECK:           %[[VAL_48:.*]] = fmaf %[[VAL_9]], %[[VAL_45]], %[[VAL_10]] : f32
-// CHECK:           %[[VAL_49:.*]] = fmaf %[[VAL_12]], %[[VAL_45]], %[[VAL_13]] : f32
-// CHECK:           %[[VAL_50:.*]] = fmaf %[[VAL_15]], %[[VAL_45]], %[[VAL_16]] : f32
-// CHECK:           %[[VAL_51:.*]] = fmaf %[[VAL_48]], %[[VAL_45]], %[[VAL_11]] : f32
-// CHECK:           %[[VAL_52:.*]] = fmaf %[[VAL_49]], %[[VAL_45]], %[[VAL_14]] : f32
-// CHECK:           %[[VAL_53:.*]] = fmaf %[[VAL_50]], %[[VAL_45]], %[[VAL_17]] : f32
-// CHECK:           %[[VAL_54:.*]] = fmaf %[[VAL_51]], %[[VAL_47]], %[[VAL_52]] : f32
-// CHECK:           %[[VAL_55:.*]] = fmaf %[[VAL_54]], %[[VAL_47]], %[[VAL_53]] : f32
-// CHECK:           %[[VAL_56:.*]] = mulf %[[VAL_55]], %[[VAL_47]] : f32
-// CHECK:           %[[VAL_57:.*]] = fmaf %[[VAL_3]], %[[VAL_46]], %[[VAL_56]] : f32
-// CHECK:           %[[VAL_58:.*]] = addf %[[VAL_45]], %[[VAL_57]] : f32
-// CHECK:           %[[VAL_59:.*]] = fmaf %[[VAL_44]], %[[CST_LN2]], %[[VAL_58]] : f32
-// CHECK:           %[[VAL_60:.*]] = cmpf ult, %[[X]], %[[VAL_1]] : f32
-// CHECK:           %[[VAL_61:.*]] = cmpf oeq, %[[X]], %[[VAL_1]] : f32
-// CHECK:           %[[VAL_62:.*]] = cmpf oeq, %[[X]], %[[VAL_25]] : f32
-// CHECK:           %[[VAL_63:.*]] = select %[[VAL_62]], %[[VAL_25]], %[[VAL_59]] : f32
-// CHECK:           %[[VAL_64:.*]] = select %[[VAL_60]], %[[VAL_26]], %[[VAL_63]] : f32
-// CHECK:           %[[VAL_65:.*]] = select %[[VAL_61]], %[[VAL_24]], %[[VAL_64]] : f32
-// CHECK:           return %[[VAL_65]] : f32
+// CHECK:           %[[VAL_25:.*]] = bitcast %[[VAL_24]] : f32 to i32
+// CHECK:           %[[VAL_26:.*]] = and %[[VAL_25]], %[[VAL_19]] : i32
+// CHECK:           %[[VAL_27:.*]] = or %[[VAL_26]], %[[VAL_20]] : i32
+// CHECK:           %[[VAL_28:.*]] = bitcast %[[VAL_27]] : i32 to f32
+// CHECK:           %[[VAL_29:.*]] = bitcast %[[VAL_24]] : f32 to i32
+// CHECK:           %[[VAL_30:.*]] = shift_right_unsigned %[[VAL_29]], %[[VAL_21]] : i32
+// CHECK:           %[[VAL_31:.*]] = sitofp %[[VAL_30]] : i32 to f32
+// CHECK:           %[[VAL_32:.*]] = subf %[[VAL_31]], %[[VAL_18]] : f32
+// CHECK:           %[[VAL_33:.*]] = cmpf olt, %[[VAL_28]], %[[VAL_8]] : f32
+// CHECK:           %[[VAL_34:.*]] = select %[[VAL_33]], %[[VAL_28]], %[[VAL_1]] : f32
+// CHECK:           %[[VAL_35:.*]] = subf %[[VAL_28]], %[[VAL_2]] : f32
+// CHECK:           %[[VAL_36:.*]] = select %[[VAL_33]], %[[VAL_2]], %[[VAL_1]] : f32
+// CHECK:           %[[VAL_37:.*]] = subf %[[VAL_32]], %[[VAL_36]] : f32
+// CHECK:           %[[VAL_38:.*]] = addf %[[VAL_35]], %[[VAL_34]] : f32
+// CHECK:           %[[VAL_39:.*]] = mulf %[[VAL_38]], %[[VAL_38]] : f32
+// CHECK:           %[[VAL_40:.*]] = mulf %[[VAL_39]], %[[VAL_38]] : f32
+// CHECK:           %[[VAL_41:.*]] = fmaf %[[VAL_9]], %[[VAL_38]], %[[VAL_10]] : f32
+// CHECK:           %[[VAL_42:.*]] = fmaf %[[VAL_12]], %[[VAL_38]], %[[VAL_13]] : f32
+// CHECK:           %[[VAL_43:.*]] = fmaf %[[VAL_15]], %[[VAL_38]], %[[VAL_16]] : f32
+// CHECK:           %[[VAL_44:.*]] = fmaf %[[VAL_41]], %[[VAL_38]], %[[VAL_11]] : f32
+// CHECK:           %[[VAL_45:.*]] = fmaf %[[VAL_42]], %[[VAL_38]], %[[VAL_14]] : f32
+// CHECK:           %[[VAL_46:.*]] = fmaf %[[VAL_43]], %[[VAL_38]], %[[VAL_17]] : f32
+// CHECK:           %[[VAL_47:.*]] = fmaf %[[VAL_44]], %[[VAL_40]], %[[VAL_45]] : f32
+// CHECK:           %[[VAL_48:.*]] = fmaf %[[VAL_47]], %[[VAL_40]], %[[VAL_46]] : f32
+// CHECK:           %[[VAL_49:.*]] = mulf %[[VAL_48]], %[[VAL_40]] : f32
+// CHECK:           %[[VAL_50:.*]] = fmaf %[[VAL_3]], %[[VAL_39]], %[[VAL_49]] : f32
+// CHECK:           %[[VAL_51:.*]] = addf %[[VAL_38]], %[[VAL_50]] : f32
+// CHECK:           %[[VAL_52:.*]] = fmaf %[[VAL_37]], %[[VAL_22]], %[[VAL_51]] : f32
+// CHECK:           %[[VAL_53:.*]] = cmpf ult, %[[X]], %[[VAL_1]] : f32
+// CHECK:           %[[VAL_54:.*]] = cmpf oeq, %[[X]], %[[VAL_1]] : f32
+// CHECK:           %[[VAL_55:.*]] = cmpf oeq, %[[X]], %[[VAL_6]] : f32
+// CHECK:           %[[VAL_56:.*]] = select %[[VAL_55]], %[[VAL_6]], %[[VAL_52]] : f32
+// CHECK:           %[[VAL_57:.*]] = select %[[VAL_53]], %[[VAL_7]], %[[VAL_56]] : f32
+// CHECK:           %[[VAL_58:.*]] = select %[[VAL_54]], %[[VAL_5]], %[[VAL_57]] : f32
+// CHECK:           return %[[VAL_58]] : f32
 // CHECK:         }
 func @log_scalar(%arg0: f32) -> f32 {
   %0 = math.log %arg0 : f32
@@ -187,8 +179,7 @@ func @log_scalar(%arg0: f32) -> f32 {
 // CHECK-LABEL:   func @log_vector(
 // CHECK-SAME:                     %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> {
 // CHECK:           %[[CST_LN2:.*]] = constant dense<0.693147182> : vector<8xf32>
-// CHECK-COUNT-6:   vector.broadcast
-// CHECK-COUNT-4:   select
+// CHECK-COUNT-5:   select
 // CHECK:           %[[VAL_71:.*]] = select
 // CHECK:           return %[[VAL_71]] : vector<8xf32>
 // CHECK:         }
@@ -212,8 +203,7 @@ func @log2_scalar(%arg0: f32) -> f32 {
 // CHECK-LABEL:   func @log2_vector(
 // CHECK-SAME:                      %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> {
 // CHECK:           %[[CST_LOG2E:.*]] = constant dense<1.44269502> : vector<8xf32>
-// CHECK-COUNT-6:   vector.broadcast
-// CHECK-COUNT-4:   select
+// CHECK-COUNT-5:   select
 // CHECK:           %[[VAL_71:.*]] = select
 // CHECK:           return %[[VAL_71]] : vector<8xf32>
 // CHECK:         }
@@ -234,7 +224,7 @@ func @log2_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
 // CHECK:           %[[VAL_69:.*]] = subf %[[U]], %[[CST_ONE]] : f32
 // CHECK:           %[[VAL_70:.*]] = divf %[[LOG_U]], %[[VAL_69]] : f32
 // CHECK:           %[[LOG_LARGE:.*]] = mulf %[[X]], %[[VAL_70]] : f32
-// CHECK:           %[[VAL_72:.*]] = llvm.or %[[U_SMALL]], %[[U_INF]]  : i1
+// CHECK:           %[[VAL_72:.*]] = or %[[U_SMALL]], %[[U_INF]]  : i1
 // CHECK:           %[[APPROX:.*]] = select %[[VAL_72]], %[[X]], %[[LOG_LARGE]] : f32
 // CHECK:           return %[[APPROX]] : f32
 // CHECK:         }
@@ -246,8 +236,7 @@ func @log1p_scalar(%arg0: f32) -> f32 {
 // CHECK-LABEL:   func @log1p_vector(
 // CHECK-SAME:                       %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> {
 // CHECK:           %[[CST_ONE:.*]] = constant dense<1.000000e+00> : vector<8xf32>
-// CHECK-COUNT-6:   vector.broadcast
-// CHECK-COUNT-5:   select
+// CHECK-COUNT-6:   select
 // CHECK:           %[[VAL_79:.*]] = select
 // CHECK:           return %[[VAL_79]] : vector<8xf32>
 // CHECK:         }


        


More information about the Mlir-commits mailing list