[Mlir-commits] [mlir] f3bdb56 - [mlir][math] Add math.ctlz expansion to control flow + arith operations
Rob Suderman
llvmlistbot at llvm.org
Wed Jun 1 11:50:28 PDT 2022
Author: Rob Suderman
Date: 2022-06-01T11:45:04-07:00
New Revision: f3bdb56d61e3e7bbcb2615f087cc63b67c60ab59
URL: https://github.com/llvm/llvm-project/commit/f3bdb56d61e3e7bbcb2615f087cc63b67c60ab59
DIFF: https://github.com/llvm/llvm-project/commit/f3bdb56d61e3e7bbcb2615f087cc63b67c60ab59.diff
LOG: [mlir][math] Add math.ctlz expansion to control flow + arith operations
Ctlz is an intrinsic in LLVM but does not have equivalent operations in SPIR-V.
Including a decomposition gives an alternative path for these platforms.
Reviewed By: NatashaKnk
Differential Revision: https://reviews.llvm.org/D126261
Added:
mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
mlir/test/Dialect/Math/expand-math.mlir
mlir/test/lib/Dialect/Math/TestExpandMath.cpp
Modified:
mlir/include/mlir/Dialect/Math/Transforms/Passes.h
mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
mlir/test/lib/Dialect/Math/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp
mlir/test/Dialect/Math/expand-tanh.mlir
mlir/test/lib/Dialect/Math/TestExpandTanh.cpp
################################################################################
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 8de5782fe9c9f..9dbead1768e8e 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -13,6 +13,7 @@ namespace mlir {
class RewritePatternSet;
+void populateExpandCtlzPattern(RewritePatternSet &patterns);
void populateExpandTanhPattern(RewritePatternSet &patterns);
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index 0343338602fcc..cc3a961cb0b5c 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -1,6 +1,6 @@
add_mlir_dialect_library(MLIRMathTransforms
AlgebraicSimplification.cpp
- ExpandTanh.cpp
+ ExpandPatterns.cpp
PolynomialApproximation.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
similarity index 53%
rename from mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp
rename to mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index fe8c2fe98e4aa..9de58ceabd92d 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -53,6 +54,67 @@ static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
return success();
}
+/// Expands `math.ctlz` on a scalar integer into an `scf.while` loop:
+///
+///   count = bitwidth;
+///   while (x != 0) { x = x >>u 1; count = count - 1; }
+///   return count;
+///
+/// Each iteration drops one significant bit of the input, so the loop runs
+/// once per significant bit and the remaining count is the number of leading
+/// zeros (an input of 0 yields the full bit width).
+///
+/// Only scalar integer operands are handled. For shaped operands (e.g. a
+/// vector or tensor of integers) `getIntOrFloatBitWidth()` is not defined and
+/// a shaped value cannot serve as the `scf.condition` predicate, so such ops
+/// are reported as a match failure and left untouched.
+static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
+                                   PatternRewriter &rewriter) {
+  Value operand = op.getOperand();
+  Type elementTy = operand.getType();
+  // Bail out before creating any IR so a failed match leaves nothing behind.
+  if (!elementTy.isa<IntegerType>())
+    return rewriter.notifyMatchFailure(op,
+                                       "expected a scalar integer operand");
+  Type resultTy = op.getType();
+  Location loc = op.getLoc();
+
+  int bitWidth = elementTy.getIntOrFloatBitWidth();
+  auto zero =
+      rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
+  auto leadingZeros = rewriter.create<arith::ConstantOp>(
+      loc, IntegerAttr::get(elementTy, bitWidth));
+
+  // Loop-carried values: (remaining input, leading-zero count, zero constant).
+  SmallVector<Value> operands = {operand, leadingZeros, zero};
+  SmallVector<Type> types = {elementTy, elementTy, elementTy};
+  SmallVector<Location> locations = {loc, loc, loc};
+
+  auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
+  Block *before =
+      rewriter.createBlock(&whileOp.getBefore(), {}, types, locations);
+  Block *after =
+      rewriter.createBlock(&whileOp.getAfter(), {}, types, locations);
+
+  // The conditional block of the while loop: keep iterating while input != 0.
+  {
+    rewriter.setInsertionPointToStart(before);
+    Value input = before->getArgument(0);
+    Value zeroArg = before->getArgument(2);
+
+    Value inputNotZero = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ne, input, zeroArg);
+    rewriter.create<scf::ConditionOp>(loc, inputNotZero,
+                                      before->getArguments());
+  }
+
+  // The body of the while loop: shift right until reaching a value of 0,
+  // decrementing the count once per shift.
+  {
+    rewriter.setInsertionPointToStart(after);
+    Value input = after->getArgument(0);
+    Value count = after->getArgument(1);
+
+    auto one =
+        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
+    auto shifted = rewriter.create<arith::ShRUIOp>(loc, resultTy, input, one);
+    auto countMinusOne =
+        rewriter.create<arith::SubIOp>(loc, resultTy, count, one);
+
+    rewriter.create<scf::YieldOp>(
+        loc, ValueRange({shifted, countMinusOne, after->getArgument(2)}));
+  }
+
+  // Loop result #1 is the final leading-zero count.
+  rewriter.setInsertionPointAfter(whileOp);
+  rewriter.replaceOp(op, whileOp->getResult(1));
+  return success();
+}
+
+/// Adds the `math.ctlz` expansion (an `scf.while` loop of `arith` shifts and
+/// subtractions, implemented by `convertCtlzOp`) to `patterns`.
+void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
+  patterns.add(&convertCtlzOp);
+}
+
+/// Adds the `math.tanh` expansion implemented by `convertTanhOp` to `patterns`.
void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
patterns.add(convertTanhOp);
}
diff --git a/mlir/test/Dialect/Math/expand-tanh.mlir b/mlir/test/Dialect/Math/expand-math.mlir
similarity index 59%
rename from mlir/test/Dialect/Math/expand-tanh.mlir
rename to mlir/test/Dialect/Math/expand-math.mlir
index 6724268c9a36b..0d7a63589ee47 100644
--- a/mlir/test/Dialect/Math/expand-tanh.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-expand-tanh | FileCheck %s
+// RUN: mlir-opt %s --split-input-file -test-expand-math | FileCheck %s
// CHECK-LABEL: func @tanh
func.func @tanh(%arg: f32) -> f32 {
@@ -21,3 +21,22 @@ func.func @tanh(%arg: f32) -> f32 {
// CHECK: %[[COND:.+]] = arith.cmpf oge, %arg0, %[[ZERO]] : f32
// CHECK: %[[RESULT:.+]] = arith.select %[[COND]], %[[RES1]], %[[RES2]] : f32
// CHECK: return %[[RESULT]]
+
+// -----
+
+// CHECK-LABEL: func @ctlz
+func.func @ctlz(%arg: i32) -> i32 {
+ // CHECK: %[[C0:.+]] = arith.constant 0 : i32
+ // CHECK: %[[C32:.+]] = arith.constant 32 : i32
+ // CHECK: %[[C1:.+]] = arith.constant 1 : i32
+ // CHECK: %[[WHILE:.+]]:3 = scf.while (%[[A1:.+]] = %arg0, %[[A2:.+]] = %[[C32]], %[[A3:.+]] = %[[C0]])
+ // CHECK: %[[CMP:.+]] = arith.cmpi ne, %[[A1]], %[[A3]]
+ // CHECK: scf.condition(%[[CMP]]) %[[A1]], %[[A2]], %[[A3]]
+ // CHECK: %[[SHR:.+]] = arith.shrui %[[A1]], %[[C1]]
+ // CHECK: %[[SUB:.+]] = arith.subi %[[A2]], %[[C1]]
+ // CHECK: scf.yield %[[SHR]], %[[SUB]], %[[A3]]
+ %res = math.ctlz %arg : i32
+
+ // CHECK: return %[[WHILE]]#1
+ return %res : i32
+}
diff --git a/mlir/test/lib/Dialect/Math/CMakeLists.txt b/mlir/test/lib/Dialect/Math/CMakeLists.txt
index dd2f726928ae6..ff62dc4fd053c 100644
--- a/mlir/test/lib/Dialect/Math/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Math/CMakeLists.txt
@@ -1,7 +1,7 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRMathTestPasses
TestAlgebraicSimplification.cpp
- TestExpandTanh.cpp
+ TestExpandMath.cpp
TestPolynomialApproximation.cpp
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/Math/TestExpandTanh.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
similarity index 55%
rename from mlir/test/lib/Dialect/Math/TestExpandTanh.cpp
rename to mlir/test/lib/Dialect/Math/TestExpandMath.cpp
index fd83452bc4292..82e0e71f5026b 100644
--- a/mlir/test/lib/Dialect/Math/TestExpandTanh.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -1,4 +1,4 @@
-//===- TestExpandTanh.cpp - Test expand tanh op into exp form -------------===//
+//===- TestExpandMath.cpp - Test expanding math operations ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,35 +6,41 @@
//
//===----------------------------------------------------------------------===//
//
-// This file contains test passes for expanding tanh.
+// This file contains test passes for expanding math operations.
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
namespace {
-struct TestExpandTanhPass
- : public PassWrapper<TestExpandTanhPass, OperationPass<>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestExpandTanhPass)
+/// Test pass that greedily applies the math expansion patterns
+/// (`math.ctlz` and `math.tanh`) to the operation it is run on.
+struct TestExpandMathPass
+    : public PassWrapper<TestExpandMathPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestExpandMathPass)
void runOnOperation() override;
- StringRef getArgument() const final { return "test-expand-tanh"; }
- StringRef getDescription() const final { return "Test expanding tanh"; }
+  StringRef getArgument() const final { return "test-expand-math"; }
+  // The expansions create `arith` ops (and `scf` ops for ctlz); declaring the
+  // dialects here ensures they are loaded before the pass runs.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithmeticDialect, scf::SCFDialect>();
+  }
+  StringRef getDescription() const final { return "Test expanding math"; }
};
} // namespace
-void TestExpandTanhPass::runOnOperation() {
+/// Collects every math expansion pattern and applies them with the greedy
+/// rewrite driver; the driver's convergence result is discarded via `(void)`.
+void TestExpandMathPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
+  populateExpandCtlzPattern(patterns);
populateExpandTanhPattern(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
namespace mlir {
namespace test {
-void registerTestExpandTanhPass() { PassRegistration<TestExpandTanhPass>(); }
+/// Registers TestExpandMathPass so it is available as `-test-expand-math`.
+void registerTestExpandMathPass() {
+  PassRegistration<TestExpandMathPass>();
+}
} // namespace test
} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index dcd8946d9c404..e75c2758e88db 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -76,7 +76,7 @@ void registerTestDecomposeCallGraphTypes();
void registerTestDiagnosticsPass();
void registerTestDominancePass();
void registerTestDynamicPipelinePass();
-void registerTestExpandTanhPass();
+void registerTestExpandMathPass();
void registerTestComposeSubView();
void registerTestMultiBuffering();
void registerTestIRVisitorsPass();
@@ -172,7 +172,7 @@ void registerTestPasses() {
mlir::test::registerTestDataLayoutQuery();
mlir::test::registerTestDominancePass();
mlir::test::registerTestDynamicPipelinePass();
- mlir::test::registerTestExpandTanhPass();
+ mlir::test::registerTestExpandMathPass();
mlir::test::registerTestComposeSubView();
mlir::test::registerTestMultiBuffering();
mlir::test::registerTestIRVisitorsPass();
More information about the Mlir-commits
mailing list