[Mlir-commits] [mlir] 711c589 - [mlir][math] Update math arith expansions for vectorization
Robert Suderman
llvmlistbot at llvm.org
Thu Apr 6 11:43:04 PDT 2023
Author: Robert Suderman
Date: 2023-04-06T18:42:01Z
New Revision: 711c58938f36d91af1dc4209946bcf5e70869445
URL: https://github.com/llvm/llvm-project/commit/711c58938f36d91af1dc4209946bcf5e70869445
DIFF: https://github.com/llvm/llvm-project/commit/711c58938f36d91af1dc4209946bcf5e70869445.diff
LOG: [mlir][math] Update math arith expansions for vectorization
The math arithmetic expansions did not support vector types.
Update the lowerings so that they support vector types. This
includes a new implementation of `math.ctlz` as a binary search,
so it no longer has a data-dependent termination time.
Reviewed By: jpienaar, NatashaKnk
Differential Revision: https://reviews.llvm.org/D147289
Added:
Modified:
mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
mlir/test/Dialect/Math/expand-math.mlir
mlir/test/lib/Dialect/Math/TestExpandMath.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 364dd05c093ba..91aef84348a96 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -14,22 +14,46 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
+/// Build an `arith.constant` of `type` holding the floating-point `value`.
+/// A shaped `type` (e.g. a vector) yields a splat of `value` over that shape;
+/// any other type yields a scalar constant of that type.
+static Value createFloatConst(Location loc, Type type, double value,
+                              OpBuilder &b) {
+  FloatAttr scalar = b.getFloatAttr(getElementTypeOrSelf(type), value);
+  auto shaped = dyn_cast<ShapedType>(type);
+  if (!shaped)
+    return b.create<arith::ConstantOp>(loc, scalar);
+
+  auto splat = DenseElementsAttr::get(shaped, scalar);
+  return b.create<arith::ConstantOp>(loc, splat);
+}
+
+/// Create an integer `arith.constant` of `type` holding `value`. When `type`
+/// is a shaped type (e.g. a vector), the result is a splat of `value` over
+/// that shape (presumably the shape must be static -- confirm against
+/// DenseElementsAttr); otherwise it is a scalar constant of `type`.
+static Value createIntConst(Location loc, Type type, int64_t value,
+                            OpBuilder &b) {
+  auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
+  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
+    return b.create<arith::ConstantOp>(loc,
+                                       DenseElementsAttr::get(shapedTy, attr));
+  }
+
+  return b.create<arith::ConstantOp>(loc, attr);
+}
+
/// Expands tanh op into
/// 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
/// 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0
static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
auto floatType = op.getOperand().getType();
Location loc = op.getLoc();
- auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
- auto floatTwo = rewriter.getFloatAttr(floatType, 2.0);
- Value one = rewriter.create<arith::ConstantOp>(loc, floatOne);
- Value two = rewriter.create<arith::ConstantOp>(loc, floatTwo);
+ Value one = createFloatConst(loc, floatType, 1.0, rewriter);
+ Value two = createFloatConst(loc, floatType, 2.0, rewriter);
Value doubledX = rewriter.create<arith::MulFOp>(loc, op.getOperand(), two);
// Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
@@ -46,8 +70,7 @@ static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
Value negativeRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
// tanh(x) = x >= 0 ? positiveRes : negativeRes
- auto floatZero = rewriter.getFloatAttr(floatType, 0.0);
- Value zero = rewriter.create<arith::ConstantOp>(loc, floatZero);
+ Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
Value cmpRes = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
op.getOperand(), zero);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmpRes, positiveRes,
@@ -55,6 +78,7 @@ static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
return success();
}
+// Converts math.tan to math.sin, math.cos, and arith.divf.
static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operand = op.getOperand();
@@ -66,52 +90,47 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
return success();
}
+// Converts math.ctlz to arith operations by performing a branch-free binary
+// search on the bits (select-based, no control flow), so the expansion also
+// applies to vector operands. Element types wider than 64 bits are rejected.
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
PatternRewriter &rewriter) {
auto operand = op.getOperand();
- auto elementTy = operand.getType();
- auto resultTy = op.getType();
+  auto operandTy = operand.getType();
+  auto eTy = getElementTypeOrSelf(operandTy);
Location loc = op.getLoc();
- int bitWidth = elementTy.getIntOrFloatBitWidth();
- auto zero =
- rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
- auto leadingZeros = rewriter.create<arith::ConstantOp>(
- loc, IntegerAttr::get(elementTy, bitWidth));
-
- SmallVector<Value> operands = {operand, leadingZeros, zero};
- SmallVector<Type> types = {elementTy, elementTy, elementTy};
- SmallVector<Location> locations = {loc, loc, loc};
-
- auto whileOp = rewriter.create<scf::WhileOp>(
- loc, types, operands,
- [&](OpBuilder &beforeBuilder, Location beforeLoc, ValueRange args) {
- // The conditional block of the while loop.
- Value input = args[0];
- Value zero = args[2];
-
- Value inputNotZero = beforeBuilder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ne, input, zero);
- beforeBuilder.create<scf::ConditionOp>(loc, inputNotZero, args);
- },
- [&](OpBuilder &afterBuilder, Location afterLoc, ValueRange args) {
- // The body of the while loop: shift right until reaching a value of 0.
- Value input = args[0];
- Value leadingZeros = args[1];
-
- auto one = afterBuilder.create<arith::ConstantOp>(
- loc, IntegerAttr::get(elementTy, 1));
- auto shifted =
- afterBuilder.create<arith::ShRUIOp>(loc, resultTy, input, one);
- auto leadingZerosMinusOne = afterBuilder.create<arith::SubIOp>(
- loc, resultTy, leadingZeros, one);
-
- afterBuilder.create<scf::YieldOp>(
- loc, ValueRange({shifted, leadingZerosMinusOne, args[2]}));
- });
-
- rewriter.setInsertionPointAfter(whileOp);
- rewriter.replaceOp(op, whileOp->getResult(1));
+  int32_t bitwidth = eTy.getIntOrFloatBitWidth();
+  if (bitwidth > 64)
+    return failure();
+
+  // Mask with the low `bitwidth` bits set. Computed without shifting a 64-bit
+  // value by 64, which would be undefined behavior for a zero-width type.
+  uint64_t allbits =
+      bitwidth == 64 ? ~uint64_t(0) : (uint64_t(1) << bitwidth) - 1;
+
+  // First step: the largest power of two strictly below `bitwidth` (0 when
+  // bitwidth <= 1). Starting here rather than at bitwidth / 2 makes the steps
+  // sum to at least bitwidth - 1 even when bitwidth is not a power of two
+  // (e.g. i24 uses 16, 8, 4, 2, 1 instead of 12, 6, 3, 1 which undercounts),
+  // while every shift amount stays below bitwidth. For power-of-two widths
+  // this is exactly bitwidth / 2.
+  int32_t firstStep = 1;
+  while (firstStep < bitwidth)
+    firstStep *= 2;
+  firstStep /= 2;
+
+  // Binary search: if the top `half` bits of x are all zero, shift them out
+  // and add `half` to the running count.
+  Value x = operand;
+  Value count = createIntConst(loc, operandTy, 0, rewriter);
+  for (int32_t half = firstStep; half > 0; half /= 2) {
+    auto bits = createIntConst(loc, operandTy, half, rewriter);
+    auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter);
+
+    // x <= mask  <=>  the top `half` bits of x are zero.
+    Value pred =
+        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, x, mask);
+    Value add = rewriter.create<arith::AddIOp>(loc, count, bits);
+    Value shift = rewriter.create<arith::ShLIOp>(loc, x, bits);
+
+    x = rewriter.create<arith::SelectOp>(loc, pred, shift, x);
+    count = rewriter.create<arith::SelectOp>(loc, pred, add, count);
+  }
+
+  // The search above never counts past bitwidth - 1; a zero input has
+  // `bitwidth` leading zeros, so select that explicitly.
+  Value zero = createIntConst(loc, operandTy, 0, rewriter);
+  Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+                                              operand, zero);
+
+  Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter);
+  Value sel = rewriter.create<arith::SelectOp>(loc, pred, bwval, count);
+  rewriter.replaceOp(op, sel);
return success();
}
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 49ac15fd97b7e..a66ea082f1ef3 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -24,6 +24,16 @@ func.func @tanh(%arg: f32) -> f32 {
// -----
+
+// Verify that math.tanh on a vector operand is expanded away.
+// CHECK-LABEL: func @vector_tanh
+func.func @vector_tanh(%arg: vector<4xf32>) -> vector<4xf32> {
+ // CHECK-NOT: math.tanh
+ %res = math.tanh %arg : vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @tan
func.func @tan(%arg: f32) -> f32 {
%res = math.tan %arg : f32
@@ -33,23 +43,79 @@ func.func @tan(%arg: f32) -> f32 {
// CHECK-SAME: %[[ARG0:.+]]: f32
// CHECK: %[[SIN:.+]] = math.sin %[[ARG0]]
// CHECK: %[[COS:.+]] = math.cos %[[ARG0]]
-// CEHCK: %[[DIV:.+]] = arith.div %[[SIN]] %[[COS]]
+// CHECK: %[[DIV:.+]] = arith.divf %[[SIN]], %[[COS]]
+
+
+// -----
+
+// Verify that math.tan on a vector operand lowers to sin / cos, like the
+// scalar case, and that no math.tan remains.
+// CHECK-LABEL: func @vector_tan
+// CHECK-SAME: %[[ARG0:.+]]: vector<4xf32>
+// CHECK: %[[SIN:.+]] = math.sin %[[ARG0]]
+// CHECK: %[[COS:.+]] = math.cos %[[ARG0]]
+// CHECK: %[[DIV:.+]] = arith.divf %[[SIN]], %[[COS]]
+// CHECK-NOT: math.tan
+func.func @vector_tan(%arg: vector<4xf32>) -> vector<4xf32> {
+ %res = math.tan %arg : vector<4xf32>
+ return %res : vector<4xf32>
+}
// -----
-// CHECK-LABEL: func @ctlz
func.func @ctlz(%arg: i32) -> i32 {
- // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32
- // CHECK-DAG: %[[C32:.+]] = arith.constant 32 : i32
- // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i32
- // CHECK: %[[WHILE:.+]]:3 = scf.while (%[[A1:.+]] = %arg0, %[[A2:.+]] = %[[C32]], %[[A3:.+]] = %[[C0]])
- // CHECK: %[[CMP:.+]] = arith.cmpi ne, %[[A1]], %[[A3]]
- // CHECK: scf.condition(%[[CMP]]) %[[A1]], %[[A2]], %[[A3]]
- // CHECK: %[[SHR:.+]] = arith.shrui %[[A1]], %[[C1]]
- // CHECK: %[[SUB:.+]] = arith.subi %[[A2]], %[[C1]]
- // CHECK: scf.yield %[[SHR]], %[[SUB]], %[[A3]]
%res = math.ctlz %arg : i32
-
- // CHECK: return %[[WHILE]]#1
return %res : i32
}
+
+// For i32 the expansion is a binary search with shift amounts 16, 8, 4, 2, 1,
+// followed by a select that yields 32 for a zero input. The constant patterns
+// carry an explicit " : i32" suffix so that a short value (e.g. 2) cannot bind
+// to a longer constant that shares its prefix (e.g. 2147483647).
+// CHECK-LABEL: @ctlz
+// CHECK-SAME: %[[ARG0:.+]]: i32
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32
+// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : i32
+// CHECK-DAG: %[[C65535:.+]] = arith.constant 65535 : i32
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : i32
+// CHECK-DAG: %[[C16777215:.+]] = arith.constant 16777215 : i32
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : i32
+// CHECK-DAG: %[[C268435455:.+]] = arith.constant 268435455 : i32
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : i32
+// CHECK-DAG: %[[C1073741823:.+]] = arith.constant 1073741823 : i32
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i32
+// CHECK-DAG: %[[C2147483647:.+]] = arith.constant 2147483647 : i32
+// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : i32
+
+// CHECK: %[[PRED:.+]] = arith.cmpi ule, %[[ARG0]], %[[C65535]]
+// CHECK: %[[SHL:.+]] = arith.shli %[[ARG0]], %[[C16]]
+// CHECK: %[[SELX0:.+]] = arith.select %[[PRED]], %[[SHL]], %[[ARG0]]
+// CHECK: %[[SELY0:.+]] = arith.select %[[PRED]], %[[C16]], %[[C0]]
+
+// CHECK: %[[PRED:.+]] = arith.cmpi ule, %[[SELX0]], %[[C16777215]]
+// CHECK: %[[ADD:.+]] = arith.addi %[[SELY0]], %[[C8]]
+// CHECK: %[[SHL:.+]] = arith.shli %[[SELX0]], %[[C8]]
+// CHECK: %[[SELX1:.+]] = arith.select %[[PRED]], %[[SHL]], %[[SELX0]]
+// CHECK: %[[SELY1:.+]] = arith.select %[[PRED]], %[[ADD]], %[[SELY0]]
+
+// CHECK: %[[PRED:.+]] = arith.cmpi ule, %[[SELX1]], %[[C268435455]] : i32
+// CHECK: %[[ADD:.+]] = arith.addi %[[SELY1]], %[[C4]]
+// CHECK: %[[SHL:.+]] = arith.shli %[[SELX1]], %[[C4]]
+// CHECK: %[[SELX2:.+]] = arith.select %[[PRED]], %[[SHL]], %[[SELX1]]
+// CHECK: %[[SELY2:.+]] = arith.select %[[PRED]], %[[ADD]], %[[SELY1]]
+
+// CHECK: %[[PRED:.+]] = arith.cmpi ule, %[[SELX2]], %[[C1073741823]] : i32
+// CHECK: %[[ADD:.+]] = arith.addi %[[SELY2]], %[[C2]]
+// CHECK: %[[SHL:.+]] = arith.shli %[[SELX2]], %[[C2]]
+// CHECK: %[[SELX3:.+]] = arith.select %[[PRED]], %[[SHL]], %[[SELX2]]
+// CHECK: %[[SELY3:.+]] = arith.select %[[PRED]], %[[ADD]], %[[SELY2]]
+
+// CHECK: %[[PRED:.+]] = arith.cmpi ule, %[[SELX3]], %[[C2147483647]] : i32
+// CHECK: %[[ADD:.+]] = arith.addi %[[SELY3]], %[[C1]]
+// CHECK: %[[SELY4:.+]] = arith.select %[[PRED]], %[[ADD]], %[[SELY3]]
+
+// CHECK: %[[PRED:.+]] = arith.cmpi eq, %[[ARG0]], %[[C0]] : i32
+// CHECK: %[[SEL:.+]] = arith.select %[[PRED]], %[[C32]], %[[SELY4]] : i32
+// CHECK: return %[[SEL]]
+
+// -----
+
+// Verify that math.ctlz on a vector operand is expanded away.
+func.func @ctlz_vector(%arg: vector<4xi32>) -> vector<4xi32> {
+ %res = math.ctlz %arg : vector<4xi32>
+ return %res : vector<4xi32>
+}
+
+// CHECK-LABEL: @ctlz_vector
+// CHECK-NOT: math.ctlz
diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index 28819518b2780..29b862e410c0f 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -26,7 +27,8 @@ struct TestExpandMathPass
void runOnOperation() override;
StringRef getArgument() const final { return "test-expand-math"; }
void getDependentDialects(DialectRegistry &registry) const override {
+ // Declare the dialects this test pass depends on (arith, scf, vector).
- registry.insert<arith::ArithDialect, scf::SCFDialect>();
+ registry
+ .insert<arith::ArithDialect, scf::SCFDialect, vector::VectorDialect>();
}
StringRef getDescription() const final { return "Test expanding math"; }
};
More information about the Mlir-commits
mailing list