[Mlir-commits] [mlir] 1b7feac - [mlir][tosa] Split canonicalization and folders out of TosaOps.

Jacques Pienaar llvmlistbot at llvm.org
Fri Jul 22 07:20:33 PDT 2022


Author: Jacques Pienaar
Date: 2022-07-22T07:20:25-07:00
New Revision: 1b7feac2a6c42f5f4302579eeafbe904f5ccf972

URL: https://github.com/llvm/llvm-project/commit/1b7feac2a6c42f5f4302579eeafbe904f5ccf972
DIFF: https://github.com/llvm/llvm-project/commit/1b7feac2a6c42f5f4302579eeafbe904f5ccf972.diff

LOG: [mlir][tosa] Split canonicalization and folders out of TosaOps.

Scope the ops file to op definitions. Use "canonicalization" as the grouping for
canonicalization patterns and folders (OpTransforms was also considered, but it
felt too generic, and patterns and folders are used together).

Reviewed By: silvas, rsuderman

Differential Revision: https://reviews.llvm.org/D130297

Added: 
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
    mlir/lib/Dialect/Tosa/CMakeLists.txt
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index ff4225bef39df..afdd8017cec58 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -34,17 +34,6 @@ namespace tosa {
 } // namespace tosa
 } // namespace mlir
 
-//===----------------------------------------------------------------------===//
-// Utility Functions
-//===----------------------------------------------------------------------===//
-namespace mlir {
-namespace tosa {
-/// Appends the canonicalization patterns for all the TOSA ops to the `patterns`
-void populateTosaOpsCanonicalizationPatterns(MLIRContext *ctx,
-                                             RewritePatternSet &patterns);
-} // namespace tosa
-} // namespace mlir
-
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/Tosa/IR/TosaAttributes.h.inc"
 

diff  --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt
index 520f6425cbe67..77e9051c7418a 100644
--- a/mlir/lib/Dialect/Tosa/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt
@@ -1,7 +1,8 @@
 add_mlir_dialect_library(MLIRTosaDialect
+  IR/TosaOps.cpp
+  IR/TosaCanonicalizations.cpp
   Utils/ConversionUtils.cpp
   Utils/QuantUtils.cpp
-  IR/TosaOps.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
new file mode 100644
index 0000000000000..7bb6339c69b7f
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -0,0 +1,543 @@
+//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// \file
+// TOSA canonicalization patterns and folders.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/InliningUtils.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+//===----------------------------------------------------------------------===//
+// Operator Canonicalizers.
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Rewrites a single-operand `tosa.concat` into its operand. When the operand
+/// type differs from the result type, a `tensor.cast` is inserted instead.
+struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
+  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ConcatOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.input1().size() != 1)
+      return failure();
+
+    Value operand = op.input1().front();
+    if (operand.getType() != op.getType()) {
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), operand);
+      return success();
+    }
+
+    rewriter.replaceOp(op, operand);
+    return success();
+  }
+};
+} // namespace
+
+void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add<ConcatOptimization>(context);
+}
+
+namespace {
+/// Collapses `reshape(reshape(x))` into a single `reshape(x)` that uses the
+/// outer op's target shape.
+struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {
+  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto producer = op.input1().getDefiningOp<tosa::ReshapeOp>();
+    if (!producer)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
+        op, op.getType(), producer.input1(), op.new_shape());
+    return success();
+  }
+};
+
+/// Replaces `reshape(constant)` with a new `tosa.const` holding the same data
+/// under the reshaped type.
+struct ReshapeConstOptimization : public OpRewritePattern<tosa::ReshapeOp> {
+  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.input1();
+    ArrayAttr newShape = op.new_shape();
+
+    // Only constant inputs can be reshaped at compile time.
+    DenseElementsAttr inputAttr;
+    if (!matchPattern(input, m_Constant(&inputAttr)))
+      return failure();
+
+    // Avoid duplicating non-splat constant data when the constant has other
+    // users besides this reshape.
+    if (!input.hasOneUse() && !inputAttr.isSplat())
+      return failure();
+
+    SmallVector<int64_t> newShapeValues = llvm::to_vector<6>(
+        llvm::map_range(newShape.getValue(), [](const Attribute &val) {
+          return val.cast<IntegerAttr>().getValue().getSExtValue();
+        }));
+
+    // `DenseElementsAttr::reshape` requires a fully static shape with the same
+    // element count. Bail out on dynamic/inferred (negative) dimensions or a
+    // size mismatch instead of asserting.
+    if (llvm::any_of(newShapeValues, [](int64_t dim) { return dim < 0; }))
+      return failure();
+    int64_t newNumElements = 1;
+    for (int64_t dim : newShapeValues)
+      newNumElements *= dim;
+    if (newNumElements != inputAttr.getNumElements())
+      return failure();
+
+    // Build the new constant with the requested output shape.
+    ShapedType inputShape = input.getType().cast<ShapedType>();
+    DenseElementsAttr outputAttr =
+        inputAttr.reshape(inputShape.clone(newShapeValues));
+    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputAttr.getType(),
+                                               outputAttr);
+    return success();
+  }
+};
+} // namespace
+
+void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                            MLIRContext *context) {
+  results.add<ReshapeReshapeOptimization>(context);
+  results.add<ReshapeConstOptimization>(context);
+}
+
+/// Canonicalizes `select(not(p), a, b)` into `select(p, b, a)` by dropping the
+/// negation and swapping the true/false operands in place.
+LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
+  auto negatedPred = op.pred().getDefiningOp<tosa::LogicalNotOp>();
+  if (!negatedPred)
+    return failure();
+
+  Value newPred = negatedPred.input1();
+  Value newOnTrue = op.on_false();
+  Value newOnFalse = op.on_true();
+  rewriter.updateRootInPlace(op, [&]() {
+    op.getOperation()->setOperands({newPred, newOnTrue, newOnFalse});
+  });
+  return success();
+}
+
+namespace {
+/// Removes a `tosa.transpose` whose constant permutation is the identity
+/// `[0, 1, ..., rank-1]`.
+struct NoOpOptimization : public OpRewritePattern<tosa::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    // Replacing the result with the input is only type-correct when both
+    // types match (e.g. not when the result type is more static).
+    if (op.input1().getType() != op.getType())
+      return failure();
+
+    DenseIntElementsAttr permAttr;
+    if (!matchPattern(op.perms(), m_Constant(&permAttr)))
+      return failure();
+
+    SmallVector<int64_t> permValues = llvm::to_vector<6>(
+        llvm::map_range(permAttr.getValues<APInt>(),
+                        [](const APInt &val) { return val.getSExtValue(); }));
+
+    if (!llvm::equal(llvm::seq<int64_t>(0, permValues.size()), permValues))
+      return failure();
+
+    rewriter.replaceOp(op, op.input1());
+    return success();
+  }
+};
+} // namespace
+
+void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<NoOpOptimization>(context);
+}
+
+namespace {
+/// Folds `add(x, 0)` and `add(0, x)` to `x` for integer splat zeros, provided
+/// `x` already has the result type (so no broadcast is being performed).
+struct AddZeroOptimization : public OpRewritePattern<tosa::AddOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::AddOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input1 = op.input1();
+    Value input2 = op.input2();
+
+    // Only integer zeros are matched. Float zeros are not: x + 0.0 is not an
+    // identity for x = -0.0.
+    auto isIntegerZeroSplat = [](Value value) {
+      DenseElementsAttr attr;
+      return matchPattern(value, m_Constant(&attr)) && attr.isSplat() &&
+             attr.getType().getElementType().isa<IntegerType>() &&
+             attr.getSplatValue<APInt>().isZero();
+    };
+
+    if (input2.getType() == op.getType() && isIntegerZeroSplat(input1)) {
+      rewriter.replaceOp(op, input2);
+      return success();
+    }
+
+    if (input1.getType() == op.getType() && isIntegerZeroSplat(input2)) {
+      rewriter.replaceOp(op, input1);
+      return success();
+    }
+
+    return failure();
+  }
+};
+} // namespace
+
+void AddOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                        MLIRContext *context) {
+  results.add<AddZeroOptimization>(context);
+}
+
+namespace {
+/// Folds `mul(x, 1)` and `mul(1, x)` to `x` when `x` already has the result
+/// type (so no broadcast is being performed).
+struct MulOneOptimization : public OpRewritePattern<tosa::MulOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::MulOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input1 = op.input1();
+    Value input2 = op.input2();
+
+    // Returns true if `constant` is a splat multiplicative identity. Integer
+    // multiplies carry a `shift` attribute applied to the product, so an
+    // integer `x * 1` is only the identity when that shift is zero.
+    auto isIdentity = [&](Value constant) {
+      DenseElementsAttr attr;
+      if (!matchPattern(constant, m_Constant(&attr)) || !attr.isSplat())
+        return false;
+      Type elementTy = attr.getType().getElementType();
+      if (elementTy.isa<FloatType>())
+        return attr.getSplatValue<APFloat>().isExactlyValue(1);
+      if (elementTy.isa<IntegerType>())
+        return op.shift() == 0 && matchPattern(constant, m_One());
+      return false;
+    };
+
+    if (input2.getType() == op.getType() && isIdentity(input1)) {
+      rewriter.replaceOp(op, input2);
+      return success();
+    }
+
+    if (input1.getType() == op.getType() && isIdentity(input2)) {
+      rewriter.replaceOp(op, input1);
+      return success();
+    }
+
+    return failure();
+  }
+};
+} // namespace
+
+void MulOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                        MLIRContext *context) {
+  results.add<MulOneOptimization>(context);
+}
+
+namespace {
+/// Makes the pad value of a `tosa.pad` without a `pad_const` operand explicit.
+/// It adds a scalar `tosa.const` operand holding 0 for float and
+/// non-quantized integer inputs, or the input zero point for quantized
+/// integer inputs.
+struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::PadOp op,
+                                PatternRewriter &rewriter) const override {
+    // Already explicit; nothing to do.
+    if (op.pad_const())
+      return failure();
+
+    auto input = op.input1();
+    auto padding = op.padding();
+
+    ShapedType inputTy = input.getType().cast<ShapedType>();
+    Type elementTy = inputTy.getElementType();
+
+    Attribute constantAttr;
+    if (elementTy.isa<FloatType>()) {
+      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
+    } else if (elementTy.isa<IntegerType>() && !op.quantization_info()) {
+      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
+    } else if (elementTy.isa<IntegerType>() && op.quantization_info()) {
+      auto value = op.quantization_info()->getInputZp();
+      constantAttr = rewriter.getIntegerAttr(elementTy, value);
+    }
+
+    if (!constantAttr) {
+      return rewriter.notifyMatchFailure(
+          op,
+          "tosa.pad canonicalization encountered an unknown element type");
+    }
+
+    auto denseAttr = DenseElementsAttr::get(
+        RankedTensorType::get({}, elementTy), constantAttr);
+    auto constantVal = rewriter.create<tosa::ConstOp>(
+        op.getLoc(), denseAttr.getType(), denseAttr);
+
+    rewriter.replaceOpWithNewOp<tosa::PadOp>(
+        op, op.getType(), ValueRange{input, padding, constantVal},
+        op->getAttrs());
+    return success();
+  }
+};
+} // namespace
+
+void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                        MLIRContext *context) {
+  results.add<MaterializePadValue>(context);
+}
+
+namespace {
+/// Removes a `tosa.max_pool2d` whose input and output have identical static
+/// types with dims 1 and 2 equal to 1 (presumably H and W of an NHWC tensor).
+struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.input();
+    Value output = op.output();
+    ShapedType inputType = input.getType().cast<ShapedType>();
+    ShapedType outputType = output.getType().cast<ShapedType>();
+
+    if (!inputType.hasStaticShape() || !outputType.hasStaticShape())
+      return failure();
+
+    // Replacing the result with the input is only type-correct when both
+    // types match exactly.
+    if (inputType != outputType)
+      return failure();
+
+    // With identical types, a 1x1 input implies a 1x1 output: this is a no-op.
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    if (inputShape[1] != 1 || inputShape[2] != 1)
+      return failure();
+
+    rewriter.replaceOp(op, input);
+    return success();
+  }
+};
+} // namespace
+
+void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<MaxPool2dIsNoOp>(context);
+}
+
+namespace {
+/// Removes a `tosa.clamp` whose bounds cover the entire value range of the
+/// input element type.
+struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ClampOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.input();
+    // Only statically shaped ranked inputs are handled. The dyn_cast yields
+    // null for unranked tensors, so check it before using the type.
+    auto inputType = input.getType().dyn_cast<RankedTensorType>();
+    if (!inputType || !inputType.hasStaticShape())
+      return failure();
+    Type inputElementType = inputType.getElementType();
+
+    if (inputElementType.isF32()) {
+      auto minClamp = op.min_fp();
+      auto maxClamp = op.max_fp();
+      // No-op when the bounds are the largest-magnitude finite values or the
+      // infinities, with the appropriate sign.
+      bool isMin = (minClamp.isLargest() || minClamp.isInfinity()) &&
+                   minClamp.isNegative();
+      bool isMax = (maxClamp.isLargest() || maxClamp.isInfinity()) &&
+                   !maxClamp.isNegative();
+
+      if (isMin && isMax) {
+        rewriter.replaceOp(op, input);
+        return success();
+      }
+      return failure();
+    }
+
+    // The integer bounds are compared as int64_t. The range limits of unsigned
+    // types of 64+ bits, and of signed types wider than 64 bits, do not fit:
+    // the ui64 max would wrap to -1, and getSExtValue asserts above 64 bits.
+    if (inputElementType.isUnsignedInteger()) {
+      unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
+      if (bitWidth >= 64)
+        return failure();
+
+      int64_t minClamp = op.min_int();
+      int64_t maxClamp = op.max_int();
+
+      int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
+      int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
+
+      if (minClamp <= intMin && maxClamp >= intMax) {
+        rewriter.replaceOp(op, input);
+        return success();
+      }
+      return failure();
+    }
+
+    if (inputElementType.isa<IntegerType>()) {
+      unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
+      if (bitWidth > 64)
+        return failure();
+
+      int64_t minClamp = op.min_int();
+      int64_t maxClamp = op.max_int();
+
+      int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
+      int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
+
+      if (minClamp <= intMin && maxClamp >= intMax) {
+        rewriter.replaceOp(op, input);
+        return success();
+      }
+      return failure();
+    }
+
+    return failure();
+  }
+};
+
+/// Merges `clamp(clamp(x))` into a single clamp whose bounds are the
+/// intersection of the two ranges (max of the mins, min of the maxes).
+/// NOTE(review): disjoint ranges yield min > max; confirm that is acceptable.
+struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
+  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ClampOp op,
+                                PatternRewriter &rewriter) const override {
+    auto clampOp = op.input().getDefiningOp<tosa::ClampOp>();
+    if (!clampOp)
+      return failure();
+
+    auto minFp = std::max(op.min_fp(), clampOp.min_fp()).convertToFloat();
+    auto maxFp = std::min(op.max_fp(), clampOp.max_fp()).convertToFloat();
+
+    auto minInt = std::max(op.min_int(), clampOp.min_int());
+    auto maxInt = std::min(op.max_int(), clampOp.max_int());
+
+    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
+        op, op.getType(), clampOp.input(), rewriter.getI64IntegerAttr(minInt),
+        rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
+        rewriter.getF32FloatAttr(maxFp));
+    return success();
+  }
+};
+} // namespace
+
+void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                          MLIRContext *context) {
+  results.add<ClampIsNoOp>(context);
+  results.add<ClampClampOptimization>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// Operator Folders.
+//===----------------------------------------------------------------------===//
+
+/// A cast to the operand's own type is an identity and folds to the operand.
+OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
+  Value source = input();
+  return source.getType() == getType() ? OpFoldResult(source) : OpFoldResult();
+}
+
+/// A constant folds to its own `value` attribute.
+OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.empty() && "constant has no operands");
+  Attribute folded = valueAttr();
+  return folded;
+}
+
+// Defines OP::fold for the reduction ops. A reduction over an axis of size 1
+// folds to its input, but only when the input already has the result type.
+// The axis is range-checked before it is passed to `getDimSize`.
+#define REDUCE_FOLDER(OP)                                                      \
+  OpFoldResult OP::fold(ArrayRef<Attribute> operands) {                        \
+    ShapedType inputTy = input().getType().cast<ShapedType>();                 \
+    if (!inputTy.hasRank())                                                    \
+      return {};                                                               \
+    if (inputTy != getType())                                                  \
+      return {};                                                               \
+    int64_t axisVal = axis();                                                  \
+    if (axisVal < 0 || axisVal >= inputTy.getRank())                           \
+      return {};                                                               \
+    if (inputTy.getDimSize(axisVal) == 1)                                      \
+      return input();                                                          \
+    return {};                                                                 \
+  }
+
+REDUCE_FOLDER(ReduceAllOp)
+REDUCE_FOLDER(ReduceAnyOp)
+REDUCE_FOLDER(ReduceMaxOp)
+REDUCE_FOLDER(ReduceMinOp)
+REDUCE_FOLDER(ReduceProdOp)
+REDUCE_FOLDER(ReduceSumOp)
+#undef REDUCE_FOLDER
+
+/// A reshape between identical ranked tensor types is a no-op and folds to its
+/// input.
+OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
+  auto srcTy = input1().getType().dyn_cast<RankedTensorType>();
+  auto dstTy = getType().dyn_cast<RankedTensorType>();
+  bool isIdentity = srcTy && dstTy && srcTy == dstTy;
+  if (isIdentity)
+    return input1();
+  return {};
+}
+
+/// Folds a `tosa.pad` with a constant all-zero (splat) padding to its input,
+/// provided the input already has the result type.
+OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
+  if (operands[1] && input1().getType() == getType()) {
+    // The folded padding attribute is not guaranteed to be dense; check it.
+    auto densePad = operands[1].dyn_cast<DenseElementsAttr>();
+    if (densePad && densePad.isSplat() &&
+        densePad.getSplatValue<APInt>().isZero()) {
+      return input1();
+    }
+  }
+
+  return {};
+}
+
+/// A slice whose ranked, fully static input type equals its result type takes
+/// the whole tensor and folds to the input.
+OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
+  auto srcTy = input().getType().dyn_cast<RankedTensorType>();
+  auto dstTy = getType().dyn_cast<RankedTensorType>();
+  if (srcTy && dstTy && srcTy == dstTy && srcTy.hasStaticShape())
+    return input();
+  return {};
+}
+
+/// Folds `select` when both branches are the same value, or when the
+/// predicate is a constant splat that picks a single branch.
+OpFoldResult tosa::SelectOp::fold(ArrayRef<Attribute> operands) {
+  Value trueVal = on_true();
+  Value falseVal = on_false();
+  if (trueVal == falseVal)
+    return trueVal;
+
+  auto predicate = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+  if (!predicate || !predicate.isSplat())
+    return {};
+
+  if (predicate.getSplatValue<APInt>().getBoolValue())
+    return trueVal;
+  return falseVal;
+}
+
+/// Tiling with a multiple of 1 along every dimension is an identity. It folds
+/// to the input when the input already has the result type.
+OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
+  bool allOnes = llvm::all_of(multiples().getValue(), [](Attribute val) {
+    return val.cast<IntegerAttr>().getValue().getSExtValue() == 1;
+  });
+  if (!allOnes || input1().getType() != getType())
+    return {};
+  return input1();
+}
+
+/// Folds a `tosa.transpose` with a constant permutation:
+///  - a splat input becomes a splat constant of the result type;
+///  - an identity permutation folds to the input when the types match.
+OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
+  if (!operands[1])
+    return {};
+
+  auto resultTy = getType().cast<ShapedType>();
+
+  // Transposing splat values just means reshaping. `reshape` requires a fully
+  // static result type with the same element type, so only fold in that case.
+  if (auto input = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
+    if (input.isSplat() && resultTy.hasStaticShape() &&
+        input.getType().getElementType() == resultTy.getElementType())
+      return input.reshape(resultTy);
+  }
+
+  auto permsAttr = operands[1].dyn_cast<DenseIntElementsAttr>();
+  if (!permsAttr)
+    return {};
+
+  auto perms = llvm::to_vector<6>(
+      llvm::map_range(permsAttr.getValues<APInt>(),
+                      [](const APInt &val) { return val.getSExtValue(); }));
+
+  if (llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms) &&
+      input1().getType() == getType())
+    return input1();
+  return {};
+}

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 93fddb5e5d6c8..38a067bcd6551 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -21,9 +21,7 @@
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/InliningUtils.h"
-#include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
 
@@ -96,533 +94,6 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
   return nullptr;
 }
 
-//===----------------------------------------------------------------------===//
-// Operator Canonicalizers.
-//===----------------------------------------------------------------------===//
-
-template <typename... Args>
-void addOpsCanonicalizations(MLIRContext *ctx, RewritePatternSet &patterns) {
-  (void)std::initializer_list<int>{
-      0, (Args::getCanonicalizationPatterns(patterns, ctx), 0)...};
-}
-
-void mlir::tosa::populateTosaOpsCanonicalizationPatterns(
-    MLIRContext *ctx, RewritePatternSet &patterns) {
-  addOpsCanonicalizations<
-#define GET_OP_LIST
-#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
-      >(ctx, patterns);
-}
-
-struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
-  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::ConcatOp op,
-                                PatternRewriter &rewriter) const override {
-    if (op.input1().size() != 1)
-      return failure();
-    if (op.input1().front().getType() != op.getType()) {
-      rewriter
-          .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
-                                              op.input1().front())
-          .getResult();
-      return success();
-    }
-
-    rewriter.replaceOp(op, op.input1().front());
-    return success();
-  }
-};
-
-void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                           MLIRContext *context) {
-  results.add<ConcatOptimization>(context);
-}
-
-struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {
-  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.input1();
-    Operation *definingOp = input.getDefiningOp();
-    if (!definingOp)
-      return failure();
-
-    if (tosa::ReshapeOp reshapeOp = dyn_cast<tosa::ReshapeOp>(definingOp)) {
-      rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
-          op, op.getType(), reshapeOp.input1(), op.new_shape());
-      return success();
-    }
-
-    return failure();
-  }
-};
-
-struct ReshapeConstOptimization : public OpRewritePattern<tosa::ReshapeOp> {
-  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.input1();
-    ArrayAttr newShape = op.new_shape();
-
-    // Check if input is constant
-    DenseElementsAttr inputAttr;
-    if (!matchPattern(input, m_Constant(&inputAttr)))
-      return failure();
-
-    // Check if has >1 consumer and is not splat
-    if (!input.hasOneUse() && !inputAttr.isSplat())
-      return failure();
-
-    // Grab the new shape
-    SmallVector<int64_t> newShapeValues = llvm::to_vector<6>(
-        llvm::map_range(newShape.getValue(), [](const Attribute &val) {
-          return val.cast<IntegerAttr>().getValue().getSExtValue();
-        }));
-
-    // Build new const op with correct output shape
-    ShapedType inputShape = input.getType().cast<ShapedType>();
-    DenseElementsAttr outputAttr =
-        inputAttr.reshape(inputShape.clone(newShapeValues));
-    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputAttr.getType(),
-                                               outputAttr);
-    return success();
-  }
-};
-
-void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                            MLIRContext *context) {
-  results.add<ReshapeReshapeOptimization>(context);
-  results.add<ReshapeConstOptimization>(context);
-}
-
-LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
-  auto notOp = op.pred().getDefiningOp<tosa::LogicalNotOp>();
-  if (!notOp)
-    return failure();
-  rewriter.updateRootInPlace(op, [&]() {
-    op.getOperation()->setOperands(
-        {notOp.input1(), op.on_false(), op.on_true()});
-  });
-  return success();
-}
-
-struct NoOpOptimization : public OpRewritePattern<tosa::TransposeOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::TransposeOp op,
-                                PatternRewriter &rewriter) const override {
-    auto perm = op.perms();
-
-    DenseIntElementsAttr permAttr;
-    if (!matchPattern(perm, m_Constant(&permAttr))) {
-      return failure();
-    }
-
-    SmallVector<int64_t> permValues = llvm::to_vector<6>(
-        llvm::map_range(permAttr.getValues<APInt>(),
-                        [](const APInt &val) { return val.getSExtValue(); }));
-
-    for (int i = 0, s = permValues.size(); i < s; i++) {
-      if (i != permValues[i]) {
-        return failure();
-      }
-    }
-
-    rewriter.replaceOp(op, op.input1());
-    return success();
-  }
-};
-
-void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                              MLIRContext *context) {
-  results.add<NoOpOptimization>(context);
-}
-
-struct AddZeroOptimization : public OpRewritePattern<tosa::AddOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::AddOp op,
-                                PatternRewriter &rewriter) const override {
-    auto input1 = op.input1();
-    auto input2 = op.input2();
-
-    DenseElementsAttr input1Attr;
-    if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() &&
-        input2.getType() == op.getType()) {
-      if (input1Attr.getType().getElementType().isa<IntegerType>() &&
-          input1Attr.getSplatValue<APInt>().isZero()) {
-        rewriter.replaceOp(op, op.input2());
-        return success();
-      }
-    }
-
-    DenseElementsAttr input2Attr;
-    if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() &&
-        input1.getType() == op.getType()) {
-      if (input2Attr.getType().getElementType().isa<IntegerType>() &&
-          input2Attr.getSplatValue<APInt>().isZero()) {
-        rewriter.replaceOp(op, op.input1());
-        return success();
-      }
-    }
-
-    return failure();
-  }
-};
-
-void AddOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                        MLIRContext *context) {
-  results.add<AddZeroOptimization>(context);
-}
-
-struct MulOneOptimization : public OpRewritePattern<tosa::MulOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::MulOp op,
-                                PatternRewriter &rewriter) const override {
-    auto input1 = op.input1();
-    auto input2 = op.input2();
-
-    DenseElementsAttr input1Attr;
-    if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() &&
-        input2.getType() == op.getType()) {
-      if (input1Attr.getType().getElementType().isa<FloatType>() &&
-          input1Attr.getSplatValue<APFloat>().isExactlyValue(1)) {
-        rewriter.replaceOp(op, op.input2());
-        return success();
-      }
-
-      if (input1Attr.getType().getElementType().isa<IntegerType>() &&
-          matchPattern(input1, m_One())) {
-        rewriter.replaceOp(op, op.input2());
-        return success();
-      }
-    }
-
-    DenseElementsAttr input2Attr;
-    if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() &&
-        input1.getType() == op.getType()) {
-      if (input2Attr.getType().getElementType().isa<FloatType>() &&
-          input2Attr.getSplatValue<APFloat>().isExactlyValue(1)) {
-        rewriter.replaceOp(op, op.input1());
-        return success();
-      }
-
-      if (input2Attr.getType().getElementType().isa<IntegerType>() &&
-          matchPattern(input2, m_One())) {
-        rewriter.replaceOp(op, op.input1());
-        return success();
-      }
-    }
-
-    return failure();
-  }
-};
-
-void MulOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                        MLIRContext *context) {
-  results.add<MulOneOptimization>(context);
-}
-
-struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::PadOp op,
-                                PatternRewriter &rewriter) const override {
-    if (op.pad_const())
-      return failure();
-
-    auto input = op.input1();
-    auto padding = op.padding();
-
-    ShapedType inputTy = input.getType().cast<ShapedType>();
-    Type elementTy = inputTy.getElementType();
-
-    Attribute constantAttr;
-    if (elementTy.isa<FloatType>()) {
-      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
-    } else if (elementTy.isa<IntegerType>() && !op.quantization_info()) {
-      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
-    } else if (elementTy.isa<IntegerType>() && op.quantization_info()) {
-      auto value = op.quantization_info()->getInputZp();
-      constantAttr = rewriter.getIntegerAttr(elementTy, value);
-    }
-
-    if (!constantAttr) {
-      return rewriter.notifyMatchFailure(
-          op,
-          "tosa.pad to linalg lowering encountered an unknown element type");
-    }
-
-    auto denseAttr = DenseElementsAttr::get(
-        RankedTensorType::get({}, elementTy), constantAttr);
-    auto constantVal = rewriter.create<tosa::ConstOp>(
-        op.getLoc(), denseAttr.getType(), denseAttr);
-
-    rewriter.replaceOpWithNewOp<tosa::PadOp>(
-        op, op.getType(), ValueRange{input, padding, constantVal},
-        op->getAttrs());
-    return success();
-  }
-};
-
-void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                        MLIRContext *context) {
-  results.add<MaterializePadValue>(context);
-}
-
-struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.input();
-    Value output = op.output();
-    ShapedType inputType = input.getType().cast<ShapedType>();
-    ShapedType outputType = output.getType().cast<ShapedType>();
-
-    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
-      return failure();
-    }
-
-    // If the output and input shapes are 1x1, then this is a no op.
-    ArrayRef<int64_t> outputShape = outputType.getShape();
-    if (outputShape[1] != 1 || outputShape[2] != 1) {
-      return failure();
-    }
-
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-    if (inputShape[1] != 1 || inputShape[2] != 1) {
-      return failure();
-    }
-
-    rewriter.replaceOp(op, input);
-    return success();
-  }
-};
-
-void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                              MLIRContext *context) {
-  results.add<MaxPool2dIsNoOp>(context);
-}
-
-struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::ClampOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.input();
-    auto inputType = op.input().getType().template dyn_cast<RankedTensorType>();
-    auto inputElementType = inputType.getElementType();
-
-    if (!inputType.hasStaticShape()) {
-      return failure();
-    }
-
-    if (inputElementType.isF32()) {
-      auto minClamp = op.min_fp();
-      auto maxClamp = op.max_fp();
-      bool isMin = (minClamp.isLargest() || minClamp.isInfinity()) &&
-                   minClamp.isNegative();
-      bool isMax = (maxClamp.isLargest() || maxClamp.isInfinity()) &&
-                   !maxClamp.isNegative();
-
-      if (isMin && isMax) {
-        rewriter.replaceOp(op, input);
-        return success();
-      }
-      return failure();
-    }
-
-    if (inputElementType.isUnsignedInteger()) {
-      int64_t minClamp = op.min_int();
-      int64_t maxClamp = op.max_int();
-
-      int64_t intMin =
-          APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
-              .getZExtValue();
-      int64_t intMax =
-          APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
-              .getZExtValue();
-
-      if (minClamp <= intMin && maxClamp >= intMax) {
-        rewriter.replaceOp(op, input);
-        return success();
-      }
-      return failure();
-    }
-
-    if (inputElementType.isa<IntegerType>()) {
-      int64_t minClamp = op.min_int();
-      int64_t maxClamp = op.max_int();
-
-      int64_t intMin =
-          APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
-              .getSExtValue();
-      int64_t intMax =
-          APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
-              .getSExtValue();
-
-      if (minClamp <= intMin && maxClamp >= intMax) {
-        rewriter.replaceOp(op, input);
-        return success();
-      }
-      return failure();
-    }
-
-    return failure();
-  }
-};
-
-struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
-  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::ClampOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.input();
-
-    Operation *definingOp = input.getDefiningOp();
-    if (!definingOp)
-      return failure();
-
-    if (tosa::ClampOp clampOp = dyn_cast<tosa::ClampOp>(definingOp)) {
-      auto minFp = std::max(op.min_fp(), clampOp.min_fp()).convertToFloat();
-      auto maxFp = std::min(op.max_fp(), clampOp.max_fp()).convertToFloat();
-
-      auto minInt = std::max(op.min_int(), clampOp.min_int());
-      auto maxInt = std::min(op.max_int(), clampOp.max_int());
-
-      rewriter.replaceOpWithNewOp<tosa::ClampOp>(
-          op, op.getType(), clampOp.input(), rewriter.getI64IntegerAttr(minInt),
-          rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
-          rewriter.getF32FloatAttr(maxFp));
-      return success();
-    }
-
-    return failure();
-  }
-};
-
-void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                          MLIRContext *context) {
-  results.add<ClampIsNoOp>(context);
-  results.add<ClampClampOptimization>(context);
-}
-
-//===----------------------------------------------------------------------===//
-// Operator Folders.
-//===----------------------------------------------------------------------===//
-
-OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
-  if (input().getType() == getType())
-    return input();
-  return {};
-}
-
-OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.empty() && "constant has no operands");
-  return valueAttr();
-}
-
-#define REDUCE_FOLDER(OP)                                                      \
-  OpFoldResult OP::fold(ArrayRef<Attribute> operands) {                        \
-    ShapedType inputTy = input().getType().cast<ShapedType>();                 \
-    if (!inputTy.hasRank())                                                    \
-      return {};                                                               \
-    if (inputTy.getDimSize(axis()) == 1)                                       \
-      return input();                                                          \
-    return {};                                                                 \
-  }
-
-REDUCE_FOLDER(ReduceAllOp)
-REDUCE_FOLDER(ReduceAnyOp)
-REDUCE_FOLDER(ReduceMaxOp)
-REDUCE_FOLDER(ReduceMinOp)
-REDUCE_FOLDER(ReduceProdOp)
-REDUCE_FOLDER(ReduceSumOp)
-#undef REDUCE_FOLDER
-
-OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
-  auto inputTy = input1().getType().dyn_cast<RankedTensorType>();
-  auto outputTy = getType().dyn_cast<RankedTensorType>();
-
-  if (!inputTy || !outputTy || inputTy != outputTy)
-    return {};
-  return input1();
-}
-
-OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
-  // If the pad is all zeros we can fold this operation away.
-  if (operands[1]) {
-    auto densePad = operands[1].cast<DenseElementsAttr>();
-    if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) {
-      return input1();
-    }
-  }
-
-  return {};
-}
-
-OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
-  auto inputTy = input().getType().dyn_cast<RankedTensorType>();
-  auto outputTy = getType().dyn_cast<RankedTensorType>();
-
-  if (!inputTy || !outputTy || inputTy != outputTy)
-    return {};
-  if (inputTy.hasStaticShape())
-    return input();
-
-  return {};
-}
-
-OpFoldResult tosa::SelectOp::fold(ArrayRef<Attribute> operands) {
-  if (on_true() == on_false())
-    return on_true();
-
-  auto predicate = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
-  if (!predicate)
-    return {};
-
-  if (!predicate.isSplat())
-    return {};
-  return predicate.getSplatValue<APInt>().getBoolValue() ? on_true()
-                                                         : on_false();
-}
-
-OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
-  bool allOnes = true;
-  for (Attribute val : multiples().getValue()) {
-    allOnes = allOnes && val.cast<IntegerAttr>().getValue().getSExtValue() == 1;
-  }
-
-  if (allOnes && input1().getType() == getType())
-    return input1();
-  return {};
-}
-
-OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
-  if (!operands[1])
-    return {};
-
-  // Transposing splat values just means reshaping.
-  if (auto input = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
-    if (input.isSplat())
-      return input.reshape(getType().cast<ShapedType>());
-  }
-
-  auto perms = llvm::to_vector<6>(llvm::map_range(
-      operands[1].cast<DenseIntElementsAttr>().getValues<APInt>(),
-      [](const APInt &val) { return val.getSExtValue(); }));
-
-  if (llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms) &&
-      input1().getType() == getType())
-    return input1();
-  return {};
-}
-
 //===----------------------------------------------------------------------===//
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
index 7cf7ff14eb9ac..7814b91dd6f34 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
@@ -21,6 +21,20 @@ using namespace mlir::tosa;
 
 namespace {
 
+template <typename... Args>
+void addOpsCanonicalizations(MLIRContext *ctx, RewritePatternSet &patterns) {
+  (void)std::initializer_list<int>{
+      0, (Args::getCanonicalizationPatterns(patterns, ctx), 0)...};
+}
+
+void populateTosaOpsCanonicalizationPatterns(MLIRContext *ctx,
+                                             RewritePatternSet &patterns) {
+  addOpsCanonicalizations<
+#define GET_OP_LIST
+#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
+      >(ctx, patterns);
+}
+
 struct TosaLayerwiseConstantFoldPass
     : public TosaLayerwiseConstantFoldPassBase<TosaLayerwiseConstantFoldPass> {
   void runOnOperation() override {
@@ -29,7 +43,7 @@ struct TosaLayerwiseConstantFoldPass
     auto func = getOperation();
 
     mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);
-    mlir::tosa::populateTosaOpsCanonicalizationPatterns(ctx, patterns);
+    populateTosaOpsCanonicalizationPatterns(ctx, patterns);
 
     if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed())
       signalPassFailure();


        


More information about the Mlir-commits mailing list