[Mlir-commits] [mlir] b053228 - [mlir][tosa] Add shape inference for tosa.while

Rob Suderman llvmlistbot at llvm.org
Fri Sep 10 13:13:05 PDT 2021


Author: Rob Suderman
Date: 2021-09-10T13:11:53-07:00
New Revision: b0532286fe486b5ee208bae53c7afba278cbd609

URL: https://github.com/llvm/llvm-project/commit/b0532286fe486b5ee208bae53c7afba278cbd609
DIFF: https://github.com/llvm/llvm-project/commit/b0532286fe486b5ee208bae53c7afba278cbd609.diff

LOG: [mlir][tosa] Add shape inference for tosa.while

Tosa.while shape inference requires repeatedly running shape inference across
the body of the loop until the types stabilize, as we do not know the number
of iterations the loop body requires. Once the least specific argument types
are known, they are propagated to both regions.

To determine the final result types, the least restrictive types are computed
across all yields.

Differential Revision: https://reviews.llvm.org/D108801
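
For illustration, a minimal sketch distilled from the new tests in
tosa-infer-shapes.mlir (with %0 a tensor<*xi32> loop operand produced from a
tensor<i32> value): the carried value starts out unranked and is refined once
the types in the body stabilize.

    %1 = "tosa.while_loop"(%0) ({
    ^bb0(%arg1: tensor<*xi32>):   // inferred: tensor<i32>
      %c3 = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
      %2 = "tosa.greater_equal"(%c3, %arg1)
           : (tensor<i32>, tensor<*xi32>) -> tensor<*xi1>
      "tosa.yield"(%2) : (tensor<*xi1>) -> ()
    }, {
    ^bb0(%arg1: tensor<*xi32>):   // inferred: tensor<i32>
      %c1 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
      %3 = "tosa.add"(%arg1, %c1) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
      "tosa.yield"(%3) : (tensor<*xi32>) -> ()
    }) : (tensor<*xi32>) -> tensor<*xi32>   // inferred result: tensor<i32>

When an iteration changes a carried type, the meet of the old and new
knowledge generalizes it: in the second new test a carried tensor<1xi32> is
concatenated with itself, yielding tensor<2xi32> on the first pass; meeting
tensor<1xi32> with tensor<2xi32> gives tensor<?xi32>, which is stable on the
following pass and becomes the final block-argument and result type.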

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
    mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 4e8c80eaf761d..f633d2030f16f 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1376,12 +1376,12 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [
   }];
 
   let arguments = (ins
-    Variadic<Tosa_RankedTensor>:$input1,
+    Variadic<Tosa_Tensor>:$input1,
     I64Attr:$axis
   );
 
   let results = (outs
-    Tosa_RankedTensor:$output
+    Tosa_Tensor:$output
   );
 
   let hasCanonicalizer = 1;
@@ -1846,6 +1846,8 @@ def Tosa_IfOp : Tosa_Op<"cond_if", [
 //===----------------------------------------------------------------------===//
 def Tosa_WhileOp : Tosa_Op<"while_loop", [
        DeclareOpInterfaceMethods<LoopLikeOpInterface>,
+       DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                                 ["inferReturnTypeComponents"]>,
        SingleBlockImplicitTerminator<"YieldOp">,
        RecursiveSideEffects]> {
   let summary = "output = input; While (Cond(output)) {output = Body(output)}";

diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h
index b7f742e4e31c6..1bfacd09ed68c 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ShapeUtils.h
@@ -15,6 +15,7 @@
 
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Types.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallVector.h"
 
@@ -61,6 +62,10 @@ struct ValueKnowledge {
     return ValueKnowledge(false, {}, Type());
   }
 
+  ShapedTypeComponents getShapedTypeComponents() const {
+    return hasRank ? ShapedTypeComponents(sizes) : ShapedTypeComponents();
+  }
+
   Type getType() const {
     if (hasRank)
       return RankedTensorType::get(llvm::makeArrayRef(sizes), dtype);

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 6ee487bd8e651..4d28d24789d23 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/DenseMap.h"
 
 using namespace mlir;
 using namespace mlir::tosa;
@@ -1437,13 +1438,52 @@ LogicalResult IfOp::inferReturnTypeComponents(
   }
 
   for (const ValueKnowledge &result : resultKnowledge) {
-    if (result.hasRank) {
-      inferredReturnShapes.push_back(ShapedTypeComponents(result.sizes));
-    } else {
-      inferredReturnShapes.push_back(ShapedTypeComponents());
+    inferredReturnShapes.push_back(result.getShapedTypeComponents());
+  }
+
+  return success();
+}
+
+LogicalResult WhileOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  llvm::SmallVector<tosa::YieldOp> yieldOps;
+  for (auto &block : *regions[1])
+    if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
+      yieldOps.push_back(returnOp);
+
+  // TOSA's while must have a tosa.yield as its terminator. If none is
+  // found, this tosa.while is invalid.
+  if (yieldOps.empty())
+    return failure();
+
+  // Get the initial type information from the operand types.
+  llvm::SmallVector<ValueKnowledge> resultKnowledge;
+  resultKnowledge.reserve(yieldOps.front().getNumOperands());
+  for (auto operand : yieldOps.front().getOperands()) {
+    resultKnowledge.push_back(
+        ValueKnowledge::getKnowledgeFromType(operand.getType()));
+  }
+
+  for (auto yieldOp : yieldOps) {
+    if (resultKnowledge.size() != yieldOp.getNumOperands())
+      return failure();
+
+    for (auto it : llvm::enumerate(yieldOp.getOperands())) {
+      int32_t index = it.index();
+      if (auto meet = ValueKnowledge::meet(
+              resultKnowledge[index],
+              ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
+        resultKnowledge[index] = meet;
+      }
     }
   }
 
+  for (const ValueKnowledge &result : resultKnowledge) {
+    inferredReturnShapes.push_back(result.getShapedTypeComponents());
+  }
+
   return success();
 }
 

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index 006f4148ea1fe..cd94be1abf7e2 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -34,8 +34,9 @@ namespace {
 
 void propagateShapesInRegion(Region &region);
 
-void propagateShapesToTosaIf(Operation &op) {
-  tosa::IfOp ifOp = dyn_cast<tosa::IfOp>(op);
+void propagateShapesToTosaIf(
+    Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
+  IfOp ifOp = dyn_cast<IfOp>(op);
   if (!ifOp)
     return;
 
@@ -44,6 +45,17 @@ void propagateShapesToTosaIf(Operation &op) {
     if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
       return;
 
+    for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
+      auto inferredTy = shapesStorage[op.getOperand(i)];
+      auto blockArg = frontBlock.getArgument(i - 1);
+      auto oldType = blockArg.getType().cast<ShapedType>();
+
+      if (inferredTy.hasRank()) {
+        Type newType = oldType.clone(inferredTy.getDims());
+        blockArg.setType(newType);
+      }
+    }
+
     for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
       ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
           ifOp.getOperand(i + 1).getType());
@@ -58,8 +70,113 @@ void propagateShapesToTosaIf(Operation &op) {
 
     propagateShapesInRegion(region);
   }
+}
+
+void propagateShapesToTosaWhile(
+    Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
+  WhileOp whileOp = dyn_cast<WhileOp>(op);
+  if (!whileOp)
+    return;
+
+  // Determine the expected argument types for the cond/body blocks.
+  // The expected arguments must be compatible with every iteration of the
+  // loop body / condition for tosa.while.
+  llvm::SmallVector<Type> argTypes;
+  for (auto operand : op.getOperands()) {
+    auto operandTy = operand.getType().cast<ShapedType>();
+    auto shapedTypeComponent = shapesStorage[operand];
+    if (shapedTypeComponent.hasRank()) {
+      auto newTy = operandTy.clone(shapedTypeComponent.getDims());
+      argTypes.push_back(newTy);
+    } else {
+      argTypes.push_back(operand.getType());
+    }
+  }
+
+  // Save the type information so we can restore it at the end.
+  llvm::DenseMap<Value, Type> originalTypeMap;
+  for (auto &block : op.getRegion(1)) {
+    for (auto arg : block.getArguments())
+      originalTypeMap[arg] = arg.getType();
+    for (auto &op : block)
+      for (auto result : op.getResults())
+        originalTypeMap[result] = result.getType();
+  }
+
+  bool hasNewTypes = true;
+  while (hasNewTypes) {
+
+    // Set types on the block args.
+    Region &bodyRegion = op.getRegion(1);
+    Block &block = bodyRegion.front();
+    for (int i = 0, s = argTypes.size(); i < s; i++) {
+      block.getArgument(i).setType(argTypes[i]);
+    }
+
+    // Propagate to the end.
+    propagateShapesInRegion(bodyRegion);
+
+    // Find all the tosa.yield terminators and verify there is at least one.
+    llvm::SmallVector<YieldOp> yieldOps;
+    for (auto &block : bodyRegion)
+      if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
+        yieldOps.push_back(yieldOp);
+
+    if (yieldOps.empty())
+      return;
 
-  return;
+    // Using the new tosa.yield operand types, infer the new subtypes.
+    llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
+    for (auto ty : argTypes) {
+      yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
+    }
+
+    for (auto yieldOp : yieldOps) {
+      for (auto it : llvm::enumerate(yieldOp.getOperands())) {
+        auto newKnowledge =
+            ValueKnowledge::getKnowledgeFromType(it.value().getType());
+        yieldTypeInfo[it.index()] =
+            ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
+      }
+    }
+
+    // This should never happen.
+    if (yieldTypeInfo.size() != argTypes.size()) {
+      op.emitWarning("has a tosa.yield with the incorrect number of operands");
+      return;
+    }
+
+    // Determine the new block args and see if any changed.
+    hasNewTypes = false;
+    for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
+      Type newType = yieldTypeInfo[i].getType();
+      hasNewTypes |= (newType != argTypes[i]);
+      argTypes[i] = newType;
+    }
+
+    // The types inferred in the block assume the argument types specified
+    // for this iteration. Restore the original types so the next iteration
+    // starts from the newly chosen argument types alone, rather than from
+    // types left over by a previous iteration.
+    for (auto &block : bodyRegion) {
+      for (auto arg : block.getArguments())
+        arg.setType(originalTypeMap[arg]);
+      for (auto &op : block)
+        for (auto result : op.getResults())
+          result.setType(originalTypeMap[result]);
+    }
+  }
+
+  // A fixed point has been reached. Set the block arguments in both regions
+  // to the final types and run one last round of shape propagation through
+  // each region.
+  for (auto &region : op.getRegions()) {
+    for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
+      region.front().getArgument(i).setType(argTypes[i]);
+    }
+
+    propagateShapesInRegion(region);
+  }
 }
 
 void propagateShapesInRegion(Region &region) {
@@ -80,11 +197,11 @@ void propagateShapesInRegion(Region &region) {
 
   for (auto &block : region) {
     for (Operation &op : block) {
-      if (op.getDialect()->getNamespace() !=
-          tosa::TosaDialect::getDialectNamespace())
+      if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
         continue;
 
-      propagateShapesToTosaIf(op);
+      propagateShapesToTosaIf(op, shapesStorage);
+      propagateShapesToTosaWhile(op, shapesStorage);
 
       InferShapedTypeOpInterface shapeInterface =
           dyn_cast<InferShapedTypeOpInterface>(op);
@@ -110,7 +227,7 @@ void propagateShapesInRegion(Region &region) {
             if (isa<ReturnOp>(user))
               continue;
             if (user->getDialect()->getNamespace() ==
-                tosa::TosaDialect::getDialectNamespace())
+                TosaDialect::getDialectNamespace())
               continue;
 
             replaceable = false;

diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 50189d46a3882..2eecdd29de9dd 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1043,14 +1043,16 @@ func @resize_fp_offsetted(%arg0: tensor<1x2x4x1xi32>) {
 
 // CHECK-LABEL: @if_test_simple
 func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () {
+  %a = "tosa.log"(%arg0) : (tensor<f32>) -> tensor<*xf32>
+  %b = "tosa.log"(%arg1) : (tensor<f32>) -> tensor<*xf32>
   // CHECK: (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
-  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
-  ^bb1(%arg3 : tensor<f32>, %arg4 : tensor<f32>):
-    "tosa.yield"(%arg3) : (tensor<f32>) -> ()
+  %0 = "tosa.cond_if"(%arg2, %a, %b) ({
+  ^bb1(%arg3 : tensor<*xf32>, %arg4 : tensor<*xf32>):
+    "tosa.yield"(%arg3) : (tensor<*xf32>) -> ()
   }, {
-  ^bb1(%arg5 : tensor<f32>, %arg6 : tensor<f32>):
-    "tosa.yield"(%arg6) : (tensor<f32>) -> ()
-  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<*xf32>)
+  ^bb1(%arg5 : tensor<*xf32>, %arg6 : tensor<*xf32>):
+    "tosa.yield"(%arg6) : (tensor<*xf32>) -> ()
+  }) : (tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>)
   return
 }
 
@@ -1100,3 +1102,88 @@ func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor
   }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<*xf32>)
   return
 }
+
+// -----
+
+// CHECK-LABEL: @while_test
+func @while_test(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
+  // CHECK:      "tosa.add" 
+  // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i32>
+  %0 = "tosa.add"(%arg0, %arg0) : (tensor<i32>, tensor<i32>) -> tensor<*xi32>
+
+  // CHECK:      "tosa.while_loop"
+  %1 = "tosa.while_loop"(%0) ( {
+
+  // CHECK:      ^bb0
+  // CHECK-SAME: tensor<i32>
+  ^bb0(%arg2: tensor<*xi32>):
+    %2 = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+    // CHECK:       "tosa.greater_equal"
+    // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = "tosa.greater_equal"(%2, %arg2) : (tensor<i32>, tensor<*xi32>) -> tensor<*xi1>
+    // CHECK:      "tosa.yield"
+    // CHECK-SAME: tensor<i1>
+    "tosa.yield"(%3) : (tensor<*xi1>) -> ()
+  },  {
+  // CHECK:      ^bb0
+  // CHECK-SAME: tensor<i32>
+  ^bb0(%arg2: tensor<*xi32>):
+    %2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+
+    // CHECK:     "tosa.add"
+    // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i32>
+    %3 = "tosa.add"(%arg2, %2) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
+
+    // CHECK:      "tosa.yield"
+    // CHECK-SAME: tensor<i32>
+    "tosa.yield"(%3) : (tensor<*xi32>) -> ()
+
+  // CHECK:      (tensor<i32>) -> tensor<i32>
+  }) : (tensor<*xi32>) -> (tensor<*xi32>)
+
+  // CHECK:      tensor.cast
+  return %1 : tensor<*xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @while_test
+func @while_test(%arg0 : tensor<i32>, %arg1 : tensor<1xi32>) -> () {
+  // CHECK:      "tosa.while_loop"
+  %1:2 = "tosa.while_loop"(%arg0, %arg1) ( {
+
+  // CHECK:      ^bb0
+  // CHECK-SAME: tensor<i32>
+  // CHECK-SAME: tensor<?xi32>
+  ^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xi32>):
+    %2 = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+
+    // CHECK:       "tosa.greater_equal"
+    // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = "tosa.greater_equal"(%2, %arg2) : (tensor<i32>, tensor<*xi32>) -> tensor<*xi1>
+    "tosa.yield"(%3) : (tensor<*xi1>) -> ()
+  },  {
+
+  // CHECK:      ^bb0
+  // CHECK-SAME: tensor<i32>
+  // CHECK-SAME: tensor<?xi32>
+  ^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xi32>):
+    %2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+
+    // CHECK:     "tosa.add"
+    // CHECK-SAME: (tensor<i32>, tensor<i32>) -> tensor<i32>
+    %3 = "tosa.add"(%arg2, %2) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
+
+    // CHECK:      "tosa.concat"
+    // CHECK-SAME: (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
+    %4 = "tosa.concat"(%arg3, %arg3) { axis = 0 : i64 } : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>)
+
+    // CHECK:      "tosa.yield"
+    // CHECK-SAME: tensor<i32>
+    // CHECK-SAME: tensor<?xi32>
+    "tosa.yield"(%3, %4) : (tensor<*xi32>, tensor<*xi32>) -> ()
+
+  // CHECK:      (tensor<i32>, tensor<1xi32>) -> (tensor<i32>, tensor<?xi32>)
+  }) : (tensor<i32>, tensor<1xi32>) -> (tensor<*xi32>, tensor<*xi32>)
+  return
+}
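
The new cases can be exercised directly with the shape inference pass,
e.g. (assuming the test file's standard RUN invocation and pass flag):

    mlir-opt --split-input-file --tosa-infer-shapes \
      mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir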

More information about the Mlir-commits mailing list