[Mlir-commits] [mlir] 2dde4ba - [mlir][math] Added algebraic simplification for IPowI operation.
Slava Zakharin
llvmlistbot at llvm.org
Mon Aug 15 11:56:17 PDT 2022
Author: Slava Zakharin
Date: 2022-08-15T11:55:05-07:00
New Revision: 2dde4ba63974daf59f8ce5c346505f194f920131
URL: https://github.com/llvm/llvm-project/commit/2dde4ba63974daf59f8ce5c346505f194f920131
DIFF: https://github.com/llvm/llvm-project/commit/2dde4ba63974daf59f8ce5c346505f194f920131.diff
LOG: [mlir][math] Added algebraic simplification for IPowI operation.
Differential Revision: https://reviews.llvm.org/D130390
Added:
Modified:
mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
mlir/test/Dialect/Math/algebraic-simplification.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index 1cefa6744facf..b967b8699b45d 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -112,9 +112,100 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
return failure();
}
+//----------------------------------------------------------------------------//
+// IPowIOp strength reduction.
+//----------------------------------------------------------------------------//
+
+namespace {
+/// Rewrite pattern for `math.ipowi` operations whose exponent is a constant
+/// scalar or a constant splat vector. The rewrite is applied only while the
+/// absolute value of the exponent does not exceed `exponentThreshold`.
+struct IPowIStrengthReduction : public OpRewritePattern<math::IPowIOp> {
+  /// Builds the pattern.
+  ///   * `exponentThreshold` - largest `abs(exponent)` that is still expanded
+  ///     (3 unless specified).
+  ///   * `benefit`, `generatedNames` - forwarded to `OpRewritePattern`.
+  IPowIStrengthReduction(MLIRContext *context, unsigned exponentThreshold = 3,
+                         PatternBenefit benefit = 1,
+                         ArrayRef<StringRef> generatedNames = {})
+      : OpRewritePattern<math::IPowIOp>(context, benefit, generatedNames),
+        exponentThreshold(exponentThreshold) {}
+
+  LogicalResult matchAndRewrite(math::IPowIOp op,
+                                PatternRewriter &rewriter) const final;
+
+  /// Upper bound on `abs(exponent)` for which the op is rewritten.
+  unsigned exponentThreshold;
+};
+} // namespace
+
+/// Expands `math.ipowi` with a small constant exponent into explicit integer
+/// arithmetic.
+///
+/// The exponent has to be a constant scalar or a constant splat vector;
+/// otherwise the match fails. For a matched exponent:
+///   * `ipowi(x, 0)` is replaced with the constant `1` (broadcast when the
+///     result is a vector);
+///   * if `0 < abs(exponent) <= exponentThreshold`, the op is replaced with a
+///     chain of `abs(exponent) - 1` `arith.muli` operations on the base, where
+///     the base is first replaced with `1 / x` (`arith.divsi`) when the
+///     exponent is negative;
+///   * if `abs(exponent) > exponentThreshold`, the match fails and no IR is
+///     created.
+LogicalResult
+IPowIStrengthReduction::matchAndRewrite(math::IPowIOp op,
+                                        PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  Value base = op.getLhs();
+
+  IntegerAttr scalarExponent;
+  DenseIntElementsAttr vectorExponent;
+
+  bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
+  bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));
+
+  // Simplify cases with known exponent value. Non-splat vector exponents are
+  // not handled.
+  int64_t exponentValue = 0;
+  if (isScalar)
+    exponentValue = scalarExponent.getInt();
+  else if (isVector && vectorExponent.isSplat())
+    exponentValue = vectorExponent.getSplatValue<IntegerAttr>().getInt();
+  else
+    return failure();
+
+  // Maybe broadcasts scalar value into vector type compatible with `op`.
+  auto bcast = [&](Value value) -> Value {
+    if (auto vec = op.getType().dyn_cast<VectorType>())
+      return rewriter.create<vector::BroadcastOp>(loc, vec, value);
+    return value;
+  };
+
+  // Creates the scalar constant `1` of the element type of `op`. It is only
+  // invoked on paths that end in a successful rewrite, so that a failed match
+  // leaves the IR untouched.
+  auto createOne = [&]() -> Value {
+    return rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIntegerAttr(getElementTypeOrSelf(op.getType()), 1));
+  };
+
+  if (exponentValue == 0) {
+    // Replace `ipowi(x, 0)` with `1`.
+    rewriter.replaceOp(op, bcast(createOne()));
+    return success();
+  }
+
+  const bool exponentIsNegative = exponentValue < 0;
+  // Compute `abs(exponent)` in unsigned arithmetic: negating INT64_MIN as a
+  // signed value overflows (undefined behavior), while the unsigned
+  // subtraction below is well defined for every input.
+  const uint64_t exponentMagnitude =
+      exponentIsNegative ? 0 - static_cast<uint64_t>(exponentValue)
+                         : static_cast<uint64_t>(exponentValue);
+
+  // Bail out if `abs(exponent)` exceeds the threshold.
+  if (exponentMagnitude > exponentThreshold)
+    return failure();
+
+  // Invert the base for negative exponent, i.e. for
+  // `ipowi(x, negative_exponent)` set `x` to `1 / x`.
+  if (exponentIsNegative)
+    base = rewriter.create<arith::DivSIOp>(loc, bcast(createOne()), base);
+
+  Value result = base;
+  // Transform to naive sequence of multiplications:
+  //   * For positive exponent case replace:
+  //       `ipowi(x, positive_exponent)`
+  //     with:
+  //       x * x * x * ...
+  //   * For negative exponent case replace:
+  //       `ipowi(x, negative_exponent)`
+  //     with:
+  //       (1 / x) * (1 / x) * (1 / x) * ...
+  for (uint64_t i = 1; i < exponentMagnitude; ++i)
+    result = rewriter.create<arith::MulIOp>(loc, result, base);
+
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
//----------------------------------------------------------------------------//
void mlir::populateMathAlgebraicSimplificationPatterns(
RewritePatternSet &patterns) {
- patterns.add<PowFStrengthReduction>(patterns.getContext());
+ patterns.add<PowFStrengthReduction, IPowIStrengthReduction>(
+ patterns.getContext()); // No threshold passed: IPowIStrengthReduction uses its default (3).
}
diff --git a/mlir/test/Dialect/Math/algebraic-simplification.mlir b/mlir/test/Dialect/Math/algebraic-simplification.mlir
index d4f4efe0de1fa..106f7fd0c4d04 100644
--- a/mlir/test/Dialect/Math/algebraic-simplification.mlir
+++ b/mlir/test/Dialect/Math/algebraic-simplification.mlir
@@ -73,3 +73,93 @@ func.func @pow_rsqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>)
%1 = math.powf %arg1, %v : vector<4xf32>
return %0, %1 : f32, vector<4xf32>
}
+
+// Zero exponent: the scalar and the splat-vector `math.ipowi` are both
+// expected to be replaced with the constant 1.
+// CHECK-LABEL: @ipowi_zero_exp(
+// CHECK-SAME: %[[ARG0:.+]]: i32
+// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>
+// CHECK-SAME: -> (i32, vector<4xi32>) {
+func.func @ipowi_zero_exp(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>) {
+  // CHECK: %[[CST_S:.*]] = arith.constant 1 : i32
+  // CHECK: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
+  // CHECK: return %[[CST_S]], %[[CST_V]]
+  %zero_s = arith.constant 0 : i32
+  %zero_v = arith.constant dense<0> : vector<4xi32>
+  %pow_s = math.ipowi %arg0, %zero_s : i32
+  %pow_v = math.ipowi %arg1, %zero_v : vector<4xi32>
+  return %pow_s, %pow_v : i32, vector<4xi32>
+}
+
+// Exponents 1 and -1: the positive case is expected to yield the base itself,
+// the negative case a single signed division `1 / x`.
+// CHECK-LABEL: @ipowi_exp_one(
+// CHECK-SAME: %[[ARG0:.+]]: i32
+// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>
+// CHECK-SAME: -> (i32, vector<4xi32>, i32, vector<4xi32>) {
+func.func @ipowi_exp_one(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>, i32, vector<4xi32>) {
+  // CHECK: %[[CST_S:.*]] = arith.constant 1 : i32
+  // CHECK: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
+  // CHECK: %[[SCALAR:.*]] = arith.divsi %[[CST_S]], %[[ARG0]]
+  // CHECK: %[[VECTOR:.*]] = arith.divsi %[[CST_V]], %[[ARG1]]
+  // CHECK: return %[[ARG0]], %[[ARG1]], %[[SCALAR]], %[[VECTOR]]
+  %one_s = arith.constant 1 : i32
+  %one_v = arith.constant dense<1> : vector<4xi32>
+  %pos_s = math.ipowi %arg0, %one_s : i32
+  %pos_v = math.ipowi %arg1, %one_v : vector<4xi32>
+  %minus_one_s = arith.constant -1 : i32
+  %minus_one_v = arith.constant dense<-1> : vector<4xi32>
+  %neg_s = math.ipowi %arg0, %minus_one_s : i32
+  %neg_v = math.ipowi %arg1, %minus_one_v : vector<4xi32>
+  return %pos_s, %pos_v, %neg_s, %neg_v : i32, vector<4xi32>, i32, vector<4xi32>
+}
+
+// Exponents 2 and -2: the positive case is expected to become `x * x`, the
+// negative case `(1 / x) * (1 / x)`.
+// CHECK-LABEL: @ipowi_exp_two(
+// CHECK-SAME: %[[ARG0:.+]]: i32
+// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>
+// CHECK-SAME: -> (i32, vector<4xi32>, i32, vector<4xi32>) {
+func.func @ipowi_exp_two(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>, i32, vector<4xi32>) {
+  // CHECK: %[[CST_S:.*]] = arith.constant 1 : i32
+  // CHECK: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
+  // CHECK: %[[SCALAR0:.*]] = arith.muli %[[ARG0]], %[[ARG0]]
+  // CHECK: %[[VECTOR0:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
+  // CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[ARG0]]
+  // CHECK: %[[SMUL:.*]] = arith.muli %[[SCALAR1]], %[[SCALAR1]]
+  // CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[ARG1]]
+  // CHECK: %[[VMUL:.*]] = arith.muli %[[VECTOR1]], %[[VECTOR1]]
+  // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL]], %[[VMUL]]
+  %c2 = arith.constant 2 : i32
+  %v2 = arith.constant dense<2> : vector<4xi32>
+  %pos_s = math.ipowi %arg0, %c2 : i32
+  %pos_v = math.ipowi %arg1, %v2 : vector<4xi32>
+  %cm2 = arith.constant -2 : i32
+  %vm2 = arith.constant dense<-2> : vector<4xi32>
+  %neg_s = math.ipowi %arg0, %cm2 : i32
+  %neg_v = math.ipowi %arg1, %vm2 : vector<4xi32>
+  return %pos_s, %pos_v, %neg_s, %neg_v : i32, vector<4xi32>, i32, vector<4xi32>
+}
+
+// Exponents 3 and -3: the positive case is expected to become `x * x * x`,
+// the negative case `(1 / x) * (1 / x) * (1 / x)`.
+// CHECK-LABEL: @ipowi_exp_three(
+// CHECK-SAME: %[[ARG0:.+]]: i32
+// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>
+// CHECK-SAME: -> (i32, vector<4xi32>, i32, vector<4xi32>) {
+func.func @ipowi_exp_three(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32>, i32, vector<4xi32>) {
+  // CHECK: %[[CST_S:.*]] = arith.constant 1 : i32
+  // CHECK: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32>
+  // CHECK: %[[SMUL0:.*]] = arith.muli %[[ARG0]], %[[ARG0]]
+  // CHECK: %[[SCALAR0:.*]] = arith.muli %[[SMUL0]], %[[ARG0]]
+  // CHECK: %[[VMUL0:.*]] = arith.muli %[[ARG1]], %[[ARG1]]
+  // CHECK: %[[VECTOR0:.*]] = arith.muli %[[VMUL0]], %[[ARG1]]
+  // CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[ARG0]]
+  // CHECK: %[[SMUL1:.*]] = arith.muli %[[SCALAR1]], %[[SCALAR1]]
+  // CHECK: %[[SMUL2:.*]] = arith.muli %[[SMUL1]], %[[SCALAR1]]
+  // CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[ARG1]]
+  // CHECK: %[[VMUL1:.*]] = arith.muli %[[VECTOR1]], %[[VECTOR1]]
+  // CHECK: %[[VMUL2:.*]] = arith.muli %[[VMUL1]], %[[VECTOR1]]
+  // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL2]], %[[VMUL2]]
+  %c3 = arith.constant 3 : i32
+  %v3 = arith.constant dense<3> : vector<4xi32>
+  %pos_s = math.ipowi %arg0, %c3 : i32
+  %pos_v = math.ipowi %arg1, %v3 : vector<4xi32>
+  %cm3 = arith.constant -3 : i32
+  %vm3 = arith.constant dense<-3> : vector<4xi32>
+  %neg_s = math.ipowi %arg0, %cm3 : i32
+  %neg_v = math.ipowi %arg1, %vm3 : vector<4xi32>
+  return %pos_s, %pos_v, %neg_s, %neg_v : i32, vector<4xi32>, i32, vector<4xi32>
+}
More information about the Mlir-commits
mailing list