[Mlir-commits] [mlir] d4d50e4 - [mlir][tosa] Add lowering for tosa.clz using scf::whileOp

Rob Suderman llvmlistbot at llvm.org
Thu Sep 9 16:03:55 PDT 2021


Author: natashaknk
Date: 2021-09-09T15:57:35-07:00
New Revision: d4d50e47107b6d923d342d5a6ee297b56c2d87f2

URL: https://github.com/llvm/llvm-project/commit/d4d50e47107b6d923d342d5a6ee297b56c2d87f2
DIFF: https://github.com/llvm/llvm-project/commit/d4d50e47107b6d923d342d5a6ee297b56c2d87f2.diff

LOG: [mlir][tosa] Add lowering for tosa.clz using scf::whileOp

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D109540

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 07b519a15c0ea..558ba4d04acbb 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
@@ -310,6 +311,55 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
     return rewriter.create<mlir::AddIOp>(loc, resultTypes, result, extended);
   }
 
+  // tosa::ClzOp
+  if (isa<tosa::ClzOp>(op) && elementTy.isa<IntegerType>()) {
+    int bitWidth = elementTy.getIntOrFloatBitWidth();
+    auto zero =
+        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
+    auto leadingZeros = rewriter.create<mlir::ConstantOp>(
+        loc, IntegerAttr::get(elementTy, bitWidth));
+
+    SmallVector<Value> operands = {args[0], leadingZeros, zero};
+    SmallVector<Type> types = {elementTy, elementTy, elementTy};
+
+    auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
+    Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
+    Block *after = rewriter.createBlock(&whileOp.after(), {}, types);
+
+    // The conditional block of the while loop.
+    {
+      rewriter.setInsertionPointToStart(&whileOp.before().front());
+      Value input = before->getArgument(0);
+      Value zero = before->getArgument(2);
+
+      Value inputLargerThanZero =
+          rewriter.create<CmpIOp>(loc, CmpIPredicate::ne, input, zero);
+      rewriter.create<scf::ConditionOp>(loc, inputLargerThanZero,
+                                        before->getArguments());
+    }
+
+    // The body of the while loop: shift right until reaching a value of 0.
+    {
+      rewriter.setInsertionPointToStart(&whileOp.after().front());
+      Value input = after->getArgument(0);
+      Value leadingZeros = after->getArgument(1);
+
+      auto one = rewriter.create<mlir::ConstantOp>(
+          loc, IntegerAttr::get(elementTy, 1));
+      auto shifted = rewriter.create<mlir::UnsignedShiftRightOp>(
+          loc, resultTypes, input, one);
+      auto leadingZerosMinusOne =
+          rewriter.create<mlir::SubIOp>(loc, resultTypes, leadingZeros, one);
+
+      rewriter.create<scf::YieldOp>(
+          loc,
+          ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)}));
+    }
+
+    rewriter.setInsertionPointAfter(whileOp);
+    return whileOp->getResult(1);
+  }
+
   // tosa::LogicalAnd
   if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
     return rewriter.create<mlir::AndOp>(loc, resultTypes, args);
@@ -2905,6 +2955,7 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       PointwiseConverter<tosa::LogicalLeftShiftOp>,
       PointwiseConverter<tosa::LogicalRightShiftOp>,
       PointwiseConverter<tosa::ArithmeticRightShiftOp>,
+      PointwiseConverter<tosa::ClzOp>,
       PointwiseConverter<tosa::SelectOp>,
       PointwiseConverter<tosa::GreaterOp>,
       PointwiseConverter<tosa::GreaterEqualOp>,

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index b89d3f6ef0bc6..232f85a70ac87 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
@@ -32,15 +33,16 @@ struct TosaToLinalgOnTensors
     : public TosaToLinalgOnTensorsBase<TosaToLinalgOnTensors> {
 public:
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<linalg::LinalgDialect, math::MathDialect,
-                    StandardOpsDialect, tensor::TensorDialect>();
+    registry
+        .insert<linalg::LinalgDialect, math::MathDialect, StandardOpsDialect,
+                tensor::TensorDialect, scf::SCFDialect>();
   }
 
   void runOnFunction() override {
     RewritePatternSet patterns(&getContext());
     ConversionTarget target(getContext());
     target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
-                           tensor::TensorDialect>();
+                           tensor::TensorDialect, scf::SCFDialect>();
     target.addIllegalDialect<tosa::TosaDialect>();
 
     // Not every TOSA op can be legalized to linalg.

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 9cf3eba69d1ad..4209172b27478 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -357,37 +357,45 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
   // CHECK: addi
   %12 = "tosa.arithmetic_right_shift"(%arg0, %arg0) {round = 1 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
+  // CHECK: scf.while
+  // CHECK: cmpi ne
+  // CHECK: scf.condition
+  // CHECK: shift_right_unsigned
+  // CHECK: subi
+  // CHECK: scf.yield
+  %13 = "tosa.clz"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+
   // CHECK: linalg.generic
   // CHECK: cmpi
-  %13 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+  %14 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
-  %14 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+  %15 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: select
-  %15 = "tosa.select"(%13, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %16 = "tosa.select"(%14, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %16 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %17 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %17 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %18 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %18 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+  %19 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %19 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+  %20 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: constant -32768
@@ -397,27 +405,27 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
   // CHECK: cmpi slt
   // CHECK: select
   // CHECK: trunci
-  %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
+  %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
 
   // CHECK: linalg.generic
   // CHECK: sexti
-  %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
+  %22 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
 
   // CHECK: linalg.generic
   // CHECK: constant 0
   // CHECK: cmpi
-  %22 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
+  %23 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: sitofp
-  %23 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
+  %24 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: constant 0
   // CHECK: cmpi sgt
   // CHECK: subi
   // CHECK: select
-  %24 = "tosa.abs"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+  %25 = "tosa.abs"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
 
   return
 }

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index f9e468edfde2c..631620d8b9f4e 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6497,6 +6497,7 @@ cc_library(
         ":LinalgOps",
         ":MathDialect",
         ":Pass",
+        ":SCFDialect",
         ":StandardOps",
         ":TensorDialect",
         ":TosaDialect",


        


More information about the Mlir-commits mailing list