[Mlir-commits] [mlir] d4d50e4 - [mlir][tosa] Add lowering for tosa.clz using scf::whileOp
Rob Suderman
llvmlistbot at llvm.org
Thu Sep 9 16:03:55 PDT 2021
Author: natashaknk
Date: 2021-09-09T15:57:35-07:00
New Revision: d4d50e47107b6d923d342d5a6ee297b56c2d87f2
URL: https://github.com/llvm/llvm-project/commit/d4d50e47107b6d923d342d5a6ee297b56c2d87f2
DIFF: https://github.com/llvm/llvm-project/commit/d4d50e47107b6d923d342d5a6ee297b56c2d87f2.diff
LOG: [mlir][tosa] Add lowering for tosa.clz using scf::whileOp
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D109540
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 07b519a15c0ea..558ba4d04acbb 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -13,6 +13,7 @@
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
@@ -310,6 +311,55 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
return rewriter.create<mlir::AddIOp>(loc, resultTypes, result, extended);
}
+ // tosa::ClzOp
+ if (isa<tosa::ClzOp>(op) && elementTy.isa<IntegerType>()) {
+ int bitWidth = elementTy.getIntOrFloatBitWidth();
+ auto zero =
+ rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
+ auto leadingZeros = rewriter.create<mlir::ConstantOp>(
+ loc, IntegerAttr::get(elementTy, bitWidth));
+
+ SmallVector<Value> operands = {args[0], leadingZeros, zero};
+ SmallVector<Type> types = {elementTy, elementTy, elementTy};
+
+ auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
+ Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
+ Block *after = rewriter.createBlock(&whileOp.after(), {}, types);
+
+ // The conditional block of the while loop.
+ {
+ rewriter.setInsertionPointToStart(&whileOp.before().front());
+ Value input = before->getArgument(0);
+ Value zero = before->getArgument(2);
+
+ Value inputLargerThanZero =
+ rewriter.create<CmpIOp>(loc, CmpIPredicate::ne, input, zero);
+ rewriter.create<scf::ConditionOp>(loc, inputLargerThanZero,
+ before->getArguments());
+ }
+
+ // The body of the while loop: shift right until reaching a value of 0.
+ {
+ rewriter.setInsertionPointToStart(&whileOp.after().front());
+ Value input = after->getArgument(0);
+ Value leadingZeros = after->getArgument(1);
+
+ auto one = rewriter.create<mlir::ConstantOp>(
+ loc, IntegerAttr::get(elementTy, 1));
+ auto shifted = rewriter.create<mlir::UnsignedShiftRightOp>(
+ loc, resultTypes, input, one);
+ auto leadingZerosMinusOne =
+ rewriter.create<mlir::SubIOp>(loc, resultTypes, leadingZeros, one);
+
+ rewriter.create<scf::YieldOp>(
+ loc,
+ ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)}));
+ }
+
+ rewriter.setInsertionPointAfter(whileOp);
+ return whileOp->getResult(1);
+ }
+
// tosa::LogicalAnd
if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
return rewriter.create<mlir::AndOp>(loc, resultTypes, args);
@@ -2905,6 +2955,7 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
PointwiseConverter<tosa::LogicalLeftShiftOp>,
PointwiseConverter<tosa::LogicalRightShiftOp>,
PointwiseConverter<tosa::ArithmeticRightShiftOp>,
+ PointwiseConverter<tosa::ClzOp>,
PointwiseConverter<tosa::SelectOp>,
PointwiseConverter<tosa::GreaterOp>,
PointwiseConverter<tosa::GreaterEqualOp>,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index b89d3f6ef0bc6..232f85a70ac87 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -14,6 +14,7 @@
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
@@ -32,15 +33,16 @@ struct TosaToLinalgOnTensors
: public TosaToLinalgOnTensorsBase<TosaToLinalgOnTensors> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<linalg::LinalgDialect, math::MathDialect,
- StandardOpsDialect, tensor::TensorDialect>();
+ registry
+ .insert<linalg::LinalgDialect, math::MathDialect, StandardOpsDialect,
+ tensor::TensorDialect, scf::SCFDialect>();
}
void runOnFunction() override {
RewritePatternSet patterns(&getContext());
ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
- tensor::TensorDialect>();
+ tensor::TensorDialect, scf::SCFDialect>();
target.addIllegalDialect<tosa::TosaDialect>();
// Not every TOSA op can be legalized to linalg.
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 9cf3eba69d1ad..4209172b27478 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -357,37 +357,45 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
// CHECK: addi
%12 = "tosa.arithmetic_right_shift"(%arg0, %arg0) {round = 1 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ // CHECK: scf.while
+ // CHECK: cmpi ne
+ // CHECK: scf.condition
+ // CHECK: shift_right_unsigned
+ // CHECK: subi
+ // CHECK: scf.yield
+ %13 = "tosa.clz"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+
// CHECK: linalg.generic
// CHECK: cmpi
- %13 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+ %14 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
// CHECK: linalg.generic
// CHECK: cmpi
- %14 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+ %15 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
// CHECK: linalg.generic
// CHECK: select
- %15 = "tosa.select"(%13, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %16 = "tosa.select"(%14, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: cmpi
// CHECK: select
- %16 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %17 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: cmpi
// CHECK: select
- %17 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+ %18 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: cmpi
// CHECK: select
- %18 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+ %19 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: cmpi
// CHECK: select
- %19 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+ %20 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: constant -32768
@@ -397,27 +405,27 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
// CHECK: cmpi slt
// CHECK: select
// CHECK: trunci
- %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
+ %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
// CHECK: linalg.generic
// CHECK: sexti
- %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
+ %22 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
// CHECK: linalg.generic
// CHECK: constant 0
// CHECK: cmpi
- %22 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
+ %23 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
// CHECK: linalg.generic
// CHECK: sitofp
- %23 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
+ %24 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: constant 0
// CHECK: cmpi sgt
// CHECK: subi
// CHECK: select
- %24 = "tosa.abs"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+ %25 = "tosa.abs"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
return
}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index f9e468edfde2c..631620d8b9f4e 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6497,6 +6497,7 @@ cc_library(
":LinalgOps",
":MathDialect",
":Pass",
+ ":SCFDialect",
":StandardOps",
":TensorDialect",
":TosaDialect",
More information about the Mlir-commits mailing list