[Mlir-commits] [mlir] 061201e - [mlir][linalg] Enhance padding LinalgOps to handle tensor.empty cases.
Hanhan Wang
llvmlistbot at llvm.org
Wed Feb 8 18:39:48 PST 2023
Author: Hanhan Wang
Date: 2023-02-08T18:39:34-08:00
New Revision: 061201ec3d6d78ca5d5a583eb9141623ea3f66e7
URL: https://github.com/llvm/llvm-project/commit/061201ec3d6d78ca5d5a583eb9141623ea3f66e7
DIFF: https://github.com/llvm/llvm-project/commit/061201ec3d6d78ca5d5a583eb9141623ea3f66e7.diff
LOG: [mlir][linalg] Enhance padding LinalgOps to handle tensor.empty cases.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D143043
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/transform-op-pad.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index f1f92d329ebb4..3b2cd0d29d95e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -115,19 +115,28 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
currOpOperand = linalgOp.getDpsInitOperand(result.getResultNumber());
}
- // Fail if `currOpOperand` is not defined by an ExtractSliceOp.
+ // Fail if `currOpOperand` is not defined by an ExtractSliceOp or EmptyOp.
auto sliceOp = currOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
- if (!sliceOp)
+ auto emptyOp = currOpOperand->get().getDefiningOp<tensor::EmptyOp>();
+ if (!sliceOp && !emptyOp)
return failure();
- // Compute the dropped dimensions if `sliceOp` is ranke-reducing.
- llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
- OffsetSizeAndStrideOpInterface shapedOp = sliceOp;
+ llvm::SmallBitVector droppedDims;
+ SmallVector<OpFoldResult> mixedSizes;
+ if (sliceOp) {
+    // Compute the dropped dimensions if `sliceOp` is rank-reducing.
+ droppedDims = sliceOp.getDroppedDims();
+ mixedSizes = sliceOp.getMixedSizes();
+ }
+ if (emptyOp) {
+ mixedSizes = emptyOp.getMixedSizes();
+ droppedDims.resize(mixedSizes.size());
+ }
- // Upper bound the `sliceOp` sizes to obtain a static bounding box.
+ // Upper bound the sizes to obtain a static bounding box.
SmallVector<int64_t> paddedShape(shape.begin(), shape.end());
int64_t shapeIdx = 0;
- for (const auto &en : enumerate(shapedOp.getMixedSizes())) {
+ for (const auto &en : enumerate(mixedSizes)) {
// Skip dropped dimensions.
if (droppedDims.test(en.index()))
continue;
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index dcd800300df9d..e712c3d8417f0 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -39,6 +39,44 @@ transform.sequence failures(propagate) {
// -----
+#map = affine_map<()[s0] -> (-s0 + 12, 7)>
+
+// CHECK-LABEL: @static_sizes_output_divisible_on_empty_op
+func.func @static_sizes_output_divisible_on_empty_op(%arg0: tensor<24x12xf32>,
+ %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>, %iv0: index,
+ %iv1: index, %iv2: index) -> tensor<24x25xf32> {
+ %0 = affine.min #map()[%iv2]
+
+ // CHECK: %[[T0:.*]] = tensor.empty
+ // CHECK: %[[T1:.*]] = tensor.empty
+ // CHECK: %[[T2:.*]] = tensor.empty
+ %1 = tensor.empty(%0) : tensor<4x?xf32>
+ %2 = tensor.empty(%0) : tensor<?x5xf32>
+ %3 = tensor.empty() : tensor<4x5xf32>
+
+ // CHECK-DAG: %[[CST:.*]] = arith.constant 0.
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+
+ // CHECK: %[[T3:.*]] = tensor.pad %[[T0]] nofold
+ // CHECK: tensor.yield %[[CST]]
+ // CHECK: %[[T4:.*]] = tensor.pad %[[T1]] nofold
+
+ // CHECK: %[[T5:.*]] = linalg.matmul
+ // CHECK-SAME: ins(%[[T3]], %[[T4]] : tensor<4x7xf32>, tensor<7x5xf32>)
+ // CHECK-SAME: outs(%[[T2]] : tensor<4x5xf32>)
+ %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor<?x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
+ %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
+ func.return %5 : tensor<24x25xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1 = transform.structured.pad %0 {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]}
+}
+
+// -----
+
func.func @pad(%arg0: tensor<24x12xf32>,
%arg1: tensor<12x25xf32>,
%arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
More information about the Mlir-commits
mailing list