[Mlir-commits] [mlir] [mlir][TilingInterface] Make the tiling set tile sizes function use `OpFoldResult`. (PR #66566)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 15 18:35:18 PDT 2023


https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/66566

>From 9caa02e285022160132c2ae31772070dfa8fd10c Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh at nod-labs.com>
Date: Fri, 15 Sep 2023 18:24:23 -0700
Subject: [PATCH] [mlir][TilingInterface] Make the tiling set tile sizes
 function use `OpFoldResult`.

---
 .../SCF/Transforms/TileUsingInterface.h       | 11 +----
 .../TransformOps/LinalgTransformOps.cpp       | 24 +++++-----
 .../SCF/Transforms/TileUsingInterface.cpp     | 46 ++++++++-----------
 .../Dialect/Linalg/transform-op-tile.mlir     |  4 +-
 .../TilingInterface/TestTilingInterface.cpp   | 14 ++++--
 5 files changed, 45 insertions(+), 54 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index e7bcd062d96525d..ca641c596c7b7bb 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -26,7 +26,7 @@ namespace mlir {
 namespace scf {
 
 using SCFTileSizeComputationFunction =
-    std::function<SmallVector<Value>(OpBuilder &, Operation *)>;
+    std::function<SmallVector<OpFoldResult>(OpBuilder &, Operation *)>;
 
 /// Options to use to control tiling.
 struct SCFTilingOptions {
@@ -40,17 +40,10 @@ struct SCFTilingOptions {
     tileSizeComputationFunction = std::move(fun);
     return *this;
   }
-  /// Set the `tileSizeComputationFunction` to return the values `ts`. The
-  /// values must not fold away when tiling. Otherwise, use a more robust
-  /// `tileSizeComputationFunction`.
-  SCFTilingOptions &setTileSizes(const SmallVector<Value, 4> &ts) {
-    tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; };
-    return *this;
-  }
   /// Convenience function to set the `tileSizeComputationFunction` to a
   /// function that computes tile sizes at the point they are needed. Allows
   /// proper interaction with folding.
-  SCFTilingOptions &setTileSizes(ArrayRef<int64_t> ts);
+  SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> ts);
 
   /// The interchange vector to reorder the tiled loops.
   SmallVector<int64_t> interchangeVector = {};
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index dc65ac509d280dc..2e34da2a9191ae8 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -473,7 +473,9 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
 
   scf::SCFTilingOptions tilingOptions;
   tilingOptions.interchangeVector = tileInterchange;
-  tilingOptions = tilingOptions.setTileSizes(tileSizes);
+  SmallVector<OpFoldResult> tileSizesOfr =
+      getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
+  tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
   scf::SCFTileAndFuseOptions tileAndFuseOptions;
   tileAndFuseOptions.tilingOptions = tilingOptions;
   LogicalResult result = applyTilingToAll(
@@ -923,7 +925,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
     auto nextProducer = getNextProducer();
     if (failed(nextProducer)) {
       auto diag = mlir::emitSilenceableFailure(getLoc())
-             << "could not find next producer to fuse into container";
+                  << "could not find next producer to fuse into container";
       diag.attachNote(containingOp->getLoc()) << "containing op";
       return diag;
     }
@@ -1999,7 +2001,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
                                    transform::TransformState &state) {
   scf::SCFTilingOptions tilingOptions;
   tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
-    SmallVector<Value, 4> tileSizes;
+    SmallVector<OpFoldResult> tileSizes;
     Location loc = target.getLoc();
     SmallVector<OpFoldResult> allShapeSizes =
         target.createFlatListOfOperandDims(b, loc);
@@ -2012,9 +2014,8 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
     // If the shape size is dynamic, tile by 1.
     // Otherwise, do not tile (i.e. tile size 0).
     for (OpFoldResult shapeSize : shapeSizes) {
-      tileSizes.push_back(getConstantIntValue(shapeSize)
-                              ? b.create<arith::ConstantIndexOp>(loc, 0)
-                              : b.create<arith::ConstantIndexOp>(loc, 1));
+      tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
+                                                         : b.getIndexAttr(1));
     }
     return tileSizes;
   });
@@ -2549,7 +2550,7 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
     if (!tileSizes.empty()) {
       tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
                                                                   Operation *) {
-        SmallVector<Value, 4> sizes;
+        SmallVector<OpFoldResult> sizes;
         sizes.reserve(tileSizes.size());
         unsigned dynamicIdx = 0;
 
@@ -2560,10 +2561,10 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
                   getLoc(), attr.cast<IntegerAttr>().getInt());
               Value vscale =
                   b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
-              sizes.push_back(b.create<arith::MulIOp>(getLoc(), val, vscale));
+              sizes.push_back(
+                  b.create<arith::MulIOp>(getLoc(), val, vscale).getResult());
             } else {
-              sizes.push_back(b.create<arith::ConstantIndexOp>(
-                  getLoc(), cast<IntegerAttr>(attr).getInt()));
+              sizes.push_back(attr);
             }
             continue;
           }
@@ -2573,8 +2574,7 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
           assert((dynamicSizes.empty() ^ params.empty()) &&
                  "expected either dynamic sizes or parameters");
           if (!params.empty()) {
-            sizes.push_back(
-                b.create<arith::ConstantIndexOp>(getLoc(), params[index]));
+            sizes.push_back(b.getIndexAttr(params[index]));
           } else {
             sizes.push_back(dynamicSizes[index]->getResult(0));
           }
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 597676a017bf482..2d1bb3388ed55f2 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -31,19 +31,11 @@
 using namespace mlir;
 
 scf::SCFTilingOptions &
-scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
+scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
   assert(!tileSizeComputationFunction && "tile sizes already set");
-  SmallVector<int64_t> tileSizes(ts.begin(), ts.end());
+  auto tileSizes = llvm::to_vector(ts);
   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
-    OpBuilder::InsertionGuard guard(b);
-    b.setInsertionPointToStart(
-        &op->getParentWithTrait<OpTrait::IsIsolatedFromAbove>()
-             ->getRegion(0)
-             .front());
-    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
-      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
-      return v;
-    }));
+    return tileSizes;
   };
   return *this;
 }
@@ -112,11 +104,10 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of
 /// the
 ///   tile processed within the inner most loop.
-static SmallVector<scf::ForOp>
-generateTileLoopNest(OpBuilder &builder, Location loc,
-                     ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
-                     SmallVector<OpFoldResult> &offsets,
-                     SmallVector<OpFoldResult> &sizes) {
+static SmallVector<scf::ForOp> generateTileLoopNest(
+    OpBuilder &builder, Location loc, ArrayRef<Range> loopRanges,
+    ArrayRef<OpFoldResult> tileSizeVals, SmallVector<OpFoldResult> &offsets,
+    SmallVector<OpFoldResult> &sizes) {
   assert(!loopRanges.empty() && "expected at least one loop range");
   assert(loopRanges.size() == tileSizeVals.size() &&
          "expected as many tile sizes as loop ranges");
@@ -130,7 +121,8 @@ generateTileLoopNest(OpBuilder &builder, Location loc,
         getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
     Value size =
         getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
-    Value tileSize = tileSizeVals[loopRange.index()];
+    Value tileSize = getValueOrCreateConstantIndexOp(
+        builder, loc, tileSizeVals[loopRange.index()]);
     // No loops if tile size is zero. Set offset and size to the loop
     // offset and size.
     if (matchPattern(tileSize, m_Zero())) {
@@ -296,10 +288,10 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
   // skips tiling a particular dimension. This convention is significantly
   // simpler to handle instead of adjusting affine maps to account for missing
   // dimensions.
-  SmallVector<Value> tileSizeVector =
+  SmallVector<OpFoldResult> tileSizeVector =
       options.tileSizeComputationFunction(rewriter, op);
   if (tileSizeVector.size() < iterationDomain.size()) {
-    auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
+    auto zero = rewriter.getIndexAttr(0);
     tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
   }
 
@@ -402,17 +394,17 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
 FailureOr<scf::SCFReductionTilingResult>
 mlir::scf::tileReductionUsingScf(RewriterBase &b,
                                  PartialReductionOpInterface op,
-                                 ArrayRef<OpFoldResult> tileSize) {
+                                 ArrayRef<OpFoldResult> tileSizes) {
   Location loc = op.getLoc();
   // Ops implementing PartialReductionOpInterface are expected to implement
   // TilingInterface.
   auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
   SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
-  SmallVector<Value> tileSizeVector =
-      getValueOrCreateConstantIndexOp(b, loc, tileSize);
-  if (tileSizeVector.size() < iterationDomain.size()) {
-    auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
-    tileSizeVector.append(iterationDomain.size() - tileSizeVector.size(), zero);
+  auto tileSizesVector = llvm::to_vector(tileSizes);
+  if (tileSizesVector.size() < iterationDomain.size()) {
+    auto zero = b.getIndexAttr(0);
+    tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
+                           zero);
   }
   if (op->getNumResults() != 1)
     return b.notifyMatchFailure(
@@ -429,7 +421,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
 
   // 1. create the inital tensor value.
   FailureOr<Operation *> identityTensor =
-      op.generateInitialTensorForPartialReduction(b, loc, tileSize,
+      op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
                                                   reductionDims);
   if (failed(identityTensor))
     return b.notifyMatchFailure(op,
@@ -437,7 +429,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
   // 2. Create the nested loops.
   SmallVector<OpFoldResult> offsets, sizes;
   SmallVector<scf::ForOp> loops = generateTileLoopNest(
-      b, loc, iterationDomain, tileSizeVector, offsets, sizes);
+      b, loc, iterationDomain, tileSizesVector, offsets, sizes);
 
   // 3. Generate the tiled implementation within the inner most loop.
   b.setInsertionPoint(loops.back().getBody()->getTerminator());
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
index ce2a3d6ca9c58da..9df19632506a73c 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -190,16 +190,16 @@ transform.sequence failures(propagate) {
 // -----
 
 // CHECK-LABEL:   func.func @scalable_and_fixed_length_tile
-// CHECK:           %[[STEP_0:.*]] = arith.constant 4 : index
-// CHECK:           %[[STEP_1:.*]] = arith.constant 4 : index
 // CHECK:           %[[C4:.*]] = arith.constant 4 : index
 // CHECK:           %[[VS:.*]] = vector.vscale
 // CHECK:           %[[STEP_2:.*]] = arith.muli %[[C4]], %[[VS]] : index
 // CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[C128:.*]] = arith.constant 128 : index
+// CHECK:           %[[STEP_0:.*]] = arith.constant 4 : index
 // CHECK:           scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C128]] step %[[STEP_0]]
 // CHECK:             %[[C0_1:.*]] = arith.constant 0 : index
 // CHECK:             %[[C128_1:.*]] = arith.constant 128 : index
+// CHECK:             %[[STEP_1:.*]] = arith.constant 4 : index
 // CHECK:             scf.for %[[VAL_16:.*]] = %[[C0_1]] to %[[C128_1]] step %[[STEP_1]]
 // CHECK:               %[[C0_2:.*]] = arith.constant 0 : index
 // CHECK:               %[[C128_2:.*]] = arith.constant 128 : index
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index 752c885e0b87bdb..2fcc7bcadb60450 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -450,7 +450,9 @@ static void addPatternForTiling(MLIRContext *context,
                                 ArrayRef<int64_t> tileSizes,
                                 ArrayRef<int64_t> interchange = {}) {
   scf::SCFTilingOptions tilingOptions;
-  tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
+  SmallVector<OpFoldResult> tileSizesOfr =
+      getAsIndexOpFoldResult(context, tileSizes);
+  tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange);
   LinalgTransformationFilter filter(StringAttr::get(context, filterName),
                                     StringAttr::get(context, "tiled"));
   patterns.add<TestTileUsingSCFForOp>(context, tilingOptions, filter);
@@ -462,7 +464,9 @@ static void addPatternForTileFuseAndYield(MLIRContext *context,
                                           ArrayRef<int64_t> tileSizes,
                                           ArrayRef<int64_t> interchange = {}) {
   scf::SCFTilingOptions tilingOptions;
-  tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
+  SmallVector<OpFoldResult> tileSizesOfr =
+      getAsIndexOpFoldResult(context, tileSizes);
+  tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange);
   LinalgTransformationFilter filter(StringAttr::get(context, filterName),
                                     StringAttr::get(context, "tiled"));
   patterns.add<TestTileConsumerFuseAndYieldProducerUsingSCFForOp>(
@@ -475,8 +479,10 @@ static void addPatternForTileAndFuse(MLIRContext *context,
                                      ArrayRef<int64_t> tileSizes,
                                      ArrayRef<int64_t> interchange = {}) {
   scf::SCFTileAndFuseOptions tileAndFuseOptions;
-  tileAndFuseOptions.tilingOptions.setTileSizes(tileSizes).setInterchange(
-      interchange);
+  SmallVector<OpFoldResult> tileSizesOfr =
+      getAsIndexOpFoldResult(context, tileSizes);
+  tileAndFuseOptions.tilingOptions.setTileSizes(tileSizesOfr)
+      .setInterchange(interchange);
   LinalgTransformationFilter filter(StringAttr::get(context, filterName),
                                     StringAttr::get(context, "tiled"));
   patterns.add<TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp>(



More information about the Mlir-commits mailing list