[Mlir-commits] [mlir] 0934930 - [mlir] Enable specifying querying function in ValueShapeRange

Jacques Pienaar llvmlistbot at llvm.org
Tue Aug 10 11:44:54 PDT 2021


Author: Jacques Pienaar
Date: 2021-08-10T11:44:20-07:00
New Revision: 093493032d19adf879528a7e407909b36bd6570f

URL: https://github.com/llvm/llvm-project/commit/093493032d19adf879528a7e407909b36bd6570f
DIFF: https://github.com/llvm/llvm-project/commit/093493032d19adf879528a7e407909b36bd6570f.diff

LOG: [mlir] Enable specifying querying function in ValueShapeRange

This enables querying shapes, and values as shapes, without mutating the IR
directly (e.g., towards enabling inference in separate analysis and
application steps, inferring a function's shape from a constant at a
callsite, ...). Add a new ShapeAdaptor that abstracts over whether a shape
comes from a Type, a ShapedTypeComponents, or a DenseIntElementsAttribute.
This adds new accessors to ValueShapeRange to get the shape, and the value as
a shape, but doesn't restrict or remove the previous way of accessing the
Type via the Value for now; that does mean a less refined shape could
accidentally be queried, and this will be restricted in a follow-up.
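
As a rough illustration, a minimal sketch (the function name is hypothetical;
only the ValueShapeRange/ShapeAdaptor accessors added in this change are
assumed) of how an inferReturnTypeComponents-style body can now query shapes
through the range instead of through the operand's Type:

  static LogicalResult inferUnaryShape(
      ValueShapeRange operands,
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
    // Query the (possibly more refined) shape of operand 0 rather than
    // operands[0].getType().
    ShapeAdaptor inputShape = operands.getShape(0);
    if (!inputShape.hasRank()) {
      inferredReturnShapes.push_back(ShapedTypeComponents());
      return success();
    }
    SmallVector<int64_t> outShape;
    inputShape.getDims(outShape);
    inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
    return success();
  }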

The Value query is currently restricted to what can be represented as a
shape, so it only supports cases where a constant subgraph evaluation's
output is a shape. I had considered making it more general, but without a
(TBD) extern attribute concept or some such, a user cannot today uniformly
avoid the overhead.
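
For the value-as-shape case, a sketch mirroring the TransposeOp change below
(inputShape is assumed to be a ranked ShapeAdaptor for operand 0):

  SmallVector<int64_t> outputShape(inputShape.getRank(),
                                   ShapedType::kDynamicSize);
  // If operand 1 is a constant 1-D integer tensor (or the valueToShape
  // mapping provides a shape for it), it can be read as a shape; otherwise
  // the returned ShapeAdaptor is null and converts to false.
  if (ShapeAdaptor permShape = operands.getValueAsShape(1)) {
    for (int i = 0, s = inputShape.getRank(); i < s; ++i)
      outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
  }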

Update TOSA ops and also the shape inference pass.
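
On the pass side, a condensed sketch of how TosaInferShapes.cpp now builds
the range with a mapping function so in-flight shapes are queried instead of
the IR (shapesStorage is the pass-local DenseMap<Value,
ShapedTypeComponents> shown in the diff below):

  auto operandShape = [&](Value val) -> ShapeAdaptor {
    auto it = shapesStorage.find(val);
    if (it == shapesStorage.end())
      return nullptr;    // fall back to querying the Value's type
    return it->second;   // use the work-in-progress shape
  };
  ValueShapeRange range(op.getOperands(), operandShape);
  SmallVector<ShapedTypeComponents> returnedShapes;
  LogicalResult inferred = shapeInterface.inferReturnTypeComponents(
      op.getContext(), op.getLoc(), range, op.getAttrDictionary(),
      op.getRegions(), returnedShapes);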

Differential Revision: https://reviews.llvm.org/D107768

Added: 
    mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp

Modified: 
    mlir/include/mlir/Interfaces/InferTypeOpInterface.h
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
    mlir/lib/Interfaces/InferTypeOpInterface.cpp
    mlir/unittests/Interfaces/CMakeLists.txt

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index 0535ff1e714dc..e153e66bd9d97 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -16,9 +16,11 @@
 
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/PointerUnion.h"
 #include "llvm/ADT/SmallVector.h"
 
 namespace mlir {
@@ -39,19 +41,25 @@ class ShapedTypeComponents {
 
 public:
   /// Default construction is an unranked shape.
-  ShapedTypeComponents() : ranked(false), elementType(nullptr), attr(nullptr){};
+  ShapedTypeComponents() : elementType(nullptr), attr(nullptr), ranked(false){};
   ShapedTypeComponents(Type elementType)
-      : ranked(false), elementType(elementType), attr(nullptr) {}
+      : elementType(elementType), attr(nullptr), ranked(false) {}
+  ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) {
+    ranked = shapedType.hasRank();
+    elementType = shapedType.getElementType();
+    if (ranked)
+      dims = llvm::to_vector<4>(shapedType.getShape());
+  }
   template <typename Arg, typename = typename std::enable_if_t<
                               std::is_constructible<ShapeStorageT, Arg>::value>>
   ShapedTypeComponents(Arg &&arg, Type elementType = nullptr,
                        Attribute attr = nullptr)
-      : dims(std::forward<Arg>(arg)), ranked(true), elementType(elementType),
-        attr(attr) {}
+      : dims(std::forward<Arg>(arg)), elementType(elementType), attr(attr),
+        ranked(true) {}
   ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr,
                        Attribute attr = nullptr)
-      : dims(vec.begin(), vec.end()), ranked(true), elementType(elementType),
-        attr(attr) {}
+      : dims(vec.begin(), vec.end()), elementType(elementType), attr(attr),
+        ranked(true) {}
 
   /// Return the dimensions of the shape.
   /// Requires: shape is ranked.
@@ -70,24 +78,115 @@ class ShapedTypeComponents {
   Attribute getAttribute() const { return attr; };
 
 private:
+  friend class ShapeAdaptor;
+
   ShapeStorageT dims;
-  bool ranked;
   Type elementType;
   Attribute attr;
+  bool ranked;
+};
+
+/// Adaptor class to abstract the differences between whether value is from
+/// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute.
+class ShapeAdaptor {
+public:
+  ShapeAdaptor(Type t) {
+    if (auto st = t.dyn_cast<ShapedType>())
+      val = st;
+  }
+  ShapeAdaptor(Attribute t) {
+    if (auto da = t.dyn_cast<DenseIntElementsAttr>())
+      val = da;
+  }
+  ShapeAdaptor(ShapedTypeComponents *components) : val(components) {}
+  ShapeAdaptor(ShapedTypeComponents &components) : val(&components) {}
+
+  /// Returns whether the shape has a rank.
+  bool hasRank() const;
+
+  /// Returns the element type.
+  Type getElementType() const;
+
+  /// Populates the dimensions from the referenced shape.
+  /// Requires: shape is ranked.
+  void getDims(SmallVectorImpl<int64_t> &res) const;
+
+  /// Populates the dimensions of the ShapedTypeComponents.
+  /// Requires: shape is ranked.
+  void getDims(ShapedTypeComponents &res) const;
+
+  /// Returns the size of the index'th dimension.
+  /// Requires: shape is ranked.
+  int64_t getDimSize(int index) const;
+
+  /// Returns whether the index'th dimension is dynamic.
+  /// Requires: shape is ranked.
+  bool isDynamicDim(int index) const {
+    return ShapedType::isDynamic(getDimSize(index));
+  }
+
+  /// Returns whether the shape is fully static.
+  bool hasStaticShape() const;
+
+  /// Returns the rank of the shape.
+  /// Requires: shape is ranked.
+  int64_t getRank() const;
+
+  /// Returns the number of elements in the shape.
+  /// Requires: hasStaticShape
+  int64_t getNumElements() const;
+
+  /// Returns whether this is a valid (non-null) shape.
+  operator bool() const { return !val.isNull(); }
+
+  /// Dumps a textual representation to stderr.
+  void dump() const;
+
+private:
+  // Union storing either ShapedTypeComponents, ShapedType (stored as Type and
+  // casted), or DenseIntElementsAttribute (stored as Attribute).
+  PointerUnion<ShapedTypeComponents *, Type, Attribute> val = nullptr;
 };
 
 /// Range of values and shapes (corresponding effectively to Shapes dialect's
 /// ValueShape type concept).
+// Currently this exposes the Value (of operands) and the Type of the Value.
+// This is not ideal, as one can then accidentally reference an out-of-date
+// shape. It is done both to enable a gradual switch and because OpAdaptor
+// doesn't currently allow returning anything other than Value.
 class ValueShapeRange : public ValueRange::RangeBaseT {
 public:
-  ValueShapeRange(ValueRange values) : RangeBaseT(values) {}
-  template <typename Arg, typename = typename std::enable_if_t<
-                              std::is_constructible<ValueRange, Arg>::value>>
-  ValueShapeRange(Arg &&arg)
-      : ValueShapeRange(ValueRange(std::forward<Arg>(arg))) {}
+  using ValueShapeMapFn = function_ref<ShapeAdaptor(Value)>;
+
+  ValueShapeRange(ValueRange values, ValueShapeMapFn operandShape = nullptr,
+                  ValueShapeMapFn valueToShape = nullptr)
+      : RangeBaseT(values), operandShape(operandShape),
+        valueToShape(valueToShape) {}
   ValueShapeRange(const std::initializer_list<Value> &values)
       : ValueShapeRange(ValueRange(values)) {}
 
+  ValueShapeRange(const ValueShapeRange &other) : RangeBaseT(other) {
+    operandShape = other.operandShape;
+    valueToShape = other.valueToShape;
+  }
+
+  /// Sets the Value to ShapeAdaptor mapping function and returns this.
+  ValueShapeRange &setValueToShapeMapping(ValueShapeMapFn fn) {
+    valueToShape = fn;
+    return *this;
+  }
+
+  ValueShapeRange &setOperandShapeMapping(ValueShapeMapFn fn) {
+    operandShape = fn;
+    return *this;
+  }
+
+  /// Returns the set Value to ShapeAdaptor mapping function.
+  ValueShapeMapFn getValueToShapeMapping() const { return valueToShape; }
+  ValueShapeMapFn getOperandShapeMapping() const { return operandShape; }
+
+  // Accessors.
+
   /// Returns the types of the values within this range.
   /// Note: This returns only the types of Values in the ValueRange and not a
   /// more refined type.
@@ -97,7 +196,32 @@ class ValueShapeRange : public ValueRange::RangeBaseT {
   auto getType() const { return getTypes(); }
 
   /// Returns the Values in the ValueRange.
+  /// To query the most up-to-date shape of a Value, query the shape
+  /// using getShape below rather than using the type of the Value.
   ValueRange getValues() const { return ValueRange(begin(), end()); };
+
+  /// Returns an argument as a shape. If the argument is not constant or not a
+  /// shape, then the function returns a null ShapeAdaptor.
+  /// This will first query the valueToShape mapping (if set), before querying
+  /// the ValueRange.
+  ShapeAdaptor getValueAsShape(int index);
+
+  /// Returns the shape of the index'th operand.
+  // TODO: Update so that operator[] references these instead to avoid
+  // accidentally referring to a less refined shape.
+  ShapeAdaptor getShape(int index) const;
+
+  /// Returns the shape of the given Value.
+  ShapeAdaptor getShape(Value val) const;
+
+private:
+  // Mapping from Value to ShapedTypeComponents corresponding to the shape of
+  // the Value's type.
+  ValueShapeMapFn operandShape;
+
+  // Mapping from Value to ShapedTypeComponents corresponding to the constant
+  // Value interpreted as a shape.
+  ValueShapeMapFn valueToShape;
 };
 
 namespace detail {

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index d6f9905506ca7..ed0dbb5e094da 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -356,21 +356,21 @@ LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
     MLIRContext *context, ::llvm::Optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ShapedType inputTy = operands[0].getType().cast<ShapedType>();
+  ShapeAdaptor inputShape = operands.getShape(0);
   IntegerAttr axis = attributes.get("axis").cast<IntegerAttr>();
   int32_t axisVal = axis.getValue().getSExtValue();
 
-  if (!inputTy.hasRank()) {
+  if (!inputShape.hasRank()) {
     inferredReturnShapes.push_back(ShapedTypeComponents());
     return success();
   }
 
   SmallVector<int64_t> outShape;
-  outShape.reserve(inputTy.getRank() - 1);
-  for (int i = 0, s = inputTy.getRank(); i < s; i++) {
+  outShape.reserve(inputShape.getRank() - 1);
+  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
     if (i == axisVal)
       continue;
-    outShape.push_back(inputTy.getDimSize(i));
+    outShape.push_back(inputShape.getDimSize(i));
   }
 
   inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
@@ -387,21 +387,21 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
   llvm::SmallVector<int64_t> outputShape;
   bool hasRankedInput = false;
   for (auto operand : operands) {
-    ShapedType operandTy = operand.getType().cast<ShapedType>();
-    if (!operandTy.hasRank())
+    ShapeAdaptor operandShape = operands.getShape(operand);
+    if (!operandShape.hasRank())
       continue;
 
     // Copy the Operand's rank.
     if (!hasRankedInput)
-      outputShape.resize(operandTy.getRank(), ShapedType::kDynamicSize);
+      outputShape.resize(operandShape.getRank(), ShapedType::kDynamicSize);
 
     // Copy shapes until the dim is non-dynamic.
-    for (int i = 0, s = operandTy.getRank(); i < s; i++) {
-      if (i == axis || operandTy.isDynamicDim(i))
+    for (int i = 0, s = operandShape.getRank(); i < s; i++) {
+      if (i == axis || operandShape.isDynamicDim(i))
         continue;
       if (outputShape[i] == ShapedType::kDynamicSize)
-        outputShape[i] = operandTy.getDimSize(i);
-      if (outputShape[i] != operandTy.getDimSize(i))
+        outputShape[i] = operandShape.getDimSize(i);
+      if (outputShape[i] != operandShape.getDimSize(i))
         return failure();
     }
 
@@ -416,16 +416,16 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
   // Determine the dimension size along the concatenation axis.
   int concatDimSize = 0;
   for (auto operand : operands) {
-    ShapedType operandTy = operand.getType().cast<ShapedType>();
+    ShapeAdaptor operandShape = operands.getShape(operand);
 
     // We need to know the length of the concatenation axis of all inputs to
     // determine the dimension size of the output shape.
-    if (!operandTy.hasRank() || operandTy.isDynamicDim(axis)) {
+    if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
       concatDimSize = ShapedType::kDynamicSize;
       break;
     }
 
-    concatDimSize += operandTy.getDimSize(axis);
+    concatDimSize += operandShape.getDimSize(axis);
   }
 
   outputShape[axis] = concatDimSize;
@@ -438,25 +438,26 @@ LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
     MLIRContext *context, ::llvm::Optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ShapedType inputTy = operands[0].getType().cast<ShapedType>();
-  ShapedType weightTy = operands[1].getType().cast<ShapedType>();
-  ShapedType biasTy = operands[2].getType().cast<ShapedType>();
+  ShapeAdaptor inputShape = operands.getShape(0);
+  ShapeAdaptor weightShape = operands.getShape(1);
+  ShapeAdaptor biasShape = operands.getShape(2);
 
   // All shapes are dynamic.
   SmallVector<int64_t> outShape;
   outShape.resize(2, ShapedType::kDynamicSize);
 
-  if (inputTy.hasRank()) {
-    outShape[0] = inputTy.getDimSize(0);
+  if (inputShape.hasRank()) {
+    outShape[0] = inputShape.getDimSize(0);
   }
 
-  if (weightTy.hasRank()) {
-    outShape[1] = weightTy.getDimSize(0);
+  if (weightShape.hasRank()) {
+    outShape[1] = weightShape.getDimSize(0);
   }
 
-  if (biasTy.hasRank()) {
-    outShape[1] = outShape[1] == ShapedType::kDynamicSize ? biasTy.getDimSize(0)
-                                                          : outShape[1];
+  if (biasShape.hasRank()) {
+    outShape[1] = outShape[1] == ShapedType::kDynamicSize
+                      ? biasShape.getDimSize(0)
+                      : outShape[1];
   }
 
   inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
@@ -467,22 +468,23 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
     MLIRContext *context, ::llvm::Optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ShapedType lhsTy = operands[0].getType().cast<ShapedType>();
-  ShapedType rhsTy = operands[1].getType().cast<ShapedType>();
+  ShapeAdaptor lhsShape = operands.getShape(0);
+  ShapeAdaptor rhsShape = operands.getShape(1);
 
   // All shapes are dynamic.
   SmallVector<int64_t> outShape;
   outShape.resize(3, ShapedType::kDynamicSize);
 
-  if (lhsTy.hasRank()) {
-    outShape[0] = lhsTy.getDimSize(0);
-    outShape[1] = lhsTy.getDimSize(1);
+  if (lhsShape.hasRank()) {
+    outShape[0] = lhsShape.getDimSize(0);
+    outShape[1] = lhsShape.getDimSize(1);
   }
 
-  if (rhsTy.hasRank()) {
-    outShape[0] = outShape[0] == ShapedType::kDynamicSize ? rhsTy.getDimSize(0)
-                                                          : outShape[0];
-    outShape[2] = rhsTy.getDimSize(2);
+  if (rhsShape.hasRank()) {
+    outShape[0] = outShape[0] == ShapedType::kDynamicSize
+                      ? rhsShape.getDimSize(0)
+                      : outShape[0];
+    outShape[2] = rhsShape.getDimSize(2);
   }
 
   inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
@@ -493,26 +495,26 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
     MLIRContext *context, ::llvm::Optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ShapedType inputTy = operands[0].getType().cast<ShapedType>();
-  ShapedType paddingTy = operands[1].getType().cast<ShapedType>();
+  ShapeAdaptor inputShape = operands.getShape(0);
+  ShapeAdaptor paddingShape = operands.getShape(1);
   SmallVector<int64_t> outputShape;
 
   // If both inputs have unknown shape, we cannot determine the shape of the
   // output.
-  if (!inputTy.hasRank() && !paddingTy.hasRank()) {
+  if (!inputShape.hasRank() && !paddingShape.hasRank()) {
     inferredReturnShapes.push_back(ShapedTypeComponents());
     return success();
   }
 
   // If the input rank is unknown we can info the output rank using the padding
   // shape's first dim.
-  if (!inputTy.hasRank()) {
-    if (paddingTy.isDynamicDim(0)) {
+  if (!inputShape.hasRank()) {
+    if (paddingShape.isDynamicDim(0)) {
       inferredReturnShapes.push_back(ShapedTypeComponents());
       return success();
     }
 
-    outputShape.resize(paddingTy.getDimSize(0), ShapedType::kDynamicSize);
+    outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamicSize);
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
   }
@@ -520,7 +522,7 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
   DenseIntElementsAttr paddings;
   // If the paddings value is not a constant, all dimensions must be dynamic.
   if (!matchPattern(operands[1], m_Constant(&paddings))) {
-    outputShape.resize(inputTy.getRank(), ShapedType::kDynamicSize);
+    outputShape.resize(inputShape.getRank(), ShapedType::kDynamicSize);
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
   }
@@ -530,14 +532,14 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
     paddingValues.push_back(val.getSExtValue());
   }
 
-  outputShape.reserve(inputTy.getRank());
-  for (int i = 0, s = inputTy.getRank(); i < s; i++) {
-    if (inputTy.isDynamicDim(i)) {
+  outputShape.reserve(inputShape.getRank());
+  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
+    if (inputShape.isDynamicDim(i)) {
       outputShape.push_back(ShapedType::kDynamicSize);
       continue;
     }
 
-    outputShape.push_back(inputTy.getDimSize(i) + paddingValues[i * 2] +
+    outputShape.push_back(inputShape.getDimSize(i) + paddingValues[i * 2] +
                           paddingValues[i * 2 + 1]);
   }
 
@@ -549,7 +551,7 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
     MLIRContext *context, ::llvm::Optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  auto sizes = attributes.get("size").cast<ArrayAttr>().getValue();
+  ArrayAttr sizes = SliceOpAdaptor(operands, attributes).size();
   SmallVector<int64_t> outputShape;
   outputShape.reserve(sizes.size());
   for (auto val : sizes) {
@@ -564,14 +566,15 @@ LogicalResult tosa::TableOp::inferReturnTypeComponents(
     MLIRContext *context, ::llvm::Optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ShapedType inputTy = operands[0].getType().cast<ShapedType>();
+  ShapeAdaptor inputShape = operands.getShape(0);
 
-  if (!inputTy.hasRank()) {
+  if (!inputShape.hasRank()) {
     inferredReturnShapes.push_back(ShapedTypeComponents());
     return success();
   }
 
-  inferredReturnShapes.push_back(inputTy.getShape());
+  inferredReturnShapes.resize(1);
+  inputShape.getDims(inferredReturnShapes[0]);
   return success();
 }
 
@@ -579,10 +582,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
     MLIRContext *context, ::llvm::Optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  auto multiples = attributes.get("multiples").cast<ArrayAttr>().getValue();
-  ShapedType inputTy = operands[0].getType().cast<ShapedType>();
+  TileOpAdaptor adaptor(operands, attributes);
+  ArrayAttr multiples = adaptor.multiples();
+  ShapeAdaptor inputShape = operands.getShape(0);
   SmallVector<int64_t> outputShape;
-  if (!inputTy.hasRank()) {
+  if (!inputShape.hasRank()) {
     outputShape.resize(multiples.size(), ShapedType::kDynamicSize);
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
@@ -597,8 +601,8 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
 
   // Any non dynamic dimension can be multiplied to a known size.
   outputShape.reserve(multiples.size());
-  for (int i = 0, s = inputTy.getRank(); i < s; i++) {
-    int dim = inputTy.getDimSize(i);
+  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
+    int dim = inputShape.getDimSize(i);
     if (dim != ShapedType::kDynamicSize)
       dim *= multipleValues[i];
     outputShape.push_back(dim);
@@ -612,15 +616,16 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
     MLIRContext *context, ::llvm::Optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ShapedType type = operands.front().getType().cast<ShapedType>();
+  ReshapeOpAdaptor adaptor(operands, attributes);
+  ShapeAdaptor inputShape = operands.getShape(0);
 
-  auto newShape = attributes.get("new_shape").cast<ArrayAttr>();
+  ArrayAttr newShape = adaptor.new_shape();
   llvm::SmallVector<int64_t> newShapeValue;
   getI64Values(newShape, newShapeValue);
 
   // We cannot infer from the total number of elements so we must take the
   // shape attribute as exact.
-  if (!type.hasRank() || !type.hasStaticShape()) {
+  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
     inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
     return success();
   }
@@ -628,7 +633,7 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
   // Determine the number of elements covered by the slice of all static
   // dimensions. This allows us to infer the length of the remaining dynamic
   // dimension.
-  int64_t numElements = type.getNumElements();
+  int64_t numElements = inputShape.getNumElements();
   int64_t staticMul = 1;
   for (auto val : newShapeValue) {
     if (val != ShapedType::kDynamicSize) {
@@ -650,12 +655,13 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
     MLIRContext *context, ::llvm::Optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ShapedType inputTy = operands[0].getType().cast<ShapedType>();
-  ShapedType permsTy = operands[1].getType().cast<ShapedType>();
+  ShapeAdaptor inputShape = operands.getShape(0);
+  ShapeAdaptor permsShape = operands.getShape(1);
 
   // If input rank and permutation length is unknown, the output rank is
   // unknown.
-  if (!inputTy.hasRank() && (!permsTy.hasRank() || permsTy.isDynamicDim(0))) {
+  if (!inputShape.hasRank() &&
+      (!permsShape.hasRank() || permsShape.isDynamicDim(0))) {
     inferredReturnShapes.push_back(ShapedTypeComponents());
     return success();
   }
@@ -663,22 +669,22 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
   // Without the input dims we cannot determine the output dim sizes but we
   // can determine the output rank.
   SmallVector<int64_t> outputShape;
-  if (!inputTy.hasRank()) {
-    outputShape.resize(permsTy.getDimSize(0), ShapedType::kDynamicSize);
+  if (!inputShape.hasRank()) {
+    outputShape.resize(permsShape.getDimSize(0), ShapedType::kDynamicSize);
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
   }
 
   // Rank-0 means no permutations matter.
-  if (inputTy.getRank() == 0) {
+  if (inputShape.getRank() == 0) {
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
   }
 
   // Check whether the input dimensions are all the same.
   bool allTheSame = true;
-  for (int i = 1, s = inputTy.getRank(); i < s; i++) {
-    if (inputTy.getDimSize(0) != inputTy.getDimSize(i)) {
+  for (int i = 1, s = inputShape.getRank(); i < s; i++) {
+    if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
       allTheSame = false;
       break;
     }
@@ -687,24 +693,18 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
   // If all of the input dimensions are the same we don't care about the
   // permutation.
   if (allTheSame) {
-    outputShape.resize(inputTy.getRank(), inputTy.getDimSize(0));
+    outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
   }
 
-  DenseIntElementsAttr perms;
-  outputShape.resize(inputTy.getRank(), ShapedType::kDynamicSize);
+  outputShape.resize(inputShape.getRank(), ShapedType::kDynamicSize);
   // If the permuations are a constant we can directly determine the output
   // shape.
-  if (matchPattern(operands[1], m_Constant(&perms))) {
-    llvm::SmallVector<int64_t> permValues;
-    for (auto val : perms) {
-      permValues.push_back(val.getSExtValue());
-    }
-
-    outputShape.reserve(inputTy.getRank());
-    for (int i = 0, s = inputTy.getRank(); i < s; i++) {
-      outputShape[i] = inputTy.getDimSize(permValues[i]);
+  if (ShapeAdaptor permShape = operands.getValueAsShape(1)) {
+    outputShape.reserve(inputShape.getRank());
+    for (int i = 0, s = inputShape.getRank(); i < s; i++) {
+      outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
     }
   }
 
@@ -719,16 +719,18 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
   llvm::SmallVector<int64_t> outputShape;
   outputShape.resize(3, ShapedType::kDynamicSize);
 
-  if (auto ty = operands[0].getType().dyn_cast<RankedTensorType>()) {
-    outputShape[0] = ty.getDimSize(0);
-    outputShape[2] = ty.getDimSize(2);
+  ShapeAdaptor valuesShape = operands.getShape(0);
+  if (valuesShape.hasRank()) {
+    outputShape[0] = valuesShape.getDimSize(0);
+    outputShape[2] = valuesShape.getDimSize(2);
   }
 
-  if (auto ty = operands[1].getType().dyn_cast<RankedTensorType>()) {
+  ShapeAdaptor indicesShape = operands.getShape(1);
+  if (indicesShape.hasRank()) {
     if (outputShape[0] == ShapedType::kDynamicSize)
-      outputShape[0] = ty.getDimSize(0);
+      outputShape[0] = indicesShape.getDimSize(0);
     if (outputShape[1] == ShapedType::kDynamicSize)
-      outputShape[1] = ty.getDimSize(1);
+      outputShape[1] = indicesShape.getDimSize(1);
   }
 
   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
@@ -739,24 +741,25 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
     MLIRContext *context, ::llvm::Optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  ResizeOpAdaptor adaptor(operands, attributes);
   llvm::SmallVector<int64_t, 4> outputShape;
   outputShape.resize(4, ShapedType::kDynamicSize);
 
   int32_t inHeight = ShapedType::kDynamicSize;
   int32_t inWidth = ShapedType::kDynamicSize;
 
-  if (auto ty = operands[0].getType().dyn_cast<RankedTensorType>()) {
-    outputShape[0] = ty.getDimSize(0);
-    outputShape[3] = ty.getDimSize(3);
+  ShapeAdaptor inputShape = operands.getShape(adaptor.input());
+  if (inputShape.hasRank()) {
+    outputShape[0] = inputShape.getDimSize(0);
+    outputShape[3] = inputShape.getDimSize(3);
 
-    inHeight = ty.getDimSize(1);
-    inWidth = ty.getDimSize(2);
+    inHeight = inputShape.getDimSize(1);
+    inWidth = inputShape.getDimSize(2);
   }
 
-  int32_t shift =
-      attributes.get("shift").cast<IntegerAttr>().getValue().getSExtValue();
+  int32_t shift = adaptor.shift().getValue().getSExtValue();
   llvm::SmallVector<int64_t> newShape;
-  getI64Values(attributes.get("output_size").cast<ArrayAttr>(), newShape);
+  getI64Values(adaptor.output_size(), newShape);
   outputShape[1] = newShape[0];
   outputShape[2] = newShape[1];
 
@@ -764,10 +767,10 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
   llvm::SmallVector<int64_t> offsetInt;
   llvm::SmallVector<double> strideFp;
   llvm::SmallVector<double> offsetFp;
-  getI64Values(attributes.get("offset").cast<ArrayAttr>(), offsetInt);
-  getF64Values(attributes.get("offset_fp").cast<ArrayAttr>(), offsetFp);
-  getI64Values(attributes.get("stride").cast<ArrayAttr>(), strideInt);
-  getF64Values(attributes.get("stride_fp").cast<ArrayAttr>(), strideFp);
+  getI64Values(adaptor.offset(), offsetInt);
+  getF64Values(adaptor.offset_fp(), offsetFp);
+  getI64Values(adaptor.stride(), strideInt);
+  getF64Values(adaptor.stride_fp(), strideFp);
 
   // If we have a 0 zero in integers we know that the resize indexing needs to
   // be performed in floating point. Use the floating point varient to compute
@@ -812,22 +815,25 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
   llvm::SmallVector<int64_t> outputShape;
   outputShape.resize(3, ShapedType::kDynamicSize);
 
-  if (auto ty = operands[0].getType().dyn_cast<RankedTensorType>()) {
-    outputShape[0] = ty.getDimSize(0);
-    outputShape[1] = ty.getDimSize(1);
-    outputShape[2] = ty.getDimSize(2);
+  ShapeAdaptor valuesInShape = operands.getShape(0);
+  if (valuesInShape.hasRank()) {
+    outputShape[0] = valuesInShape.getDimSize(0);
+    outputShape[1] = valuesInShape.getDimSize(1);
+    outputShape[2] = valuesInShape.getDimSize(2);
   }
 
-  if (auto ty = operands[1].getType().dyn_cast<RankedTensorType>()) {
+  ShapeAdaptor indicesShape = operands.getShape(1);
+  if (indicesShape.hasRank()) {
     if (outputShape[0] == ShapedType::kDynamicSize)
-      outputShape[0] = ty.getDimSize(0);
+      outputShape[0] = indicesShape.getDimSize(0);
   }
 
-  if (auto ty = operands[2].getType().dyn_cast<RankedTensorType>()) {
+  ShapeAdaptor inputShape = operands.getShape(2);
+  if (inputShape.hasRank()) {
     if (outputShape[0] == ShapedType::kDynamicSize)
-      outputShape[0] = ty.getDimSize(0);
+      outputShape[0] = inputShape.getDimSize(0);
     if (outputShape[2] == ShapedType::kDynamicSize)
-      outputShape[2] = ty.getDimSize(2);
+      outputShape[2] = inputShape.getDimSize(2);
   }
 
   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
@@ -835,21 +841,16 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
 }
 
 static LogicalResult ReduceInferReturnTypes(
-    Value operand, IntegerAttr axis,
+    ShapeAdaptor operandShape, IntegerAttr axis,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  auto operandTy = operand.getType().cast<ShapedType>();
-  if (!operandTy.hasRank()) {
+  if (!operandShape.hasRank()) {
     inferredReturnShapes.push_back(ShapedTypeComponents());
     return success();
   }
 
-  int64_t axisVal = axis.getValue().getSExtValue();
   SmallVector<int64_t> outputShape;
-  outputShape.reserve(operandTy.getRank());
-  for (auto dim : operandTy.getShape()) {
-    outputShape.push_back(dim);
-  }
-
+  operandShape.getDims(outputShape);
+  int64_t axisVal = axis.getValue().getSExtValue();
   outputShape[axisVal] = 1;
   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
   return success();
@@ -861,7 +862,7 @@ static LogicalResult ReduceInferReturnTypes(
       ValueShapeRange operands, DictionaryAttr attributes,                     \
       RegionRange regions,                                                     \
       SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
-    return ReduceInferReturnTypes(operands[0],                                 \
+    return ReduceInferReturnTypes(operands.getShape(0),                        \
                                   attributes.get("axis").cast<IntegerAttr>(),  \
                                   inferredReturnShapes);                       \
   }
@@ -874,26 +875,26 @@ REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
 #undef REDUCE_SHAPE_INFER
 
-static LogicalResult resolveBroadcastShape(ValueRange operands,
+static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
                                            SmallVector<int64_t> &outShape) {
   int64_t outRank = 0;
-  for (auto operand : operands) {
-    auto type = operand.getType().cast<ShapedType>();
-    if (!type.hasRank())
+  for (int i = 0, e = operands.size(); i != e; ++i) {
+    auto shape = operands.getShape(i);
+    if (!shape.hasRank()) {
       return failure();
-    outRank = std::max<int64_t>(outRank, type.getRank());
+    }
+    outRank = std::max<int64_t>(outRank, shape.getRank());
   }
 
   outShape.resize(outRank, 1);
 
-  for (auto operand : operands) {
-    auto type = operand.getType().cast<ShapedType>();
-    auto shape = type.getShape();
-    auto rankDiff = outShape.size() - shape.size();
+  for (int i = 0, e = operands.size(); i != e; ++i) {
+    auto shape = operands.getShape(i);
+    auto rankDiff = outShape.size() - shape.getRank();
 
-    for (size_t i = 0; i < shape.size(); i++) {
+    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
       auto dim1 = outShape[i + rankDiff];
-      auto dim2 = shape[i];
+      auto dim2 = shape.getDimSize(i);
       auto resolvedDim = dim1;
 
       if (dim1 == 1) {
@@ -911,7 +912,7 @@ static LogicalResult resolveBroadcastShape(ValueRange operands,
 }
 
 static LogicalResult NAryInferReturnTypes(
-    ValueRange operands,
+    const ValueShapeRange &operands,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<int64_t> outShape;
   if (resolveBroadcastShape(operands, outShape).failed()) {
@@ -973,24 +974,24 @@ NARY_SHAPE_INFER(tosa::SigmoidOp)
 #undef PRED_SHAPE_INFER
 
 static LogicalResult poolingInferReturnTypes(
-    ValueRange operands, DictionaryAttr attributes,
+    const ValueShapeRange &operands, DictionaryAttr attributes,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  RankedTensorType inputTy = operands[0].getType().dyn_cast<RankedTensorType>();
+  ShapeAdaptor inputShape = operands.getShape(0);
   llvm::SmallVector<int64_t> outputShape;
   outputShape.resize(4, -1);
 
   // We only know the rank if the input type is unranked.
-  if (!inputTy) {
+  if (!inputShape) {
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
   }
 
   // Batch and number of channels are identical for pooling layer.
-  outputShape[0] = inputTy.getDimSize(0);
-  outputShape[3] = inputTy.getDimSize(3);
+  outputShape[0] = inputShape.getDimSize(0);
+  outputShape[3] = inputShape.getDimSize(3);
 
-  int32_t height = inputTy.getDimSize(1);
-  int32_t width = inputTy.getDimSize(2);
+  int32_t height = inputShape.getDimSize(1);
+  int32_t width = inputShape.getDimSize(2);
 
   llvm::SmallVector<int64_t> kernel;
   llvm::SmallVector<int64_t> stride;
@@ -1019,7 +1020,7 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
-  Conv2DOp::Adaptor adaptor(operands.getValues());
+  Conv2DOp::Adaptor adaptor(operands.getValues(), attributes);
 
   int32_t inputWidth = ShapedType::kDynamicSize;
   int32_t inputHeight = ShapedType::kDynamicSize;
@@ -1027,23 +1028,27 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
   int32_t weightHeight = ShapedType::kDynamicSize;
 
   // Input shape describes input width/height and batch.
-  if (auto inputTy = adaptor.input().getType().dyn_cast<RankedTensorType>()) {
-    outputShape[0] = inputTy.getDimSize(0);
-    inputHeight = inputTy.getDimSize(1);
-    inputWidth = inputTy.getDimSize(2);
+
+  ShapeAdaptor inputShape = operands.getShape(adaptor.input());
+  if (inputShape.hasRank()) {
+    outputShape[0] = inputShape.getDimSize(0);
+    inputHeight = inputShape.getDimSize(1);
+    inputWidth = inputShape.getDimSize(2);
   }
 
   // Weight shapes describes the filter width/height and the output channels.
-  if (auto weightTy = adaptor.weight().getType().dyn_cast<RankedTensorType>()) {
-    outputShape[3] = weightTy.getDimSize(0);
-    weightHeight = weightTy.getDimSize(1);
-    weightWidth = weightTy.getDimSize(2);
+  ShapeAdaptor weightShape = operands.getShape(adaptor.weight());
+  if (weightShape.hasRank()) {
+    outputShape[3] = weightShape.getDimSize(0);
+    weightHeight = weightShape.getDimSize(1);
+    weightWidth = weightShape.getDimSize(2);
   }
 
   // Bias shape can describe the output channels.
-  if (auto biasTy = adaptor.bias().getType().dyn_cast<RankedTensorType>()) {
+  ShapeAdaptor biasShape = operands.getShape(adaptor.bias());
+  if (biasShape.hasRank()) {
     outputShape[3] = ShapedType::isDynamic(outputShape[3])
-                         ? biasTy.getDimSize(0)
+                         ? biasShape.getDimSize(0)
                          : outputShape[3];
   }
 
@@ -1051,9 +1056,9 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
   llvm::SmallVector<int64_t> padding;
   llvm::SmallVector<int64_t> stride;
 
-  getI64Values(attributes.get("dilation").cast<ArrayAttr>(), dilation);
-  getI64Values(attributes.get("pad").cast<ArrayAttr>(), padding);
-  getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);
+  getI64Values(adaptor.dilation(), dilation);
+  getI64Values(adaptor.pad(), padding);
+  getI64Values(adaptor.stride(), stride);
 
   if (!ShapedType::isDynamic(inputHeight) &&
       !ShapedType::isDynamic(weightHeight)) {
@@ -1080,7 +1085,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamicSize);
-  Conv2DOp::Adaptor adaptor(operands.getValues());
+  Conv2DOp::Adaptor adaptor(operands.getValues(), attributes);
 
   int32_t inputWidth = ShapedType::kDynamicSize;
   int32_t inputHeight = ShapedType::kDynamicSize;
@@ -1091,34 +1096,37 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
   int32_t weightDepth = ShapedType::kDynamicSize;
 
   // Input shape describes input width/height and batch.
-  if (auto inputTy = adaptor.input().getType().dyn_cast<RankedTensorType>()) {
-    outputShape[0] = inputTy.getDimSize(0);
-    inputHeight = inputTy.getDimSize(1);
-    inputWidth = inputTy.getDimSize(2);
-    inputDepth = inputTy.getDimSize(3);
+  ShapeAdaptor inputShape = operands.getShape(adaptor.input());
+  if (inputShape.hasRank()) {
+    outputShape[0] = inputShape.getDimSize(0);
+    inputHeight = inputShape.getDimSize(1);
+    inputWidth = inputShape.getDimSize(2);
+    inputDepth = inputShape.getDimSize(3);
   }
 
   // Weight shapes describes the filter width/height and the output channels.
-  if (auto weightTy = adaptor.weight().getType().dyn_cast<RankedTensorType>()) {
-    outputShape[4] = weightTy.getDimSize(0);
-    weightHeight = weightTy.getDimSize(1);
-    weightWidth = weightTy.getDimSize(2);
-    weightDepth = weightTy.getDimSize(3);
+  ShapeAdaptor weightShape = operands.getShape(adaptor.weight());
+  if (weightShape.hasRank()) {
+    outputShape[4] = weightShape.getDimSize(0);
+    weightHeight = weightShape.getDimSize(1);
+    weightWidth = weightShape.getDimSize(2);
+    weightDepth = weightShape.getDimSize(3);
   }
 
   // Bias shape can describe the output channels.
-  if (auto biasTy = adaptor.bias().getType().dyn_cast<RankedTensorType>()) {
+  ShapeAdaptor biasShape = operands.getShape(adaptor.bias());
+  if (biasShape.hasRank()) {
     outputShape[4] =
-        (outputShape[4] == -1) ? biasTy.getDimSize(0) : outputShape[4];
+        (outputShape[4] == -1) ? biasShape.getDimSize(0) : outputShape[4];
   }
 
   llvm::SmallVector<int64_t> dilation;
   llvm::SmallVector<int64_t> padding;
   llvm::SmallVector<int64_t> stride;
 
-  getI64Values(attributes.get("dilation").cast<ArrayAttr>(), dilation);
-  getI64Values(attributes.get("pad").cast<ArrayAttr>(), padding);
-  getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);
+  getI64Values(adaptor.dilation(), dilation);
+  getI64Values(adaptor.pad(), padding);
+  getI64Values(adaptor.stride(), stride);
 
   if (!ShapedType::isDynamic(inputHeight) &&
       !ShapedType::isDynamic(weightHeight)) {
@@ -1167,7 +1175,7 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
-  DepthwiseConv2DOp::Adaptor adaptor(operands.getValues());
+  DepthwiseConv2DOp::Adaptor adaptor(operands.getValues(), attributes);
 
   int32_t inputWidth = ShapedType::kDynamicSize;
   int32_t inputHeight = ShapedType::kDynamicSize;
@@ -1178,21 +1186,23 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
   int32_t depthChannels = ShapedType::kDynamicSize;
 
   // Input shape describes input width/height and batch.
-  if (auto inputTy = adaptor.input().getType().dyn_cast<RankedTensorType>()) {
-    outputShape[0] = inputTy.getDimSize(0);
-    inputHeight = inputTy.getDimSize(1);
-    inputWidth = inputTy.getDimSize(2);
-    inputChannels = inputTy.getDimSize(3);
+  ShapeAdaptor inputShape = operands.getShape(adaptor.input());
+  if (inputShape.hasRank()) {
+    outputShape[0] = inputShape.getDimSize(0);
+    inputHeight = inputShape.getDimSize(1);
+    inputWidth = inputShape.getDimSize(2);
+    inputChannels = inputShape.getDimSize(3);
   }
 
   // Weight shapes describes the filter width/height and the output channels.
-  if (auto weightTy = adaptor.weight().getType().dyn_cast<RankedTensorType>()) {
-    weightHeight = weightTy.getDimSize(0);
-    weightWidth = weightTy.getDimSize(1);
+  ShapeAdaptor weightShape = operands.getShape(adaptor.weight());
+  if (weightShape.hasRank()) {
+    weightHeight = weightShape.getDimSize(0);
+    weightWidth = weightShape.getDimSize(1);
     inputChannels = ShapedType::isDynamic(inputChannels)
-                        ? weightTy.getDimSize(2)
+                        ? weightShape.getDimSize(2)
                         : inputChannels;
-    depthChannels = weightTy.getDimSize(3);
+    depthChannels = weightShape.getDimSize(3);
   }
 
   // If both inputChannels and depthChannels are available we can determine
@@ -1203,9 +1213,10 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
   }
 
   // Bias shape can describe the output channels.
-  if (auto biasTy = adaptor.bias().getType().dyn_cast<RankedTensorType>()) {
+  ShapeAdaptor biasShape = operands.getShape(adaptor.bias());
+  if (biasShape.hasRank()) {
     outputShape[3] = ShapedType::isDynamic(outputShape[3])
-                         ? biasTy.getDimSize(0)
+                         ? biasShape.getDimSize(0)
                          : outputShape[3];
   }
 
@@ -1213,9 +1224,9 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
   llvm::SmallVector<int64_t> padding;
   llvm::SmallVector<int64_t> stride;
 
-  getI64Values(attributes.get("dilation").cast<ArrayAttr>(), dilation);
-  getI64Values(attributes.get("pad").cast<ArrayAttr>(), padding);
-  getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);
+  getI64Values(adaptor.dilation(), dilation);
+  getI64Values(adaptor.pad(), padding);
+  getI64Values(adaptor.stride(), stride);
 
   if (!ShapedType::isDynamic(inputHeight) &&
       !ShapedType::isDynamic(weightHeight)) {
@@ -1241,9 +1252,9 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
     MLIRContext *context, ::llvm::Optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  TransposeConv2DOp::Adaptor adaptor(operands.getValues());
+  TransposeConv2DOp::Adaptor adaptor(operands.getValues(), attributes);
   llvm::SmallVector<int64_t> outputShape;
-  getI64Values(attributes.get("out_shape").cast<ArrayAttr>(), outputShape);
+  getI64Values(adaptor.out_shape(), outputShape);
 
   int32_t inputWidth = ShapedType::kDynamicSize;
   int32_t inputHeight = ShapedType::kDynamicSize;
@@ -1251,27 +1262,30 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
   int32_t weightHeight = ShapedType::kDynamicSize;
 
   // Input shape describes input width/height and batch.
-  if (auto inputTy = adaptor.input().getType().dyn_cast<RankedTensorType>()) {
+  ShapeAdaptor inputShape = operands.getShape(adaptor.input());
+  if (inputShape.hasRank()) {
     outputShape[0] = ShapedType::isDynamic(outputShape[0])
-                         ? inputTy.getDimSize(0)
+                         ? inputShape.getDimSize(0)
                          : outputShape[0];
-    inputHeight = inputTy.getDimSize(1);
-    inputWidth = inputTy.getDimSize(2);
+    inputHeight = inputShape.getDimSize(1);
+    inputWidth = inputShape.getDimSize(2);
   }
 
   // Weight shapes describes the filter width/height and the output channels.
-  if (auto weightTy = adaptor.filter().getType().dyn_cast<RankedTensorType>()) {
+  ShapeAdaptor weightShape = operands.getShape(adaptor.filter());
+  if (weightShape.hasRank()) {
     outputShape[3] = ShapedType::isDynamic(outputShape[3])
-                         ? weightTy.getDimSize(0)
+                         ? weightShape.getDimSize(0)
                          : outputShape[3];
-    weightHeight = weightTy.getDimSize(1);
-    weightWidth = weightTy.getDimSize(2);
+    weightHeight = weightShape.getDimSize(1);
+    weightWidth = weightShape.getDimSize(2);
   }
 
   // Bias shape can describe the output channels.
-  if (auto biasTy = adaptor.bias().getType().dyn_cast<RankedTensorType>()) {
+  ShapeAdaptor biasShape = operands.getShape(adaptor.bias());
+  if (biasShape.hasRank()) {
     outputShape[3] = ShapedType::isDynamic(outputShape[3])
-                         ? biasTy.getDimSize(0)
+                         ? biasShape.getDimSize(0)
                          : outputShape[3];
   }
 
@@ -1279,9 +1293,9 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
   llvm::SmallVector<int64_t> padding;
   llvm::SmallVector<int64_t> stride;
 
-  getI64Values(attributes.get("dilation").cast<ArrayAttr>(), dilation);
-  getI64Values(attributes.get("out_pad").cast<ArrayAttr>(), padding);
-  getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);
+  getI64Values(adaptor.dilation(), dilation);
+  getI64Values(adaptor.out_pad(), padding);
+  getI64Values(adaptor.stride(), stride);
 
   if (!ShapedType::isDynamic(inputHeight) &&
       !ShapedType::isDynamic(weightHeight)) {
@@ -1339,7 +1353,7 @@ LogicalResult IfOp::inferReturnTypeComponents(
     }
   }
 
-  for (auto result : resultKnowledge) {
+  for (const ValueKnowledge &result : resultKnowledge) {
     if (result.hasRank) {
       inferredReturnShapes.push_back(ShapedTypeComponents(result.sizes));
     } else {

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index 390950e7550de..006f4148ea1fe 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -25,6 +25,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/FormatVariadic.h"
 
 using namespace mlir;
 using namespace mlir::tosa;
@@ -62,6 +63,21 @@ void propagateShapesToTosaIf(Operation &op) {
 }
 
 void propagateShapesInRegion(Region &region) {
+  DenseMap<Value, ShapedTypeComponents> shapesStorage;
+  auto setShapes = [&](Value val, Type t) {
+    if (auto st = t.dyn_cast<ShapedType>())
+      shapesStorage[val] = st;
+    else
+      shapesStorage[val] = t;
+  };
+  auto operandShape = [&](Value val) -> ShapeAdaptor {
+    // Query the work-in-progress mapping, if set, rather than the type.
+    auto it = shapesStorage.find(val);
+    if (it == shapesStorage.end())
+      return nullptr;
+    return it->second;
+  };
+
   for (auto &block : region) {
     for (Operation &op : block) {
       if (op.getDialect()->getNamespace() !=
@@ -76,10 +92,12 @@ void propagateShapesInRegion(Region &region) {
         continue;
 
       SmallVector<ShapedTypeComponents> returnedShapes;
+
+      ValueShapeRange range(op.getOperands(), operandShape);
       if (shapeInterface
-              .inferReturnTypeComponents(
-                  op.getContext(), op.getLoc(), op.getOperands(),
-                  op.getAttrDictionary(), op.getRegions(), returnedShapes)
+              .inferReturnTypeComponents(op.getContext(), op.getLoc(), range,
+                                         op.getAttrDictionary(),
+                                         op.getRegions(), returnedShapes)
               .succeeded()) {
         for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
           Value result = std::get<0>(it);
@@ -99,6 +117,7 @@ void propagateShapesInRegion(Region &region) {
           }
 
           // Determine the knowledge based on the output type.
+          // TODO: should also query WIP type probably
           Type resultTy = result.getType();
           auto currentKnowledge =
               ValueKnowledge::getKnowledgeFromType(resultTy);
@@ -122,11 +141,20 @@ void propagateShapesInRegion(Region &region) {
               ValueKnowledge::join(currentKnowledge, inferredKnowledge);
           if (!newKnowledge)
             continue;
-          result.setType(newKnowledge.getType());
+          setShapes(result, newKnowledge.getType());
         }
       }
     }
   }
+
+  // Actually update types with updated shape knowledge.
+  for (auto it : shapesStorage) {
+    auto result = it.second;
+    if (result.hasRank()) {
+      Type t = it.first.getType().cast<ShapedType>().clone(result.getDims());
+      it.first.setType(t);
+    }
+  }
 }
 
 /// Pass that performs shape propagation across TOSA operations. This includes

diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index cdae04e563932..7a3f602ca3d43 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -13,6 +13,8 @@
 
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "llvm/Support/FormatVariadic.h"
 
 using namespace mlir;
 
@@ -20,6 +22,160 @@ namespace mlir {
 #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
 } // namespace mlir
 
+bool ShapeAdaptor::hasRank() const {
+  if (val.isNull())
+    return false;
+  if (auto t = val.dyn_cast<Type>())
+    return t.cast<ShapedType>().hasRank();
+  if (val.is<Attribute>())
+    return true;
+  return val.get<ShapedTypeComponents *>()->hasRank();
+}
+
+Type ShapeAdaptor::getElementType() const {
+  if (val.isNull())
+    return nullptr;
+  if (auto t = val.dyn_cast<Type>())
+    return t.cast<ShapedType>().getElementType();
+  if (val.is<Attribute>())
+    return nullptr;
+  return val.get<ShapedTypeComponents *>()->getElementType();
+}
+
+void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
+  assert(hasRank());
+  if (auto t = val.dyn_cast<Type>()) {
+    ArrayRef<int64_t> vals = t.cast<ShapedType>().getShape();
+    res.assign(vals.begin(), vals.end());
+  } else if (auto attr = val.dyn_cast<Attribute>()) {
+    auto dattr = attr.cast<DenseIntElementsAttr>();
+    res.clear();
+    res.reserve(dattr.size());
+    for (auto it : dattr.getIntValues())
+      res.push_back(it.getSExtValue());
+  } else {
+    auto vals = val.get<ShapedTypeComponents *>()->getDims();
+    res.assign(vals.begin(), vals.end());
+  }
+}
+
+void ShapeAdaptor::getDims(ShapedTypeComponents &res) const {
+  assert(hasRank());
+  res.ranked = true;
+  getDims(res.dims);
+}
+
+int64_t ShapeAdaptor::getDimSize(int index) const {
+  assert(hasRank());
+  if (auto t = val.dyn_cast<Type>())
+    return t.cast<ShapedType>().getDimSize(index);
+  if (auto attr = val.dyn_cast<Attribute>())
+    return attr.cast<DenseIntElementsAttr>()
+        .getValue<APInt>({static_cast<uint64_t>(index)})
+        .getSExtValue();
+  auto *stc = val.get<ShapedTypeComponents *>();
+  return stc->getDims()[index];
+}
+
+int64_t ShapeAdaptor::getRank() const {
+  assert(hasRank());
+  if (auto t = val.dyn_cast<Type>())
+    return t.cast<ShapedType>().getRank();
+  if (auto attr = val.dyn_cast<Attribute>())
+    return attr.cast<DenseIntElementsAttr>().size();
+  return val.get<ShapedTypeComponents *>()->getDims().size();
+}
+
+bool ShapeAdaptor::hasStaticShape() const {
+  if (!hasRank())
+    return false;
+
+  if (auto t = val.dyn_cast<Type>())
+    return t.cast<ShapedType>().hasStaticShape();
+  if (auto attr = val.dyn_cast<Attribute>()) {
+    auto dattr = attr.cast<DenseIntElementsAttr>();
+    for (auto index : dattr.getIntValues())
+      if (ShapedType::isDynamic(index.getSExtValue()))
+        return false;
+    return true;
+  }
+  auto *stc = val.get<ShapedTypeComponents *>();
+  for (int64_t dim : stc->getDims())
+    if (ShapedType::isDynamic(dim))
+      return false;
+  return true;
+}
+
+int64_t ShapeAdaptor::getNumElements() const {
+  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
+
+  if (auto t = val.dyn_cast<Type>())
+    return t.cast<ShapedType>().getNumElements();
+
+  if (auto attr = val.dyn_cast<Attribute>()) {
+    auto dattr = attr.cast<DenseIntElementsAttr>();
+    int64_t num = 1;
+    for (auto index : dattr.getIntValues()) {
+      num *= index.getZExtValue();
+      assert(num >= 0 && "integer overflow in element count computation");
+    }
+    return num;
+  }
+
+  auto *stc = val.get<ShapedTypeComponents *>();
+  int64_t num = 1;
+  for (int64_t dim : stc->getDims()) {
+    num *= dim;
+    assert(num >= 0 && "integer overflow in element count computation");
+  }
+  return num;
+}
+
+void ShapeAdaptor::dump() const {
+  if (!hasRank()) {
+    llvm::errs() << "<<unranked>>\n";
+    return;
+  }
+
+  SmallVector<int64_t> dims;
+  getDims(dims);
+  auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string {
+    if (ShapedType::isDynamic(dim))
+      return "?";
+    return llvm::formatv("{0}", dim).str();
+  });
+  llvm::errs() << "rank = " << getRank() << " dims = [";
+  llvm::interleave(mapped, llvm::errs(), "x");
+  llvm::errs() << "]\n";
+}
+
+ShapeAdaptor ValueShapeRange::getValueAsShape(int index) {
+  Value val = operator[](index);
+  if (valueToShape)
+    if (ShapeAdaptor ret = valueToShape(val))
+      return ret;
+
+  DenseIntElementsAttr attr;
+  if (!matchPattern(val, m_Constant(&attr)))
+    return nullptr;
+  if (attr.getType().getRank() != 1)
+    return nullptr;
+  return attr;
+}
+
+ShapeAdaptor ValueShapeRange::getShape(Value val) const {
+  if (operandShape)
+    if (ShapeAdaptor ret = operandShape(val))
+      return ret;
+  return val.getType();
+}
+
+ShapeAdaptor ValueShapeRange::getShape(int index) const {
+  if (index < 0 || static_cast<size_t>(index) >= size())
+    return nullptr;
+  return getShape(operator[](index));
+}
+
 LogicalResult mlir::detail::inferReturnTensorTypes(
     function_ref<LogicalResult(
         MLIRContext *, Optional<Location> location, ValueShapeRange operands,

diff --git a/mlir/unittests/Interfaces/CMakeLists.txt b/mlir/unittests/Interfaces/CMakeLists.txt
index 38a1a0b9ea2bb..003bbc41ef7c0 100644
--- a/mlir/unittests/Interfaces/CMakeLists.txt
+++ b/mlir/unittests/Interfaces/CMakeLists.txt
@@ -1,10 +1,13 @@
 add_mlir_unittest(MLIRInterfacesTests
   DataLayoutInterfacesTest.cpp
+  InferTypeOpInterfaceTest.cpp
 )
 
 target_link_libraries(MLIRInterfacesTests
   PRIVATE
   MLIRDataLayoutInterfaces
   MLIRDLTI
+  MLIRInferTypeOpInterface
   MLIRParser
+  MLIRStandard
 )

diff --git a/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp b/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp
new file mode 100644
index 0000000000000..0fce0e4b9c176
--- /dev/null
+++ b/mlir/unittests/Interfaces/InferTypeOpInterfaceTest.cpp
@@ -0,0 +1,103 @@
+//===- InferTypeOpInterfaceTest.cpp - Unit Test for type interface --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Parser.h"
+
+#include <gtest/gtest.h>
+
+using namespace mlir;
+
+class ValueShapeRangeTest : public testing::Test {
+protected:
+  void SetUp() override {
+    const char *ir = R"MLIR(
+      func @map(%arg : tensor<1xi64>) {
+        %0 = constant dense<[10]> : tensor<1xi64>
+        %1 = addi %arg, %0 : tensor<1xi64>
+        return
+      }
+    )MLIR";
+
+    registry.insert<StandardOpsDialect>();
+    ctx.appendDialectRegistry(registry);
+    module = parseSourceString(ir, &ctx);
+    mapFn = cast<FuncOp>(module->front());
+  }
+
+  // Create ValueShapeRange on the addi operation.
+  ValueShapeRange addiRange() {
+    auto &fnBody = mapFn.body();
+    return std::next(fnBody.front().begin())->getOperands();
+  }
+
+  DialectRegistry registry;
+  MLIRContext ctx;
+  OwningModuleRef module;
+  FuncOp mapFn;
+};
+
+TEST_F(ValueShapeRangeTest, ShapesFromValues) {
+  ValueShapeRange range = addiRange();
+
+  EXPECT_FALSE(range.getValueAsShape(0));
+  ASSERT_TRUE(range.getValueAsShape(1));
+  EXPECT_TRUE(range.getValueAsShape(1).hasRank());
+  EXPECT_EQ(range.getValueAsShape(1).getRank(), 1);
+  EXPECT_EQ(range.getValueAsShape(1).getDimSize(0), 10);
+  EXPECT_EQ(range.getShape(1).getRank(), 1);
+  EXPECT_EQ(range.getShape(1).getDimSize(0), 1);
+}
+
+TEST_F(ValueShapeRangeTest, MapValuesToShapes) {
+  ValueShapeRange range = addiRange();
+  ShapedTypeComponents fixed(SmallVector<int64_t>{30});
+  auto mapping = [&](Value val) -> ShapeAdaptor {
+    if (val == mapFn.getArgument(0))
+      return &fixed;
+    return nullptr;
+  };
+  range.setValueToShapeMapping(mapping);
+
+  ASSERT_TRUE(range.getValueAsShape(0));
+  EXPECT_TRUE(range.getValueAsShape(0).hasRank());
+  EXPECT_EQ(range.getValueAsShape(0).getRank(), 1);
+  EXPECT_EQ(range.getValueAsShape(0).getDimSize(0), 30);
+  ASSERT_TRUE(range.getValueAsShape(1));
+  EXPECT_TRUE(range.getValueAsShape(1).hasRank());
+  EXPECT_EQ(range.getValueAsShape(1).getRank(), 1);
+  EXPECT_EQ(range.getValueAsShape(1).getDimSize(0), 10);
+}
+
+TEST_F(ValueShapeRangeTest, SettingShapes) {
+  ShapedTypeComponents shape(SmallVector<int64_t>{10, 20});
+  ValueShapeRange range = addiRange();
+  auto mapping = [&](Value val) -> ShapeAdaptor {
+    if (val == mapFn.getArgument(0))
+      return &shape;
+    return nullptr;
+  };
+  range.setOperandShapeMapping(mapping);
+
+  ASSERT_TRUE(range.getShape(0));
+  EXPECT_EQ(range.getShape(0).getRank(), 2);
+  EXPECT_EQ(range.getShape(0).getDimSize(0), 10);
+  EXPECT_EQ(range.getShape(0).getDimSize(1), 20);
+  ASSERT_TRUE(range.getShape(1));
+  EXPECT_EQ(range.getShape(1).getRank(), 1);
+  EXPECT_EQ(range.getShape(1).getDimSize(0), 1);
+  EXPECT_FALSE(range.getShape(2));
+}


        

