[Mlir-commits] [mlir] 23149d5 - [mlir] Added ctlz and cttz to math dialect and LLVM dialect
Rob Suderman
llvmlistbot at llvm.org
Wed Dec 8 14:40:58 PST 2021
Author: Rob Suderman
Date: 2021-12-08T14:32:15-08:00
New Revision: 23149d522b92cf7525e3415c2c184ca6ecfbc1a1
URL: https://github.com/llvm/llvm-project/commit/23149d522b92cf7525e3415c2c184ca6ecfbc1a1
DIFF: https://github.com/llvm/llvm-project/commit/23149d522b92cf7525e3415c2c184ca6ecfbc1a1.diff
LOG: [mlir] Added ctlz and cttz to math dialect and LLVM dialect
Count leading/trailing zeros (`ctlz`/`cttz`) are existing LLVM intrinsics. Added
LLVM dialect support for these intrinsics, along with lowerings from the math
dialect to the LLVM dialect.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D115206
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/Math/IR/MathOps.td
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index f90533721421c..453ac8dce142f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1401,6 +1401,12 @@ class LLVM_TernarySameArgsIntrinsicOp<string func, list<OpTrait> traits = []> :
let arguments = (ins LLVM_Type:$a, LLVM_Type:$b, LLVM_Type:$c);
}
+// Base class for the count-zeros intrinsics (ctlz/cttz). Besides the integer
+// operand `$in`, these ops take an i1 `$zero_undefined` flag as a second
+// operand. They are always NoSideEffect, plus any extra `traits`.
+// NOTE(review): the `[0]` list presumably marks operand 0 as the overloaded
+// type that selects the concrete intrinsic (e.g. @llvm.ctlz.i32 vs
+// @llvm.ctlz.v8i32) — confirm against LLVM_OneResultIntrOp.
+class LLVM_CountZerosIntrinsicOp<string func, list<OpTrait> traits = []> :
+ LLVM_OneResultIntrOp<func, [], [0],
+ !listconcat([NoSideEffect], traits)> {
+ let arguments = (ins LLVM_Type:$in, I<1>:$zero_undefined);
+}
+
def LLVM_CopySignOp : LLVM_BinarySameArgsIntrinsicOp<"copysign">;
def LLVM_CosOp : LLVM_UnaryIntrinsicOp<"cos">;
def LLVM_ExpOp : LLVM_UnaryIntrinsicOp<"exp">;
@@ -1421,6 +1427,8 @@ def LLVM_SinOp : LLVM_UnaryIntrinsicOp<"sin">;
def LLVM_SqrtOp : LLVM_UnaryIntrinsicOp<"sqrt">;
def LLVM_PowOp : LLVM_BinarySameArgsIntrinsicOp<"pow">;
def LLVM_BitReverseOp : LLVM_UnaryIntrinsicOp<"bitreverse">;
+// Leading/trailing zero-bit counts; operands are `$in` and the i1
+// `$zero_undefined` flag declared by LLVM_CountZerosIntrinsicOp.
+def LLVM_CountLeadingZerosOp : LLVM_CountZerosIntrinsicOp<"ctlz">;
+def LLVM_CountTrailingZerosOp : LLVM_CountZerosIntrinsicOp<"cttz">;
def LLVM_CtPopOp : LLVM_UnaryIntrinsicOp<"ctpop">;
def LLVM_MaxNumOp : LLVM_BinarySameArgsIntrinsicOp<"maxnum">;
def LLVM_MinNumOp : LLVM_BinarySameArgsIntrinsicOp<"minnum">;
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 42a5b57e9726b..bef60175e4b62 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -297,6 +297,54 @@ def Math_SinOp : Math_FloatUnaryOp<"sin"> {
}];
}
+//===----------------------------------------------------------------------===//
+// CountLeadingZerosOp
+//===----------------------------------------------------------------------===//
+
+def Math_CountLeadingZerosOp : Math_IntegerUnaryOp<"ctlz"> {
+  let summary = "counts the leading zeros of an integer value";
+  let description = [{
+    The `ctlz` operation computes the number of leading zero bits of an
+    integer value. It applies to scalars, and element-wise to vectors and
+    tensors of integers.
+
+    Example:
+
+    ```mlir
+    // Scalar ctlz function value.
+    %a = math.ctlz %b : i32
+
+    // SIMD vector element-wise ctlz function value.
+    %f = math.ctlz %g : vector<4xi16>
+
+    // Tensor element-wise ctlz function value.
+    %x = math.ctlz %y : tensor<4x?xi8>
+    ```
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// CountTrailingZerosOp
+//===----------------------------------------------------------------------===//
+
+def Math_CountTrailingZerosOp : Math_IntegerUnaryOp<"cttz"> {
+  let summary = "counts the trailing zeros of an integer value";
+  let description = [{
+    The `cttz` operation computes the number of trailing zero bits of an
+    integer value. It applies to scalars, and element-wise to vectors and
+    tensors of integers.
+
+    Example:
+
+    ```mlir
+    // Scalar cttz function value.
+    %a = math.cttz %b : i32
+
+    // SIMD vector element-wise cttz function value.
+    %f = math.cttz %g : vector<4xi16>
+
+    // Tensor element-wise cttz function value.
+    %x = math.cttz %y : tensor<4x?xi8>
+    ```
+  }];
+}
+
//===----------------------------------------------------------------------===//
// CtPopOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 551cefe9eda0d..e83e491926a02 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -38,6 +38,54 @@ using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
+// A `CtLz/CtTz(a)` is converted into `CtLz/CtTz(a, false)`, where the i1
+// operand is the intrinsic's `zero_undefined` flag.
+//
+// Scalars and 1-D vectors map directly onto one intrinsic call. After type
+// conversion, n-D vectors (n > 1) arrive as LLVM arrays; these are unrolled
+// into one intrinsic call per 1-D slice.
+template <typename MathOp, typename LLVMOp>
+struct CountOpLowering : public ConvertOpToLLVMPattern<MathOp> {
+  using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
+  using Super = CountOpLowering<MathOp, LLVMOp>;
+
+  LogicalResult
+  matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto operandType = adaptor.getOperand().getType();
+    if (!operandType || !LLVM::isCompatibleType(operandType))
+      return failure();
+
+    auto resultType = op.getResult().getType();
+    bool isMultiDim = operandType.template isa<LLVM::LLVMArrayType>();
+    // Do every match check before creating any IR, so that a failed match
+    // leaves the IR untouched.
+    if (isMultiDim && !resultType.template isa<VectorType>())
+      return failure();
+
+    auto loc = op.getLoc();
+    auto boolType = rewriter.getIntegerType(1);
+    // A single `false` flag is shared by every intrinsic call emitted below.
+    Value zero = rewriter.create<LLVM::ConstantOp>(
+        loc, boolType, rewriter.getIntegerAttr(boolType, 0));
+
+    if (!isMultiDim) {
+      rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
+                                          zero);
+      return success();
+    }
+
+    // The callback must only *create* the op for one 1-D slice and return its
+    // value. handleMultidimensionalVectors is expected to assemble the slices
+    // and replace `op` once. Calling replaceOpWithNewOp here (as before) would
+    // replace `op` once per slice.
+    return LLVM::detail::handleMultidimensionalVectors(
+        op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
+        [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
+          return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
+                                         zero);
+        },
+        rewriter);
+  }
+};
+
+using CountLeadingZerosOpLowering =
+    CountOpLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
+using CountTrailingZerosOpLowering =
+    CountOpLowering<math::CountTrailingZerosOp, LLVM::CountTrailingZerosOp>;
+
// A `expm1` is converted into `exp - 1`.
struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
@@ -222,6 +270,8 @@ void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
CeilOpLowering,
CopySignOpLowering,
CosOpLowering,
+ CountLeadingZerosOpLowering,
+ CountTrailingZerosOpLowering,
CtPopFOpLowering,
ExpOpLowering,
Exp2OpLowering,
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index 2b22ef9e319f2..b2de213b4e34c 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -74,6 +74,39 @@ func @sine(%arg0 : f32) {
// -----
+// Scalar math.ctlz lowers to llvm.intr.ctlz with a constant `false` i1 flag.
+// CHECK-LABEL: func @ctlz(
+// CHECK-SAME: i32
+func @ctlz(%arg0 : i32) {
+ // CHECK: %[[ZERO:.+]] = llvm.mlir.constant(false) : i1
+ // CHECK: "llvm.intr.ctlz"(%arg0, %[[ZERO]]) : (i32, i1) -> i32
+ %0 = math.ctlz %arg0 : i32
+ std.return
+}
+
+// -----
+
+// Scalar math.cttz lowers to llvm.intr.cttz with a constant `false` i1 flag.
+// CHECK-LABEL: func @cttz(
+// CHECK-SAME: i32
+func @cttz(%arg0 : i32) {
+ // CHECK: %[[ZERO:.+]] = llvm.mlir.constant(false) : i1
+ // CHECK: "llvm.intr.cttz"(%arg0, %[[ZERO]]) : (i32, i1) -> i32
+ %0 = math.cttz %arg0 : i32
+ std.return
+}
+
+// -----
+
+// A 1-D vector operand maps directly onto a single llvm.intr.cttz call.
+// CHECK-LABEL: func @cttz_vec(
+// CHECK-SAME: vector<4xi32>
+func @cttz_vec(%arg0 : vector<4xi32>) {
+  // CHECK: %[[ZERO:.+]] = llvm.mlir.constant(false) : i1
+  // CHECK: "llvm.intr.cttz"(%arg0, %[[ZERO]]) : (vector<4xi32>, i1) -> vector<4xi32>
+  %0 = math.cttz %arg0 : vector<4xi32>
+  std.return
+}
+
+// -----
+
// CHECK-LABEL: func @ctpop(
// CHECK-SAME: i32
func @ctpop(%arg0 : i32) {
diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
index 0a4bf63d0c3c1..18a9badc264d2 100644
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -135,6 +135,26 @@ llvm.func @bitreverse_test(%arg0: i32, %arg1: vector<8xi32>) {
llvm.return
}
+// Translation of llvm.intr.ctlz to LLVM IR for both a scalar and a vector
+// operand; the i1 operand is the `zero_undefined` flag.
+// CHECK-LABEL: @ctlz_test
+llvm.func @ctlz_test(%arg0: i32, %arg1: vector<8xi32>) {
+ %i1 = llvm.mlir.constant(false) : i1
+ // CHECK: call i32 @llvm.ctlz.i32
+ "llvm.intr.ctlz"(%arg0, %i1) : (i32, i1) -> i32
+ // CHECK: call <8 x i32> @llvm.ctlz.v8i32
+ "llvm.intr.ctlz"(%arg1, %i1) : (vector<8xi32>, i1) -> vector<8xi32>
+ llvm.return
+}
+
+// Translation of llvm.intr.cttz to LLVM IR for both a scalar and a vector
+// operand; the i1 operand is the `zero_undefined` flag.
+// CHECK-LABEL: @cttz_test
+llvm.func @cttz_test(%arg0: i32, %arg1: vector<8xi32>) {
+ %i1 = llvm.mlir.constant(false) : i1
+ // CHECK: call i32 @llvm.cttz.i32
+ "llvm.intr.cttz"(%arg0, %i1) : (i32, i1) -> i32
+ // CHECK: call <8 x i32> @llvm.cttz.v8i32
+ "llvm.intr.cttz"(%arg1, %i1) : (vector<8xi32>, i1) -> vector<8xi32>
+ llvm.return
+}
+
// CHECK-LABEL: @ctpop_test
llvm.func @ctpop_test(%arg0: i32, %arg1: vector<8xi32>) {
// CHECK: call i32 @llvm.ctpop.i32
More information about the Mlir-commits
mailing list