[Mlir-commits] [mlir] 936819b - [mlir][tosa] make Select operator broadcastable in the pass
Rob Suderman
llvmlistbot at llvm.org
Wed Feb 8 16:53:17 PST 2023
Author: TatWai Chong
Date: 2023-02-08T16:37:19-08:00
New Revision: 936819bf55af580a94e73ff5c7e4c1cc4d5d43f6
URL: https://github.com/llvm/llvm-project/commit/936819bf55af580a94e73ff5c7e4c1cc4d5d43f6
DIFF: https://github.com/llvm/llvm-project/commit/936819bf55af580a94e73ff5c7e4c1cc4d5d43f6.diff
LOG: [mlir][tosa] make Select operator broadcastable in the pass
Making Select broadcastable makes this op easier to use; a short before/after sketch follows below.
Change-Id: I4a4bec4f7cbe532e954a5b4fe53136676ab4300c
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D139156
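A minimal before/after sketch of the new behavior (mirroring the
broadcast_select_predicate test added below; %pred, %a, %b are illustrative
names, not from the commit):

  %0 = "tosa.select"(%pred, %a, %b)
      : (tensor<i1>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>)
      -> tensor<1x32x32x8xf32>

becomes, after --tosa-make-broadcastable:

  %p = "tosa.reshape"(%pred) {new_shape = array<i64: 1, 1, 1, 1>}
      : (tensor<i1>) -> tensor<1x1x1x1xi1>
  %0 = "tosa.select"(%p, %a, %b)
      : (tensor<1x1x1x1xi1>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>)
      -> tensor<1x32x32x8xf32>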
Added:
Modified:
mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
mlir/test/Dialect/Tosa/broadcast.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
index 76b7e9560d406..b18e3b4bd2777 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -75,28 +75,28 @@ computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
}
/// Common code to create the reshape op where necessary to make the rank of the
-/// operations equal. Returns the updated input1 and input2 for the original
-/// input. The caller is expected to use these to rewrite the original operator
-/// with the RESHAPE now in the graph.
+/// operations equal. input1 and input2 are updated in place when a reshape
+/// is inserted. The caller is expected to use these to rewrite the original
+/// operator with the RESHAPE now in the graph.
static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
Location loc,
RankedTensorType outputType,
- Value input1, Value input2,
- Value &outInput1, Value &outInput2) {
+ Value &input1, Value &input2) {
auto input1Ty = input1.getType().dyn_cast<RankedTensorType>();
auto input2Ty = input2.getType().dyn_cast<RankedTensorType>();
- if (!input1Ty || !input2Ty)
- return failure();
+ if (!input1Ty || !input2Ty) {
+ return rewriter.notifyMatchFailure(loc, "input not a ranked tensor");
+ }
int64_t input1Rank = input1Ty.getRank();
int64_t input2Rank = input2Ty.getRank();
- Value higherTensorValue, lowerTensorValue;
- // Cannot rewrite as its already correct.
if (input1Rank == input2Rank)
- return failure();
+    return rewriter.notifyMatchFailure(
+        loc, "cannot rewrite as it's already correct");
+ Value higherTensorValue, lowerTensorValue;
if (input1Rank > input2Rank) {
higherTensorValue = input1;
lowerTensorValue = input2;
@@ -107,7 +107,6 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
ArrayRef<int64_t> higherRankShape =
higherTensorValue.getType().cast<RankedTensorType>().getShape();
- (void)higherRankShape;
ArrayRef<int64_t> lowerRankShape =
lowerTensorValue.getType().cast<RankedTensorType>().getShape();
@@ -115,7 +114,7 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape)
.failed())
- return failure();
+    return rewriter.notifyMatchFailure(
+        loc, "failed to compute a reshape type");
auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
auto reshapeOutputType = RankedTensorType::get(
@@ -125,7 +124,8 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
if (outputType) {
if (outputType.getShape().size() != reshapeOutputShape.size() ||
outputType.getShape().size() != higherRankShape.size())
- return failure();
+ return rewriter.notifyMatchFailure(
+        loc, "the reshaped type doesn't agree with the ranked output type");
}
auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
@@ -133,18 +133,19 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
rewriter.getDenseI64ArrayAttr(reshapeOutputShape));
if (input1Rank > input2Rank) {
- outInput1 = higherTensorValue;
- outInput2 = reshapeLower.getResult();
+ input1 = higherTensorValue;
+ input2 = reshapeLower.getResult();
} else {
- outInput1 = reshapeLower.getResult();
- outInput2 = higherTensorValue;
+ input1 = reshapeLower.getResult();
+ input2 = higherTensorValue;
}
return success();
}
namespace {
-template <typename OpTy> struct ConvertTosaOp : public OpRewritePattern<OpTy> {
+template <typename OpTy>
+struct ConvertTosaOp : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy tosaBinaryOp,
@@ -158,14 +159,12 @@ template <typename OpTy> struct ConvertTosaOp : public OpRewritePattern<OpTy> {
if (!outputType)
return failure();
- Value outInput1, outInput2;
if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
- input1, input2, outInput1, outInput2)
+ input1, input2)
.failed())
return failure();
- rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, outInput1,
- outInput2);
+ rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, input1, input2);
return success();
}
@@ -188,14 +187,13 @@ struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
if (!outputType)
return failure();
- Value outInput1, outInput2;
if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
- input1, input2, outInput1, outInput2)
+ input1, input2)
.failed())
return failure();
- rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType,
- outInput1, outInput2, shift);
+ rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType, input1,
+ input2, shift);
return success();
}
@@ -220,14 +218,63 @@ struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
if (!outputType)
return failure();
- Value outInput1, outInput2;
if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
- input1, input2, outInput1, outInput2)
+ input1, input2)
.failed())
return failure();
rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>(
- tosaBinaryOp, outputType, outInput1, outInput2, round);
+ tosaBinaryOp, outputType, input1, input2, round);
+
+ return success();
+ }
+};
+
+template <>
+struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
+ using OpRewritePattern<tosa::SelectOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tosa::SelectOp tosaOp,
+ PatternRewriter &rewriter) const override {
+
+ Value input1 = tosaOp.getPred();
+ Value input2 = tosaOp.getOnTrue();
+ Value input3 = tosaOp.getOnFalse();
+ Value output = tosaOp.getResult();
+
+ auto outputType = output.getType().dyn_cast<RankedTensorType>();
+ if (!outputType)
+ return rewriter.notifyMatchFailure(tosaOp, "output not a ranked tensor");
+
+    // Apply broadcasting to each pair of inputs separately, and chain the
+    // rewrites so that the broadcasting happens across all inputs at once.
+ bool reshaped1 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
+ input1, input2)
+ .succeeded();
+
+ bool reshaped2 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
+ input1, input3)
+ .succeeded();
+
+ bool reshaped3 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
+ input2, input3)
+ .succeeded();
+
+ if (!reshaped1 && !reshaped2 && !reshaped3)
+ return rewriter.notifyMatchFailure(
+ tosaOp,
+          "cannot rewrite as the ranks of all operands are already aligned");
+
+ int32_t result1Rank = input1.getType().cast<RankedTensorType>().getRank();
+ int32_t result2Rank = input2.getType().cast<RankedTensorType>().getRank();
+ int32_t result3Rank = input3.getType().cast<RankedTensorType>().getRank();
+
+ if ((result1Rank != result2Rank) || (result2Rank != result3Rank))
+ return rewriter.notifyMatchFailure(
+ tosaOp, "not all ranks are aligned with each other");
+
+ rewriter.replaceOpWithNewOp<tosa::SelectOp>(tosaOp, outputType, input1,
+ input2, input3);
return success();
}
@@ -263,6 +310,7 @@ struct TosaMakeBroadcastable
patterns.add<ConvertTosaOp<tosa::LogicalAndOp>>(ctx);
patterns.add<ConvertTosaOp<tosa::LogicalOrOp>>(ctx);
patterns.add<ConvertTosaOp<tosa::LogicalXorOp>>(ctx);
+ patterns.add<ConvertTosaOp<tosa::SelectOp>>(ctx);
patterns.add<ConvertTosaOp<tosa::PowOp>>(ctx);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}
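Why the three pairwise reshapeLowerToHigher calls in the SelectOp
specialization are sufficient: each successful call raises the lower-ranked
operand of its pair to the higher rank, so chaining the pairs converges all
three operands to the highest rank among them. A walk-through with the ranks
from the broadcast_select_acb test below (pred: tensor<i1>, on_true:
tensor<1x32x32x8xf32>, on_false: tensor<32x8xf32>):

  pass 1 (pred, on_true):     rank 0 vs. 4 -> pred reshaped to tensor<1x1x1x1xi1>
  pass 2 (pred, on_false):    rank 4 vs. 2 -> on_false reshaped to tensor<1x1x32x8xf32>
  pass 3 (on_true, on_false): ranks equal  -> no reshape (the helper reports a
                              match failure, which is fine as long as at least
                              one earlier pass succeeded)

All three operands now have rank 4, so the final rank check passes and the
select is rewritten in place.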
diff --git a/mlir/test/Dialect/Tosa/broadcast.mlir b/mlir/test/Dialect/Tosa/broadcast.mlir
index 3858399df59be..ed1cd1e17b24d 100644
--- a/mlir/test/Dialect/Tosa/broadcast.mlir
+++ b/mlir/test/Dialect/Tosa/broadcast.mlir
@@ -195,3 +195,91 @@ func.func @test_broadcast_scalar(%arg0: tensor<i32>, %arg1: tensor<17x16x15x14xi
%0 = "tosa.add"(%arg0, %arg1) : (tensor<i32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
return %0 : tensor<17x16x15x14xi32>
}
+
+// -----
+// CHECK-LABEL: broadcast_select_both_input
+func.func @test_broadcast_select_both_input(%arg0: tensor<1x16x16xi1>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<1x16x16xf32> {
+ // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg1) {new_shape = array<i64: 1, 1, 1>}
+ // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = array<i64: 1, 1, 1>}
+ // CHECK: %[[VAL_2:.*]] = "tosa.select"(%arg0, %[[VAL_0]], %[[VAL_1]])
+ %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x16x16xi1>, tensor<f32>, tensor<f32>) -> tensor<1x16x16xf32>
+ return %0 : tensor<1x16x16xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_one_input
+func.func @test_broadcast_select_one_input(%arg0: tensor<17x16x15x14xi1>, %arg1: tensor<17x16x15x14xf32>, %arg2: tensor<f32>) -> tensor<17x16x15x14xf32> {
+ // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg2) {new_shape = array<i64: 1, 1, 1, 1>}
+ // CHECK: %[[VAL_1:.*]] = "tosa.select"(%arg0, %arg1, %[[VAL_0]])
+ %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<17x16x15x14xi1>, tensor<17x16x15x14xf32>, tensor<f32>) -> tensor<17x16x15x14xf32>
+ return %0 : tensor<17x16x15x14xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_predicate
+func.func @test_broadcast_select_predicate(%arg0: tensor<i1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+ // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 1, 1, 1>}
+ // CHECK: %[[VAL_1:.*]] = "tosa.select"(%[[VAL_0]], %arg1, %arg2)
+ %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+ return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_abc
+func.func @test_broadcast_select_abc(%arg0: tensor<i1>, %arg1: tensor<32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+ // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 1, 1, 1>}
+ // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg1) {new_shape = array<i64: 1, 1, 32, 8>}
+ // CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %[[VAL_1]], %arg2)
+ %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+ return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_acb
+func.func @test_broadcast_select_acb(%arg0: tensor<i1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<32x8xf32>) -> tensor<1x32x32x8xf32> {
+ // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 1, 1, 1>}
+ // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = array<i64: 1, 1, 32, 8>}
+ // CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %arg1, %[[VAL_1]])
+ %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<1x32x32x8xf32>, tensor<32x8xf32>) -> tensor<1x32x32x8xf32>
+ return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_bac
+func.func @test_broadcast_select_bac(%arg0: tensor<32x8xi1>, %arg1: tensor<f32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+ // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 1, 32, 8>}
+ // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg1) {new_shape = array<i64: 1, 1, 1, 1>}
+ // CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %[[VAL_1]], %arg2)
+ %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<32x8xi1>, tensor<f32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+ return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_bca
+func.func @test_broadcast_select_bca(%arg0: tensor<32x8xi1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<i1>) -> tensor<1x32x32x8xf32> {
+ // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 1, 32, 8>}
+ // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = array<i64: 1, 1, 1, 1>}
+ // CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %arg1, %[[VAL_1]])
+ %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<32x8xi1>, tensor<1x32x32x8xf32>, tensor<i1>) -> tensor<1x32x32x8xf32>
+ return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_cab
+func.func @test_broadcast_select_cab(%arg0: tensor<1x32x32x8xi1>, %arg1: tensor<f32>, %arg2: tensor<32x8xf32>) -> tensor<1x32x32x8xf32> {
+ // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg1) {new_shape = array<i64: 1, 1, 1, 1>}
+ // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = array<i64: 1, 1, 32, 8>}
+ // CHECK: %[[VAL_2:.*]] = "tosa.select"(%arg0, %[[VAL_0]], %[[VAL_1]])
+ %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x32x32x8xi1>, tensor<f32>, tensor<32x8xf32>) -> tensor<1x32x32x8xf32>
+ return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_select_cba
+func.func @test_broadcast_select_cba(%arg0: tensor<1x32x32x8xi1>, %arg1: tensor<32x8xf32>, %arg2: tensor<i1>) -> tensor<1x32x32x8xf32> {
+ // CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg1) {new_shape = array<i64: 1, 1, 32, 8>}
+ // CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = array<i64: 1, 1, 1, 1>}
+ // CHECK: %[[VAL_2:.*]] = "tosa.select"(%arg0, %[[VAL_0]], %[[VAL_1]])
+ %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x32x32x8xi1>, tensor<32x8xf32>, tensor<i1>) -> tensor<1x32x32x8xf32>
+ return %0 : tensor<1x32x32x8xf32>
+}
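To exercise the new patterns locally, broadcast.mlir runs under the pass's
registered command-line flag (sketch; flag name per the TosaMakeBroadcastable
pass registration):

  mlir-opt --tosa-make-broadcastable mlir/test/Dialect/Tosa/broadcast.mlir | FileCheck mlir/test/Dialect/Tosa/broadcast.mlir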