[Mlir-commits] [mlir] [mlir] Convert `expand_shape` to more static form (PR #112265)

Ian Wood llvmlistbot at llvm.org
Tue Oct 15 14:58:36 PDT 2024


https://github.com/IanWood1 updated https://github.com/llvm/llvm-project/pull/112265

>From 67649194b893a9a017082964d285056f4c6656fa Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Mon, 14 Oct 2024 13:29:42 -0500
Subject: [PATCH 1/3] [mlir] Fold expand of cast

Sink tensor.cast ops through tensor.expand_shape ops when doing so makes
the expand op more static. This allows other ops further down to infer
their shapes.
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp   | 31 +++++++++++++++++++++-
 mlir/test/Dialect/Tensor/canonicalize.mlir | 14 ++++++++++
 2 files changed, 44 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4d6c5965c4fcc3..9be647f687e600 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1982,6 +1982,35 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+struct FoldExpandOfCast : public OpRewritePattern<ExpandShapeOp> {
+  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
+    if (!canFoldIntoConsumerOp(castOp))
+      return failure();
+
+    SmallVector<OpFoldResult> outputOfr =
+        getMixedValues(expandOp.getResultType().getShape(),
+                       expandOp.getOutputShape(), rewriter);
+    std::optional<SmallVector<int64_t>> constantOutputShape =
+        getConstantIntValues(outputOfr);
+    if (!constantOutputShape.has_value()) {
+      return failure();
+    }
+    auto newType = RankedTensorType::get(
+        constantOutputShape.value(), expandOp.getSrcType().getElementType());
+
+    auto newExpand = rewriter.create<ExpandShapeOp>(
+        castOp.getLoc(), newType, castOp.getSource(),
+        expandOp.getReassociationIndices());
+    rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
+                                        newExpand.getResult());
+    return success();
+  }
+};
 } // namespace
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -1989,7 +2018,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<
       ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
       ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
-      FoldReshapeWithConstant<ExpandShapeOp>,
+      FoldExpandOfCast, FoldReshapeWithConstant<ExpandShapeOp>,
       FoldReshapeWithSplat<ExpandShapeOp>,
       FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
       FoldDimOfCollapseShape>(context);
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 0aa2d33ef17ed4..1509d26151119d 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2718,3 +2718,17 @@ func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128
   %pack = tensor.pack %arg0 padding_value(%cst : f16) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %arg1 {test_attr} : tensor<?x?x?xf16> -> tensor<128x?x100x16x1xf16>
   return %pack : tensor<128x?x100x16x1xf16>
 }
+
+// -----
+
+func.func @fold_expand_of_cast(%arg0 : tensor<10x10xf32>)
+    -> tensor<?x?x?xf32> {
+  %c1 = arith.constant 1 : index 
+  %c10 = arith.constant 10 : index 
+  %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL:  func.func @fold_expand_of_cast
+//       CHECK:   tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]

>From 3f4c7bb63fc16dcfa809ae917d039c25782c7cb9 Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Tue, 15 Oct 2024 02:19:37 +0000
Subject: [PATCH 2/3] Convert to static expand_shape

When some of the `output_shape` values can be determined to be constants,
convert to a more static expand_shape op and insert cast ops. The top cast
will be (dynamic -> static), allowing it to be propagated upwards, and the
bottom cast will be (static -> dynamic), allowing it to propagate down (or
cancel with adjacent tensor.cast ops).

[skip ci]
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp   | 61 ++++++++++++++++------
 mlir/test/Dialect/Tensor/canonicalize.mlir | 14 +++++
 2 files changed, 60 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 9be647f687e600..96384385b6a060 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1983,29 +1983,60 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
   }
 };
 
-struct FoldExpandOfCast : public OpRewritePattern<ExpandShapeOp> {
+struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
   using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
                                 PatternRewriter &rewriter) const override {
-    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
-    if (!canFoldIntoConsumerOp(castOp))
-      return failure();
+    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
+    SmallVector<Value> dynamicOutputShape;
+    auto outputIt = expandOp.getOutputShape().begin();
+    for (auto [i, staticShape] : llvm::enumerate(newOutputShape)) {
+      if (!ShapedType::isDynamic(staticShape))
+        continue;
 
-    SmallVector<OpFoldResult> outputOfr =
-        getMixedValues(expandOp.getResultType().getShape(),
-                       expandOp.getOutputShape(), rewriter);
-    std::optional<SmallVector<int64_t>> constantOutputShape =
-        getConstantIntValues(outputOfr);
-    if (!constantOutputShape.has_value()) {
+      APInt cst;
+      Value val = *outputIt;
+      ++outputIt;
+      if (matchPattern(val, m_ConstantInt(&cst))) {
+        newOutputShape[i] = cst.getSExtValue();
+      } else {
+        dynamicOutputShape.push_back(val);
+      }
+    }
+
+    // Couldn't match any values, nothing to change
+    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
       return failure();
+
+    // Calculate the input shape from the output
+    SmallVector<ReassociationIndices, 4> reassoc =
+        expandOp.getReassociationIndices();
+    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
+    for (uint64_t inDim = 0; inDim < newInputShape.size(); inDim++) {
+      for (auto outDim : reassoc[inDim]) {
+        auto ofr = newOutputShape[outDim];
+        if (ShapedType::isDynamic(ofr)) {
+          newInputShape[inDim] = ShapedType::kDynamic;
+          break;
+        }
+        newInputShape[inDim] *= ofr;
+      }
     }
-    auto newType = RankedTensorType::get(
-        constantOutputShape.value(), expandOp.getSrcType().getElementType());
 
+    // `inputCast` can be propagated up and the final cast can be propagated
+    // down.
+    SmallVector<OpFoldResult> outputOfr =
+        getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
+    auto inputType = RankedTensorType::get(
+        newInputShape, expandOp.getSrcType().getElementType());
+    auto outputType = RankedTensorType::get(
+        newOutputShape, expandOp.getSrcType().getElementType());
+    auto inputCast = rewriter.create<CastOp>(expandOp.getLoc(), inputType,
+                                             expandOp.getSrc());
     auto newExpand = rewriter.create<ExpandShapeOp>(
-        castOp.getLoc(), newType, castOp.getSource(),
-        expandOp.getReassociationIndices());
+        expandOp.getLoc(), outputType, inputCast.getResult(),
+        expandOp.getReassociationIndices(), outputOfr);
     rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
                                         newExpand.getResult());
     return success();
@@ -2018,7 +2049,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<
       ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
       ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
-      FoldExpandOfCast, FoldReshapeWithConstant<ExpandShapeOp>,
+      ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
       FoldReshapeWithSplat<ExpandShapeOp>,
       FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
       FoldDimOfCollapseShape>(context);
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 1509d26151119d..52dcfd1d427d93 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2732,3 +2732,17 @@ func.func @fold_expand_of_cast(%arg0 : tensor<10x10xf32>)
 }
 // CHECK-LABEL:  func.func @fold_expand_of_cast
 //       CHECK:   tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]
+
+// -----
+
+func.func @fold_expand_of_cast_dynamic(%arg0 : tensor<?x10xf32>)
+    -> tensor<?x?x?xf32> {
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %0 = tensor.cast %arg0 : tensor<?x10xf32> to tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL:  func.func @fold_expand_of_cast_dynamic
+//       CHECK:   tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]

>From 4e16a9764cc0b125d3b851fd077865fe50b62003 Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Tue, 15 Oct 2024 21:58:24 +0000
Subject: [PATCH 3/3] Redo logic to ensure cast gets folded

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp   | 46 +++++++++++++++-------
 mlir/test/Dialect/Tensor/canonicalize.mlir | 38 +++++++++++++++---
 2 files changed, 64 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 96384385b6a060..ee0e8c2d201226 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -24,6 +24,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
@@ -1988,20 +1989,41 @@ struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
 
   LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
                                 PatternRewriter &rewriter) const override {
+    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
+    if (!canFoldIntoConsumerOp(castOp))
+      return failure();
+
+    const ArrayRef<int64_t> castSrcShape =
+        castOp.getSource().getType().getShape();
+    const SmallVector<ReassociationIndices, 4> reassoc =
+        expandOp.getReassociationIndices();
+
     SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
     SmallVector<Value> dynamicOutputShape;
     auto outputIt = expandOp.getOutputShape().begin();
-    for (auto [i, staticShape] : llvm::enumerate(newOutputShape)) {
-      if (!ShapedType::isDynamic(staticShape))
-        continue;
 
-      APInt cst;
-      Value val = *outputIt;
-      ++outputIt;
-      if (matchPattern(val, m_ConstantInt(&cst))) {
-        newOutputShape[i] = cst.getSExtValue();
-      } else {
-        dynamicOutputShape.push_back(val);
+    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
+      for (const uint64_t outDim : innerReassoc) {
+        if (!ShapedType::isDynamic(newOutputShape[outDim]))
+          continue;
+
+        // If the cast's src type is dynamic, don't infer any of the
+        // corresponding expanded dimensions. `tensor.expand_shape` requires at
+        // least one of the expanded dimensions to be dynamic if the input is
+        // dynamic.
+        Value val = *outputIt;
+        ++outputIt;
+        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
+          dynamicOutputShape.push_back(val);
+          continue;
+        }
+
+        APInt cst;
+        if (matchPattern(val, m_ConstantInt(&cst))) {
+          newOutputShape[outDim] = cst.getSExtValue();
+        } else {
+          dynamicOutputShape.push_back(val);
+        }
       }
     }
 
@@ -2010,8 +2032,6 @@ struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
       return failure();
 
     // Calculate the input shape from the output
-    SmallVector<ReassociationIndices, 4> reassoc =
-        expandOp.getReassociationIndices();
     SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
     for (uint64_t inDim = 0; inDim < newInputShape.size(); inDim++) {
       for (auto outDim : reassoc[inDim]) {
@@ -2024,8 +2044,6 @@ struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
       }
     }
 
-    // `inputCast` can be propagated up and the final cast can be propagated
-    // down.
     SmallVector<OpFoldResult> outputOfr =
         getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
     auto inputType = RankedTensorType::get(
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 52dcfd1d427d93..63f394a14d3899 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2722,20 +2722,22 @@ func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128
 // -----
 
 func.func @fold_expand_of_cast(%arg0 : tensor<10x10xf32>)
-    -> tensor<?x?x?xf32> {
+    -> tensor<10x1x10xf32> {
   %c1 = arith.constant 1 : index 
   %c10 = arith.constant 10 : index 
   %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
   %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
       : tensor<?x?xf32> into tensor<?x?x?xf32>
-  return %1 : tensor<?x?x?xf32>
+  %2 = tensor.cast %1 : tensor<?x?x?xf32> to tensor<10x1x10xf32>
+  return %2 : tensor<10x1x10xf32>
 }
 // CHECK-LABEL:  func.func @fold_expand_of_cast
-//       CHECK:   tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]
+//       CHECK:   %[[RES:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]
+//       CHECK:   return %[[RES]]
 
 // -----
 
-func.func @fold_expand_of_cast_dynamic(%arg0 : tensor<?x10xf32>)
+func.func @sink_expand_of_cast(%arg0 : tensor<?x10xf32>)
     -> tensor<?x?x?xf32> {
   %c1 = arith.constant 1 : index
   %c10 = arith.constant 10 : index
@@ -2744,5 +2746,29 @@ func.func @fold_expand_of_cast_dynamic(%arg0 : tensor<?x10xf32>)
       : tensor<?x?xf32> into tensor<?x?x?xf32>
   return %1 : tensor<?x?x?xf32>
 }
-// CHECK-LABEL:  func.func @fold_expand_of_cast_dynamic
-//       CHECK:   tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]
+// CHECK-LABEL:  func.func @sink_expand_of_cast
+//   CHECK-DAG:   %[[C10:.*]] = arith.constant 10
+//   CHECK-DAG:   %[[C1:.*]] = arith.constant 1
+//       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] 
+//  CHECK-SAME:     output_shape [%[[C10]], %[[C1]], 10]
+//       CHECK:   %[[RES:.+]] = tensor.cast %[[EXPAND]]
+//       CHECK:   return %[[RES]]
+
+// -----
+
+func.func @partial_sink_expand_of_cast(%arg0 : tensor<10x10xf32>, %arg1 : index, %arg2 : index)
+    -> tensor<?x?x?xf32> {
+  %c10 = arith.constant 10 : index
+  %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %c10]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL:  func.func @partial_sink_expand_of_cast
+//       CHECK:   %[[CAST:.+]] = tensor.cast
+//  CHECK-SAME:     tensor<10x10xf32> to tensor<?x10xf32>
+//       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] 
+//  CHECK-SAME:     output_shape [%{{.*}}, %{{.*}}, 10]
+//       CHECK:   %[[RES:.+]] = tensor.cast %[[EXPAND]]
+//  CHECK-SAME:     tensor<?x?x10xf32> to tensor<?x?x?xf32>
+//       CHECK:   return %[[RES]]



More information about the Mlir-commits mailing list