[Mlir-commits] [mlir] 23149d5 - [mlir] Added ctlz and cttz to math dialect and LLVM dialect

Rob Suderman llvmlistbot at llvm.org
Wed Dec 8 14:40:58 PST 2021


Author: Rob Suderman
Date: 2021-12-08T14:32:15-08:00
New Revision: 23149d522b92cf7525e3415c2c184ca6ecfbc1a1

URL: https://github.com/llvm/llvm-project/commit/23149d522b92cf7525e3415c2c184ca6ecfbc1a1
DIFF: https://github.com/llvm/llvm-project/commit/23149d522b92cf7525e3415c2c184ca6ecfbc1a1.diff

LOG: [mlir] Added ctlz and cttz to math dialect and LLVM dialect

Count leading/trailing zeros are existing LLVM intrinsics. Added LLVM dialect
support for these intrinsics, with lowerings from the math dialect to the
LLVM dialect.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D115206

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Dialect/Math/IR/MathOps.td
    mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
    mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
    mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index f90533721421c..453ac8dce142f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1401,6 +1401,12 @@ class LLVM_TernarySameArgsIntrinsicOp<string func, list<OpTrait> traits = []> :
   let arguments = (ins LLVM_Type:$a, LLVM_Type:$b, LLVM_Type:$c);
 }
 
+class LLVM_CountZerosIntrinsicOp<string func, list<OpTrait> traits = []>
+    : LLVM_OneResultIntrOp<func, /*overloadedResults=*/[],
+        /*overloadedOperands=*/[0], !listconcat([NoSideEffect], traits)> {
+  let arguments = (ins LLVM_Type:$in, I<1>:$zero_undefined);
+}
+
 def LLVM_CopySignOp : LLVM_BinarySameArgsIntrinsicOp<"copysign">;
 def LLVM_CosOp : LLVM_UnaryIntrinsicOp<"cos">;
 def LLVM_ExpOp : LLVM_UnaryIntrinsicOp<"exp">;
@@ -1421,6 +1427,8 @@ def LLVM_SinOp : LLVM_UnaryIntrinsicOp<"sin">;
 def LLVM_SqrtOp : LLVM_UnaryIntrinsicOp<"sqrt">;
 def LLVM_PowOp : LLVM_BinarySameArgsIntrinsicOp<"pow">;
 def LLVM_BitReverseOp : LLVM_UnaryIntrinsicOp<"bitreverse">;
+def LLVM_CountLeadingZerosOp : LLVM_CountZerosIntrinsicOp<"ctlz">; // llvm.ctlz
+def LLVM_CountTrailingZerosOp : LLVM_CountZerosIntrinsicOp<"cttz">; // llvm.cttz
 def LLVM_CtPopOp : LLVM_UnaryIntrinsicOp<"ctpop">;
 def LLVM_MaxNumOp : LLVM_BinarySameArgsIntrinsicOp<"maxnum">;
 def LLVM_MinNumOp : LLVM_BinarySameArgsIntrinsicOp<"minnum">;

diff  --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 42a5b57e9726b..bef60175e4b62 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -297,6 +297,54 @@ def Math_SinOp : Math_FloatUnaryOp<"sin"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// CountLeadingZerosOp
+//===----------------------------------------------------------------------===//
+
+def Math_CountLeadingZerosOp : Math_IntegerUnaryOp<"ctlz"> {
+  let summary = "counts the leading zeros of an integer value";
+  let description = [{
+    The `ctlz` operation computes the number of leading zeros of an integer
+    value. A zero input yields the bit width of the element type.
+
+    Example:
+    ```mlir
+    // Scalar leading zero count.
+    %a = math.ctlz %b : i32
+
+    // SIMD vector element-wise leading zero count.
+    %f = math.ctlz %g : vector<4xi16>
+
+    // Tensor element-wise leading zero count.
+    %x = math.ctlz %y : tensor<4x?xi8>
+    ```
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// CountTrailingZerosOp
+//===----------------------------------------------------------------------===//
+
+def Math_CountTrailingZerosOp : Math_IntegerUnaryOp<"cttz"> {
+  let summary = "counts the trailing zeros of an integer value";
+  let description = [{
+    The `cttz` operation computes the number of trailing zeros of an integer
+    value. A zero input yields the bit width of the element type.
+
+    Example:
+    ```mlir
+    // Scalar trailing zero count.
+    %a = math.cttz %b : i32
+
+    // SIMD vector element-wise trailing zero count.
+    %f = math.cttz %g : vector<4xi16>
+
+    // Tensor element-wise trailing zero count.
+    %x = math.cttz %y : tensor<4x?xi8>
+    ```
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // CtPopOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 551cefe9eda0d..e83e491926a02 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -38,6 +38,54 @@ using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
 using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
 using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
 
+// A `CtLz/CtTz(a)` is converted into `CtLz/CtTz(a, false)`; `false` keeps the
+// result defined for a == 0. N-D vectors are unrolled into 1-D intrinsic calls.
+template <typename MathOp, typename LLVMOp>
+struct CountOpLowering : public ConvertOpToLLVMPattern<MathOp> {
+  using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto operandType = adaptor.getOperand().getType();
+
+    if (!operandType || !LLVM::isCompatibleType(operandType))
+      return failure();
+
+    auto loc = op.getLoc();
+    auto resultType = op.getResult().getType();
+    auto boolType = rewriter.getIntegerType(1);
+    auto boolZero = rewriter.getIntegerAttr(boolType, 0);
+
+    if (!operandType.template isa<LLVM::LLVMArrayType>()) {
+      LLVM::ConstantOp zero =
+          rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero);
+      rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
+                                          zero);
+      return success();
+    }
+
+    auto vectorType = resultType.template dyn_cast<VectorType>();
+    if (!vectorType)
+      return failure();
+
+    return LLVM::detail::handleMultidimensionalVectors(
+        op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
+        [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
+          LLVM::ConstantOp zero =
+              rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero);
+          return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
+                                         zero);
+        },
+        rewriter);
+  }
+};
+
+using CountLeadingZerosOpLowering =
+    CountOpLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
+using CountTrailingZerosOpLowering =
+    CountOpLowering<math::CountTrailingZerosOp, LLVM::CountTrailingZerosOp>;
+
 // A `expm1` is converted into `exp - 1`.
 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
   using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
@@ -222,6 +270,8 @@ void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
     CeilOpLowering,
     CopySignOpLowering,
     CosOpLowering,
+    CountLeadingZerosOpLowering,
+    CountTrailingZerosOpLowering,
     CtPopFOpLowering,
     ExpOpLowering,
     Exp2OpLowering,

diff  --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index 2b22ef9e319f2..b2de213b4e34c 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -74,6 +74,50 @@ func @sine(%arg0 : f32) {
 
 // -----
 
+// CHECK-LABEL: func @ctlz(
+// CHECK-SAME: i32
+func @ctlz(%arg0 : i32) {
+  // CHECK: %[[ZERO:.+]] = llvm.mlir.constant(false) : i1
+  // CHECK: "llvm.intr.ctlz"(%arg0, %[[ZERO]]) : (i32, i1) -> i32
+  %0 = math.ctlz %arg0 : i32
+  std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @cttz(
+// CHECK-SAME: i32
+func @cttz(%arg0 : i32) {
+  // CHECK: %[[ZERO:.+]] = llvm.mlir.constant(false) : i1
+  // CHECK: "llvm.intr.cttz"(%arg0, %[[ZERO]]) : (i32, i1) -> i32
+  %0 = math.cttz %arg0 : i32
+  std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @ctlz_vec(
+// CHECK-SAME: vector<4xi32>
+func @ctlz_vec(%arg0 : vector<4xi32>) {
+  // CHECK: %[[ZERO:.+]] = llvm.mlir.constant(false) : i1
+  // CHECK: "llvm.intr.ctlz"(%arg0, %[[ZERO]]) : (vector<4xi32>, i1) -> vector<4xi32>
+  %0 = math.ctlz %arg0 : vector<4xi32>
+  std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @cttz_vec(
+// CHECK-SAME: vector<4xi32>
+func @cttz_vec(%arg0 : vector<4xi32>) {
+  // CHECK: %[[ZERO:.+]] = llvm.mlir.constant(false) : i1
+  // CHECK: "llvm.intr.cttz"(%arg0, %[[ZERO]]) : (vector<4xi32>, i1) -> vector<4xi32>
+  %0 = math.cttz %arg0 : vector<4xi32>
+  std.return
+}
+
+// -----
+
 // CHECK-LABEL: func @ctpop(
 // CHECK-SAME: i32
 func @ctpop(%arg0 : i32) {

diff  --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
index 0a4bf63d0c3c1..18a9badc264d2 100644
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -135,6 +135,26 @@ llvm.func @bitreverse_test(%arg0: i32, %arg1: vector<8xi32>) {
   llvm.return
 }
 
+// CHECK-LABEL: @ctlz_test
+llvm.func @ctlz_test(%arg0: i32, %arg1: vector<8xi32>) {
+  %i1 = llvm.mlir.constant(false) : i1
+  // CHECK: call i32 @llvm.ctlz.i32(i32 %{{.*}}, i1 false)
+  "llvm.intr.ctlz"(%arg0, %i1) : (i32, i1) -> i32
+  // CHECK: call <8 x i32> @llvm.ctlz.v8i32(<8 x i32> %{{.*}}, i1 false)
+  "llvm.intr.ctlz"(%arg1, %i1) : (vector<8xi32>, i1) -> vector<8xi32>
+  llvm.return
+}
+
+// CHECK-LABEL: @cttz_test
+llvm.func @cttz_test(%arg0: i32, %arg1: vector<8xi32>) {
+  %i1 = llvm.mlir.constant(false) : i1
+  // CHECK: call i32 @llvm.cttz.i32(i32 %{{.*}}, i1 false)
+  "llvm.intr.cttz"(%arg0, %i1) : (i32, i1) -> i32
+  // CHECK: call <8 x i32> @llvm.cttz.v8i32(<8 x i32> %{{.*}}, i1 false)
+  "llvm.intr.cttz"(%arg1, %i1) : (vector<8xi32>, i1) -> vector<8xi32>
+  llvm.return
+}
+
 // CHECK-LABEL: @ctpop_test
 llvm.func @ctpop_test(%arg0: i32, %arg1: vector<8xi32>) {
   // CHECK: call i32 @llvm.ctpop.i32


        


More information about the Mlir-commits mailing list