[Mlir-commits] [mlir] c3e3d59 - [mlir][tensor] Fix tensor::PackOp fold() handling of padding value (#87296)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 2 13:49:31 PDT 2024
Author: Han-Chung Wang
Date: 2024-04-02T13:49:28-07:00
New Revision: c3e3d59fab8ae8161810c861d78c7b5fcabb1a2e
URL: https://github.com/llvm/llvm-project/commit/c3e3d59fab8ae8161810c861d78c7b5fcabb1a2e
DIFF: https://github.com/llvm/llvm-project/commit/c3e3d59fab8ae8161810c861d78c7b5fcabb1a2e.diff
LOG: [mlir][tensor] Fix tensor::PackOp fold() handling of padding value (#87296)
It is not enough to check whether the source is a splat constant; we must also
check that the splat value matches the padding value.
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 38a9ad60bb7948..0ce40e81371209 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1068,10 +1068,13 @@ void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// Try to remove a tensor operation if it would only reshape a constant.
/// Removes the op and replaces the constant with a new constant of the result
-/// shape.
-static OpFoldResult reshapeConstantSource(DenseElementsAttr source,
- TensorType result) {
- if (source && source.isSplat() && result.hasStaticShape())
+/// shape. When an optional cst attribute is passed, it is reshaped only if the
+/// splat value matches the value in the attribute.
+static OpFoldResult
+reshapeConstantSource(DenseElementsAttr source, TensorType result,
+ std::optional<Attribute> cst = std::nullopt) {
+ if (source && source.isSplat() && result.hasStaticShape() &&
+ (!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))
return source.resizeSplat(result);
return {};
@@ -4143,9 +4146,12 @@ bool PackOp::isLikePad() {
}
OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
+ std::optional<Attribute> paddingValue;
+ if (auto pad = adaptor.getPaddingValue())
+ paddingValue = pad;
if (OpFoldResult reshapedSource = reshapeConstantSource(
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
- getResult().getType()))
+ getDestType(), paddingValue))
return reshapedSource;
return {};
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 9ab54fe9c133db..ac365c9d297e88 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -830,6 +830,39 @@ func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x1
// -----
+// CHECK-LABEL: func @fold_padding_value_pack_constant_splat
+// CHECK-NOT: tensor.pack
+// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
+func.func @fold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+ %pad = arith.constant 1.000000e-01 : f32
+ %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
+ %0 = tensor.pack %cst
+ padding_value(%pad : f32)
+ outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
+ inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
+ return %0 : tensor<8x16x8x32xf32>
+}
+
+
+// -----
+
+// CHECK-LABEL: func @nofold_padding_value_pack_constant_splat
+// CHECK: arith.constant dense<1.000000e-01> : tensor<63x127xf32>
+// CHECK: tensor.pack
+func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+ %pad = arith.constant 0.0 : f32
+ %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
+ %0 = tensor.pack %cst
+ padding_value(%pad : f32)
+ outer_dims_perm = [1, 0]
+ inner_dims_pos = [0, 1]
+ inner_tiles = [8, 32]
+ into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
+ return %0 : tensor<8x16x8x32xf32>
+}
+
+// -----
+
func.func @fold_padding_value_pack(%arg0: tensor<1200x500000xf32>) -> tensor<31250x1200x16x1xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<31250x1200x16x1xf32>
More information about the Mlir-commits mailing list