[Mlir-commits] [mlir] 2dde4ba - [mlir][math] Added algebraic simplification for IPowI operation.

Slava Zakharin llvmlistbot at llvm.org
Mon Aug 15 11:56:17 PDT 2022


Author: Slava Zakharin
Date: 2022-08-15T11:55:05-07:00
New Revision: 2dde4ba63974daf59f8ce5c346505f194f920131

URL: https://github.com/llvm/llvm-project/commit/2dde4ba63974daf59f8ce5c346505f194f920131
DIFF: https://github.com/llvm/llvm-project/commit/2dde4ba63974daf59f8ce5c346505f194f920131.diff

LOG: [mlir][math] Added algebraic simplification for IPowI operation.

Differential Revision: https://reviews.llvm.org/D130390

Added: 
    

Modified: 
    mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
    mlir/test/Dialect/Math/algebraic-simplification.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 1cefa6744facf..b967b8699b45d 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -112,9 +112,100 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
   return failure();
 }
 
+//----------------------------------------------------------------------------//
+// IPowIOp strength reduction.
+//----------------------------------------------------------------------------//
+
+namespace {
+/// Rewrites `math.ipowi` with a small constant exponent into integer muls.
+struct IPowIStrengthReduction : public OpRewritePattern<math::IPowIOp> {
+  IPowIStrengthReduction(MLIRContext *context, unsigned exponentThreshold = 3,
+                         PatternBenefit benefit = 1,
+                         ArrayRef<StringRef> generatedNames = {})
+      : OpRewritePattern<math::IPowIOp>(context, benefit, generatedNames),
+        exponentThreshold(exponentThreshold) {}
+  LogicalResult matchAndRewrite(math::IPowIOp op,
+                                PatternRewriter &rewriter) const final;
+  /// Largest `abs(exponent)` that is still expanded into multiplications.
+  unsigned exponentThreshold;
+};
+} // namespace
+
+LogicalResult
+IPowIStrengthReduction::matchAndRewrite(math::IPowIOp op,
+                                        PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  Value base = op.getLhs();
+
+  // The exponent must be a constant: either a scalar or a splat vector.
+  IntegerAttr scalarExponent;
+  DenseIntElementsAttr vectorExponent;
+  APInt exponentBits;
+  if (matchPattern(op.getRhs(), m_Constant(&scalarExponent)))
+    exponentBits = scalarExponent.getValue();
+  else if (matchPattern(op.getRhs(), m_Constant(&vectorExponent)) &&
+           vectorExponent.isSplat())
+    exponentBits = vectorExponent.getSplatValue<APInt>();
+  else
+    return failure();
+
+  // An exponent needing >64 bits exceeds any threshold (and would assert).
+  if (exponentBits.getMinSignedBits() > 64)
+    return failure();
+  int64_t exponentValue = exponentBits.getSExtValue();
+
+  // Broadcasts a scalar `value` to the vector type of `op`, if it has one.
+  auto bcast = [&](Value value) -> Value {
+    if (auto vec = op.getType().dyn_cast<VectorType>())
+      return rewriter.create<vector::BroadcastOp>(loc, vec, value);
+    return value;
+  };
+
+  // Materializes the constant `1` with the (possibly vector) type of `op`.
+  auto createOne = [&]() -> Value {
+    Value one = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIntegerAttr(getElementTypeOrSelf(op.getType()), 1));
+    return bcast(one);
+  };
+
+  // Replace `ipowi(x, 0)` with `1`.
+  if (exponentValue == 0) {
+    rewriter.replaceOp(op, createOne());
+    return success();
+  }
+
+  // Bail out if `abs(exponent)` exceeds the threshold; checked pre-negation.
+  auto threshold = static_cast<int64_t>(exponentThreshold);
+  if (exponentValue > threshold || exponentValue < -threshold)
+    return failure();
+
+  // Negative exponent: rewrite `ipowi(x, -n)` as `ipowi(1 / x, n)` (divsi).
+  if (exponentValue < 0) {
+    base = rewriter.create<arith::DivSIOp>(loc, createOne(), base);
+    exponentValue = -exponentValue;
+  }
+
+  // Transform to naive sequence of multiplications:
+  //   * For positive exponent case replace:
+  //       `ipowi(x, positive_exponent)`
+  //     with:
+  //       x * x * x * ...
+  //   * For negative exponent case replace:
+  //       `ipowi(x, negative_exponent)`
+  //     with:
+  //       (1 / x) * (1 / x) * (1 / x) * ...
+  Value result = base;
+  for (int64_t i = 1; i < exponentValue; ++i)
+    result = rewriter.create<arith::MulIOp>(loc, result, base);
+
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
 //----------------------------------------------------------------------------//
 
 void mlir::populateMathAlgebraicSimplificationPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<PowFStrengthReduction>(patterns.getContext());
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<PowFStrengthReduction, IPowIStrengthReduction>(ctx);
 }

diff  --git a/mlir/test/Dialect/Math/algebraic-simplification.mlir b/mlir/test/Dialect/Math/algebraic-simplification.mlir
index d4f4efe0de1fa..106f7fd0c4d04 100644
--- a/mlir/test/Dialect/Math/algebraic-simplification.mlir
+++ b/mlir/test/Dialect/Math/algebraic-simplification.mlir
@@ -73,3 +73,93 @@ func.func @pow_rsqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>)
   %1 = math.powf %arg1, %v : vector<4xf32>
   return %0, %1 : f32, vector<4xf32>
 }
+
+// CHECK-LABEL: @ipowi_zero_exp(
+// CHECK-SAME: %[[ARG0:.+]]: i32
+// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>
+// CHECK-SAME: -> (i32, vector<4xi32>) {
+func.func @ipowi_zero_exp(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>) {
+  // CHECK: %[[CST_S:.*]] = arith.constant 1 : i32
+  // CHECK: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
+  // CHECK: return %[[CST_S]], %[[CST_V]]
+  %c0 = arith.constant 0 : i32
+  %v0 = arith.constant dense<0> : vector<4xi32>
+  %pow_s = math.ipowi %arg0, %c0 : i32
+  %pow_v = math.ipowi %arg1, %v0 : vector<4xi32>
+  return %pow_s, %pow_v : i32, vector<4xi32>
+}
+
+// CHECK-LABEL: @ipowi_exp_one(
+// CHECK-SAME: %[[ARG0:.+]]: i32
+// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>
+// CHECK-SAME: -> (i32, vector<4xi32>, i32, vector<4xi32>) {
+func.func @ipowi_exp_one(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>, i32, vector<4xi32>) {
+  // CHECK: %[[CST_S:.*]] = arith.constant 1 : i32
+  // CHECK: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
+  // CHECK: %[[SCALAR:.*]] = arith.divsi %[[CST_S]], %[[ARG0]]
+  // CHECK: %[[VECTOR:.*]] = arith.divsi %[[CST_V]], %[[ARG1]]
+  // CHECK: return %[[ARG0]], %[[ARG1]], %[[SCALAR]], %[[VECTOR]]
+  %one_s = arith.constant 1 : i32
+  %one_v = arith.constant dense<1> : vector<4xi32>
+  %pos_s = math.ipowi %arg0, %one_s : i32
+  %pos_v = math.ipowi %arg1, %one_v : vector<4xi32>
+  %neg_one_s = arith.constant -1 : i32
+  %neg_one_v = arith.constant dense<-1> : vector<4xi32>
+  %neg_s = math.ipowi %arg0, %neg_one_s : i32
+  %neg_v = math.ipowi %arg1, %neg_one_v : vector<4xi32>
+  return %pos_s, %pos_v, %neg_s, %neg_v : i32, vector<4xi32>, i32, vector<4xi32>
+}
+
+// CHECK-LABEL: @ipowi_exp_two(
+// CHECK-SAME: %[[ARG0:.+]]: i32
+// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>
+// CHECK-SAME: -> (i32, vector<4xi32>, i32, vector<4xi32>) {
+func.func @ipowi_exp_two(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>, i32, vector<4xi32>) {
+  // CHECK: %[[CST_S:.*]] = arith.constant 1 : i32
+  // CHECK: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
+  // CHECK: %[[SCALAR0:.*]] = arith.muli %[[ARG0]], %[[ARG0]]
+  // CHECK: %[[VECTOR0:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
+  // CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[ARG0]]
+  // CHECK: %[[SMUL:.*]] = arith.muli %[[SCALAR1]], %[[SCALAR1]]
+  // CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[ARG1]]
+  // CHECK: %[[VMUL:.*]] = arith.muli %[[VECTOR1]], %[[VECTOR1]]
+  // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL]], %[[VMUL]]
+  %c2 = arith.constant 2 : i32
+  %v2 = arith.constant dense<2> : vector<4xi32>
+  %0 = math.ipowi %arg0, %c2 : i32
+  %1 = math.ipowi %arg1, %v2 : vector<4xi32>
+  %cm2 = arith.constant -2 : i32
+  %vm2 = arith.constant dense<-2> : vector<4xi32>
+  %2 = math.ipowi %arg0, %cm2 : i32
+  %3 = math.ipowi %arg1, %vm2 : vector<4xi32>
+  return %0, %1, %2, %3 : i32, vector<4xi32>, i32, vector<4xi32>
+}
+
+// CHECK-LABEL: @ipowi_exp_three(
+// CHECK-SAME: %[[ARG0:.+]]: i32
+// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>
+// CHECK-SAME: -> (i32, vector<4xi32>, i32, vector<4xi32>) {
+func.func @ipowi_exp_three(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>, i32, vector<4xi32>) {
+  // CHECK: %[[CST_S:.*]] = arith.constant 1 : i32
+  // CHECK: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
+  // CHECK: %[[SMUL0:.*]] = arith.muli %[[ARG0]], %[[ARG0]]
+  // CHECK: %[[SCALAR0:.*]] = arith.muli %[[SMUL0]], %[[ARG0]]
+  // CHECK: %[[VMUL0:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
+  // CHECK: %[[VECTOR0:.*]] = arith.muli %[[VMUL0]], %[[ARG1]]
+  // CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[ARG0]]
+  // CHECK: %[[SMUL1:.*]] = arith.muli %[[SCALAR1]], %[[SCALAR1]]
+  // CHECK: %[[SMUL2:.*]] = arith.muli %[[SMUL1]], %[[SCALAR1]]
+  // CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[ARG1]]
+  // CHECK: %[[VMUL1:.*]] = arith.muli %[[VECTOR1]], %[[VECTOR1]]
+  // CHECK: %[[VMUL2:.*]] = arith.muli %[[VMUL1]], %[[VECTOR1]]
+  // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL2]], %[[VMUL2]]
+  %c3 = arith.constant 3 : i32
+  %v3 = arith.constant dense<3> : vector<4xi32>
+  %0 = math.ipowi %arg0, %c3 : i32
+  %1 = math.ipowi %arg1, %v3 : vector<4xi32>
+  %cm3 = arith.constant -3 : i32
+  %vm3 = arith.constant dense<-3> : vector<4xi32>
+  %2 = math.ipowi %arg0, %cm3 : i32
+  %3 = math.ipowi %arg1, %vm3 : vector<4xi32>
+  return %0, %1, %2, %3 : i32, vector<4xi32>, i32, vector<4xi32>
+}


        


More information about the Mlir-commits mailing list