[Mlir-commits] [mlir] [mlir][TilingInterface] Make the tiling set tile sizes function use `OpFoldResult`. (PR #66566)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 18 10:18:53 PDT 2023
https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/66566
From 3a8279c5f7d3498a68e818ea3fa8a21b8c761731 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Thu, 14 Sep 2023 01:19:44 +0000
Subject: [PATCH 1/2] Revert "[mlir][vector] Improve lowering to LLVM for `minf`, `maxf` reductions"
This reverts commit dad9de0ae5360b18c890985d212bec266bf8c122.
---
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 62 +++++++++++--------
.../VectorToLLVM/vector-to-llvm.mlir | 16 +++--
2 files changed, 47 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 92f7aa69760395a..8c8d53f0d6df68f 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -566,31 +566,35 @@ static Value createIntegerReductionComparisonOpLowering(
return result;
}
-namespace {
-template <typename Source>
-struct VectorToScalarMapper;
-template <>
-struct VectorToScalarMapper<LLVM::vector_reduce_fmaximum> {
- using Type = LLVM::MaximumOp;
-};
-template <>
-struct VectorToScalarMapper<LLVM::vector_reduce_fminimum> {
- using Type = LLVM::MinimumOp;
-};
-} // namespace
+/// Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum
+/// with vector types.
+static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
+ Value rhs, bool isMin) {
+ auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
+ Type i1Type = builder.getI1Type();
+ if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
+ i1Type = VectorType::get(vecType.getShape(), i1Type);
+ Value cmp = builder.create<LLVM::FCmpOp>(
+ loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
+ lhs, rhs);
+ Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
+ Value isNan = builder.create<LLVM::FCmpOp>(
+ loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
+ Value nan = builder.create<LLVM::ConstantOp>(
+ loc, lhs.getType(),
+ builder.getFloatAttr(floatType,
+ APFloat::getQNaN(floatType.getFloatSemantics())));
+ return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
+}
template <class LLVMRedIntrinOp>
-static Value
-createFPReductionComparisonOpLowering(ConversionPatternRewriter &rewriter,
- Location loc, Type llvmType,
- Value vectorOperand, Value accumulator) {
+static Value createFPReductionComparisonOpLowering(
+ ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
+ Value vectorOperand, Value accumulator, bool isMin) {
Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
- if (accumulator) {
- result =
- rewriter.create<typename VectorToScalarMapper<LLVMRedIntrinOp>::Type>(
- loc, result, accumulator);
- }
+ if (accumulator)
+ result = createMinMaxF(rewriter, loc, result, accumulator, /*isMin=*/isMin);
return result;
}
@@ -763,13 +767,17 @@ class VectorReductionOpConversion
ReductionNeutralFPOne>(
rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
} else if (kind == vector::CombiningKind::MINF) {
- result =
- createFPReductionComparisonOpLowering<LLVM::vector_reduce_fminimum>(
- rewriter, loc, llvmType, operand, acc);
+ // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
+ // NaNs/-0.0/+0.0 in the same way.
+ result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
+ rewriter, loc, llvmType, operand, acc,
+ /*isMin=*/true);
} else if (kind == vector::CombiningKind::MAXF) {
- result =
- createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmaximum>(
- rewriter, loc, llvmType, operand, acc);
+ // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
+ // NaNs/-0.0/+0.0 in the same way.
+ result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
+ rewriter, loc, llvmType, operand, acc,
+ /*isMin=*/false);
} else
return failure();
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 514594240d22a1b..9a0287d241345b8 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1323,8 +1323,12 @@ func.func @reduce_fmax_f32(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
}
// CHECK-LABEL: @reduce_fmax_f32(
// CHECK-SAME: %[[A:.*]]: vector<16xf32>, %[[B:.*]]: f32)
-// CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fmaximum(%[[A]]) : (vector<16xf32>) -> f32
-// CHECK: %[[R:.*]] = llvm.intr.maximum(%[[V]], %[[B]]) : (f32, f32) -> f32
+// CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fmax(%[[A]]) : (vector<16xf32>) -> f32
+// CHECK: %[[C0:.*]] = llvm.fcmp "ogt" %[[V]], %[[B]] : f32
+// CHECK: %[[S0:.*]] = llvm.select %[[C0]], %[[V]], %[[B]] : i1, f32
+// CHECK: %[[C1:.*]] = llvm.fcmp "uno" %[[V]], %[[B]] : f32
+// CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7FC00000 : f32) : f32
+// CHECK: %[[R:.*]] = llvm.select %[[C1]], %[[NAN]], %[[S0]] : i1, f32
// CHECK: return %[[R]] : f32
// -----
@@ -1335,8 +1339,12 @@ func.func @reduce_fmin_f32(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
}
// CHECK-LABEL: @reduce_fmin_f32(
// CHECK-SAME: %[[A:.*]]: vector<16xf32>, %[[B:.*]]: f32)
-// CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fminimum(%[[A]]) : (vector<16xf32>) -> f32
-// CHECK: %[[R:.*]] = llvm.intr.minimum(%[[V]], %[[B]]) : (f32, f32) -> f32
+// CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fmin(%[[A]]) : (vector<16xf32>) -> f32
+// CHECK: %[[C0:.*]] = llvm.fcmp "olt" %[[V]], %[[B]] : f32
+// CHECK: %[[S0:.*]] = llvm.select %[[C0]], %[[V]], %[[B]] : i1, f32
+// CHECK: %[[C1:.*]] = llvm.fcmp "uno" %[[V]], %[[B]] : f32
+// CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7FC00000 : f32) : f32
+// CHECK: %[[R:.*]] = llvm.select %[[C1]], %[[NAN]], %[[S0]] : i1, f32
// CHECK: return %[[R]] : f32
// -----
From b7f1db56902061c6ae961dff0a6af79fc14f2cc6 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh at nod-labs.com>
Date: Fri, 15 Sep 2023 18:24:23 -0700
Subject: [PATCH 2/2] [mlir][TilingInterface] Make the tiling set tile sizes function use `OpFoldResult`.
---
.../SCF/Transforms/TileUsingInterface.h | 11 +----
.../TransformOps/LinalgTransformOps.cpp | 29 ++++++-----
.../SCF/Transforms/TileUsingInterface.cpp | 48 ++++++++-----------
.../Dialect/Linalg/transform-op-tile.mlir | 4 +-
.../TilingInterface/TestTilingInterface.cpp | 14 ++++--
5 files changed, 48 insertions(+), 58 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index e7bcd062d96525d..ca641c596c7b7bb 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -26,7 +26,7 @@ namespace mlir {
namespace scf {
using SCFTileSizeComputationFunction =
- std::function<SmallVector<Value>(OpBuilder &, Operation *)>;
+ std::function<SmallVector<OpFoldResult>(OpBuilder &, Operation *)>;
/// Options to use to control tiling.
struct SCFTilingOptions {
@@ -40,17 +40,10 @@ struct SCFTilingOptions {
tileSizeComputationFunction = std::move(fun);
return *this;
}
- /// Set the `tileSizeComputationFunction` to return the values `ts`. The
- /// values must not fold away when tiling. Otherwise, use a more robust
- /// `tileSizeComputationFunction`.
- SCFTilingOptions &setTileSizes(const SmallVector<Value, 4> &ts) {
- tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; };
- return *this;
- }
/// Convenience function to set the `tileSizeComputationFunction` to a
/// function that computes tile sizes at the point they are needed. Allows
/// proper interaction with folding.
- SCFTilingOptions &setTileSizes(ArrayRef<int64_t> ts);
+ SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> ts);
/// The interchange vector to reorder the tiled loops.
SmallVector<int64_t> interchangeVector = {};
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 3421a3c169dbba1..bc6c4f851987841 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -472,7 +472,9 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
scf::SCFTilingOptions tilingOptions;
tilingOptions.interchangeVector = tileInterchange;
- tilingOptions = tilingOptions.setTileSizes(tileSizes);
+ SmallVector<OpFoldResult> tileSizesOfr =
+ getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
+ tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
scf::SCFTileAndFuseOptions tileAndFuseOptions;
tileAndFuseOptions.tilingOptions = tilingOptions;
LogicalResult result = applyTilingToAll(
@@ -922,7 +924,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
auto nextProducer = getNextProducer();
if (failed(nextProducer)) {
auto diag = mlir::emitSilenceableFailure(getLoc())
- << "could not find next producer to fuse into container";
+ << "could not find next producer to fuse into container";
diag.attachNote(containingOp->getLoc()) << "containing op";
return diag;
}
@@ -1994,7 +1996,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
transform::TransformState &state) {
scf::SCFTilingOptions tilingOptions;
tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
- SmallVector<Value, 4> tileSizes;
+ SmallVector<OpFoldResult> tileSizes;
Location loc = target.getLoc();
SmallVector<OpFoldResult> allShapeSizes =
target.createFlatListOfOperandDims(b, loc);
@@ -2007,9 +2009,8 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
// If the shape size is dynamic, tile by 1.
// Otherwise, do not tile (i.e. tile size 0).
for (OpFoldResult shapeSize : shapeSizes) {
- tileSizes.push_back(getConstantIntValue(shapeSize)
- ? b.create<arith::ConstantIndexOp>(loc, 0)
- : b.create<arith::ConstantIndexOp>(loc, 1));
+ tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
+ : b.getIndexAttr(1));
}
return tileSizes;
});
@@ -2535,7 +2536,7 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
if (!tileSizes.empty()) {
tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
Operation *) {
- SmallVector<Value, 4> sizes;
+ SmallVector<OpFoldResult> sizes;
sizes.reserve(tileSizes.size());
unsigned dynamicIdx = 0;
@@ -2546,10 +2547,10 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
getLoc(), attr.cast<IntegerAttr>().getInt());
Value vscale =
b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
- sizes.push_back(b.create<arith::MulIOp>(getLoc(), val, vscale));
+ sizes.push_back(
+ b.create<arith::MulIOp>(getLoc(), val, vscale).getResult());
} else {
- sizes.push_back(b.create<arith::ConstantIndexOp>(
- getLoc(), cast<IntegerAttr>(attr).getInt()));
+ sizes.push_back(attr);
}
continue;
}
@@ -2559,8 +2560,7 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
assert((dynamicSizes.empty() ^ params.empty()) &&
"expected either dynamic sizes or parameters");
if (!params.empty()) {
- sizes.push_back(
- b.create<arith::ConstantIndexOp>(getLoc(), params[index]));
+ sizes.push_back(b.getIndexAttr(params[index]));
} else {
sizes.push_back(dynamicSizes[index]->getResult(0));
}
@@ -2974,13 +2974,12 @@ transform::TileToScfForOp::apply(transform::TransformRewriter &rewriter,
if (!tileSizes.empty()) {
tilingOptions.setTileSizeComputationFunction(
[&, index](OpBuilder &b, Operation *) {
- SmallVector<Value, 4> sizes;
+ SmallVector<OpFoldResult> sizes;
sizes.reserve(tileSizes.size());
unsigned dynamicIdx = 0;
for (OpFoldResult ofr : getMixedSizes()) {
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
- sizes.push_back(b.create<arith::ConstantIndexOp>(
- getLoc(), cast<IntegerAttr>(attr).getInt()));
+ sizes.push_back(attr);
} else {
sizes.push_back(
dynamicSizeProducers[dynamicIdx++][index]->getResult(0));
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 597676a017bf482..910d5e4f4f1100f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -31,19 +31,11 @@
using namespace mlir;
scf::SCFTilingOptions &
-scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
+scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
assert(!tileSizeComputationFunction && "tile sizes already set");
- SmallVector<int64_t> tileSizes(ts.begin(), ts.end());
+ auto tileSizes = llvm::to_vector(ts);
tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
- OpBuilder::InsertionGuard guard(b);
- b.setInsertionPointToStart(
- &op->getParentWithTrait<OpTrait::IsIsolatedFromAbove>()
- ->getRegion(0)
- .front());
- return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
- Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
- return v;
- }));
+ return tileSizes;
};
return *this;
}
@@ -108,15 +100,14 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
/// Generate an empty loop nest that represents the tiled loop nest shell.
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
-/// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
+/// - `tileSizes` are the tile sizes to use. A zero tile size leaves that loop untiled.
/// - In `offsets` and `sizes` return the multi-dimensional offset and size of
/// the
/// tile processed within the inner most loop.
-static SmallVector<scf::ForOp>
-generateTileLoopNest(OpBuilder &builder, Location loc,
- ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
- SmallVector<OpFoldResult> &offsets,
- SmallVector<OpFoldResult> &sizes) {
+static SmallVector<scf::ForOp> generateTileLoopNest(
+ OpBuilder &builder, Location loc, ArrayRef<Range> loopRanges,
+ ArrayRef<OpFoldResult> tileSizes, SmallVector<OpFoldResult> &offsets,
+ SmallVector<OpFoldResult> &sizes) {
assert(!loopRanges.empty() && "expected at least one loop range");
- assert(loopRanges.size() == tileSizeVals.size() &&
+ assert(loopRanges.size() == tileSizes.size() &&
"expected as many tile sizes as loop ranges");
@@ -130,7 +121,8 @@ generateTileLoopNest(OpBuilder &builder, Location loc,
getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
Value size =
getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
- Value tileSize = tileSizeVals[loopRange.index()];
+ Value tileSize = getValueOrCreateConstantIndexOp(
+ builder, loc, tileSizes[loopRange.index()]);
// No loops if tile size is zero. Set offset and size to the loop
// offset and size.
if (matchPattern(tileSize, m_Zero())) {
@@ -296,10 +288,10 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
// skips tiling a particular dimension. This convention is significantly
// simpler to handle instead of adjusting affine maps to account for missing
// dimensions.
- SmallVector<Value> tileSizeVector =
+ SmallVector<OpFoldResult> tileSizeVector =
options.tileSizeComputationFunction(rewriter, op);
if (tileSizeVector.size() < iterationDomain.size()) {
- auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
+ auto zero = rewriter.getIndexAttr(0);
tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
}
@@ -402,17 +394,17 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
FailureOr<scf::SCFReductionTilingResult>
mlir::scf::tileReductionUsingScf(RewriterBase &b,
PartialReductionOpInterface op,
- ArrayRef<OpFoldResult> tileSize) {
+ ArrayRef<OpFoldResult> tileSizes) {
Location loc = op.getLoc();
// Ops implementing PartialReductionOpInterface are expected to implement
// TilingInterface.
auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
- SmallVector<Value> tileSizeVector =
- getValueOrCreateConstantIndexOp(b, loc, tileSize);
- if (tileSizeVector.size() < iterationDomain.size()) {
- auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
- tileSizeVector.append(iterationDomain.size() - tileSizeVector.size(), zero);
+ auto tileSizesVector = llvm::to_vector(tileSizes);
+ if (tileSizesVector.size() < iterationDomain.size()) {
+ auto zero = b.getIndexAttr(0);
+ tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
+ zero);
}
if (op->getNumResults() != 1)
return b.notifyMatchFailure(
@@ -429,7 +421,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
// 1. create the inital tensor value.
FailureOr<Operation *> identityTensor =
- op.generateInitialTensorForPartialReduction(b, loc, tileSize,
+ op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
reductionDims);
if (failed(identityTensor))
return b.notifyMatchFailure(op,
@@ -437,7 +429,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
// 2. Create the nested loops.
SmallVector<OpFoldResult> offsets, sizes;
SmallVector<scf::ForOp> loops = generateTileLoopNest(
- b, loc, iterationDomain, tileSizeVector, offsets, sizes);
+ b, loc, iterationDomain, tileSizesVector, offsets, sizes);
// 3. Generate the tiled implementation within the inner most loop.
b.setInsertionPoint(loops.back().getBody()->getTerminator());
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
index d4629dcb29c3efc..2608b703898611b 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -190,16 +190,16 @@ transform.sequence failures(propagate) {
// -----
// CHECK-LABEL: func.func @scalable_and_fixed_length_tile
-// CHECK: %[[STEP_0:.*]] = arith.constant 4 : index
-// CHECK: %[[STEP_1:.*]] = arith.constant 4 : index
// CHECK: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[VS:.*]] = vector.vscale
// CHECK: %[[STEP_2:.*]] = arith.muli %[[C4]], %[[VS]] : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C128:.*]] = arith.constant 128 : index
+// CHECK: %[[STEP_0:.*]] = arith.constant 4 : index
// CHECK: scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C128]] step %[[STEP_0]]
// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
// CHECK: %[[C128_1:.*]] = arith.constant 128 : index
+// CHECK: %[[STEP_1:.*]] = arith.constant 4 : index
// CHECK: scf.for %[[VAL_16:.*]] = %[[C0_1]] to %[[C128_1]] step %[[STEP_1]]
// CHECK: %[[C0_2:.*]] = arith.constant 0 : index
// CHECK: %[[C128_2:.*]] = arith.constant 128 : index
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index 752c885e0b87bdb..2fcc7bcadb60450 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -450,7 +450,9 @@ static void addPatternForTiling(MLIRContext *context,
ArrayRef<int64_t> tileSizes,
ArrayRef<int64_t> interchange = {}) {
scf::SCFTilingOptions tilingOptions;
- tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
+ SmallVector<OpFoldResult> tileSizesOfr =
+ getAsIndexOpFoldResult(context, tileSizes);
+ tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange);
LinalgTransformationFilter filter(StringAttr::get(context, filterName),
StringAttr::get(context, "tiled"));
patterns.add<TestTileUsingSCFForOp>(context, tilingOptions, filter);
@@ -462,7 +464,9 @@ static void addPatternForTileFuseAndYield(MLIRContext *context,
ArrayRef<int64_t> tileSizes,
ArrayRef<int64_t> interchange = {}) {
scf::SCFTilingOptions tilingOptions;
- tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
+ SmallVector<OpFoldResult> tileSizesOfr =
+ getAsIndexOpFoldResult(context, tileSizes);
+ tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange);
LinalgTransformationFilter filter(StringAttr::get(context, filterName),
StringAttr::get(context, "tiled"));
patterns.add<TestTileConsumerFuseAndYieldProducerUsingSCFForOp>(
@@ -475,8 +479,10 @@ static void addPatternForTileAndFuse(MLIRContext *context,
ArrayRef<int64_t> tileSizes,
ArrayRef<int64_t> interchange = {}) {
scf::SCFTileAndFuseOptions tileAndFuseOptions;
- tileAndFuseOptions.tilingOptions.setTileSizes(tileSizes).setInterchange(
- interchange);
+ SmallVector<OpFoldResult> tileSizesOfr =
+ getAsIndexOpFoldResult(context, tileSizes);
+ tileAndFuseOptions.tilingOptions.setTileSizes(tileSizesOfr)
+ .setInterchange(interchange);
LinalgTransformationFilter filter(StringAttr::get(context, filterName),
StringAttr::get(context, "tiled"));
patterns.add<TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp>(
More information about the Mlir-commits mailing list