[Mlir-commits] [mlir] 936819b - [mlir][tosa] make Select operator broadcastable in the pass

Rob Suderman llvmlistbot at llvm.org
Wed Feb 8 16:53:17 PST 2023


Author: TatWai Chong
Date: 2023-02-08T16:37:19-08:00
New Revision: 936819bf55af580a94e73ff5c7e4c1cc4d5d43f6

URL: https://github.com/llvm/llvm-project/commit/936819bf55af580a94e73ff5c7e4c1cc4d5d43f6
DIFF: https://github.com/llvm/llvm-project/commit/936819bf55af580a94e73ff5c7e4c1cc4d5d43f6.diff

LOG: [mlir][tosa] make Select operator broadcastable in the pass

Making Select broadcastable makes this op easier to use.

Change-Id: I4a4bec4f7cbe532e954a5b4fe53136676ab4300c

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D139156

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
    mlir/test/Dialect/Tosa/broadcast.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
index 76b7e9560d406..b18e3b4bd2777 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -75,28 +75,28 @@ computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
 }
 
 /// Common code to create the reshape op where necessary to make the rank of the
-/// operations equal. Returns the updated input1 and input2 for the original
-/// input. The caller is expected to use these to rewrite the original operator
-/// with the RESHAPE now in the graph.
+/// operations equal. input1 and input2 will be updated when the rank has
+/// changed. The caller is expected to use these to rewrite the original
+/// operator with the RESHAPE now in the graph.
 static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
                                           Location loc,
                                           RankedTensorType outputType,
-                                          Value input1, Value input2,
-                                          Value &outInput1, Value &outInput2) {
+                                          Value &input1, Value &input2) {
   auto input1Ty = input1.getType().dyn_cast<RankedTensorType>();
   auto input2Ty = input2.getType().dyn_cast<RankedTensorType>();
 
-  if (!input1Ty || !input2Ty)
-    return failure();
+  if (!input1Ty || !input2Ty) {
+    return rewriter.notifyMatchFailure(loc, "input not a ranked tensor");
+  }
 
   int64_t input1Rank = input1Ty.getRank();
   int64_t input2Rank = input2Ty.getRank();
 
-  Value higherTensorValue, lowerTensorValue;
-  // Cannot rewrite as its already correct.
   if (input1Rank == input2Rank)
-    return failure();
+    return rewriter.notifyMatchFailure(loc,
+                                       "cannot rewrite as its already correct");
 
+  Value higherTensorValue, lowerTensorValue;
   if (input1Rank > input2Rank) {
     higherTensorValue = input1;
     lowerTensorValue = input2;
@@ -107,7 +107,6 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
 
   ArrayRef<int64_t> higherRankShape =
       higherTensorValue.getType().cast<RankedTensorType>().getShape();
-  (void)higherRankShape;
   ArrayRef<int64_t> lowerRankShape =
       lowerTensorValue.getType().cast<RankedTensorType>().getShape();
 
@@ -115,7 +114,7 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
 
   if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape)
           .failed())
-    return failure();
+    return rewriter.notifyMatchFailure(loc, "fail to compute a reshape type");
 
   auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
   auto reshapeOutputType = RankedTensorType::get(
@@ -125,7 +124,8 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
   if (outputType) {
     if (outputType.getShape().size() != reshapeOutputShape.size() ||
         outputType.getShape().size() != higherRankShape.size())
-      return failure();
+      return rewriter.notifyMatchFailure(
+          loc, "the reshaped type doesn't agrees with the ranked output type");
   }
 
   auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
@@ -133,18 +133,19 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
       rewriter.getDenseI64ArrayAttr(reshapeOutputShape));
 
   if (input1Rank > input2Rank) {
-    outInput1 = higherTensorValue;
-    outInput2 = reshapeLower.getResult();
+    input1 = higherTensorValue;
+    input2 = reshapeLower.getResult();
   } else {
-    outInput1 = reshapeLower.getResult();
-    outInput2 = higherTensorValue;
+    input1 = reshapeLower.getResult();
+    input2 = higherTensorValue;
   }
 
   return success();
 }
 
 namespace {
-template <typename OpTy> struct ConvertTosaOp : public OpRewritePattern<OpTy> {
+template <typename OpTy>
+struct ConvertTosaOp : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(OpTy tosaBinaryOp,
@@ -158,14 +159,12 @@ template <typename OpTy> struct ConvertTosaOp : public OpRewritePattern<OpTy> {
     if (!outputType)
       return failure();
 
-    Value outInput1, outInput2;
     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
-                             input1, input2, outInput1, outInput2)
+                             input1, input2)
             .failed())
       return failure();
 
-    rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, outInput1,
-                                      outInput2);
+    rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, input1, input2);
 
     return success();
   }
@@ -188,14 +187,13 @@ struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
     if (!outputType)
       return failure();
 
-    Value outInput1, outInput2;
     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
-                             input1, input2, outInput1, outInput2)
+                             input1, input2)
             .failed())
       return failure();
 
-    rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType,
-                                             outInput1, outInput2, shift);
+    rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType, input1,
+                                             input2, shift);
 
     return success();
   }
@@ -220,14 +218,63 @@ struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
     if (!outputType)
       return failure();
 
-    Value outInput1, outInput2;
     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
-                             input1, input2, outInput1, outInput2)
+                             input1, input2)
             .failed())
       return failure();
 
     rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>(
-        tosaBinaryOp, outputType, outInput1, outInput2, round);
+        tosaBinaryOp, outputType, input1, input2, round);
+
+    return success();
+  }
+};
+
+template <>
+struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
+  using OpRewritePattern<tosa::SelectOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SelectOp tosaOp,
+                                PatternRewriter &rewriter) const override {
+
+    Value input1 = tosaOp.getPred();
+    Value input2 = tosaOp.getOnTrue();
+    Value input3 = tosaOp.getOnFalse();
+    Value output = tosaOp.getResult();
+
+    auto outputType = output.getType().dyn_cast<RankedTensorType>();
+    if (!outputType)
+      return rewriter.notifyMatchFailure(tosaOp, "output not a ranked tensor");
+
+    // Apply broadcasting to each pair of inputs separately, and chain them as
+    // compound as below so that the broadcasting happens all at once.
+    bool reshaped1 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
+                                          input1, input2)
+                         .succeeded();
+
+    bool reshaped2 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
+                                          input1, input3)
+                         .succeeded();
+
+    bool reshaped3 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
+                                          input2, input3)
+                         .succeeded();
+
+    if (!reshaped1 && !reshaped2 && !reshaped3)
+      return rewriter.notifyMatchFailure(
+          tosaOp,
+          "cannot rewrite as the rank of all operands is already aligned");
+
+    int32_t result1Rank = input1.getType().cast<RankedTensorType>().getRank();
+    int32_t result2Rank = input2.getType().cast<RankedTensorType>().getRank();
+    int32_t result3Rank = input3.getType().cast<RankedTensorType>().getRank();
+
+    if ((result1Rank != result2Rank) || (result2Rank != result3Rank))
+      return rewriter.notifyMatchFailure(
+          tosaOp, "not all ranks are aligned with each other");
+
+    rewriter.replaceOpWithNewOp<tosa::SelectOp>(tosaOp, outputType, input1,
+                                                input2, input3);
 
     return success();
   }
@@ -263,6 +310,7 @@ struct TosaMakeBroadcastable
     patterns.add<ConvertTosaOp<tosa::LogicalAndOp>>(ctx);
     patterns.add<ConvertTosaOp<tosa::LogicalOrOp>>(ctx);
     patterns.add<ConvertTosaOp<tosa::LogicalXorOp>>(ctx);
+    patterns.add<ConvertTosaOp<tosa::SelectOp>>(ctx);
     patterns.add<ConvertTosaOp<tosa::PowOp>>(ctx);
     (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
   }

diff  --git a/mlir/test/Dialect/Tosa/broadcast.mlir b/mlir/test/Dialect/Tosa/broadcast.mlir
index 3858399df59be..ed1cd1e17b24d 100644
--- a/mlir/test/Dialect/Tosa/broadcast.mlir
+++ b/mlir/test/Dialect/Tosa/broadcast.mlir
@@ -195,3 +195,91 @@ func.func @test_broadcast_scalar(%arg0: tensor<i32>, %arg1: tensor<17x16x15x14xi
   %0 = "tosa.add"(%arg0, %arg1) : (tensor<i32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
   return %0 : tensor<17x16x15x14xi32>
 }
+
+// -----
+// CHECK-LABEL: broadcast_select_both_input
+func.func @test_broadcast_select_both_input(%arg0: tensor<1x16x16xi1>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<1x16x16xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg1) {new_shape = array<i64: 1, 1, 1>}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = array<i64: 1, 1, 1>}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%arg0, %[[VAL_0]], %[[VAL_1]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x16x16xi1>, tensor<f32>, tensor<f32>) -> tensor<1x16x16xf32>
+  return %0 : tensor<1x16x16xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_one_input
+func.func @test_broadcast_select_one_input(%arg0: tensor<17x16x15x14xi1>, %arg1: tensor<17x16x15x14xf32>, %arg2: tensor<f32>) -> tensor<17x16x15x14xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg2) {new_shape = array<i64: 1, 1, 1, 1>}
+  // CHECK: %[[VAL_1:.*]] = "tosa.select"(%arg0, %arg1, %[[VAL_0]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<17x16x15x14xi1>, tensor<17x16x15x14xf32>, tensor<f32>) -> tensor<17x16x15x14xf32>
+  return %0 : tensor<17x16x15x14xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_predicate
+func.func @test_broadcast_select_predicate(%arg0: tensor<i1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 1, 1, 1>}
+  // CHECK: %[[VAL_1:.*]] = "tosa.select"(%[[VAL_0]], %arg1, %arg2)
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_abc
+func.func @test_broadcast_select_abc(%arg0: tensor<i1>, %arg1: tensor<32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 1, 1, 1>}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg1) {new_shape = array<i64: 1, 1, 32, 8>}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %[[VAL_1]], %arg2)
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_acb
+func.func @test_broadcast_select_acb(%arg0: tensor<i1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 1, 1, 1>}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = array<i64: 1, 1, 32, 8>}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %arg1, %[[VAL_1]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<1x32x32x8xf32>, tensor<32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_bac
+func.func @test_broadcast_select_bac(%arg0: tensor<32x8xi1>, %arg1: tensor<f32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 1, 32, 8>}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg1) {new_shape = array<i64: 1, 1, 1, 1>}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %[[VAL_1]], %arg2)
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<32x8xi1>, tensor<f32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_bca
+func.func @test_broadcast_select_bca(%arg0: tensor<32x8xi1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<i1>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 1, 32, 8>}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = array<i64: 1, 1, 1, 1>}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %arg1, %[[VAL_1]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<32x8xi1>, tensor<1x32x32x8xf32>, tensor<i1>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_cab
+func.func @test_broadcast_select_cab(%arg0: tensor<1x32x32x8xi1>, %arg1: tensor<f32>, %arg2: tensor<32x8xf32>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg1) {new_shape = array<i64: 1, 1, 1, 1>}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = array<i64: 1, 1, 32, 8>}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%arg0, %[[VAL_0]], %[[VAL_1]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x32x32x8xi1>, tensor<f32>, tensor<32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_cba
+func.func @test_broadcast_select_cba(%arg0: tensor<1x32x32x8xi1>, %arg1: tensor<32x8xf32>, %arg2: tensor<i1>) -> tensor<1x32x32x8xf32> {
+  // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg1) {new_shape = array<i64: 1, 1, 32, 8>}
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = array<i64: 1, 1, 1, 1>}
+  // CHECK: %[[VAL_2:.*]] = "tosa.select"(%arg0, %[[VAL_0]], %[[VAL_1]])
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x32x32x8xi1>, tensor<32x8xf32>, tensor<i1>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}


        


More information about the Mlir-commits mailing list