[Mlir-commits] [mlir] [mlir][tensor] Fold when source is const (PR #71643)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 8 01:28:15 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tensor
Author: Rik Huijzer (rikhuijzer)
<details>
<summary>Changes</summary>
Fixes https://github.com/llvm/llvm-project/issues/60656.
This patch implements a basic fold for various reshape/resize tensor operations. Specifically, the fold removes a tensor reshape/resize op when its source is a constant splat tensor and the result type has a static shape, replacing the op with a splat constant of the result type. For example, the following function:
```mlir
func.func @main(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
%cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
%0 = tensor.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
return %0 : tensor<8x16x8x32xf32>
}
```
will be changed into the following with `mlir-opt -canonicalize`:
```mlir
func.func @main(%arg0: tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
%cst = arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
return %cst : tensor<8x16x8x32xf32>
}
```
As a side-note, this patch is essentially an extension of https://github.com/llvm/llvm-project/commit/f79f430d4b268429f96be95622facd2775b25624. That commit implemented the folding for `tensor.extract_slice`, whereas this patch also implements it for `tensor.gather`, `tensor.reshape`, `tensor.pack`, and `tensor.unpack`.
---
Full diff: https://github.com/llvm/llvm-project/pull/71643.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+6)
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+47-6)
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+48)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 21e1f87bfa53709..c184971e478195e 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -659,6 +659,7 @@ def Tensor_GatherOp : Tensor_Op<"gather", [
}
}];
let hasVerifier = 1;
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
@@ -986,6 +987,7 @@ def Tensor_ReshapeOp: Tensor_Op<"reshape", [
$source `(` $shape `)` attr-dict `:` functional-type(operands, results)
}];
let hasVerifier = 1;
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
@@ -1867,6 +1869,8 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
}];
let hasCanonicalizeMethod = 1;
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
@@ -1948,6 +1952,8 @@ def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> {
}];
let hasCanonicalizeMethod = 1;
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 6fc45379111fc34..c33dd603cb02899 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -834,6 +834,16 @@ void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
ReplaceEmptyTensorStaticShapeDims>(context);
}
+/// Try to remove a tensor operation if it would only reshape a constant.
+/// Removes the op and replaces the constant with a new constant of the result shape.
+static OpFoldResult reshapeConstantSource(DenseElementsAttr source,
+ TensorType result) {
+ if (source && source.isSplat() && result.hasStaticShape())
+ return source.resizeSplat(result);
+
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//
@@ -1089,6 +1099,14 @@ LogicalResult GatherOp::verify() {
return success();
}
+OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
+ if (OpFoldResult reshapedSource = reshapeConstantSource(
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
+ getResult().getType()))
+ return reshapedSource;
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//
@@ -1367,6 +1385,14 @@ LogicalResult ReshapeOp::verify() {
return success();
}
+OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
+ if (OpFoldResult reshapedSource = reshapeConstantSource(
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
+ getResult().getType()))
+ return reshapedSource;
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// Reassociative reshape ops
//===----------------------------------------------------------------------===//
@@ -2153,12 +2179,10 @@ static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
}
OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
- if (auto splat =
- llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
- auto resultType = llvm::cast<ShapedType>(getResult().getType());
- if (resultType.hasStaticShape())
- return splat.resizeSplat(resultType);
- }
+ if (OpFoldResult reshapedSource = reshapeConstantSource(
+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
+ getResult().getType()))
+ return reshapedSource;
if (getSourceType() == getType() &&
succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
return this->getSource();
@@ -3823,6 +3847,14 @@ bool PackOp::isLikePad() {
return isLikePadUnPad(*this, packedTensorType);
}
+OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
+ if (OpFoldResult reshapedSource = reshapeConstantSource(
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
+ getResult().getType()))
+ return reshapedSource;
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// UnPackOp
//===----------------------------------------------------------------------===//
@@ -3951,6 +3983,15 @@ bool UnPackOp::isLikeUnPad() {
RankedTensorType packedTensorType = getSourceType();
return isLikePadUnPad(*this, packedTensorType);
}
+
+OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
+ if (OpFoldResult reshapedSource = reshapeConstantSource(
+ llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
+ getResult().getType()))
+ return reshapedSource;
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// Common Canonicalizers and Folders.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 1078ee3b59a4306..ea8c17640d7c143 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -672,6 +672,30 @@ func.func @fold_extract_insert(%input : tensor<?x?x?xf32>, %slice: tensor<4x?x8x
// -----
+// CHECK-LABEL: func @fold_gather_constant_splat
+// CHECK-NOT: tensor.gather
+// CHECK: arith.constant dense<1.000000e-01> : tensor<1x2x1x1x1xf32>
+func.func @fold_gather_constant_splat(%indices : tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32> {
+ %cst = arith.constant dense<1.000000e-01> : tensor<4x4x4xf32>
+ %0 = tensor.gather %cst[%indices] gather_dims([0, 1, 2]) :
+ (tensor<4x4x4xf32>, tensor<1x2x 3xindex>) -> tensor<1x2x 1x1x1xf32>
+ return %0 : tensor<1x2x 1x1x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_reshape_constant_splat
+// CHECK-NOT: tensor.reshape
+// CHECK: arith.constant dense<1.000000e-01> : tensor<4xf32>
+func.func @fold_reshape_constant_splat(%shape : tensor<1xi32>) -> tensor<4xf32> {
+ %cst = arith.constant dense<1.000000e-01> : tensor<4x1xf32>
+ %0 = tensor.reshape %cst(%shape)
+ : (tensor<4x1xf32>, tensor<1xi32>) -> tensor<4xf32>
+ return %0 : tensor<4xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_extract_constant_splat
// CHECK-NOT: tensor.extract_slice
// CHECK: arith.constant dense<42> : tensor<4x4xi32>
@@ -683,6 +707,30 @@ func.func @fold_extract_constant_splat() -> (tensor<4x4xi32>) {
// -----
+// CHECK-LABEL: func @fold_pack_constant_splat
+// CHECK-NOT: tensor.pack
+// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
+func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+ %cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
+ %0 = tensor.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
+ inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
+ return %0 : tensor<8x16x8x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_unpack_constant_splat
+// CHECK-NOT: tensor.unpack
+// CHECK: arith.constant dense<1.000000e-01> : tensor<128x256xf32>
+func.func @fold_unpack_constant_splat(%dest : tensor<128x256xf32>) -> tensor<128x256xf32> {
+ %cst = arith.constant dense<1.000000e-01> : tensor<16x8x8x32xf32>
+ %0 = tensor.unpack %cst inner_dims_pos = [0, 1]
+ inner_tiles = [8, 32] into %dest : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
+ return %0 : tensor<128x256xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_overlapping_insert
// CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/71643
More information about the Mlir-commits
mailing list