[Mlir-commits] [mlir] 286e7bd - [mlir][tosa] Make the tosa MakeBroadcastable pass handle unranked tensors.

Rob Suderman llvmlistbot at llvm.org
Thu Jul 22 17:58:49 PDT 2021


Author: Rob Suderman
Date: 2021-07-22T17:57:05-07:00
New Revision: 286e7bdd3ea4f1c7a90a2877e28f353dcd9a7493

URL: https://github.com/llvm/llvm-project/commit/286e7bdd3ea4f1c7a90a2877e28f353dcd9a7493
DIFF: https://github.com/llvm/llvm-project/commit/286e7bdd3ea4f1c7a90a2877e28f353dcd9a7493.diff

LOG: [mlir][tosa] Make the tosa MakeBroadcastable pass handle unranked tensors.

If this pass executes without shape inference, it is possible for unranked tensors
to appear in the IR. The pass should handle unranked tensors gracefully.
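
For context, the failure mode is an operand whose type is an unranked tensor
(e.g. tensor<*xf32>): cast<RankedTensorType> asserts on such a type, whereas
dyn_cast simply returns null. Below is a minimal sketch of the guard this patch
adds; the helper name checkBothRanked is illustrative and not part of the
commit.

  #include "mlir/IR/BuiltinTypes.h"        // RankedTensorType
  #include "mlir/IR/Value.h"               // mlir::Value
  #include "mlir/Support/LogicalResult.h"  // LogicalResult, failure(), success()

  using namespace mlir;

  // Illustrative helper mirroring the new guard in reshapeLowerToHigher:
  // dyn_cast yields a null type for an unranked operand, so the pattern can
  // bail out with failure() instead of tripping cast<>'s assertion.
  static LogicalResult checkBothRanked(Value input1, Value input2) {
    auto input1Ty = input1.getType().dyn_cast<RankedTensorType>();
    auto input2Ty = input2.getType().dyn_cast<RankedTensorType>();
    if (!input1Ty || !input2Ty)
      return failure(); // At least one operand is unranked; skip the rewrite.
    return success();
  }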

Differential Revision: https://reviews.llvm.org/D106617

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
index e850e1f517d2f..98df91198aca1 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -108,18 +108,24 @@ static void computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
 /// operations equal. Returns the updated input1 and input2 for the original
 /// input. The caller is expected to use these to rewrite the original operator
 /// with the RESHAPE now in the graph.
-static int reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
-                                RankedTensorType outputType, Value input1,
-                                Value input2, Value &outInput1,
-                                Value &outInput2) {
+static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
+                                          Location loc,
+                                          RankedTensorType outputType,
+                                          Value input1, Value input2,
+                                          Value &outInput1, Value &outInput2) {
+  auto input1Ty = input1.getType().dyn_cast<RankedTensorType>();
+  auto input2Ty = input2.getType().dyn_cast<RankedTensorType>();
 
-  int64_t input1Rank = input1.getType().cast<RankedTensorType>().getRank();
-  int64_t input2Rank = input2.getType().cast<RankedTensorType>().getRank();
+  if (!input1Ty || !input2Ty)
+    return failure();
+
+  int64_t input1Rank = input1Ty.getRank();
+  int64_t input2Rank = input2Ty.getRank();
 
   Value higherTensorValue, lowerTensorValue;
-  // return if rank already match
+  // Cannot rewrite as the ranks already match.
   if (input1Rank == input2Rank)
-    return 1;
+    return failure();
 
   if (input1Rank > input2Rank) {
     higherTensorValue = input1;
@@ -129,24 +135,27 @@ static int reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
     lowerTensorValue = input1;
   }
 
-  ArrayRef<int64_t> outputRankShape = outputType.getShape();
   ArrayRef<int64_t> higherRankShape =
       higherTensorValue.getType().cast<RankedTensorType>().getShape();
   (void)higherRankShape;
   ArrayRef<int64_t> lowerRankShape =
       lowerTensorValue.getType().cast<RankedTensorType>().getShape();
 
-  // outputRank == higherRank == max(input1Rank, input2Rank)
-  assert(higherRankShape.size() == outputRankShape.size());
-
   SmallVector<int64_t, 4> reshapeOutputShape;
 
-  computeReshapeOutput(outputRankShape, lowerRankShape, reshapeOutputShape);
+  computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape);
 
   auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
   auto reshapeOutputType = RankedTensorType::get(
       ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
 
+  // Verify the rank agrees with the output type if the output type is ranked.
+  if (outputType) {
+    if (outputType.getShape().size() != reshapeOutputShape.size() ||
+        outputType.getShape().size() != higherRankShape.size())
+      return failure();
+  }
+
   auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
       loc, reshapeOutputType, lowerTensorValue,
       rewriter.getI64ArrayAttr(reshapeOutputShape));
@@ -159,7 +168,7 @@ static int reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
     outInput2 = higherTensorValue;
   }
 
-  return 0;
+  return success();
 }
 
 namespace {
@@ -173,11 +182,13 @@ struct ConvertTosaOp : public OpRewritePattern<OpTy> {
     Value input1 = tosaBinaryOp.input1();
     Value input2 = tosaBinaryOp.input2();
     Value output = tosaBinaryOp.getResult();
-    auto outputType = output.getType().cast<RankedTensorType>();
+
+    auto outputType = output.getType().dyn_cast<RankedTensorType>();
 
     Value outInput1, outInput2;
     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
-                             input1, input2, outInput1, outInput2))
+                             input1, input2, outInput1, outInput2)
+            .failed())
       return failure();
 
     rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, outInput1,
@@ -200,11 +211,12 @@ struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
     Value input2 = tosaBinaryOp.input2();
     int32_t shift = tosaBinaryOp.shift();
     Value output = tosaBinaryOp.getResult();
-    auto outputType = output.getType().cast<RankedTensorType>();
+    auto outputType = output.getType().dyn_cast<RankedTensorType>();
 
     Value outInput1, outInput2;
     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
-                             input1, input2, outInput1, outInput2))
+                             input1, input2, outInput1, outInput2)
+            .failed())
       return failure();
 
     rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType,
@@ -233,7 +245,8 @@ struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
 
     Value outInput1, outInput2;
     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
-                             input1, input2, outInput1, outInput2))
+                             input1, input2, outInput1, outInput2)
+            .failed())
       return failure();
 
     rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>(
