[Mlir-commits] [mlir] [mlir] Convert `expand_shape` to more static form (PR #112265)
Ian Wood
llvmlistbot at llvm.org
Tue Oct 22 14:43:44 PDT 2024
https://github.com/IanWood1 updated https://github.com/llvm/llvm-project/pull/112265
From 67649194b893a9a017082964d285056f4c6656fa Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Mon, 14 Oct 2024 13:29:42 -0500
Subject: [PATCH 1/5] [mlir] Fold expand of cast
Sink tensor.cast ops through tensor.expand_shape ops when doing so makes
the expand op more static. This allows other ops further down the chain
to infer their shapes.
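For example, the test added in this patch is rewritten roughly as follows
(a hand-written sketch of the before/after IR, not verbatim pattern output):

  // Before: the cast hides the static shape of %arg0 from the expand.
  %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
      : tensor<?x?xf32> into tensor<?x?x?xf32>

  // After: the expand consumes the static source directly; a trailing cast
  // restores the original dynamic result type so existing uses still match.
  %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [10, 1, 10]
      : tensor<10x10xf32> into tensor<10x1x10xf32>
  %1 = tensor.cast %expanded : tensor<10x1x10xf32> to tensor<?x?x?xf32>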
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 31 +++++++++++++++++++++-
mlir/test/Dialect/Tensor/canonicalize.mlir | 14 ++++++++++
2 files changed, 44 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4d6c5965c4fcc3..9be647f687e600 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1982,6 +1982,35 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
return success();
}
};
+
+struct FoldExpandOfCast : public OpRewritePattern<ExpandShapeOp> {
+ using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
+ PatternRewriter &rewriter) const override {
+ auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
+ if (!canFoldIntoConsumerOp(castOp))
+ return failure();
+
+ SmallVector<OpFoldResult> outputOfr =
+ getMixedValues(expandOp.getResultType().getShape(),
+ expandOp.getOutputShape(), rewriter);
+ std::optional<SmallVector<int64_t>> constantOutputShape =
+ getConstantIntValues(outputOfr);
+ if (!constantOutputShape.has_value()) {
+ return failure();
+ }
+ auto newType = RankedTensorType::get(
+ constantOutputShape.value(), expandOp.getSrcType().getElementType());
+
+ auto newExpand = rewriter.create<ExpandShapeOp>(
+ castOp.getLoc(), newType, castOp.getSource(),
+ expandOp.getReassociationIndices());
+ rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
+ newExpand.getResult());
+ return success();
+ }
+};
} // namespace
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -1989,7 +2018,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<
ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
- FoldReshapeWithConstant<ExpandShapeOp>,
+ FoldExpandOfCast, FoldReshapeWithConstant<ExpandShapeOp>,
FoldReshapeWithSplat<ExpandShapeOp>,
FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
FoldDimOfCollapseShape>(context);
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 0aa2d33ef17ed4..1509d26151119d 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2718,3 +2718,17 @@ func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128
%pack = tensor.pack %arg0 padding_value(%cst : f16) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %arg1 {test_attr} : tensor<?x?x?xf16> -> tensor<128x?x100x16x1xf16>
return %pack : tensor<128x?x100x16x1xf16>
}
+
+// -----
+
+func.func @fold_expand_of_cast(%arg0 : tensor<10x10xf32>)
+ -> tensor<?x?x?xf32> {
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
+ %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
+ : tensor<?x?xf32> into tensor<?x?x?xf32>
+ return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func.func @fold_expand_of_cast
+// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]
From 3f4c7bb63fc16dcfa809ae917d039c25782c7cb9 Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Tue, 15 Oct 2024 02:19:37 +0000
Subject: [PATCH 2/5] Convert to static expand_shape
When the output_shape operands can be determined to be constant, convert
to a static expand_shape op and insert cast ops. The top cast will be
(dynamic -> static), allowing it to be propagated upwards, and the bottom
cast will be (static -> dynamic), allowing it to be propagated down (or
to cancel with adjacent tensor.cast ops).
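Roughly (a hand-written sketch reusing the %c constants from the tests,
not verbatim pattern output):

  // Before: all output_shape operands are known constants.
  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
      : tensor<?x?xf32> into tensor<?x?x?xf32>

  // After: a (dynamic -> static) cast feeds a fully static expand_shape,
  // and a (static -> dynamic) cast restores the original result type.
  %in = tensor.cast %0 : tensor<?x?xf32> to tensor<10x10xf32>
  %expanded = tensor.expand_shape %in [[0, 1], [2]] output_shape [10, 1, 10]
      : tensor<10x10xf32> into tensor<10x1x10xf32>
  %1 = tensor.cast %expanded : tensor<10x1x10xf32> to tensor<?x?x?xf32>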
[skip ci]
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 61 ++++++++++++++++------
mlir/test/Dialect/Tensor/canonicalize.mlir | 14 +++++
2 files changed, 60 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 9be647f687e600..96384385b6a060 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1983,29 +1983,60 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
}
};
-struct FoldExpandOfCast : public OpRewritePattern<ExpandShapeOp> {
+struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
PatternRewriter &rewriter) const override {
- auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
- if (!canFoldIntoConsumerOp(castOp))
- return failure();
+ SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
+ SmallVector<Value> dynamicOutputShape;
+ auto outputIt = expandOp.getOutputShape().begin();
+ for (auto [i, staticShape] : llvm::enumerate(newOutputShape)) {
+ if (!ShapedType::isDynamic(staticShape))
+ continue;
- SmallVector<OpFoldResult> outputOfr =
- getMixedValues(expandOp.getResultType().getShape(),
- expandOp.getOutputShape(), rewriter);
- std::optional<SmallVector<int64_t>> constantOutputShape =
- getConstantIntValues(outputOfr);
- if (!constantOutputShape.has_value()) {
+ APInt cst;
+ Value val = *outputIt;
+ ++outputIt;
+ if (matchPattern(val, m_ConstantInt(&cst))) {
+ newOutputShape[i] = cst.getSExtValue();
+ } else {
+ dynamicOutputShape.push_back(val);
+ }
+ }
+
+ // Couldn't match any values, nothing to change
+ if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
return failure();
+
+ // Calculate the input shape from the output
+ SmallVector<ReassociationIndices, 4> reassoc =
+ expandOp.getReassociationIndices();
+ SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
+ for (uint64_t inDim = 0; inDim < newInputShape.size(); inDim++) {
+ for (auto outDim : reassoc[inDim]) {
+ auto ofr = newOutputShape[outDim];
+ if (ShapedType::isDynamic(ofr)) {
+ newInputShape[inDim] = ShapedType::kDynamic;
+ break;
+ }
+ newInputShape[inDim] *= ofr;
+ }
}
- auto newType = RankedTensorType::get(
- constantOutputShape.value(), expandOp.getSrcType().getElementType());
+ // `inputCast` can be propagated up and the final cast can be propagated
+ // down.
+ SmallVector<OpFoldResult> outputOfr =
+ getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
+ auto inputType = RankedTensorType::get(
+ newInputShape, expandOp.getSrcType().getElementType());
+ auto outputType = RankedTensorType::get(
+ newOutputShape, expandOp.getSrcType().getElementType());
+ auto inputCast = rewriter.create<CastOp>(expandOp.getLoc(), inputType,
+ expandOp.getSrc());
auto newExpand = rewriter.create<ExpandShapeOp>(
- castOp.getLoc(), newType, castOp.getSource(),
- expandOp.getReassociationIndices());
+ expandOp.getLoc(), outputType, inputCast.getResult(),
+ expandOp.getReassociationIndices(), outputOfr);
rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
newExpand.getResult());
return success();
@@ -2018,7 +2049,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<
ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
- FoldExpandOfCast, FoldReshapeWithConstant<ExpandShapeOp>,
+ ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
FoldReshapeWithSplat<ExpandShapeOp>,
FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
FoldDimOfCollapseShape>(context);
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 1509d26151119d..52dcfd1d427d93 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2732,3 +2732,17 @@ func.func @fold_expand_of_cast(%arg0 : tensor<10x10xf32>)
}
// CHECK-LABEL: func.func @fold_expand_of_cast
// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]
+
+// -----
+
+func.func @fold_expand_of_cast_dynamic(%arg0 : tensor<?x10xf32>)
+ -> tensor<?x?x?xf32> {
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ %0 = tensor.cast %arg0 : tensor<?x10xf32> to tensor<?x?xf32>
+ %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
+ : tensor<?x?xf32> into tensor<?x?x?xf32>
+ return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func.func @fold_expand_of_cast_dynamic
+// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]
From 4e16a9764cc0b125d3b851fd077865fe50b62003 Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Tue, 15 Oct 2024 21:58:24 +0000
Subject: [PATCH 3/5] Redo logic to ensure cast gets folded
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 46 +++++++++++++++-------
mlir/test/Dialect/Tensor/canonicalize.mlir | 38 +++++++++++++++---
2 files changed, 64 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 96384385b6a060..ee0e8c2d201226 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -24,6 +24,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
@@ -1988,20 +1989,41 @@ struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
PatternRewriter &rewriter) const override {
+ auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
+ if (!canFoldIntoConsumerOp(castOp))
+ return failure();
+
+ const ArrayRef<int64_t> castSrcShape =
+ castOp.getSource().getType().getShape();
+ const SmallVector<ReassociationIndices, 4> reassoc =
+ expandOp.getReassociationIndices();
+
SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
SmallVector<Value> dynamicOutputShape;
auto outputIt = expandOp.getOutputShape().begin();
- for (auto [i, staticShape] : llvm::enumerate(newOutputShape)) {
- if (!ShapedType::isDynamic(staticShape))
- continue;
- APInt cst;
- Value val = *outputIt;
- ++outputIt;
- if (matchPattern(val, m_ConstantInt(&cst))) {
- newOutputShape[i] = cst.getSExtValue();
- } else {
- dynamicOutputShape.push_back(val);
+ for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
+ for (const uint64_t outDim : innerReassoc) {
+ if (!ShapedType::isDynamic(newOutputShape[outDim]))
+ continue;
+
+ // If the cast's src type is dynamic, don't infer any of the
+ // corresponding expanded dimensions. `tensor.expand_shape` requires at
+ // least one of the expanded dimensions to be dynamic if the input is
+ // dynamic.
+ Value val = *outputIt;
+ ++outputIt;
+ if (ShapedType::isDynamic(castSrcShape[inputDim])) {
+ dynamicOutputShape.push_back(val);
+ continue;
+ }
+
+ APInt cst;
+ if (matchPattern(val, m_ConstantInt(&cst))) {
+ newOutputShape[outDim] = cst.getSExtValue();
+ } else {
+ dynamicOutputShape.push_back(val);
+ }
}
}
@@ -2010,8 +2032,6 @@ struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
return failure();
// Calculate the input shape from the output
- SmallVector<ReassociationIndices, 4> reassoc =
- expandOp.getReassociationIndices();
SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
for (uint64_t inDim = 0; inDim < newInputShape.size(); inDim++) {
for (auto outDim : reassoc[inDim]) {
@@ -2024,8 +2044,6 @@ struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
}
}
- // `inputCast` can be propagated up and the final cast can be propagated
- // down.
SmallVector<OpFoldResult> outputOfr =
getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
auto inputType = RankedTensorType::get(
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 52dcfd1d427d93..63f394a14d3899 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2722,20 +2722,22 @@ func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128
// -----
func.func @fold_expand_of_cast(%arg0 : tensor<10x10xf32>)
- -> tensor<?x?x?xf32> {
+ -> tensor<10x1x10xf32> {
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
%0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
%1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
: tensor<?x?xf32> into tensor<?x?x?xf32>
- return %1 : tensor<?x?x?xf32>
+ %2 = tensor.cast %1 : tensor<?x?x?xf32> to tensor<10x1x10xf32>
+ return %2 : tensor<10x1x10xf32>
}
// CHECK-LABEL: func.func @fold_expand_of_cast
-// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]
+// CHECK: %[[RES:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]
+// CHECK: return %[[RES]]
// -----
-func.func @fold_expand_of_cast_dynamic(%arg0 : tensor<?x10xf32>)
+func.func @sink_expand_of_cast(%arg0 : tensor<?x10xf32>)
-> tensor<?x?x?xf32> {
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
@@ -2744,5 +2746,29 @@ func.func @fold_expand_of_cast_dynamic(%arg0 : tensor<?x10xf32>)
: tensor<?x?xf32> into tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
}
-// CHECK-LABEL: func.func @fold_expand_of_cast_dynamic
-// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]
+// CHECK-LABEL: func.func @sink_expand_of_cast
+// CHECK-DAG: %[[C10:.*]] = arith.constant 10
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: output_shape [%[[C10]], %[[C1]], 10]
+// CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]]
+// CHECK: return %[[RES]]
+
+// -----
+
+func.func @partial_sink_expand_of_cast(%arg0 : tensor<10x10xf32>, %arg1 : index, %arg2 : index)
+ -> tensor<?x?x?xf32> {
+ %c10 = arith.constant 10 : index
+ %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
+ %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %c10]
+ : tensor<?x?xf32> into tensor<?x?x?xf32>
+ return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func.func @partial_sink_expand_of_cast
+// CHECK: %[[CAST:.+]] = tensor.cast
+// CHECK-SAME: tensor<10x10xf32> to tensor<?x10xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: output_shape [%{{.*}}, %{{.*}}, 10]
+// CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]]
+// CHECK-SAME: tensor<?x?x10xf32> to tensor<?x?x?xf32>
+// CHECK: return %[[RES]]
From c6e1139536a622ab2b46f98a85336db4cc7fa404 Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Tue, 22 Oct 2024 16:16:56 +0000
Subject: [PATCH 4/5] Drop const qualifier
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index ee0e8c2d201226..c7a675733311a9 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1993,9 +1993,8 @@ struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
if (!canFoldIntoConsumerOp(castOp))
return failure();
- const ArrayRef<int64_t> castSrcShape =
- castOp.getSource().getType().getShape();
- const SmallVector<ReassociationIndices, 4> reassoc =
+ ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
+ SmallVector<ReassociationIndices, 4> reassoc =
expandOp.getReassociationIndices();
SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
@@ -2003,7 +2002,7 @@ struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
auto outputIt = expandOp.getOutputShape().begin();
for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
- for (const uint64_t outDim : innerReassoc) {
+ for (uint64_t outDim : innerReassoc) {
if (!ShapedType::isDynamic(newOutputShape[outDim]))
continue;
From 8cf255b1552413ec568d172644baf00b30fb64a4 Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Tue, 22 Oct 2024 21:36:44 +0000
Subject: [PATCH 5/5] Add comment to rewrite pattern
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index c7a675733311a9..6bf01b2ee1b9fc 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1984,6 +1984,10 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
}
};
+/// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
+/// matching constant output_shape operands of the expand. This makes the
+/// `tensor.expand_shape` more static and creates a consumer cast that can be
+/// propagated further.
struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;