[Mlir-commits] [mlir] 061201e - [mlir][linalg] Enhance padding LinalgOps to handle tensor.empty cases.

Hanhan Wang llvmlistbot at llvm.org
Wed Feb 8 18:39:48 PST 2023


Author: Hanhan Wang
Date: 2023-02-08T18:39:34-08:00
New Revision: 061201ec3d6d78ca5d5a583eb9141623ea3f66e7

URL: https://github.com/llvm/llvm-project/commit/061201ec3d6d78ca5d5a583eb9141623ea3f66e7
DIFF: https://github.com/llvm/llvm-project/commit/061201ec3d6d78ca5d5a583eb9141623ea3f66e7.diff

LOG: [mlir][linalg] Enhance padding LinalgOps to handle tensor.empty cases.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D143043

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/transform-op-pad.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index f1f92d329ebb4..3b2cd0d29d95e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -115,19 +115,28 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
     currOpOperand = linalgOp.getDpsInitOperand(result.getResultNumber());
   }
 
-  // Fail if `currOpOperand` is not defined by an ExtractSliceOp.
+  // Fail if `currOpOperand` is not defined by an ExtractSliceOp or EmptyOp.
   auto sliceOp = currOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
-  if (!sliceOp)
+  auto emptyOp = currOpOperand->get().getDefiningOp<tensor::EmptyOp>();
+  if (!sliceOp && !emptyOp)
     return failure();
 
-  // Compute the dropped dimensions if `sliceOp` is ranke-reducing.
-  llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
-  OffsetSizeAndStrideOpInterface shapedOp = sliceOp;
+  llvm::SmallBitVector droppedDims;
+  SmallVector<OpFoldResult> mixedSizes;
+  if (sliceOp) {
+    // Compute the dropped dimensions if `sliceOp` is ranke-reducing.
+    droppedDims = sliceOp.getDroppedDims();
+    mixedSizes = sliceOp.getMixedSizes();
+  }
+  if (emptyOp) {
+    mixedSizes = emptyOp.getMixedSizes();
+    droppedDims.resize(mixedSizes.size());
+  }
 
-  // Upper bound the `sliceOp` sizes to obtain a static bounding box.
+  // Upper bound the sizes to obtain a static bounding box.
   SmallVector<int64_t> paddedShape(shape.begin(), shape.end());
   int64_t shapeIdx = 0;
-  for (const auto &en : enumerate(shapedOp.getMixedSizes())) {
+  for (const auto &en : enumerate(mixedSizes)) {
     // Skip dropped dimensions.
     if (droppedDims.test(en.index()))
       continue;

diff  --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index dcd800300df9d..e712c3d8417f0 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -39,6 +39,44 @@ transform.sequence failures(propagate) {
 
 // -----
 
+#map = affine_map<()[s0] -> (-s0 + 12, 7)>
+
+// CHECK-LABEL: @static_sizes_output_divisible_on_empty_op
+func.func @static_sizes_output_divisible_on_empty_op(%arg0: tensor<24x12xf32>,
+    %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>, %iv0: index,
+    %iv1: index, %iv2: index) -> tensor<24x25xf32> {
+  %0 = affine.min #map()[%iv2]
+
+  //      CHECK: %[[T0:.*]] = tensor.empty
+  //      CHECK: %[[T1:.*]] = tensor.empty
+  //      CHECK: %[[T2:.*]] = tensor.empty
+  %1 = tensor.empty(%0) : tensor<4x?xf32>
+  %2 = tensor.empty(%0) : tensor<?x5xf32>
+  %3 = tensor.empty() : tensor<4x5xf32>
+
+  //  CHECK-DAG: %[[CST:.*]] = arith.constant 0.
+  //  CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+
+  //      CHECK: %[[T3:.*]] = tensor.pad %[[T0]] nofold
+  //      CHECK: tensor.yield %[[CST]]
+  //      CHECK: %[[T4:.*]] = tensor.pad %[[T1]] nofold
+
+  //      CHECK: %[[T5:.*]] = linalg.matmul
+  // CHECK-SAME:              ins(%[[T3]], %[[T4]] : tensor<4x7xf32>, tensor<7x5xf32>)
+  // CHECK-SAME:              outs(%[[T2]] : tensor<4x5xf32>)
+  %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor<?x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
+  %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
+  func.return %5 : tensor<24x25xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1 = transform.structured.pad %0 {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]}
+}
+
+// -----
+
 func.func @pad(%arg0: tensor<24x12xf32>,
                %arg1: tensor<12x25xf32>,
                %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {


        


More information about the Mlir-commits mailing list