[Mlir-commits] [mlir] 7e5b10b - [mlir][Linalg] Add support for tiling tensor.pad to scf.forall

Nicolas Vasilache llvmlistbot at llvm.org
Wed Apr 12 04:43:55 PDT 2023


Author: Nicolas Vasilache
Date: 2023-04-12T04:43:47-07:00
New Revision: 7e5b10b9f74d4e34a0eb2db806b98ecb1c31daf1

URL: https://github.com/llvm/llvm-project/commit/7e5b10b9f74d4e34a0eb2db806b98ecb1c31daf1
DIFF: https://github.com/llvm/llvm-project/commit/7e5b10b9f74d4e34a0eb2db806b98ecb1c31daf1.diff

LOG: [mlir][Linalg] Add support for tiling tensor.pad to scf.forall

Also, properly propagate the `nofold` attribute to the new `tensor.pad` op created by `tensor::bubbleUpPadSlice`.

Differential Revision: https://reviews.llvm.org/D148114

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
    mlir/test/Dialect/Linalg/transform-op-tile.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 453fd85f19bb6..18d485ae5463b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -354,8 +354,6 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
         return getValueOrCreateConstantIndexOp(b, loc, ofr);
       }));
 
-  Operation *tiledOp = nullptr;
-
   // 1. Create the ForallOp. We don't use the lambda body-builder
   // version because we require the use of RewriterBase in the body, so we
   // manually move the insertion point to the body below.
@@ -371,6 +369,8 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
   // 3. Clone the tileable op and update its destination operands to use the
   // output bbArgs of the ForallOp.
   ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
+  Operation *tiledOp = nullptr;
+  SmallVector<Value> tiledValues;
   {
     // 3.a. RAII guard, inserting within forallOp, before terminator.
     OpBuilder::InsertionGuard g(b);
@@ -395,13 +395,12 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
     assert(tilingResult->tiledOps.size() == 1 &&
            "expected a single produced tiled op");
     tiledOp = tilingResult->tiledOps.front();
+    tiledValues = tilingResult->tiledValues;
   }
 
   // 5. Parallel insert back into the result tensor.
-  auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
-  assert(tilingInterfaceOp && "Tiled op does not implement TilingInterface");
   for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
-                           tilingInterfaceOp->getResults(), destBbArgs)) {
+                           tiledValues, destBbArgs)) {
     // 5.a. Partial subset information is inserted just before the terminator.
     OpBuilder::InsertionGuard g(b);
     b.setInsertionPoint(forallOp.getTerminator());

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 1e8d4762f0f80..63f7a5af6f5e4 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -617,7 +617,8 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
     // Create pad(extract_slice(x)).
     Value newSliceOp = b.create<tensor::ExtractSliceOp>(
         loc, padOp.getSource(), newOffsets, newLengths, newStrides);
-    auto newPadOp = b.create<PadOp>(loc, Type(), newSliceOp, newLows, newHighs);
+    auto newPadOp = b.create<PadOp>(loc, Type(), newSliceOp, newLows, newHighs,
+                                    /*nofold=*/padOp.getNofold());
 
     // Copy region to new PadOp.
     IRMapping bvm;

diff  --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
index df48c2138fbd8..15e8f8c41eba7 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -123,3 +123,28 @@ func.func @tile_linalg_matmul(
     -> tensor<128x128xf32>
   return %0, %1 : tensor<128x128xf32>, tensor<128x128xf32>
 }
+
+// -----
+
+// CHECK-LABEL: tile_tensor_pad
+func.func @tile_tensor_pad(
+  %arg0 : tensor<?x?xf32>, %cst : f32, %low: index, %high: index) 
+    -> tensor<20x40xf32>
+{
+  // CHECK: scf.forall
+  // CHECK:   scf.if
+  // CHECK:     tensor.generate
+  // CHECK:   else
+  // CHECK:     tensor.pad {{.*}} nofold 
+  %0 = tensor.pad %arg0 nofold low[%low, %low] high[%high, %high] {
+        ^bb0(%arg9: index, %arg10: index):
+          tensor.yield %cst : f32
+  } : tensor<?x?xf32> to tensor<20x40xf32>
+  return %0 : tensor<20x40xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  transform.structured.tile_to_forall_op %0 tile_sizes[1, 1]
+}


        


More information about the Mlir-commits mailing list