[Mlir-commits] [mlir] f3bdb56 - [mlir][math] Add math.ctlz expansion to control flow + arith operations

Rob Suderman llvmlistbot at llvm.org
Wed Jun 1 11:50:28 PDT 2022


Author: Rob Suderman
Date: 2022-06-01T11:45:04-07:00
New Revision: f3bdb56d61e3e7bbcb2615f087cc63b67c60ab59

URL: https://github.com/llvm/llvm-project/commit/f3bdb56d61e3e7bbcb2615f087cc63b67c60ab59
DIFF: https://github.com/llvm/llvm-project/commit/f3bdb56d61e3e7bbcb2615f087cc63b67c60ab59.diff

LOG: [mlir][math] Add math.ctlz expansion to control flow + arith operations

Ctlz is an intrinsic in LLVM but does not have equivalent operations in SPIR-V.
Including a decomposition gives an alternative path for these platforms.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D126261

Added: 
    mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
    mlir/test/Dialect/Math/expand-math.mlir
    mlir/test/lib/Dialect/Math/TestExpandMath.cpp

Modified: 
    mlir/include/mlir/Dialect/Math/Transforms/Passes.h
    mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
    mlir/test/lib/Dialect/Math/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp
    mlir/test/Dialect/Math/expand-tanh.mlir
    mlir/test/lib/Dialect/Math/TestExpandTanh.cpp


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 8de5782fe9c9f..9dbead1768e8e 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -13,6 +13,7 @@ namespace mlir {
 
 class RewritePatternSet;
 
+void populateExpandCtlzPattern(RewritePatternSet &patterns);
 void populateExpandTanhPattern(RewritePatternSet &patterns);
 
 void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);

diff  --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index 0343338602fcc..cc3a961cb0b5c 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -1,6 +1,6 @@
 add_mlir_dialect_library(MLIRMathTransforms
   AlgebraicSimplification.cpp
-  ExpandTanh.cpp
+  ExpandPatterns.cpp
   PolynomialApproximation.cpp
 
   ADDITIONAL_HEADER_DIRS

diff  --git a/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
similarity index 53%
rename from mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp
rename to mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index fe8c2fe98e4aa..9de58ceabd92d 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -53,6 +54,67 @@ static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
   return success();
 }
 
+static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
+                                   PatternRewriter &rewriter) {
+  auto operand = op.getOperand();
+  auto elementTy = operand.getType();
+  auto resultTy = op.getType();
+  Location loc = op.getLoc();
+
+  int bitWidth = elementTy.getIntOrFloatBitWidth();
+  auto zero =
+      rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
+  auto leadingZeros = rewriter.create<arith::ConstantOp>(
+      loc, IntegerAttr::get(elementTy, bitWidth));
+
+  SmallVector<Value> operands = {operand, leadingZeros, zero};
+  SmallVector<Type> types = {elementTy, elementTy, elementTy};
+  SmallVector<Location> locations = {loc, loc, loc};
+
+  auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
+  Block *before =
+      rewriter.createBlock(&whileOp.getBefore(), {}, types, locations);
+  Block *after =
+      rewriter.createBlock(&whileOp.getAfter(), {}, types, locations);
+
+  // "Before" region: keep looping while the remaining input is non-zero.
+  {
+    rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
+    Value input = before->getArgument(0);
+    Value zero = before->getArgument(2);
+
+    Value inputNotZero = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ne, input, zero);
+    rewriter.create<scf::ConditionOp>(loc, inputNotZero,
+                                      before->getArguments());
+  }
+
+  // "After" region: shift the input right by one and decrement the count.
+  {
+    rewriter.setInsertionPointToStart(&whileOp.getAfter().front());
+    Value input = after->getArgument(0);
+    Value leadingZeros = after->getArgument(1);
+
+    auto one =
+        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
+    auto shifted = rewriter.create<arith::ShRUIOp>(loc, resultTy, input, one);
+    auto leadingZerosMinusOne =
+        rewriter.create<arith::SubIOp>(loc, resultTy, leadingZeros, one);
+
+    rewriter.create<scf::YieldOp>(
+        loc,
+        ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)}));
+  }
+
+  rewriter.setInsertionPointAfter(whileOp);
+  rewriter.replaceOp(op, whileOp->getResult(1));
+  return success();
+}
+
+void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
+  patterns.add(convertCtlzOp);
+}
+
 void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
   patterns.add(convertTanhOp);
 }

diff  --git a/mlir/test/Dialect/Math/expand-tanh.mlir b/mlir/test/Dialect/Math/expand-math.mlir
similarity index 59%
rename from mlir/test/Dialect/Math/expand-tanh.mlir
rename to mlir/test/Dialect/Math/expand-math.mlir
index 6724268c9a36b..0d7a63589ee47 100644
--- a/mlir/test/Dialect/Math/expand-tanh.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-expand-tanh | FileCheck %s
+// RUN: mlir-opt %s --split-input-file -test-expand-math | FileCheck %s
 
 // CHECK-LABEL: func @tanh
 func.func @tanh(%arg: f32) -> f32 {
@@ -21,3 +21,22 @@ func.func @tanh(%arg: f32) -> f32 {
 // CHECK: %[[COND:.+]] = arith.cmpf oge, %arg0, %[[ZERO]] : f32
 // CHECK: %[[RESULT:.+]] = arith.select %[[COND]], %[[RES1]], %[[RES2]] : f32
 // CHECK: return %[[RESULT]]
+
+// -----
+
+// CHECK-LABEL: func @ctlz
+func.func @ctlz(%arg: i32) -> i32 {
+  // CHECK: %[[C0:.+]] = arith.constant 0 : i32
+  // CHECK: %[[C32:.+]] = arith.constant 32 : i32
+  // CHECK: %[[C1:.+]] = arith.constant 1 : i32
+  // CHECK: %[[WHILE:.+]]:3 = scf.while (%[[A1:.+]] = %arg0, %[[A2:.+]] = %[[C32]], %[[A3:.+]] = %[[C0]])
+  // CHECK:   %[[CMP:.+]] = arith.cmpi ne, %[[A1]], %[[A3]]
+  // CHECK:   scf.condition(%[[CMP]]) %[[A1]], %[[A2]], %[[A3]]
+  // CHECK:   %[[SHR:.+]] = arith.shrui %[[A1]], %[[C1]]
+  // CHECK:   %[[SUB:.+]] = arith.subi %[[A2]], %[[C1]]
+  // CHECK:   scf.yield %[[SHR]], %[[SUB]], %[[A3]]
+  %res = math.ctlz %arg : i32
+
+  // CHECK: return %[[WHILE]]#1
+  return %res : i32
+}

diff  --git a/mlir/test/lib/Dialect/Math/CMakeLists.txt b/mlir/test/lib/Dialect/Math/CMakeLists.txt
index dd2f726928ae6..ff62dc4fd053c 100644
--- a/mlir/test/lib/Dialect/Math/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Math/CMakeLists.txt
@@ -1,7 +1,7 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRMathTestPasses
   TestAlgebraicSimplification.cpp
-  TestExpandTanh.cpp
+  TestExpandMath.cpp
   TestPolynomialApproximation.cpp
 
   EXCLUDE_FROM_LIBMLIR

diff  --git a/mlir/test/lib/Dialect/Math/TestExpandTanh.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
similarity index 55%
rename from mlir/test/lib/Dialect/Math/TestExpandTanh.cpp
rename to mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index fd83452bc4292..82e0e71f5026b 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandTanh.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -1,4 +1,4 @@
-//===- TestExpandTanh.cpp - Test expand tanh op into exp form -------------===//
+//===- TestExpandMath.cpp - Test expanding math ops -----------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,35 +6,41 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file contains test passes for expanding tanh.
+// This file contains test passes for expanding math operations.
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
 
 namespace {
-struct TestExpandTanhPass
-    : public PassWrapper<TestExpandTanhPass, OperationPass<>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestExpandTanhPass)
+struct TestExpandMathPass
+    : public PassWrapper<TestExpandMathPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestExpandMathPass)
 
   void runOnOperation() override;
-  StringRef getArgument() const final { return "test-expand-tanh"; }
-  StringRef getDescription() const final { return "Test expanding tanh"; }
+  StringRef getArgument() const final { return "test-expand-math"; }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithmeticDialect, scf::SCFDialect>();
+  }
+  StringRef getDescription() const final { return "Test expanding math"; }
 };
 } // namespace
 
-void TestExpandTanhPass::runOnOperation() {
+void TestExpandMathPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
+  populateExpandCtlzPattern(patterns);
   populateExpandTanhPattern(patterns);
   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
 }
 
 namespace mlir {
 namespace test {
-void registerTestExpandTanhPass() { PassRegistration<TestExpandTanhPass>(); }
+void registerTestExpandMathPass() { PassRegistration<TestExpandMathPass>(); }
 } // namespace test
 } // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index dcd8946d9c404..e75c2758e88db 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -76,7 +76,7 @@ void registerTestDecomposeCallGraphTypes();
 void registerTestDiagnosticsPass();
 void registerTestDominancePass();
 void registerTestDynamicPipelinePass();
-void registerTestExpandTanhPass();
+void registerTestExpandMathPass();
 void registerTestComposeSubView();
 void registerTestMultiBuffering();
 void registerTestIRVisitorsPass();
@@ -172,7 +172,7 @@ void registerTestPasses() {
   mlir::test::registerTestDataLayoutQuery();
   mlir::test::registerTestDominancePass();
   mlir::test::registerTestDynamicPipelinePass();
-  mlir::test::registerTestExpandTanhPass();
+  mlir::test::registerTestExpandMathPass();
   mlir::test::registerTestComposeSubView();
   mlir::test::registerTestMultiBuffering();
   mlir::test::registerTestIRVisitorsPass();


        


More information about the Mlir-commits mailing list