[Mlir-commits] [mlir] 711c589 - [mlir][math] Update math arith expansions for vectorization

Robert Suderman llvmlistbot at llvm.org
Thu Apr 6 11:43:04 PDT 2023


Author: Robert Suderman
Date: 2023-04-06T18:42:01Z
New Revision: 711c58938f36d91af1dc4209946bcf5e70869445

URL: https://github.com/llvm/llvm-project/commit/711c58938f36d91af1dc4209946bcf5e70869445
DIFF: https://github.com/llvm/llvm-project/commit/711c58938f36d91af1dc4209946bcf5e70869445.diff

LOG: [mlir][math] Update math arith expansions for vectorization

The math arithmetic expansions did not support vector types. Update the
lowerings so that they also handle vector types. This includes a new
implementation of `math.ctlz` as a fixed-length binary search, so it no
longer has a data-dependent termination time.

Reviewed By: jpienaar, NatashaKnk

Differential Revision: https://reviews.llvm.org/D147289

Added: 
    

Modified: 
    mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
    mlir/test/Dialect/Math/expand-math.mlir
    mlir/test/lib/Dialect/Math/TestExpandMath.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 364dd05c093ba..91aef84348a96 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -14,22 +14,46 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
 
+/// Create a float constant.
+static Value createFloatConst(Location loc, Type type, double value,
+                              OpBuilder &b) {
+  auto attr = b.getFloatAttr(getElementTypeOrSelf(type), value);
+  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
+    return b.create<arith::ConstantOp>(loc,
+                                       DenseElementsAttr::get(shapedTy, attr));
+  }
+
+  return b.create<arith::ConstantOp>(loc, attr);
+}
+
+/// Create an integer constant.
+static Value createIntConst(Location loc, Type type, int64_t value,
+                            OpBuilder &b) {
+  auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
+  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
+    return b.create<arith::ConstantOp>(loc,
+                                       DenseElementsAttr::get(shapedTy, attr));
+  }
+
+  return b.create<arith::ConstantOp>(loc, attr);
+}
+
 /// Expands tanh op into
 ///   1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
 ///   2) exp^{2x}-1 / exp^{2x}+1  , if x < 0
 static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
   auto floatType = op.getOperand().getType();
   Location loc = op.getLoc();
-  auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
-  auto floatTwo = rewriter.getFloatAttr(floatType, 2.0);
-  Value one = rewriter.create<arith::ConstantOp>(loc, floatOne);
-  Value two = rewriter.create<arith::ConstantOp>(loc, floatTwo);
+  Value one = createFloatConst(loc, floatType, 1.0, rewriter);
+  Value two = createFloatConst(loc, floatType, 2.0, rewriter);
   Value doubledX = rewriter.create<arith::MulFOp>(loc, op.getOperand(), two);
 
   // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
@@ -46,8 +70,7 @@ static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
   Value negativeRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
 
   // tanh(x) = x >= 0 ? positiveRes : negativeRes
-  auto floatZero = rewriter.getFloatAttr(floatType, 0.0);
-  Value zero = rewriter.create<arith::ConstantOp>(loc, floatZero);
+  Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
   Value cmpRes = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
                                                 op.getOperand(), zero);
   rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmpRes, positiveRes,
@@ -55,6 +78,7 @@ static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
   return success();
 }
 
+// Converts math.tan to math.sin, math.cos, and arith.divf.
 static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operand = op.getOperand();
@@ -66,52 +90,47 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
   return success();
 }
 
+// Converts math.ctlz to arith operations via a branchless binary search over
+// the bits. NOTE: only correct for power-of-two bit widths (i8, i16, i32...).
 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
                                    PatternRewriter &rewriter) {
   auto operand = op.getOperand();
-  auto elementTy = operand.getType();
-  auto resultTy = op.getType();
+  auto operandTy = operand.getType();
+  auto eTy = getElementTypeOrSelf(operandTy);
   Location loc = op.getLoc();
 
-  int bitWidth = elementTy.getIntOrFloatBitWidth();
-  auto zero =
-      rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
-  auto leadingZeros = rewriter.create<arith::ConstantOp>(
-      loc, IntegerAttr::get(elementTy, bitWidth));
-
-  SmallVector<Value> operands = {operand, leadingZeros, zero};
-  SmallVector<Type> types = {elementTy, elementTy, elementTy};
-  SmallVector<Location> locations = {loc, loc, loc};
-
-  auto whileOp = rewriter.create<scf::WhileOp>(
-      loc, types, operands,
-      [&](OpBuilder &beforeBuilder, Location beforeLoc, ValueRange args) {
-        // The conditional block of the while loop.
-        Value input = args[0];
-        Value zero = args[2];
-
-        Value inputNotZero = beforeBuilder.create<arith::CmpIOp>(
-            loc, arith::CmpIPredicate::ne, input, zero);
-        beforeBuilder.create<scf::ConditionOp>(loc, inputNotZero, args);
-      },
-      [&](OpBuilder &afterBuilder, Location afterLoc, ValueRange args) {
-        // The body of the while loop: shift right until reaching a value of 0.
-        Value input = args[0];
-        Value leadingZeros = args[1];
-
-        auto one = afterBuilder.create<arith::ConstantOp>(
-            loc, IntegerAttr::get(elementTy, 1));
-        auto shifted =
-            afterBuilder.create<arith::ShRUIOp>(loc, resultTy, input, one);
-        auto leadingZerosMinusOne = afterBuilder.create<arith::SubIOp>(
-            loc, resultTy, leadingZeros, one);
-
-        afterBuilder.create<scf::YieldOp>(
-            loc, ValueRange({shifted, leadingZerosMinusOne, args[2]}));
-      });
-
-  rewriter.setInsertionPointAfter(whileOp);
-  rewriter.replaceOp(op, whileOp->getResult(1));
+  int32_t bitwidth = eTy.getIntOrFloatBitWidth();
+  if (bitwidth > 64)
+    return failure();
+
+  uint64_t allbits = -1;
+  if (bitwidth < 64) {
+    allbits = allbits >> (64 - bitwidth);
+  }
+
+  Value x = operand;
+  Value count = createIntConst(loc, operandTy, 0, rewriter);
+  for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
+    auto half = bw / 2;
+    auto bits = createIntConst(loc, operandTy, half, rewriter);
+    auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter);
+
+    Value pred =
+        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, x, mask);
+    Value add = rewriter.create<arith::AddIOp>(loc, count, bits);
+    Value shift = rewriter.create<arith::ShLIOp>(loc, x, bits);
+
+    x = rewriter.create<arith::SelectOp>(loc, pred, shift, x);
+    count = rewriter.create<arith::SelectOp>(loc, pred, add, count);
+  }
+
+  Value zero = createIntConst(loc, operandTy, 0, rewriter);
+  Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+                                              operand, zero);
+
+  Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter);
+  Value sel = rewriter.create<arith::SelectOp>(loc, pred, bwval, count);
+  rewriter.replaceOp(op, sel);
   return success();
 }
 

diff  --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 49ac15fd97b7e..a66ea082f1ef3 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -24,6 +24,16 @@ func.func @tanh(%arg: f32) -> f32 {
 
 // -----
 
+
+// CHECK-LABEL: func @vector_tanh
+func.func @vector_tanh(%arg: vector<4xf32>) -> vector<4xf32> {
+  // CHECK-NOT: math.tanh
+  %res = math.tanh %arg : vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @tan
 func.func @tan(%arg: f32) -> f32 {
   %res = math.tan %arg : f32
@@ -33,23 +43,79 @@ func.func @tan(%arg: f32) -> f32 {
 // CHECK-SAME: %[[ARG0:.+]]: f32
 // CHECK: %[[SIN:.+]] = math.sin %[[ARG0]]
 // CHECK: %[[COS:.+]] = math.cos %[[ARG0]]
-// CEHCK: %[[DIV:.+]] = arith.div %[[SIN]] %[[COS]]
+// CHECK: %[[DIV:.+]] = arith.divf %[[SIN]], %[[COS]]
+
+
+// -----
+
+// CHECK-LABEL: func @vector_tan
+func.func @vector_tan(%arg: vector<4xf32>) -> vector<4xf32> {
+  %res = math.tan %arg : vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-NOT: math.tan
 
 // -----
 
-// CHECK-LABEL: func @ctlz
 func.func @ctlz(%arg: i32) -> i32 {
-  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32
-  // CHECK-DAG: %[[C32:.+]] = arith.constant 32 : i32
-  // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i32
-  // CHECK: %[[WHILE:.+]]:3 = scf.while (%[[A1:.+]] = %arg0, %[[A2:.+]] = %[[C32]], %[[A3:.+]] = %[[C0]])
-  // CHECK:   %[[CMP:.+]] = arith.cmpi ne, %[[A1]], %[[A3]]
-  // CHECK:   scf.condition(%[[CMP]]) %[[A1]], %[[A2]], %[[A3]]
-  // CHECK:   %[[SHR:.+]] = arith.shrui %[[A1]], %[[C1]]
-  // CHECK:   %[[SUB:.+]] = arith.subi %[[A2]], %[[C1]]
-  // CHECK:   scf.yield %[[SHR]], %[[SUB]], %[[A3]]
   %res = math.ctlz %arg : i32
-
-  // CHECK: return %[[WHILE]]#1
   return %res : i32
 }
+
+// CHECK-LABEL: @ctlz
+// CHECK-SAME: %[[ARG0:.+]]: i32
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0
+// CHECK-DAG: %[[C16:.+]] = arith.constant 16
+// CHECK-DAG: %[[C65535:.+]] = arith.constant 65535
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8
+// CHECK-DAG: %[[C16777215:.+]] = arith.constant 16777215
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4
+// CHECK-DAG: %[[C268435455:.+]] = arith.constant 268435455
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2
+// CHECK-DAG: %[[C1073741823:.+]] = arith.constant 1073741823
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1
+// CHECK-DAG: %[[C2147483647:.+]] = arith.constant 2147483647
+// CHECK-DAG: %[[C32:.+]] = arith.constant 32
+
+// CHECK: %[[PRED:.+]] = arith.cmpi ule, %[[ARG0]], %[[C65535]]
+// CHECK: %[[SHL:.+]] = arith.shli %[[ARG0]], %[[C16]]
+// CHECK: %[[SELX0:.+]] = arith.select %[[PRED]], %[[SHL]], %[[ARG0]]
+// CHECK: %[[SELY0:.+]] = arith.select %[[PRED]], %[[C16]], %[[C0]]
+
+// CHECK: %[[PRED:.+]] = arith.cmpi ule, %[[SELX0]], %[[C16777215]]
+// CHECK: %[[ADD:.+]] = arith.addi %[[SELY0]], %[[C8]]
+// CHECK: %[[SHL:.+]] = arith.shli %[[SELX0]], %[[C8]]
+// CHECK: %[[SELX1:.+]] = arith.select %[[PRED]], %[[SHL]], %[[SELX0]]
+// CHECK: %[[SELY1:.+]] = arith.select %[[PRED]], %[[ADD]], %[[SELY0]]
+
+// CHECK: %[[PRED:.+]] = arith.cmpi ule, %[[SELX1]], %[[C268435455]] : i32
+// CHECK: %[[ADD:.+]] = arith.addi %[[SELY1]], %[[C4]]
+// CHECK: %[[SHL:.+]] = arith.shli %[[SELX1]], %[[C4]]
+// CHECK: %[[SELX2:.+]] = arith.select %[[PRED]], %[[SHL]], %[[SELX1]]
+// CHECK: %[[SELY2:.+]] = arith.select %[[PRED]], %[[ADD]], %[[SELY1]]
+
+
+// CHECK: %[[PRED:.+]] = arith.cmpi ule, %[[SELX2]], %[[C1073741823]] : i32
+// CHECK: %[[ADD:.+]] = arith.addi %[[SELY2]], %[[C2]]
+// CHECK: %[[SHL:.+]] = arith.shli %[[SELX2]], %[[C2]]
+// CHECK: %[[SELX3:.+]] = arith.select %[[PRED]], %[[SHL]], %[[SELX2]]
+// CHECK: %[[SELY3:.+]] = arith.select %[[PRED]], %[[ADD]], %[[SELY2]]
+
+// CHECK: %[[PRED:.+]] = arith.cmpi ule, %[[SELX3]], %[[C2147483647]] : i32
+// CHECK: %[[ADD:.+]] = arith.addi %[[SELY3]], %[[C1]]
+// CHECK: %[[SELY4:.+]] = arith.select %[[PRED]], %[[ADD]], %[[SELY3]]
+
+// CHECK: %[[PRED:.+]] = arith.cmpi eq, %[[ARG0]], %[[C0]] : i32
+// CHECK: %[[SEL:.+]] = arith.select %[[PRED]], %[[C32]], %[[SELY4]] : i32
+// CHECK: return %[[SEL]]
+
+// -----
+
+func.func @ctlz_vector(%arg: vector<4xi32>) -> vector<4xi32> {
+  %res = math.ctlz %arg : vector<4xi32>
+  return %res : vector<4xi32>
+}
+
+// CHECK-LABEL: @ctlz_vector
+// CHECK-NOT: math.ctlz

diff  --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index 28819518b2780..29b862e410c0f 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -26,7 +27,8 @@ struct TestExpandMathPass
   void runOnOperation() override;
   StringRef getArgument() const final { return "test-expand-math"; }
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<arith::ArithDialect, scf::SCFDialect>();
+    registry
+        .insert<arith::ArithDialect, scf::SCFDialect, vector::VectorDialect>();
   }
   StringRef getDescription() const final { return "Test expanding math"; }
 };


        


More information about the Mlir-commits mailing list