[Mlir-commits] [mlir] [mlir][tensor] Fold when source is const (PR #71643)

Rik Huijzer llvmlistbot at llvm.org
Wed Nov 8 01:27:48 PST 2023


https://github.com/rikhuijzer created https://github.com/llvm/llvm-project/pull/71643

Fixes https://github.com/llvm/llvm-project/issues/60656.

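This patch implements a basic fold for several tensor operations that only rearrange or select elements of their source (`tensor.gather`, `tensor.reshape`, `tensor.pack`, and `tensor.unpack`). Specifically, the fold removes such an op when its source is a splat constant tensor and its result type has a static shape, replacing it with a splat constant of the result type. For example, the following function: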

```mlir
func.func @main(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
  %cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
  %0 = tensor.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
    inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
  return %0 : tensor<8x16x8x32xf32>
}
```
will be changed into the following with `mlir-opt -canonicalize`:
```mlir
func.func @main(%arg0: tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
  %cst = arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
  return %cst : tensor<8x16x8x32xf32>
}
```

As a side-note, this patch is essentially an extension of https://github.com/llvm/llvm-project/commit/f79f430d4b268429f96be95622facd2775b25624. That commit implemented the folding for `tensor.extract_slice`, whereas this patch also implements it for `tensor.gather`, `tensor.reshape`, `tensor.pack`, and `tensor.unpack`.

>From 4c021b09cb2fe6adeec4a9b6be863000e93d59cc Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Wed, 8 Nov 2023 10:10:51 +0100
Subject: [PATCH] [mlir][tensor] Fold when source is const

---
 .../mlir/Dialect/Tensor/IR/TensorOps.td       |  6 +++
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 53 ++++++++++++++++---
 mlir/test/Dialect/Tensor/canonicalize.mlir    | 48 +++++++++++++++++
 3 files changed, 101 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 21e1f87bfa53709..c184971e478195e 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -659,6 +659,7 @@ def Tensor_GatherOp : Tensor_Op<"gather", [
     }
   }];
   let hasVerifier = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -986,6 +987,7 @@ def Tensor_ReshapeOp: Tensor_Op<"reshape", [
     $source `(` $shape `)` attr-dict `:` functional-type(operands, results)
   }];
   let hasVerifier = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1867,6 +1869,8 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
   }];
 
   let hasCanonicalizeMethod = 1;
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1948,6 +1952,8 @@ def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> {
   }];
 
   let hasCanonicalizeMethod = 1;
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 6fc45379111fc34..c33dd603cb02899 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -834,6 +834,16 @@ void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
               ReplaceEmptyTensorStaticShapeDims>(context);
 }
 
+/// Folds an op whose source is a splat constant and whose result shape is
+/// static into a splat of the result type; yields a null result otherwise.
+static OpFoldResult reshapeConstantSource(DenseElementsAttr source,
+                                          TensorType result) {
+  if (!source || !source.isSplat() || !result.hasStaticShape())
+    return {};
+
+  return source.resizeSplat(result);
+}
+
 //===----------------------------------------------------------------------===//
 // ExtractOp
 //===----------------------------------------------------------------------===//
@@ -1089,6 +1099,14 @@ LogicalResult GatherOp::verify() {
   return success();
 }
 
+OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
+  // Folds only for a splat constant source and a statically shaped result;
+  // the helper returns a null OpFoldResult (no fold) in every other case.
+  auto constantSource =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource());
+  return reshapeConstantSource(constantSource, getResult().getType());
+}
+
 //===----------------------------------------------------------------------===//
 // InsertOp
 //===----------------------------------------------------------------------===//
@@ -1367,6 +1385,14 @@ LogicalResult ReshapeOp::verify() {
   return success();
 }
 
+OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
+  // A splat constant source with a static result shape becomes a splat of
+  // the result type; a null result from the helper leaves the op in place.
+  auto splatCandidate =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource());
+  return reshapeConstantSource(splatCandidate, getResult().getType());
+}
+
 //===----------------------------------------------------------------------===//
 // Reassociative reshape ops
 //===----------------------------------------------------------------------===//
@@ -2153,12 +2179,10 @@ static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
 }
 
 OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
-  if (auto splat =
-          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
-    auto resultType = llvm::cast<ShapedType>(getResult().getType());
-    if (resultType.hasStaticShape())
-      return splat.resizeSplat(resultType);
-  }
+  if (OpFoldResult reshapedSource = reshapeConstantSource(
+          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
+          getResult().getType()))
+    return reshapedSource;
   if (getSourceType() == getType() &&
       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
     return this->getSource();
@@ -3823,6 +3847,14 @@ bool PackOp::isLikePad() {
   return isLikePadUnPad(*this, packedTensorType);
 }
 
+OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
+  // With a padding value the result may hold non-splat elements: do not fold.
+  Attribute source = getPaddingValue() ? Attribute() : adaptor.getSource();
+  return reshapeConstantSource(
+      llvm::dyn_cast_if_present<DenseElementsAttr>(source),
+      getResult().getType());
+}
+
 //===----------------------------------------------------------------------===//
 // UnPackOp
 //===----------------------------------------------------------------------===//
@@ -3951,6 +3983,15 @@ bool UnPackOp::isLikeUnPad() {
   RankedTensorType packedTensorType = getSourceType();
   return isLikePadUnPad(*this, packedTensorType);
 }
+
+OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
+  // Delegates to the shared helper: it folds when the source is a splat
+  // constant and the result shape is static, and returns null otherwise.
+  auto packedConstant =
+      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource());
+  return reshapeConstantSource(packedConstant, getResult().getType());
+}
+
 //===----------------------------------------------------------------------===//
 // Common Canonicalizers and Folders.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 1078ee3b59a4306..ea8c17640d7c143 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -672,6 +672,30 @@ func.func @fold_extract_insert(%input : tensor<?x?x?xf32>, %slice: tensor<4x?x8x
 
 // -----
 
+// CHECK-LABEL: func @fold_gather_constant_splat
+//   CHECK-NOT: tensor.gather
+//       CHECK: arith.constant dense<1.000000e-01> : tensor<1x2x1x1x1xf32>
+func.func @fold_gather_constant_splat(%indices : tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32> {
+  %cst = arith.constant dense<1.000000e-01> : tensor<4x4x4xf32>
+  %0 = tensor.gather %cst[%indices] gather_dims([0, 1, 2]) :
+    (tensor<4x4x4xf32>, tensor<1x2x 3xindex>) -> tensor<1x2x 1x1x1xf32>
+  return %0 : tensor<1x2x 1x1x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_reshape_constant_splat
+//   CHECK-NOT: tensor.reshape
+//       CHECK: arith.constant dense<1.000000e-01> : tensor<4xf32>
+func.func @fold_reshape_constant_splat(%shape : tensor<1xi32>) -> tensor<4xf32> {
+  %cst = arith.constant dense<1.000000e-01> : tensor<4x1xf32>
+  %0 = tensor.reshape %cst(%shape)
+             : (tensor<4x1xf32>, tensor<1xi32>) -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_extract_constant_splat
 //   CHECK-NOT: tensor.extract_slice
 //       CHECK: arith.constant dense<42> : tensor<4x4xi32>
@@ -683,6 +707,30 @@ func.func @fold_extract_constant_splat() -> (tensor<4x4xi32>) {
 
 // -----
 
+// CHECK-LABEL: func @fold_pack_constant_splat
+//   CHECK-NOT: tensor.pack
+//       CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
+func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+  %cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
+  %0 = tensor.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
+    inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
+  return %0 : tensor<8x16x8x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_unpack_constant_splat
+//   CHECK-NOT: tensor.unpack
+//       CHECK: arith.constant dense<1.000000e-01> : tensor<128x256xf32>
+func.func @fold_unpack_constant_splat(%dest : tensor<128x256xf32>) -> tensor<128x256xf32> {
+  %cst = arith.constant dense<1.000000e-01> : tensor<16x8x8x32xf32>
+  %0 = tensor.unpack %cst inner_dims_pos = [0, 1]
+    inner_tiles = [8, 32] into %dest : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
+  return %0 : tensor<128x256xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_overlapping_insert
 //  CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
 func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {



More information about the Mlir-commits mailing list