[Mlir-commits] [mlir] [mlir][TilingInterface] Make the tiling set tile sizes function use `OpFoldResult`. (PR #66566)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 18 10:18:53 PDT 2023


https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/66566

>From 3a8279c5f7d3498a68e818ea3fa8a21b8c761731 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Thu, 14 Sep 2023 01:19:44 +0000
Subject: [PATCH 1/2] Revert "[mlir][vector] Improve lowering to LLVM for
 `minf`, `maxf` reductions"

This reverts commit dad9de0ae5360b18c890985d212bec266bf8c122.
---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 62 +++++++++++--------
 .../VectorToLLVM/vector-to-llvm.mlir          | 16 +++--
 2 files changed, 47 insertions(+), 31 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 92f7aa69760395a..8c8d53f0d6df68f 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -566,31 +566,35 @@ static Value createIntegerReductionComparisonOpLowering(
   return result;
 }
 
-namespace {
-template <typename Source>
-struct VectorToScalarMapper;
-template <>
-struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
-  using Type = LLVM::MaximumOp;
-};
-template <>
-struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
-  using Type = LLVM::MinimumOp;
-};
-} // namespace
+/// Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum
+/// with vector types.
+static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
+                           Value rhs, bool isMin) {
+  auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
+  Type i1Type = builder.getI1Type();
+  if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
+    i1Type = VectorType::get(vecType.getShape(), i1Type);
+  Value cmp = builder.create<LLVM::FCmpOp>(
+      loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
+      lhs, rhs);
+  Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
+  Value isNan = builder.create<LLVM::FCmpOp>(
+      loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
+  Value nan = builder.create<LLVM::ConstantOp>(
+      loc, lhs.getType(),
+      builder.getFloatAttr(floatType,
+                           APFloat::getQNaN(floatType.getFloatSemantics())));
+  return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
+}
 
 template <class LLVMRedIntrinOp>
-static Value
-createFPReductionComparisonOpLowering(ConversionPatternRewriter &rewriter,
-                                      Location loc, Type llvmType,
-                                      Value vectorOperand, Value accumulator) {
+static Value createFPReductionComparisonOpLowering(
+    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
+    Value vectorOperand, Value accumulator, bool isMin) {
   Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
 
-  if (accumulator) {
-    result =
-        rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
-            loc, result, accumulator);
-  }
+  if (accumulator)
+    result = createMinMaxF(rewriter, loc, result, accumulator, /*isMin=*/isMin);
 
   return result;
 }
@@ -763,13 +767,17 @@ class VectorReductionOpConversion
                                             ReductionNeutralFPOne>(
           rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
     } else if (kind == vector::CombiningKind::MINF) {
-      result =
-          createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
-              rewriter, loc, llvmType, operand, acc);
+      // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
+      // NaNs/-0.0/+0.0 in the same way.
+      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
+          rewriter, loc, llvmType, operand, acc,
+          /*isMin=*/true);
     } else if (kind == vector::CombiningKind::MAXF) {
-      result =
-          createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
-              rewriter, loc, llvmType, operand, acc);
+      // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
+      // NaNs/-0.0/+0.0 in the same way.
+      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
+          rewriter, loc, llvmType, operand, acc,
+          /*isMin=*/false);
     } else
       return failure();
 
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 514594240d22a1b..9a0287d241345b8 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1323,8 +1323,12 @@ func.func @reduce_fmax_f32(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
 }
 // CHECK-LABEL: @reduce_fmax_f32(
 // CHECK-SAME: %[[A:.*]]: vector<16xf32>, %[[B:.*]]: f32)
-//      CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fmaximum(%[[A]]) : (vector<16xf32>) -> f32
-//      CHECK: %[[R:.*]] = llvm.intr.maximum(%[[V]], %[[B]]) : (f32, f32) -> f32
+//      CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fmax(%[[A]]) : (vector<16xf32>) -> f32
+//      CHECK: %[[C0:.*]] = llvm.fcmp "ogt" %[[V]], %[[B]] : f32
+//      CHECK: %[[S0:.*]] = llvm.select %[[C0]], %[[V]], %[[B]] : i1, f32
+//      CHECK: %[[C1:.*]] = llvm.fcmp "uno" %[[V]], %[[B]] : f32
+//      CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7FC00000 : f32) : f32
+//      CHECK: %[[R:.*]] = llvm.select %[[C1]], %[[NAN]], %[[S0]] : i1, f32
 //      CHECK: return %[[R]] : f32
 
 // -----
@@ -1335,8 +1339,12 @@ func.func @reduce_fmin_f32(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
 }
 // CHECK-LABEL: @reduce_fmin_f32(
 // CHECK-SAME: %[[A:.*]]: vector<16xf32>, %[[B:.*]]: f32)
-//      CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fminimum(%[[A]]) : (vector<16xf32>) -> f32
-//      CHECK: %[[R:.*]] = llvm.intr.minimum(%[[V]], %[[B]]) : (f32, f32) -> f32
+//      CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fmin(%[[A]]) : (vector<16xf32>) -> f32
+//      CHECK: %[[C0:.*]] = llvm.fcmp "olt" %[[V]], %[[B]] : f32
+//      CHECK: %[[S0:.*]] = llvm.select %[[C0]], %[[V]], %[[B]] : i1, f32
+//      CHECK: %[[C1:.*]] = llvm.fcmp "uno" %[[V]], %[[B]] : f32
+//      CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7FC00000 : f32) : f32
+//      CHECK: %[[R:.*]] = llvm.select %[[C1]], %[[NAN]], %[[S0]] : i1, f32
 //      CHECK: return %[[R]] : f32
 
 // -----

>From b7f1db56902061c6ae961dff0a6af79fc14f2cc6 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh at nod-labs.com>
Date: Fri, 15 Sep 2023 18:24:23 -0700
Subject: [PATCH 2/2] [mlir][TilingInterface] Make the tiling set tile sizes
 function use `OpFoldResult`.

---
 .../SCF/Transforms/TileUsingInterface.h       | 11 +----
 .../TransformOps/LinalgTransformOps.cpp       | 29 ++++++-----
 .../SCF/Transforms/TileUsingInterface.cpp     | 48 ++++++++-----------
 .../Dialect/Linalg/transform-op-tile.mlir     |  4 +-
 .../TilingInterface/TestTilingInterface.cpp   | 14 ++++--
 5 files changed, 48 insertions(+), 58 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index e7bcd062d96525d..ca641c596c7b7bb 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -26,7 +26,7 @@ namespace mlir {
 namespace scf {
 
 using SCFTileSizeComputationFunction =
-    std::function<SmallVector<Value>(OpBuilder &, Operation *)>;
+    std::function<SmallVector<OpFoldResult>(OpBuilder &, Operation *)>;
 
 /// Options to use to control tiling.
 struct SCFTilingOptions {
@@ -40,17 +40,10 @@ struct SCFTilingOptions {
     tileSizeComputationFunction = std::move(fun);
     return *this;
   }
-  /// Set the `tileSizeComputationFunction` to return the values `ts`. The
-  /// values must not fold away when tiling. Otherwise, use a more robust
-  /// `tileSizeComputationFunction`.
-  SCFTilingOptions &setTileSizes(const SmallVector<Value, 4> &ts) {
-    tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; };
-    return *this;
-  }
   /// Convenience function to set the `tileSizeComputationFunction` to a
   /// function that computes tile sizes at the point they are needed. Allows
   /// proper interaction with folding.
-  SCFTilingOptions &setTileSizes(ArrayRef<int64_t> ts);
+  SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> ts);
 
   /// The interchange vector to reorder the tiled loops.
   SmallVector<int64_t> interchangeVector = {};
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 3421a3c169dbba1..bc6c4f851987841 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -472,7 +472,9 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
 
   scf::SCFTilingOptions tilingOptions;
   tilingOptions.interchangeVector = tileInterchange;
-  tilingOptions = tilingOptions.setTileSizes(tileSizes);
+  SmallVector<OpFoldResult> tileSizesOfr =
+      getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
+  tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
   scf::SCFTileAndFuseOptions tileAndFuseOptions;
   tileAndFuseOptions.tilingOptions = tilingOptions;
   LogicalResult result = applyTilingToAll(
@@ -922,7 +924,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
     auto nextProducer = getNextProducer();
     if (failed(nextProducer)) {
       auto diag = mlir::emitSilenceableFailure(getLoc())
-             << "could not find next producer to fuse into container";
+                  << "could not find next producer to fuse into container";
       diag.attachNote(containingOp->getLoc()) << "containing op";
       return diag;
     }
@@ -1994,7 +1996,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
                                    transform::TransformState &state) {
   scf::SCFTilingOptions tilingOptions;
   tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
-    SmallVector<Value, 4> tileSizes;
+    SmallVector<OpFoldResult> tileSizes;
     Location loc = target.getLoc();
     SmallVector<OpFoldResult> allShapeSizes =
         target.createFlatListOfOperandDims(b, loc);
@@ -2007,9 +2009,8 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
     // If the shape size is dynamic, tile by 1.
     // Otherwise, do not tile (i.e. tile size 0).
     for (OpFoldResult shapeSize : shapeSizes) {
-      tileSizes.push_back(getConstantIntValue(shapeSize)
-                              ? b.create<arith::ConstantIndexOp>(loc, 0)
-                              : b.create<arith::ConstantIndexOp>(loc, 1));
+      tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
+                                                         : b.getIndexAttr(1));
     }
     return tileSizes;
   });
@@ -2535,7 +2536,7 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
     if (!tileSizes.empty()) {
       tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
                                                                   Operation *) {
-        SmallVector<Value, 4> sizes;
+        SmallVector<OpFoldResult> sizes;
         sizes.reserve(tileSizes.size());
         unsigned dynamicIdx = 0;
 
@@ -2546,10 +2547,10 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
                   getLoc(), attr.cast<IntegerAttr>().getInt());
               Value vscale =
                   b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
-              sizes.push_back(b.create<arith::MulIOp>(getLoc(), val, vscale));
+              sizes.push_back(
+                  b.create<arith::MulIOp>(getLoc(), val, vscale).getResult());
             } else {
-              sizes.push_back(b.create<arith::ConstantIndexOp>(
-                  getLoc(), cast<IntegerAttr>(attr).getInt()));
+              sizes.push_back(attr);
             }
             continue;
           }
@@ -2559,8 +2560,7 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
           assert((dynamicSizes.empty() ^ params.empty()) &&
                  "expected either dynamic sizes or parameters");
           if (!params.empty()) {
-            sizes.push_back(
-                b.create<arith::ConstantIndexOp>(getLoc(), params[index]));
+            sizes.push_back(b.getIndexAttr(params[index]));
           } else {
             sizes.push_back(dynamicSizes[index]->getResult(0));
           }
@@ -2974,13 +2974,12 @@ transform::TileToScfForOp::apply(transform::TransformRewriter &rewriter,
     if (!tileSizes.empty()) {
       tilingOptions.setTileSizeComputationFunction(
           [&, index](OpBuilder &b, Operation *) {
-            SmallVector<Value, 4> sizes;
+            SmallVector<OpFoldResult> sizes;
             sizes.reserve(tileSizes.size());
             unsigned dynamicIdx = 0;
             for (OpFoldResult ofr : getMixedSizes()) {
               if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
-                sizes.push_back(b.create<arith::ConstantIndexOp>(
-                    getLoc(), cast<IntegerAttr>(attr).getInt()));
+                sizes.push_back(attr);
               } else {
                 sizes.push_back(
                     dynamicSizeProducers[dynamicIdx++][index]->getResult(0));
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 597676a017bf482..910d5e4f4f1100f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -31,19 +31,11 @@
 using namespace mlir;
 
 scf::SCFTilingOptions &
-scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
+scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
   assert(!tileSizeComputationFunction && "tile sizes already set");
-  SmallVector<int64_t> tileSizes(ts.begin(), ts.end());
+  auto tileSizes = llvm::to_vector(ts);
   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
-    OpBuilder::InsertionGuard guard(b);
-    b.setInsertionPointToStart(
-        &op->getParentWithTrait<OpTrait::IsIsolatedFromAbove>()
-             ->getRegion(0)
-             .front());
-    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
-      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
-      return v;
-    }));
+    return tileSizes;
   };
   return *this;
 }
@@ -108,15 +100,14 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
 
 /// Generate an empty loop nest that represents the tiled loop nest shell.
 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
-/// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
+/// - `tileSizes` are the tile sizes to use. A zero tile size leaves that loop untiled.
 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of
 /// the
 ///   tile processed within the inner most loop.
-static SmallVector<scf::ForOp>
-generateTileLoopNest(OpBuilder &builder, Location loc,
-                     ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
-                     SmallVector<OpFoldResult> &offsets,
-                     SmallVector<OpFoldResult> &sizes) {
+static SmallVector<scf::ForOp> generateTileLoopNest(
+    OpBuilder &builder, Location loc, ArrayRef<Range> loopRanges,
+    ArrayRef<OpFoldResult> tileSizes, SmallVector<OpFoldResult> &offsets,
+    SmallVector<OpFoldResult> &sizes) {
   assert(!loopRanges.empty() && "expected at least one loop range");
-  assert(loopRanges.size() == tileSizeVals.size() &&
+  assert(loopRanges.size() == tileSizes.size() &&
          "expected as many tile sizes as loop ranges");
@@ -130,7 +121,8 @@ generateTileLoopNest(OpBuilder &builder, Location loc,
         getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
     Value size =
         getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
-    Value tileSize = tileSizeVals[loopRange.index()];
+    Value tileSize = getValueOrCreateConstantIndexOp(
+        builder, loc, tileSizes[loopRange.index()]);
     // No loops if tile size is zero. Set offset and size to the loop
     // offset and size.
     if (matchPattern(tileSize, m_Zero())) {
@@ -296,10 +288,10 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
   // skips tiling a particular dimension. This convention is significantly
   // simpler to handle instead of adjusting affine maps to account for missing
   // dimensions.
-  SmallVector<Value> tileSizeVector =
+  SmallVector<OpFoldResult> tileSizeVector =
       options.tileSizeComputationFunction(rewriter, op);
   if (tileSizeVector.size() < iterationDomain.size()) {
-    auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
+    auto zero = rewriter.getIndexAttr(0);
     tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
   }
 
@@ -402,17 +394,17 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
 FailureOr<scf::SCFReductionTilingResult>
 mlir::scf::tileReductionUsingScf(RewriterBase &b,
                                  PartialReductionOpInterface op,
-                                 ArrayRef<OpFoldResult> tileSize) {
+                                 ArrayRef<OpFoldResult> tileSizes) {
   Location loc = op.getLoc();
   // Ops implementing PartialReductionOpInterface are expected to implement
   // TilingInterface.
   auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
   SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
-  SmallVector<Value> tileSizeVector =
-      getValueOrCreateConstantIndexOp(b, loc, tileSize);
-  if (tileSizeVector.size() < iterationDomain.size()) {
-    auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
-    tileSizeVector.append(iterationDomain.size() - tileSizeVector.size(), zero);
+  auto tileSizesVector = llvm::to_vector(tileSizes);
+  if (tileSizesVector.size() < iterationDomain.size()) {
+    auto zero = b.getIndexAttr(0);
+    tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
+                           zero);
   }
   if (op->getNumResults() != 1)
     return b.notifyMatchFailure(
@@ -429,7 +421,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
 
   // 1. create the inital tensor value.
   FailureOr<Operation *> identityTensor =
-      op.generateInitialTensorForPartialReduction(b, loc, tileSize,
+      op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
                                                   reductionDims);
   if (failed(identityTensor))
     return b.notifyMatchFailure(op,
@@ -437,7 +429,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
   // 2. Create the nested loops.
   SmallVector<OpFoldResult> offsets, sizes;
   SmallVector<scf::ForOp> loops = generateTileLoopNest(
-      b, loc, iterationDomain, tileSizeVector, offsets, sizes);
+      b, loc, iterationDomain, tileSizesVector, offsets, sizes);
 
   // 3. Generate the tiled implementation within the inner most loop.
   b.setInsertionPoint(loops.back().getBody()->getTerminator());
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
index d4629dcb29c3efc..2608b703898611b 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -190,16 +190,16 @@ transform.sequence failures(propagate) {
 // -----
 
 // CHECK-LABEL:   func.func @scalable_and_fixed_length_tile
-// CHECK:           %[[STEP_0:.*]] = arith.constant 4 : index
-// CHECK:           %[[STEP_1:.*]] = arith.constant 4 : index
 // CHECK:           %[[C4:.*]] = arith.constant 4 : index
 // CHECK:           %[[VS:.*]] = vector.vscale
 // CHECK:           %[[STEP_2:.*]] = arith.muli %[[C4]], %[[VS]] : index
 // CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[C128:.*]] = arith.constant 128 : index
+// CHECK:           %[[STEP_0:.*]] = arith.constant 4 : index
 // CHECK:           scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C128]] step %[[STEP_0]]
 // CHECK:             %[[C0_1:.*]] = arith.constant 0 : index
 // CHECK:             %[[C128_1:.*]] = arith.constant 128 : index
+// CHECK:             %[[STEP_1:.*]] = arith.constant 4 : index
 // CHECK:             scf.for %[[VAL_16:.*]] = %[[C0_1]] to %[[C128_1]] step %[[STEP_1]]
 // CHECK:               %[[C0_2:.*]] = arith.constant 0 : index
 // CHECK:               %[[C128_2:.*]] = arith.constant 128 : index
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index 752c885e0b87bdb..2fcc7bcadb60450 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -450,7 +450,9 @@ static void addPatternForTiling(MLIRContext *context,
                                 ArrayRef<int64_t> tileSizes,
                                 ArrayRef<int64_t> interchange = {}) {
   scf::SCFTilingOptions tilingOptions;
-  tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
+  SmallVector<OpFoldResult> tileSizesOfr =
+      getAsIndexOpFoldResult(context, tileSizes);
+  tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange);
   LinalgTransformationFilter filter(StringAttr::get(context, filterName),
                                     StringAttr::get(context, "tiled"));
   patterns.add<TestTileUsingSCFForOp>(context, tilingOptions, filter);
@@ -462,7 +464,9 @@ static void addPatternForTileFuseAndYield(MLIRContext *context,
                                           ArrayRef<int64_t> tileSizes,
                                           ArrayRef<int64_t> interchange = {}) {
   scf::SCFTilingOptions tilingOptions;
-  tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
+  SmallVector<OpFoldResult> tileSizesOfr =
+      getAsIndexOpFoldResult(context, tileSizes);
+  tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange);
   LinalgTransformationFilter filter(StringAttr::get(context, filterName),
                                     StringAttr::get(context, "tiled"));
   patterns.add<TestTileConsumerFuseAndYieldProducerUsingSCFForOp>(
@@ -475,8 +479,10 @@ static void addPatternForTileAndFuse(MLIRContext *context,
                                      ArrayRef<int64_t> tileSizes,
                                      ArrayRef<int64_t> interchange = {}) {
   scf::SCFTileAndFuseOptions tileAndFuseOptions;
-  tileAndFuseOptions.tilingOptions.setTileSizes(tileSizes).setInterchange(
-      interchange);
+  SmallVector<OpFoldResult> tileSizesOfr =
+      getAsIndexOpFoldResult(context, tileSizes);
+  tileAndFuseOptions.tilingOptions.setTileSizes(tileSizesOfr)
+      .setInterchange(interchange);
   LinalgTransformationFilter filter(StringAttr::get(context, filterName),
                                     StringAttr::get(context, "tiled"));
   patterns.add<TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp>(



More information about the Mlir-commits mailing list