[Mlir-commits] [mlir] [TOSA] Move CreateOpAndInfer into ConversionUtils.h (PR #106122)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Aug 26 12:29:02 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Tai Ly (Tai78641)

<details>
<summary>Changes</summary>

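This moves `CreateOpAndInfer` from TensorFlow's `legalize_utils.h` into MLIR's `ConversionUtils.h`, and replaces the local `createOpAndInfer` helper in `TosaDecomposeTransposeConv.cpp` with the shared version.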

The function is renamed to `CreateOpAndInferShape` so that this change can land upstream independently of TensorFlow. Keeping the old name would cause a redefinition error in the TensorFlow build unless this change landed together with the removal of `CreateOpAndInfer` from TensorFlow.


---
Full diff: https://github.com/llvm/llvm-project/pull/106122.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h (+137) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp (+23-70) 
- (modified) mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp (+9-3) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index ceab7d9c628a54..60e7ed1ce2f876 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -15,7 +15,9 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/PatternMatch.h"
 #include <optional>
 
@@ -79,6 +81,149 @@ checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op,
 LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc,
                             Value &input1, Value &input2);
 
+LogicalResult EqualizeRanks(ImplicitLocOpBuilder &builder, Value &input1,
+                            Value &input2);
+
+namespace detail {
+
+// Creates a TOSA operation and performs shape inference on the individual
+// op, refining the type of its first result with the inferred shape. This
+// allows shape inference while lowering to TOSA.
+//
+// This helper lives in a named `detail` namespace rather than an anonymous
+// one: an anonymous namespace in a header gives every including translation
+// unit its own copy, so the externally visible CreateOpAndInferShape
+// templates below would refer to a different entity in each unit.
+template <typename TosaOp, typename... Args>
+TosaOp createOpAndInferShape(ImplicitLocOpBuilder &builder, Type resultTy,
+                             Args &&...args) {
+  auto op = builder.create<TosaOp>(resultTy, std::forward<Args>(args)...);
+
+  InferShapedTypeOpInterface shapeInterface =
+      dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
+  if (!shapeInterface)
+    return op;
+
+  SmallVector<ShapedTypeComponents> returnedShapes;
+  if (shapeInterface
+          .inferReturnTypeComponents(op.getContext(), builder.getLoc(),
+                                     op->getOperands(), op->getAttrDictionary(),
+                                     op->getPropertiesStorage(),
+                                     op->getRegions(), returnedShapes)
+          .failed())
+    return op;
+
+  // We need to use the element type of the existing result type to generate
+  // the new result shaped type. This is because rescale can include a cast to
+  // different bit-width types and does not have a TypeAttr to define the
+  // target type.
+  auto result = op->getResult(0);
+  const auto &predictedShape = returnedShapes[0];
+  auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(resultTy);
+
+  // Compute the knowledge based on the inferred type.
+  auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
+  inferredKnowledge.dtype = mlir::cast<ShapedType>(resultTy).getElementType();
+  inferredKnowledge.hasRank = predictedShape.hasRank();
+  if (predictedShape.hasRank()) {
+    for (auto dim : predictedShape.getDims()) {
+      inferredKnowledge.sizes.push_back(dim);
+    }
+  }
+
+  // Compute the new type based on the joined version: getType() yields a
+  // ranked tensor type when the joined knowledge has a rank and an unranked
+  // tensor type otherwise.
+  auto newKnowledge = ValueKnowledge::join(currentKnowledge, inferredKnowledge);
+  result.setType(newKnowledge.getType());
+  return op;
+}
+
+} // namespace detail
+
+// Creates a TOSA operation by:
+//   - first equalizing operand ranks (via EqualizeRanks, which may insert
+//     tosa.reshape ops) for ops with the SameOperandsAndResultRank trait,
+//   - creating the operator, and
+//   - performing shape inference on the created operator.
+// Rank equalization is only applied for the operand patterns (Value, Value),
+// (Value, Value, bool) for ArithmeticRightShiftOp and (Value, Value, Value)
+// for Select; the caller's Values are never reassigned.
+template <typename TosaOp, typename... Args>
+TosaOp CreateOpAndInferShape(ImplicitLocOpBuilder &builder, Type resultTy,
+                             Args &&...args) {
+  if constexpr (TosaOp::template hasTrait<
+                    ::mlir::OpTrait::SameOperandsAndResultRank>()) {
+    // op requires same ranks for tensor operands
+    using ArgTypes = std::tuple<std::decay_t<Args>...>;
+    if constexpr (sizeof...(Args) == 2) {
+      using ArgX = std::tuple_element_t<0, ArgTypes>;
+      using ArgY = std::tuple_element_t<1, ArgTypes>;
+      if constexpr (std::is_same_v<ArgX, Value> &&
+                    std::is_same_v<ArgY, Value>) {
+        // Copy into locals: EqualizeRanks may reassign its Value& arguments.
+        Value x = std::get<0>(std::tie(args...));
+        Value y = std::get<1>(std::tie(args...));
+        if (EqualizeRanks(builder, x, y).failed()) {
+          // incompatible broadcast shapes, no reshape is inserted
+          // ResultsBroadcastableShape verify will handle this
+        }
+        return detail::createOpAndInferShape<TosaOp>(builder, resultTy, x, y);
+      }
+    }
+    if constexpr (sizeof...(Args) == 3) {
+      using ArgX = std::tuple_element_t<0, ArgTypes>;
+      using ArgY = std::tuple_element_t<1, ArgTypes>;
+      using ArgZ = std::tuple_element_t<2, ArgTypes>;
+      if constexpr (std::is_same_v<ArgX, Value> &&
+                    std::is_same_v<ArgY, Value> && std::is_same_v<ArgZ, bool>) {
+        // special case for ArithmeticRightShiftOp
+        Value x = std::get<0>(std::tie(args...));
+        Value y = std::get<1>(std::tie(args...));
+        bool round = std::get<2>(std::tie(args...));
+        if (EqualizeRanks(builder, x, y).failed()) {
+          // incompatible broadcast shapes, no reshape is inserted
+          // ResultsBroadcastableShape verify will handle this
+        }
+        return detail::createOpAndInferShape<TosaOp>(builder, resultTy, x, y,
+                                                     round);
+      }
+      if constexpr (std::is_same_v<ArgX, Value> &&
+                    std::is_same_v<ArgY, Value> &&
+                    std::is_same_v<ArgZ, Value>) {
+        // special case for Select
+        Value x = std::get<0>(std::tie(args...));
+        Value y = std::get<1>(std::tie(args...));
+        Value z = std::get<2>(std::tie(args...));
+
+        if (EqualizeRanks(builder, x, y).failed() ||
+            EqualizeRanks(builder, x, z).failed() ||
+            EqualizeRanks(builder, y, z).failed()) {
+          // incompatible broadcast shapes, no reshape is inserted
+          // ResultsBroadcastableShape verify will handle this
+        }
+
+        return detail::createOpAndInferShape<TosaOp>(builder, resultTy, x, y,
+                                                     z);
+      }
+    }
+  }
+
+  return detail::createOpAndInferShape<TosaOp>(builder, resultTy,
+                                               std::forward<Args>(args)...);
+}
+
+// Convenience overload of CreateOpAndInferShape for pattern rewriters: it
+// wraps `rewriter` and `loc` in an ImplicitLocOpBuilder and forwards all
+// remaining arguments unchanged.
+template <typename TosaOp, typename... Args>
+TosaOp CreateOpAndInferShape(PatternRewriter &rewriter, Location loc,
+                             Type resultTy, Args &&...args) {
+  ImplicitLocOpBuilder builder(loc, rewriter);
+  return CreateOpAndInferShape<TosaOp>(builder, resultTy,
+                                       std::forward<Args>(args)...);
+}
+
 } // namespace tosa
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index a94bb3a920b1db..0779cdb9667a1a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -26,53 +26,6 @@ using namespace mlir::tosa;
 
 namespace {
 
-template <typename TosaOp, typename... Args>
-TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc, Type resultTy,
-                        Args &&...args) {
-  auto op = rewriter.create<TosaOp>(loc, resultTy, args...);
-
-  InferShapedTypeOpInterface shapeInterface =
-      dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
-  if (!shapeInterface)
-    return op;
-
-  SmallVector<ShapedTypeComponents> returnedShapes;
-  if (shapeInterface
-          .inferReturnTypeComponents(
-              op.getContext(), op.getLoc(), op->getOperands(),
-              op->getDiscardableAttrDictionary(), op->getPropertiesStorage(),
-              op->getRegions(), returnedShapes)
-          .failed())
-    return op;
-
-  // We need to use the element type of the existing result type to generate
-  // the new result shaped type. This is because rescale can include a cast to
-  // different bit-width types and does not have a TypeAttr to define the
-  // target type.
-  auto result = op->getResult(0);
-  auto predictedShape = returnedShapes[0];
-  auto currentKnowledge =
-      mlir::tosa::ValueKnowledge::getKnowledgeFromType(resultTy);
-
-  // Compute the knowledge based on the inferred type.
-  auto inferredKnowledge =
-      mlir::tosa::ValueKnowledge::getPessimisticValueState();
-  inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
-  inferredKnowledge.hasRank = predictedShape.hasRank();
-  if (predictedShape.hasRank()) {
-    for (auto dim : predictedShape.getDims()) {
-      inferredKnowledge.sizes.push_back(dim);
-    }
-  }
-
-  // Compute the new type based on the joined version.
-  auto newKnowledge =
-      mlir::tosa::ValueKnowledge::join(currentKnowledge, inferredKnowledge);
-  auto newTy = newKnowledge.getType();
-  result.setType(newTy);
-  return op;
-}
-
 class TransposeConvNonStridedConverter
     : public OpRewritePattern<tosa::TransposeConv2DOp> {
 public:
@@ -187,20 +140,20 @@ class TransposeConvStridedConverter
         (weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
     DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
         RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
-    Value weightPaddingVal = createOpAndInfer<tosa::ConstOp>(
+    Value weightPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
         rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
 
     if (op.getQuantizationInfo().has_value()) {
       auto quantInfo = op.getQuantizationInfo().value();
-      weight = createOpAndInfer<tosa::PadOp>(
+      weight = CreateOpAndInferShape<tosa::PadOp>(
           rewriter, loc, UnrankedTensorType::get(weightETy), weight,
           weightPaddingVal, nullptr,
           rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeightZp()));
 
     } else {
-      weight = createOpAndInfer<tosa::PadOp>(rewriter, loc,
-                                             UnrankedTensorType::get(weightETy),
-                                             weight, weightPaddingVal);
+      weight = CreateOpAndInferShape<tosa::PadOp>(
+          rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+          weightPaddingVal);
     }
 
     weightTy = cast<ShapedType>(weight.getType());
@@ -212,7 +165,7 @@ class TransposeConvStridedConverter
         outputChannels, weightHeight / stride[0],
         stride[0],      weightWidth / stride[1],
         stride[1],      inputChannels};
-    weight = createOpAndInfer<tosa::ReshapeOp>(
+    weight = CreateOpAndInferShape<tosa::ReshapeOp>(
         rewriter, loc, UnrankedTensorType::get(weightETy), weight,
         rewriter.getDenseI64ArrayAttr(weightReshapeDims0));
 
@@ -221,7 +174,7 @@ class TransposeConvStridedConverter
         loc, RankedTensorType::get({6}, rewriter.getI32Type()),
         rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));
 
-    weight = createOpAndInfer<tosa::TransposeOp>(
+    weight = CreateOpAndInferShape<tosa::TransposeOp>(
         rewriter, loc, UnrankedTensorType::get(weightETy), weight,
         transposeWeightVal);
 
@@ -229,15 +182,15 @@ class TransposeConvStridedConverter
     llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
         outputChannels * stride[0] * stride[1], weightHeight / stride[0],
         weightWidth / stride[1], inputChannels};
-    weight = createOpAndInfer<tosa::ReshapeOp>(
+    weight = CreateOpAndInferShape<tosa::ReshapeOp>(
         rewriter, loc, UnrankedTensorType::get(weightETy), weight,
         rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
     ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());
 
-    weight = createOpAndInfer<tosa::ReverseOp>(
+    weight = CreateOpAndInferShape<tosa::ReverseOp>(
         rewriter, loc, UnrankedTensorType::get(weightETy), weight,
         /* axis = */ rewriter.getI32IntegerAttr(1));
-    weight = createOpAndInfer<tosa::ReverseOp>(
+    weight = CreateOpAndInferShape<tosa::ReverseOp>(
         rewriter, loc, UnrankedTensorType::get(weightETy), weight,
         /* axis = */ rewriter.getI32IntegerAttr(2));
 
@@ -251,19 +204,19 @@ class TransposeConvStridedConverter
     DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
         RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);
 
-    Value inputPaddingVal = createOpAndInfer<tosa::ConstOp>(
+    Value inputPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
         rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
 
     if (op.getQuantizationInfo().has_value()) {
       auto quantInfo = op.getQuantizationInfo().value();
-      input = createOpAndInfer<tosa::PadOp>(
+      input = CreateOpAndInferShape<tosa::PadOp>(
           rewriter, loc, UnrankedTensorType::get(inputETy), input,
           inputPaddingVal, nullptr,
           rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInputZp()));
     } else {
-      input = createOpAndInfer<tosa::PadOp>(rewriter, loc,
-                                            UnrankedTensorType::get(inputETy),
-                                            input, inputPaddingVal);
+      input = CreateOpAndInferShape<tosa::PadOp>(
+          rewriter, loc, UnrankedTensorType::get(inputETy), input,
+          inputPaddingVal);
     }
 
     // We use a zero bias as we need to broadcast the bias.
@@ -279,7 +232,7 @@ class TransposeConvStridedConverter
     // Perform the convolution using the zero bias.
     Value conv2d;
     if (op.getQuantizationInfo()) {
-      conv2d = createOpAndInfer<tosa::Conv2DOp>(
+      conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
                    rewriter, loc, UnrankedTensorType::get(resultETy), input,
                    weight, zeroBias,
                    /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
@@ -288,7 +241,7 @@ class TransposeConvStridedConverter
                    *op.getQuantizationInfo())
                    .getResult();
     } else {
-      conv2d = createOpAndInfer<tosa::Conv2DOp>(
+      conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
                    rewriter, loc, UnrankedTensorType::get(resultETy), input,
                    weight, zeroBias,
                    /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
@@ -307,7 +260,7 @@ class TransposeConvStridedConverter
     // Factor striding out of the convolution result.
     llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
         batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
-    conv2d = createOpAndInfer<tosa::ReshapeOp>(
+    conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
         rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
         rewriter.getDenseI64ArrayAttr(convReshapeDims0));
 
@@ -316,14 +269,14 @@ class TransposeConvStridedConverter
         loc, RankedTensorType::get({6}, rewriter.getI32Type()),
         rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));
 
-    conv2d = createOpAndInfer<tosa::TransposeOp>(
+    conv2d = CreateOpAndInferShape<tosa::TransposeOp>(
         rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
         transposeConvVal);
 
     // Fuse striding behavior back into width / height.
     llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
         batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
-    conv2d = createOpAndInfer<tosa::ReshapeOp>(
+    conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
         rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
         rewriter.getDenseI64ArrayAttr(convReshapeDims1));
 
@@ -348,7 +301,7 @@ class TransposeConvStridedConverter
     sliceSize[1] = resultSliceHeight;
     sliceSize[2] = resultSliceWidth;
 
-    auto slice = createOpAndInfer<tosa::SliceOp>(
+    auto slice = CreateOpAndInferShape<tosa::SliceOp>(
                      rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
                      rewriter.getDenseI64ArrayAttr(sliceBegin),
                      rewriter.getDenseI64ArrayAttr(sliceSize))
@@ -363,10 +316,10 @@ class TransposeConvStridedConverter
     DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get(
         RankedTensorType::get({4, 2}, rewriter.getI32Type()), resultPadding);
 
-    Value resultPaddingVal = createOpAndInfer<tosa::ConstOp>(
+    Value resultPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
         rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr);
 
-    Value resultPad = createOpAndInfer<tosa::PadOp>(
+    Value resultPad = CreateOpAndInferShape<tosa::PadOp>(
         rewriter, loc, UnrankedTensorType::get(resultETy), slice,
         resultPaddingVal);
 
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index f276924a8a9f62..1f6e3b2ab83919 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -102,6 +102,12 @@ computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
 
 LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
                                         Value &input1, Value &input2) {
+  ImplicitLocOpBuilder builder(loc, rewriter);
+  return EqualizeRanks(builder, input1, input2);
+}
+
+LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
+                                        Value &input1, Value &input2) {
   auto input1Ty = llvm::dyn_cast<RankedTensorType>(input1.getType());
   auto input2Ty = llvm::dyn_cast<RankedTensorType>(input2.getType());
 
@@ -140,9 +146,9 @@ LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
   auto reshapeOutputType = RankedTensorType::get(
       ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
 
-  auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
-      loc, reshapeOutputType, lowerTensorValue,
-      rewriter.getDenseI64ArrayAttr(reshapeOutputShape));
+  auto reshapeLower = builder.create<tosa::ReshapeOp>(
+      reshapeOutputType, lowerTensorValue,
+      builder.getDenseI64ArrayAttr(reshapeOutputShape));
 
   if (input1Rank > input2Rank) {
     input1 = higherTensorValue;

``````````

</details>


https://github.com/llvm/llvm-project/pull/106122


More information about the Mlir-commits mailing list