[Mlir-commits] [mlir] [MLIR][TENSOR] Decompose tensor.reshape to tensor.collapse+expand (PR #91499)

Gaurav Shukla llvmlistbot at llvm.org
Wed May 8 09:22:37 PDT 2024


https://github.com/Shukla-Gaurav created https://github.com/llvm/llvm-project/pull/91499

This patch adds a pass and rewrite pattern that decompose a `tensor.reshape` op whose target shape is produced by a `tensor.from_elements` op into a `tensor.collapse_shape` + `tensor.expand_shape` pair, inserting `tensor.cast` ops where the inferred intermediate types require them.
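
With a target shape supplied by `tensor.from_elements`, a reshape such as

  %shape = tensor.from_elements %c2, %c12 : tensor<2xindex>
  %0 = tensor.reshape %arg0(%shape)
      : (tensor<4x6xf32>, tensor<2xindex>) -> tensor<2x12xf32>

is rewritten roughly as the following sketch (`%arg0`, `%c2`, and `%c12` are
assumed to be defined elsewhere; trivial `tensor.cast` ops fold away):

  %collapsed = tensor.collapse_shape %arg0 [[0, 1]]
      : tensor<4x6xf32> into tensor<24xf32>
  %expanded = tensor.expand_shape %collapsed [[0, 1]] output_shape [%c2, %c12]
      : tensor<24xf32> into tensor<2x12xf32>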

From fd57711b075ffe9a3b62f91058e10559a1e14a87 Mon Sep 17 00:00:00 2001
From: Gaurav Shukla <gaurav at nod-labs.com>
Date: Wed, 8 May 2024 21:46:42 +0530
Subject: [PATCH] [MLIR][TENSOR] Decompose tensor.reshape to
 tensor.collapse+expand

Signed-off-by: Gaurav Shukla <gaurav.shukla at amd.com>
---
 .../mlir/Dialect/Tensor/Transforms/Passes.h   |   3 +
 .../mlir/Dialect/Tensor/Transforms/Passes.td  |  11 +
 .../Dialect/Tensor/Transforms/Transforms.h    |   4 +
 .../Dialect/Tensor/Transforms/CMakeLists.txt  |   1 +
 .../Transforms/DecomposeTensorReshapeOp.cpp   | 294 ++++++++++++++++++
 5 files changed, 313 insertions(+)
 create mode 100644 mlir/lib/Dialect/Tensor/Transforms/DecomposeTensorReshapeOp.cpp

diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
index 48f9066934a25..7a71fe95da7dc 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
@@ -21,6 +21,9 @@ namespace tensor {
 /// Creates an instance of the `tensor` subset folding pass.
 std::unique_ptr<Pass> createFoldTensorSubsetOpsPass();
 
+/// Creates an instance of the `tensor` reshape decomposition pass.
+std::unique_ptr<Pass> createDecomposeTensorReshapeOpPass();
+
 /// Creates an instance of the `tensor` dialect bufferization pass.
 std::unique_ptr<Pass> createTensorBufferizePass();
 
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
index 4cc3844f29120..d88fb75b8a50d 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td
@@ -27,6 +27,17 @@ def FoldTensorSubsetOps : Pass<"fold-tensor-subset-ops"> {
   ];
 }
 
+def DecomposeTensorReshapeOp : Pass<"decompose-tensor-reshape-op"> {
+  let summary = "Decompose tensor reshape op into expand_shape/collapse_shape ops";
+  let description = [{
+    Decomposes a `tensor.reshape` op into a `tensor.collapse_shape` op
+    followed by a `tensor.expand_shape` op (inserting `tensor.cast` ops where
+    needed), provided the target shape is the result of a
+    `tensor.from_elements` op.
+  }];
+  let constructor = "mlir::tensor::createDecomposeTensorReshapeOpPass()";
+  let dependentDialects = [
+      "affine::AffineDialect", "tensor::TensorDialect"
+  ];
+}
+
 def TensorBufferize : Pass<"tensor-bufferize", "func::FuncOp"> {
   let summary = "Bufferize the `tensor` dialect";
   let constructor = "mlir::tensor::createTensorBufferizePass()";
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index e8a09c4741043..4275c47144010 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -59,6 +59,10 @@ void populateDropRedundantInsertSliceRankExpansionPatterns(
 /// `tensor.collapse_shape` into other ops.
 void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);
 
+/// Populates `patterns` with patterns that decompose a `tensor.reshape` op into
+/// `tensor.expand_shape` and `tensor.collapse_shape` ops.
+void populateDecomposeTensorReshapePatterns(RewritePatternSet &patterns);
+
 /// Populates `patterns` with patterns that fold tensor.empty with
 /// tensor.[extract_slice|expand_shape|collapse_shape].
 ///
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index c6ef6ed86e0d9..c82fd32172611 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   ConcatOpPatterns.cpp
+  DecomposeTensorReshapeOp.cpp
   EmptyOpPatterns.cpp
   ExtractSliceFromReshapeUtils.cpp
   FoldTensorSubsetOps.cpp
diff --git a/mlir/lib/Dialect/Tensor/Transforms/DecomposeTensorReshapeOp.cpp b/mlir/lib/Dialect/Tensor/Transforms/DecomposeTensorReshapeOp.cpp
new file mode 100644
index 0000000000000..72b8662814823
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/DecomposeTensorReshapeOp.cpp
@@ -0,0 +1,294 @@
+//===- DecomposeTensorReshapeOp.cpp - Decompose tensor reshape op ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Decompose tensor reshape op into tensor collapse-expand pair.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <type_traits>
+
+namespace mlir {
+namespace tensor {
+// Pull in the TableGen-generated base class for the pass defined below.
+#define GEN_PASS_DEF_DECOMPOSETENSORRESHAPEOP
+#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
+} // namespace tensor
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+
+// Infer the type to which the input of a 'tensor.reshape' op must be cast when
+// lowered.
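+// For example, a tensor<?x?xf32> input being reshaped into a 0-D result is
+// first cast to tensor<1x1xf32>, so that it can be collapsed into tensor<f32>
+// without collapsing dynamic dimensions.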
+TensorType inferReshapeInputType(TypedValue<TensorType> input,
+                                 SmallVector<OpFoldResult> newShape) {
+  // No need to cast input for non-empty target shape
+  if (!newShape.empty())
+    return input.getType();
+
+  // The input type must be cast into a tensor with the same rank and all static
+  // dimensions set to 1. This prevents the generation of a
+  // tensor.collapse_shape op that converts a dynamically shaped tensor into a
+  // 0D tensor.
+  SmallVector<int64_t> shape(input.getType().getRank(), 1);
+  return input.getType().clone(shape);
+}
+
+// Infer the result type of 'tensor.expand_shape' in the collapse-expand
+// pair emitted for a 'tensor.reshape' op.
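+// For example, with a static 24-element input and target shape [2, 12], the
+// expanded type has shape 2x12. If the input is dynamically shaped and the
+// inferred shape would be fully static, the leading dimension is made dynamic.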
+TensorType inferReshapeExpandedType(TensorType inputType,
+                                    ArrayRef<int64_t> newShape) {
+  // Special case for 0D output tensor. Note: Watch out when using Type::clone()
+  // with just '{}', as it will invoke the incorrect overload.
+  if (newShape.empty())
+    return inputType.clone(ArrayRef<int64_t>{});
+
+  // Check if the input is static, and if so, get its total size
+  bool inputIsStatic = inputType.hasStaticShape();
+  int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;
+
+  // Compute result shape
+  bool resultIsStatic = true;
+  auto resultShape =
+      llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
+        // If this is not a placeholder, do not change it
+        if (size >= 0)
+          return size;
+
+        // If we do not know the total size of the tensor, keep this dimension
+        // dynamic in the result shape.
+        if (!inputIsStatic) {
+          resultIsStatic = false;
+          return ShapedType::kDynamic;
+        }
+
+        // Calculate the product of all elements in 'newShape' except for the -1
+        // placeholder, which we discard by negating the result.
+        int64_t totalSizeNoPlaceholder = -std::accumulate(
+            newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());
+
+        // If there is a 0 component in 'newShape', resolve the placeholder as
+        // 0.
+        if (totalSizeNoPlaceholder == 0)
+          return 0;
+
+        // Resolve the placeholder as the quotient between the total tensor size
+        // and the product of all other sizes.
+        return totalSize / totalSizeNoPlaceholder;
+      });
+
+  // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
+  // shaped input from being reshaped into a statically shaped result. We may
+  // simply turn the first result dimension dynamic to address this.
+  if (!inputIsStatic && resultIsStatic)
+    resultShape[0] = ShapedType::kDynamic;
+
+  // The 'tensor.expand_shape' op also forbids a statically shaped input from
+  // being reshaped into a dynamically shaped result, but the placeholder
+  // inference algorithm above guarantees that this will never be the case.
+  assert(!inputIsStatic || resultIsStatic);
+
+  // Create result type
+  return inputType.clone(resultShape);
+}
+
+// Infer the result type of 'tensor.collapse_shape' in the collapse-expand
+// pair emitted for a 'tensor.reshape' op.
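+// For example, for a source of type tensor<4x6xf32> and an expanded type of
+// tensor<2x12xf32>, the common collapsed type is tensor<24xf32>.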
+TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
+  auto lhsShape = lhsType.getShape();
+  auto rhsShape = rhsType.getShape();
+
+  if (lhsShape.empty() || rhsShape.empty())
+    return lhsType.clone(ArrayRef<int64_t>{});
+
+  if (ShapedType::isDynamicShape(lhsShape) ||
+      ShapedType::isDynamicShape(rhsShape))
+    return lhsType.clone({ShapedType::kDynamic});
+
+  SmallVector<int64_t> intermediateShape;
+  unsigned currLhsDim = 0, currRhsDim = 0;
+  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
+    int64_t rhsSize = rhsShape[currRhsDim];
+    int64_t lhsSize = lhsShape[currLhsDim];
+    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
+           currRhsDim < rhsShape.size()) {
+      if (lhsSize < rhsSize) {
+        currLhsDim++;
+        if (currLhsDim < lhsShape.size()) {
+          lhsSize *= lhsShape[currLhsDim];
+        }
+      } else {
+        currRhsDim++;
+        if (currRhsDim < rhsShape.size()) {
+          rhsSize *= rhsShape[currRhsDim];
+        }
+      }
+    }
+    if (lhsSize == rhsSize) {
+      intermediateShape.push_back(lhsSize);
+    }
+    currRhsDim++;
+    currLhsDim++;
+  }
+
+  // Static shapes are guaranteed to be compatible by the op verifier, so all
+  // leftover dimensions should be 1.
+  for (; currLhsDim < lhsShape.size(); currLhsDim++) {
+    assert(lhsShape[currLhsDim] == 1);
+  }
+  for (; currRhsDim < rhsShape.size(); currRhsDim++) {
+    assert(rhsShape[currRhsDim] == 1);
+  }
+
+  return lhsType.clone(intermediateShape);
+}
+
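+// Compute the reassociation map for a 'tensor.collapse_shape' op that
+// collapses 'srcType' into 'dstType'. For example, collapsing
+// tensor<4x6x1xf32> into tensor<24xf32> yields the reassociation [[0, 1, 2]].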
+SmallVector<ReassociationExprs>
+createReassociationMapForCollapse(OpBuilder &builder, Type srcType,
+                                  Type dstType) {
+  auto srcShape = cast<TensorType>(srcType).getShape();
+  auto dstShape = cast<TensorType>(dstType).getShape();
+
+  if (srcShape.empty() || dstShape.empty())
+    return {};
+
+  if (ShapedType::isDynamicShape(srcShape) ||
+      ShapedType::isDynamicShape(dstShape)) {
+    assert(dstShape.size() == 1);
+    SmallVector<AffineExpr, 2> exprs;
+    for (auto i : llvm::seq<int64_t>(srcShape.size()))
+      exprs.push_back(builder.getAffineDimExpr(i));
+    return {exprs};
+  }
+
+  SmallVector<ReassociationExprs> reassociationMap(dstShape.size());
+  unsigned currSrcDim = 0, currDstDim = 0;
+  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
+    int64_t dstSize = dstShape[currDstDim];
+    int64_t srcSize = srcShape[currSrcDim];
+    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
+      reassociationMap[currDstDim].push_back(
+          builder.getAffineDimExpr(currSrcDim++));
+      srcSize *= srcShape[currSrcDim];
+    }
+    if (srcSize == dstSize) {
+      reassociationMap[currDstDim].push_back(
+          builder.getAffineDimExpr(currSrcDim++));
+      // If the next dim in collapsedShape is not 1, treat subsequent dims in
+      // expandedShape which are 1 to be collapsed.
+      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
+        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
+          reassociationMap[currDstDim].push_back(
+              builder.getAffineDimExpr(currSrcDim++));
+        }
+      }
+    }
+    currDstDim++;
+  }
+
+  // If the source and target shapes are compatible, both iterators must have
+  // reached the end. This condition is guaranteed by the op verifier for
+  // static shapes.
+  assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
+  return reassociationMap;
+}
+
+// Create a tensor.collapse_shape op that reshapes the input into the given
+// result type.
+Value createCollapse(OpBuilder &builder, Location loc, TensorType resultType,
+                     Value input) {
+  auto reassociationMap =
+      createReassociationMapForCollapse(builder, input.getType(), resultType);
+  return builder.createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
+                                                       reassociationMap);
+}
+
+// Create a tensor.expand_shape op that reshapes the input into the given result
+// type.
+Value createExpand(OpBuilder &builder, Location loc, TensorType resultType,
+                   Value input, SmallVector<OpFoldResult> outputShape) {
+  auto reassociationMap =
+      createReassociationMapForCollapse(builder, resultType, input.getType());
+  return builder.createOrFold<tensor::ExpandShapeOp>(
+      loc, resultType, input, reassociationMap, outputShape);
+}
+
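+// Rewrite a 'tensor.reshape' op whose target shape is produced by a
+// 'tensor.from_elements' op into a 'tensor.collapse_shape' +
+// 'tensor.expand_shape' pair, with 'tensor.cast' ops inserted where needed.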
+struct DecomposeTensorReshapeOp : public OpRewritePattern<ReshapeOp> {
+  using OpRewritePattern<ReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = reshapeOp.getLoc();
+    auto resultType = reshapeOp.getResult().getType();
+    auto input = cast<TypedValue<TensorType>>(reshapeOp.getSource());
+    Value newShape = reshapeOp.getShape();
+    auto fromElementsOp = newShape.getDefiningOp<FromElementsOp>();
+    if (!fromElementsOp)
+      return failure();
+    auto shapeElements = fromElementsOp.getElements();
+    SmallVector<OpFoldResult> newShapeList(shapeElements.begin(),
+                                           shapeElements.end());
+
+    // Infer all intermediate types
+    auto inputType = inferReshapeInputType(input, newShapeList);
+    auto expandedType =
+        inferReshapeExpandedType(inputType, resultType.getShape());
+    auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);
+
+    // Cast input if needed
+    auto castInput =
+        rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);
+
+    // Emit the collapse-expand pair
+    auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
+    auto expanded =
+        createExpand(rewriter, loc, expandedType, collapsed, newShapeList);
+
+    // Cast to final result type if needed
+    auto result =
+        rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
+    rewriter.replaceOp(reshapeOp, result);
+    return success();
+  }
+};
+
+} // namespace
+
+void tensor::populateDecomposeTensorReshapePatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<DecomposeTensorReshapeOp>(patterns.getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// Pass registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct DecomposeTensorReshapeOpPass final
+    : public tensor::impl::DecomposeTensorReshapeOpBase<
+          DecomposeTensorReshapeOpPass> {
+  void runOnOperation() override;
+};
+
+} // namespace
+
+void DecomposeTensorReshapeOpPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  tensor::populateDecomposeTensorReshapePatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+}
+
+std::unique_ptr<Pass> tensor::createDecomposeTensorReshapeOpPass() {
+  return std::make_unique<DecomposeTensorReshapeOpPass>();
+}
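
Once registered through the usual TableGen pass registration, the rewrite
should be exercisable with something like
`mlir-opt --decompose-tensor-reshape-op`, or pulled in programmatically via
`tensor::populateDecomposeTensorReshapePatterns`.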


