[Mlir-commits] [mlir] [MLIR][Tensor] Canonicalize fully covering slice insertions into tensors with unit prefixes (PR #92912)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 21 06:01:12 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Andi Drebes (andidr)

<details>
<summary>Changes</summary>

If the destination tensor of the insertion of a slice has the same number of elements as the slice, but with a shape that only differs by a prefix of unit-sized dimensions, and if the insertion happens at zero offsets, unit strides and with a size matching the size of the destination, the insertion covers all elements of the destination. The result of such an insertion is equivalent to the slice, with its shape expanded to the type of the destination.

Example:
```mlir
  %0 = tensor.insert_slice %slice into
     %x[0, 0, 0, 0, 0][1, 1, 1, 16, 32][1, 1, 1, 1, 1] :
     tensor<16x32xf32> into tensor<1x1x1x16x32xf32>
```

folds into:

```mlir
  %0 = tensor.expand_shape %slice[[0,1,2,3], [4]] :
          tensor<16x32xf32> into tensor<1x1x1x16x32xf32>
```

This PR adds a canonicalization pattern for `InsertSliceOp` that performs this rewrite.

---
Full diff: https://github.com/llvm/llvm-project/pull/92912.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+87-1) 
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+12) 


``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 8545c7b9af8f7..52d7005470232 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2835,6 +2835,91 @@ struct InsertSliceOpSourceCastInserter final
     return success();
   }
 };
+
+/// If the destination tensor of the insertion of a slice has the same
+/// number of elements as the slice, but with a shape that only
+/// differs by a prefix of unit-sized dimensions, and if the insertion
+/// happens at zero offsets, unit strides and with a size matching the
+/// size of the destination, the insertion covers all elements of the
+/// destination. The result of such an insertion is equivalent to the
+/// slice, with its shape expanded to the type of the destination.
+///
+/// Example:
+/// ```mlir
+///   %0 = tensor.insert_slice %slice into
+///           %x[0, 0, 0, 0, 0][1, 1, 1, 16, 32][1, 1, 1, 1, 1] :
+///           tensor<16x32xf32> into tensor<1x1x1x16x32xf32>
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %0 = tensor.expand_shape %slice[[0,1,2,3], [4]] :
+///           tensor<16x32xf32> into tensor<1x1x1x16x32xf32>
+/// ```
+///
+/// A rank-0 slice has no dimension that the unit prefix could be
+/// grouped with, so it is expanded with an empty reassociation.
+/// Insertions whose source or result type has a dynamic shape are
+/// never rewritten.
+struct InsertSliceOpFullRewriteCanonicalizer final
+    : public OpRewritePattern<InsertSliceOp> {
+  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
+                                PatternRewriter &rewriter) const override {
+    RankedTensorType sourceType = insertSliceOp.getSourceType();
+    RankedTensorType resultType = insertSliceOp.getType();
+
+    // There is nothing to expand when both types are identical, and
+    // full coverage can only be established for static shapes.
+    if (sourceType == resultType || !sourceType.hasStaticShape() ||
+        !resultType.hasStaticShape())
+      return mlir::failure();
+
+    ArrayRef<int64_t> sourceShape = sourceType.getShape();
+    ArrayRef<int64_t> resultShape = resultType.getShape();
+    if (!isSameSizedSuffixShape(resultShape, sourceShape) ||
+        failed(foldIdentityOffsetSizeAndStrideOpInterface(insertSliceOp,
+                                                          resultType)))
+      return mlir::failure();
+
+    // Number of leading unit-sized dimensions of the result that have
+    // no counterpart in the source type.
+    size_t unitPrefixLength = resultShape.size() - sourceShape.size();
+
+    // A rank-0 source has no dimension that could absorb the unit prefix;
+    // expanding it requires an empty reassociation. Otherwise the unit
+    // prefix and the first common dimension are grouped onto source
+    // dimension 0 and every remaining dimension maps one-to-one.
+    SmallVector<ReassociationIndices> reassocIndices;
+    if (!sourceShape.empty()) {
+      ReassociationIndices unitPrefixExpansion;
+      for (size_t dim = 0; dim <= unitPrefixLength; ++dim)
+        unitPrefixExpansion.push_back(dim);
+      reassocIndices.push_back(unitPrefixExpansion);
+
+      for (size_t dim = unitPrefixLength + 1; dim < resultShape.size(); ++dim)
+        reassocIndices.push_back({static_cast<int64_t>(dim)});
+    }
+
+    rewriter.replaceOpWithNewOp<ExpandShapeOp>(
+        insertSliceOp, resultType, insertSliceOp.getSource(), reassocIndices);
+    return mlir::success();
+  }
+
+private:
+  /// Returns true if `shape` ends with `suffix` and every dimension of
+  /// `shape` that precedes this suffix has size one.
+  static bool isSameSizedSuffixShape(ArrayRef<int64_t> shape,
+                                     ArrayRef<int64_t> suffix) {
+    if (shape.size() < suffix.size())
+      return false;
+
+    ArrayRef<int64_t> prefix = shape.take_front(shape.size() - suffix.size());
+    return llvm::all_of(prefix, [](int64_t d) { return d == 1; }) &&
+           shape.take_back(suffix.size()) == suffix;
+  }
+};
 } // namespace
 
 llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
@@ -2845,7 +2930,8 @@ void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
   results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
               InsertSliceOpCastFolder<InsertSliceOp>,
-              InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
+              InsertSliceOpSourceCastInserter<InsertSliceOp>,
+              InsertSliceOpFullRewriteCanonicalizer>(context);
 }
 
 Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 914e5e8b8c4b8..8e66ef9f89c74 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -542,6 +542,18 @@ func.func @trivial_insert_slice(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6
 
 // -----
 
+// CHECK-LABEL: func @trivial_insert_slice_unit_prefix
+//  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
+//   CHECK-NOT:   tensor.insert_slice
+//       CHECK:   %[[EXPANDED:.[a-z0-9A-Z_]+]] = tensor.expand_shape %[[ARG0]] {{\[\[0, 1, 2, 3\], \[4\], \[5\], \[6\]\] output}}_shape {{\[1, 1, 1, 4, 6, 16, 32\]}} : tensor<4x6x16x32xi8> into tensor<1x1x1x4x6x16x32xi8>
+//       CHECK:   return %[[EXPANDED]] : tensor<1x1x1x4x6x16x32xi8>
+func.func @trivial_insert_slice_unit_prefix(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<1x1x1x4x6x16x32xi8>) -> tensor<1x1x1x4x6x16x32xi8> {
+  %0 = tensor.insert_slice %arg0 into %arg1[0, 0, 0, 0, 0, 0, 0] [1, 1, 1, 4, 6, 16, 32] [1, 1, 1, 1, 1, 1, 1] : tensor<4x6x16x32xi8> into tensor<1x1x1x4x6x16x32xi8>
+  return %0 : tensor<1x1x1x4x6x16x32xi8>
+}
+
+// -----
+
 // CHECK-LABEL: func @empty_insert_slice
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<0x2xi8>
 //  CHECK-SAME:   %[[ARG1:.[a-z0-9A-Z_]+]]: tensor<3x3xi8>

``````````

</details>


https://github.com/llvm/llvm-project/pull/92912


More information about the Mlir-commits mailing list