[Mlir-commits] [mlir] c584771 - Revert "[mlir][TilingInterface] Enable tile and fuse using TilingInterface."
Mahesh Ravishankar
llvmlistbot at llvm.org
Tue Jun 21 09:57:23 PDT 2022
Author: Mahesh Ravishankar
Date: 2022-06-21T16:56:59Z
New Revision: c584771f54cf94bb396c22f5cca895dd3f23c245
URL: https://github.com/llvm/llvm-project/commit/c584771f54cf94bb396c22f5cca895dd3f23c245
DIFF: https://github.com/llvm/llvm-project/commit/c584771f54cf94bb396c22f5cca895dd3f23c245.diff
LOG: Revert "[mlir][TilingInterface] Enable tile and fuse using TilingInterface."
This reverts commit ea75511319d9dff8c38c8794c3949c40b63a38d7 due to build failures.
Added:
Modified:
mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
mlir/include/mlir/Interfaces/TilingInterface.td
mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducer.cpp
mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 1f3ee8a5b27f6..6e8af767ff8a3 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -10,12 +10,9 @@
#define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/TilingInterface.h"
-#include <deque>
-
namespace mlir {
class Operation;
class PatternRewriter;
@@ -58,7 +55,7 @@ struct SCFTilingResult {
SmallVector<scf::ForOp> loops;
};
-/// Pattern to tile an op that implements the `TilingInterface` using
+/// Pattern to tile an op that implementas the `TilingInterface` using
/// `scf.for` for iterating over the tiles.
struct TileUsingSCFForOp : public OpInterfaceRewritePattern<TilingInterface> {
/// Construct a generic pattern applied to all TilingInterface ops.
@@ -84,56 +81,6 @@ struct TileUsingSCFForOp : public OpInterfaceRewritePattern<TilingInterface> {
SCFTilingOptions options;
};
-/// Pattern to tile and fuse a sequence of operations, by tiling the consumer
-/// and fusing its producers. Note that this assumes that it is valid to
-/// tile+fuse the producer into the innermost tiled loop. Its up to the caller
-/// to ensure that the tile sizes provided make this fusion valid.
-///
-/// For example, for the following sequence
-///
-/// ```mlir
-/// %0 = linalg.fill ...
-/// %1 = linalg.matmul ... outs(%0 : ...) ...
-/// ```
-///
-/// it is legal to fuse the fill with the matmul only if the matmul is tiled
-/// along the parallel dimensions and not the reduction dimension, i.e. the tile
-/// size for the reduction dimension should be 0.
-struct SCFTileAndFuseResult {
- SmallVector<Operation *> tiledAndFusedOps;
- SmallVector<scf::ForOp> loops;
-};
-struct TileConsumerAndFuseProducersUsingSCFForOp
- : public OpInterfaceRewritePattern<TilingInterface> {
-
- /// Construct a generic pattern applied to all TilingInterface ops.
- TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context,
- SCFTilingOptions options,
- PatternBenefit benefit = 1);
-
- /// Construct a generic pattern applied to `opName`.
- TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName,
- MLIRContext *context,
- SCFTilingOptions options,
- PatternBenefit benefit = 1);
-
- /// `matchAndRewrite` implementation that returns the significant transformed
- /// pieces of IR.
- FailureOr<SCFTileAndFuseResult>
- returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const;
-
- LogicalResult matchAndRewrite(TilingInterface op,
- PatternRewriter &rewriter) const override {
- return returningMatchAndRewrite(op, rewriter);
- }
-
-private:
- /// This pattern uses the tiling pattern. Instead of using inheritance, use
- /// the patterns as private object that is instantiated at the same time as
- /// this pattern.
- TileUsingSCFForOp tilingPattern;
-};
-
} // namespace scf
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 28c22aecdf318..e6267e9cf02e5 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -9,7 +9,6 @@
#ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
@@ -21,14 +20,6 @@ namespace tensor {
void populateSplitPaddingPatterns(RewritePatternSet &patterns,
PatternBenefit baseBenefit = 1);
-/// Pattern to swap an `tensor.extract_slice` with its producer when the
-/// producer implements the `TilingInterface`. The pattern itself does not
-/// provide a mechanism to control where the application happens. With use of
-/// transform dialect that control is done within the transform dialect. Other
-/// use cases can inherit from this pattern and add necessary controls.
-FailureOr<Value> replaceExtractSliceWithTiledProducer(
- OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
-
} // namespace tensor
} // namespace mlir
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index f3fdc30168b28..606901375ede8 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -120,48 +120,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
/*defaultImplementation=*/[{
return failure();
}]
- >,
- InterfaceMethod<
- /*desc=*/[{
- Method to generate the code that produces a tile of the result.
-
- Generates the IR that computes the tile of a result of the
- operation. The `offsets` and `sizes` describe the tile of
- the output required. This is different from
- `getTiledImplementation` which generates the tiled
- implementation of the operation given a tile of the
- iteration space. This method generates a tiled
- implementation of the operation based on the tile of the
- result required. This method enables fusion by using tile
- and fuse. The method returns failure if the operation can't be
- tiled to generate the result tile. In practical terms this
- implies it cannot be tiled and fused with its consumers.
-
- - `dest` are the Value into which the result of the tiled
- operation is to be inserted into. The type of the `dest`
- Values is same as the types returned by
- `getDestinationOperands` method.
- - `offsets` provides the offset of the tile within the
- iteration space
- - `sizes` provides the size of the tile.
- - `tileDestOperands` specifies whether to also tile `dest` operands
- or not. Avoiding tiling `dest` operands can be useful for
- composition with various looping container ops.
- }],
- /*retType=*/"FailureOr<Value>",
- /*methodName=*/"generateResultTileValue",
- /*args=*/(ins
- "OpBuilder &":$b,
- "unsigned":$resultNumber,
- "ValueRange":$dest,
- "ArrayRef<OpFoldResult>":$offsets,
- "ArrayRef<OpFoldResult>":$sizes,
- "bool":$tileDestOperands),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- return failure();
- }]
>
- ];
+ ];
}
#endif // MLIR_TILINGINTERFACE
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 88b21f15081f3..c67097ab3d695 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -30,6 +30,7 @@ template <typename LinalgOpTy>
struct LinalgOpTilingInterface
: public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
LinalgOpTy> {
+
/// Return the destination operands.
SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
return llvm::cast<LinalgOp>(op).getOutputOperands();
@@ -46,8 +47,6 @@ struct LinalgOpTilingInterface
/// Return the iteration domain range.
SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(op);
Location loc = op->getLoc();
LinalgOp linalgOp = cast<LinalgOp>(op);
auto allShapesSizes = linalgOp.createFlatListOfOperandDims(b, loc);
@@ -130,65 +129,16 @@ struct LinalgOpTilingInterface
resultSizes = sliceOp.getMixedSizes();
return success();
}
-
- FailureOr<Value> generateResultTileValue(Operation *op, OpBuilder &b,
- unsigned resultNumber,
- ValueRange dest,
- ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes,
- bool tileDestOperands) const {
- auto linalgOp = cast<LinalgOp>(op);
-
- // Check that the indexing map used for the output is a projected
- // permutation. This could be relaxed with a more general approach that can
- // map the offsets and sizes from the result to iteration space tiles
- // (filling in full extent for dimensions not used to access the result).
- AffineMap indexingMap =
- linalgOp.getTiedIndexingMapForResult(op->getResult(resultNumber));
- if (!indexingMap.isProjectedPermutation()) {
- return op->emitOpError(
- "unhandled tiled implementation generation when result is not "
- "accessed using a permuted projection");
- }
-
- auto numLoops = linalgOp.getNumLoops();
- auto tilingInterfaceOp = cast<TilingInterface>(op);
- SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
- iterationTileSizes(numLoops);
- if (!indexingMap.isPermutation()) {
- SmallVector<Range> iterationDomain =
- tilingInterfaceOp.getIterationDomain(b);
- for (auto range : llvm::enumerate(iterationDomain)) {
- iterationTileOffsets[range.index()] = range.value().offset;
- iterationTileSizes[range.index()] = range.value().size;
- }
- }
- for (auto resultExpr : llvm::enumerate(indexingMap.getResults())) {
- unsigned dimPosition =
- resultExpr.value().cast<AffineDimExpr>().getPosition();
- iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
- iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
- }
-
- SmallVector<Operation *> tiledOp = tilingInterfaceOp.getTiledImplementation(
- b, dest, iterationTileOffsets, iterationTileSizes, tileDestOperands);
- if (tiledOp.size() != 1)
- return op->emitOpError("failed to generate tiled implementation");
-
- return tiledOp[0]->getResult(resultNumber);
- }
};
} // namespace
-template <typename OpType>
-static void registerOne(MLIRContext *ctx) {
+template <typename OpType> static void registerOne(MLIRContext *ctx) {
OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
}
/// Variadic helper function.
-template <typename... OpTypes>
-static void registerAll(MLIRContext *ctx) {
+template <typename... OpTypes> static void registerAll(MLIRContext *ctx) {
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
(void)std::initializer_list<int>{0, (registerOne<OpTypes>(ctx), 0)...};
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1bad67f3d7f4d..4646abcf3e8d1 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -42,10 +42,6 @@ scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
return *this;
}
-//===----------------------------------------------------------------------===//
-// TileUsingSCFForOp pattern implementation.
-//===----------------------------------------------------------------------===//
-
/// Generate an empty loop nest that represents the tiled loop nest shell.
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
/// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
@@ -251,155 +247,3 @@ scf::TileUsingSCFForOp::returningMatchAndRewrite(
rewriter.replaceOp(op, tilingResult.loops.front().getResults());
return tilingResult;
}
-
-//===----------------------------------------------------------------------===//
-// TileConsumerAndFuseProducersUsingSCFForOp pattern implementation.
-//===----------------------------------------------------------------------===//
-
-scf::TileConsumerAndFuseProducersUsingSCFForOp::
- TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context,
- scf::SCFTilingOptions options,
- PatternBenefit benefit)
- : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
- tilingPattern(context, std::move(options)) {}
-
-scf::TileConsumerAndFuseProducersUsingSCFForOp::
- TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName,
- MLIRContext *context,
- scf::SCFTilingOptions options,
- PatternBenefit benefit)
- : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
- tilingPattern(context, std::move(options)) {}
-
-/// Return the `Value` that is defined by an operation that implements
-/// the `TilingInterface`. Looks through `iter_args` of scf.for nest
-/// if required.
-static Optional<OpResult> getFusableProducer(Value v) {
- while (auto blockArg = v.dyn_cast<BlockArgument>()) {
- auto loopOp = dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp());
- if (!loopOp)
- return llvm::None;
- v = loopOp.getOpOperandForRegionIterArg(blockArg).get();
- }
- if (!isa_and_nonnull<TilingInterface>(v.getDefiningOp()))
- return llvm::None;
- return v.cast<OpResult>();
-}
-
-FailureOr<scf::SCFTileAndFuseResult>
-scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
- TilingInterface op, PatternRewriter &rewriter) const {
- // This transformation is only valid for ops that return values (i.e. not
- // valid to use with operations that have memref operands).
- if (!op->getNumResults()) {
- return rewriter.notifyMatchFailure(
- op, "invalid pattern for op with no results");
- }
-
- // 1. First tile the consumer.
- SCFTileAndFuseResult tileAndFuseResult;
- {
- FailureOr<SCFTilingResult> tilingResult =
- tilingPattern.returningMatchAndRewrite(op, rewriter);
- if (failed(tilingResult)) {
- return failure();
- }
- tileAndFuseResult.tiledAndFusedOps.push_back(tilingResult->tiledOp);
- tileAndFuseResult.loops = std::move(tilingResult->loops);
- }
-
- // 2. Typically, the operands of the tiled operation are slices of the
- // operands of the untiled operation. These are expressed in IR using
- // `tensor.extract_slice` operations with source being the operands of the
- // untiled operation. Create a worklist of these `tensor.extract_slice`
- // operations. If the producers of the source of the `tensor.extract_slice`
- // can be tiled such that the tiled value is generated in-place, that
- // effectively tiles + fuses the operations.
- auto addCandidateSlices = [](Operation *fusedOp,
- std::deque<tensor::ExtractSliceOp> &candidates) {
- for (Value operand : fusedOp->getOperands())
- if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
- candidates.push_back(sliceOp);
- };
-
- std::deque<tensor::ExtractSliceOp> candidates;
- addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
- OpBuilder::InsertionGuard g(rewriter);
- while (!candidates.empty()) {
- // 2a. Traverse the slices in BFS fashion.
- tensor::ExtractSliceOp candidateSliceOp = candidates.front();
- candidates.pop_front();
-
- // 2b. Get the producer of the source (potentially walking through
- // `iter_args` of nested `scf.for`)
- Optional<OpResult> fusableProducer =
- getFusableProducer(candidateSliceOp.source());
- if (!fusableProducer)
- continue;
-
- // 2c. Generate the tiled implementation of the producer of the source
- rewriter.setInsertionPoint(candidateSliceOp);
- FailureOr<Value> fusedProducerValue =
- tensor::replaceExtractSliceWithTiledProducer(
- rewriter, candidateSliceOp, fusableProducer.getValue());
- if (failed(fusedProducerValue))
- continue;
- rewriter.replaceOp(candidateSliceOp, fusedProducerValue.getValue());
-
- // 2d. The operands of the fused producer might themselved be slices of
- // values produced by operations that implement the `TilingInterface`.
- // Add these operations to the worklist.
- Operation *fusedProducer = fusedProducerValue->getDefiningOp();
- tileAndFuseResult.tiledAndFusedOps.push_back(fusedProducer);
- addCandidateSlices(fusedProducer, candidates);
-
- // 2e. If the operation being fused creates a value that is used as `outs`
- // in the tiled operation, the result of the unfused operation will be
- // used in the `iter_args` of the tiled loop generated. When the
- // operation is fused, this use in `iter_args` needs to be modified to
- // use the destination of the fused operation. For example, starting
- // with
- //
- // ```mlir
- // %0 = linalg.init_tensor ...
- // %1 = linalg.fill ... outs(%0:...)...
- // %2 = linalg.matmul ... outs(%1:...)....
- // ```
- //
- // First the `linalg.matmul` gets tiled
- //
- // ```mlir
- // %0 = linalg.init_tensor
- // %1 = linalg.fill
- // %2 = scf.for .... iter_args(%arg0 = %1)...
- // ...
- // ... = linalg.matmul ...
- //
- // ```
- //
- // When the `linalg.fill` gets fused, the `iter_args` needs to be
- // modified
- //
- // ```mlir
- // %0 = linalg.init_tensor
- // %1 = scf.for ... iter_args(%arg0 = %0)...
- // ...
- // %2 = linalg.fill ...
- // %3 = linalg.matmul ... outs(%2: ...)...
- // ```
- TilingInterface unfusedProducerOp =
- cast<TilingInterface>(fusableProducer->getOwner());
- scf::ForOp outerMostTiledLoop = tileAndFuseResult.loops.front();
- SmallVector<Value> unfusedProducerOpDestValues =
- unfusedProducerOp.getDestinationOperands(rewriter);
- for (OpOperand &uses : unfusedProducerOp->getUses()) {
- if (uses.getOwner() == outerMostTiledLoop.getOperation()) {
- unsigned resultNumber = uses.get().cast<OpResult>().getResultNumber();
- unsigned operandNumber = uses.getOperandNumber();
- outerMostTiledLoop->setOperand(
- operandNumber, unfusedProducerOpDestValues[resultNumber]);
- }
- }
- }
- return tileAndFuseResult;
-}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index 8479c43211e83..f4983c4d5c886 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -2,7 +2,6 @@ add_mlir_dialect_library(MLIRTensorTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
SplitPadding.cpp
- SwapExtractSliceWithProducer.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Transforms
@@ -19,6 +18,5 @@ add_mlir_dialect_library(MLIRTensorTransforms
MLIRPass
MLIRSCFDialect
MLIRTensorDialect
- MLIRTilingInterface
MLIRTransforms
)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducer.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducer.cpp
deleted file mode 100644
index 8d570cfdf7592..0000000000000
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducer.cpp
+++ /dev/null
@@ -1,43 +0,0 @@
-//===- SwapExtractSliceWithProducer.cpp - Swapping `tensor.extract_slice` ---=//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Swap a `tensor.extract_slice` with the producer of the source if the producer
-// implements the `TilingInterface`. When used in conjunction with tiling this
-// effectively tiles + fuses the producer with its consumer.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
-#include "mlir/Interfaces/TilingInterface.h"
-
-using namespace mlir;
-
-FailureOr<Value> tensor::replaceExtractSliceWithTiledProducer(
- OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) {
- auto producerOp = dyn_cast<TilingInterface>(producer.getOwner());
- if (!producerOp)
- return failure();
-
- // `TilingInterface` currently only supports strides being 1.
- if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
- return !isConstantIntValue(ofr, 1);
- }))
- return failure();
-
- FailureOr<Value> tiledResult = producerOp.generateResultTileValue(
- builder, producer.getResultNumber(),
- producerOp.getDestinationOperands(builder), sliceOp.getMixedOffsets(),
- sliceOp.getMixedSizes(), true);
- if (failed(tiledResult))
- return failure();
-
- return tiledResult.getValue();
-}
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
deleted file mode 100644
index dd77211d8ccc6..0000000000000
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ /dev/null
@@ -1,185 +0,0 @@
-// RUN: mlir-opt -test-tiling-interface=tile-consumer-and-fuse-producer-using-scf-for -split-input-file %s | FileCheck %s
-
-func.func @gemm_fill_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %cst = arith.constant 0.0 : f32
- %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
- %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
- %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
- %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
- %gemm = linalg.matmul {__internal_linalg_transform__ = "fusion"}
- ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
- return %gemm : tensor<?x?xf32>
-}
-// CHECK: func.func @gemm_fill_fusion(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
-// CHECK: %[[INIT:.+]] = linalg.init_tensor
-// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] =
-// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT]])
-// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] =
-// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
-// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
-// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
-// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]]
-// CHECK: %[[FILL_TILE:.+]] = linalg.fill
-// CHECK-SAME: outs(%[[INIT_TILE]] :
-// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
-// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] :
-// CHECK-SAME: outs(%[[FILL_TILE]] :
-// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]]
-// CHECK scf.yield %[[INSERT]]
-
-// -----
-
-func.func @gemm_generic_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
- %arg2 : tensor<?xf32>) -> tensor<?x?xf32> {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %cst = arith.constant 0.0 : f32
- %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
- %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
- %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
- %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
- %gemm = linalg.matmul
- ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
- %generic = linalg.generic {
- __internal_linalg_transform__ = "fusion",
- indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]}
- ins(%gemm, %arg2 : tensor<?x?xf32>, tensor<?xf32>) outs(%init : tensor<?x?xf32>) {
- ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
- %add = arith.addf %b0, %b1 : f32
- linalg.yield %add : f32
- } -> tensor<?x?xf32>
- return %generic : tensor<?x?xf32>
-}
-// CHECK: func.func @gemm_generic_fusion(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?xf32>)
-// CHECK: %[[INIT:.+]] = linalg.init_tensor
-// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] =
-// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT]])
-// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] =
-// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
-// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
-// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
-// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]]
-// CHECK: %[[FILL_TILE:.+]] = linalg.fill
-// CHECK-SAME: outs(%[[INIT_TILE]] :
-// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
-// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] :
-// CHECK-SAME: outs(%[[FILL_TILE]] :
-// CHECK-DAG: %[[BIAS_TILE:.+]] = tensor.extract_slice %[[ARG2]][%[[IV1]]]
-// CHECK-DAG: %[[OUTS_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]]
-// CHECK: %[[GENERIC_TILE:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[GEMM_TILE]], %[[BIAS_TILE]] :
-// CHECK-SAME: outs(%[[OUTS_TILE]] :
-// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]]
-// CHECK scf.yield %[[INSERT]]
-
-// -----
-
-func.func @gemm_gemm_fusion(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?x?xf32>, %rhs1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %cst = arith.constant 0.0 : f32
- %d0 = tensor.dim %lhs0, %c0 : tensor<?x?xf32>
- %d1 = tensor.dim %rhs0, %c1 : tensor<?x?xf32>
- %init0 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
- %fill0 = linalg.fill ins(%cst : f32) outs(%init0 : tensor<?x?xf32>) -> tensor<?x?xf32>
- %gemm0 = linalg.matmul
- ins(%lhs0, %rhs0 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill0 : tensor<?x?xf32>) -> tensor<?x?xf32>
- %d2 = tensor.dim %rhs1, %c1 : tensor<?x?xf32>
- %init1 = linalg.init_tensor [%d0, %d2] : tensor<?x?xf32>
- %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
- %gemm1 = linalg.matmul {__internal_linalg_transform__ = "gemm_fusion"}
- ins(%gemm0, %rhs1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill1 : tensor<?x?xf32>) -> tensor<?x?xf32>
- return %gemm1 : tensor<?x?xf32>
-}
-// CHECK: func.func @gemm_gemm_fusion(
-// CHECK-SAME: %[[LHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[RHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
-// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[LHS0]], %[[C0]]
-// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[RHS0]], %[[C1]]
-// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
-// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[RHS1]], %[[C1]]
-// CHECK: %[[INIT1:.+]] = linalg.init_tensor [%[[D0]], %[[D2]]]
-// CHECK: scf.for %[[IV:[a-zA-Z0-9]+]] =
-// CHECK-SAME: iter_args(%[[ITERARG:.+]] = %[[INIT1]])
-// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
-// CHECK-DAG: %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][0, 0]
-// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV]], 0]
-// CHECK: %[[FILL0_TILE:.+]] = linalg.fill
-// CHECK-SAME: outs(%[[INIT0_TILE]] :
-// CHECK: %[[GEMM0_TILE:.+]] = linalg.matmul
-// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
-// CHECK-SAME: outs(%[[FILL0_TILE]] :
-// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
-// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV]], 0]
-// CHECK: %[[FILL1_TILE:.+]] = linalg.fill
-// CHECK-SAME: outs(%[[INIT1_TILE]] :
-// CHECK: %[[GEMM1_TILE:.+]] = linalg.matmul
-// CHECK-SAME: ins(%[[GEMM0_TILE]], %[[RHS1_TILE]] :
-// CHECK-SAME: outs(%[[FILL1_TILE]] :
-// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GEMM1_TILE]] into %[[ITERARG]][%[[IV]], 0]
-// CHECK scf.yield %[[INSERT]]
-
-// -----
-
-func.func @gemm_transpose_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %cst = arith.constant 0.0 : f32
- %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
- %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
- %init0 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
- %fill = linalg.fill ins(%cst : f32) outs(%init0 : tensor<?x?xf32>) -> tensor<?x?xf32>
- %gemm = linalg.matmul
- ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
- %init1 = linalg.init_tensor [%d1, %d0] : tensor<?x?xf32>
- %transpose = linalg.generic {
- __internal_linalg_transform__ = "fusion",
- indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>],
- iterator_types = ["parallel", "parallel"]}
- ins(%gemm : tensor<?x?xf32>) outs(%init1 : tensor<?x?xf32>) {
- ^bb0(%b0 : f32, %b1 : f32):
- linalg.yield %b0 : f32
- } -> tensor<?x?xf32>
- return %transpose : tensor<?x?xf32>
-}
-// CHECK: func.func @gemm_transpose_fusion(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
-// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
-// CHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [%[[D1]], %[[D0]]]
-// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] =
-// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT1]])
-// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] =
-// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
-// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
-// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
-// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV0]], %[[IV1]]]
-// CHECK: %[[FILL_TILE:.+]] = linalg.fill
-// CHECK-SAME: outs(%[[INIT0_TILE]] :
-// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
-// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] :
-// CHECK-SAME: outs(%[[FILL_TILE]] :
-// CHECK-DAG: %[[OUTS_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV1]], %[[IV0]]]
-// CHECK: %[[GENERIC_TILE:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[GEMM_TILE]] :
-// CHECK-SAME: outs(%[[OUTS_TILE]] :
-// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]]
-// CHECK scf.yield %[[INSERT]]
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
index a7367a713ff4f..1e094329db66f 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-tiling-interface=tile-using-scf-for -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -test-tiling-interface -split-input-file %s | FileCheck %s
func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index 6241603d6a679..f3ba7a1c5f52d 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -29,9 +29,8 @@ using namespace mlir;
namespace {
-/// Pattern for testing `TileUsingSCFForOp` pattern (that tiles operations using
-/// the `TilingInterface` with `scf.for` ops for iterating over the tiles) while
-/// using a `filter` to avoid recursive application.
+/// Construct a generic pattern applied to all TilingInterface ops that verify
+/// `filter`.
struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp {
TestTileUsingSCFForOpWithFilter(MLIRContext *context,
scf::SCFTilingOptions options,
@@ -53,7 +52,8 @@ struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp {
if (failed(filter.checkAndNotify(rewriter, op)))
return failure();
- auto tilingResult = returningMatchAndRewrite(op, rewriter);
+ FailureOr<scf::SCFTilingResult> tilingResult =
+ returningMatchAndRewrite(op, rewriter);
if (failed(tilingResult)) {
return failure();
}
@@ -65,50 +65,6 @@ struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp {
linalg::LinalgTransformationFilter filter;
};
-/// Pattern for testing `TileConsumerAndFUseProducersUsingSCFForOp` pattern
-/// (that tiles and fuses operations using the `TilingInterface` with `scf.for`
-/// ops for iterating over the tiles) while using a `filter` to avoid recursive
-/// application.
-struct TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter
- : public scf::TileConsumerAndFuseProducersUsingSCFForOp {
- TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter(
- MLIRContext *context, scf::SCFTilingOptions options,
- linalg::LinalgTransformationFilter filter =
- linalg::LinalgTransformationFilter(),
- PatternBenefit benefit = 1)
- : scf::TileConsumerAndFuseProducersUsingSCFForOp(context, options,
- benefit),
- filter(filter) {}
-
- /// Construct a generic pattern applied to `opName`.
- TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter(
- StringRef opName, MLIRContext *context, scf::SCFTilingOptions options,
- linalg::LinalgTransformationFilter filter =
- linalg::LinalgTransformationFilter(),
- PatternBenefit benefit = 1)
- : scf::TileConsumerAndFuseProducersUsingSCFForOp(context, options,
- benefit),
- filter(filter) {}
-
- LogicalResult matchAndRewrite(TilingInterface op,
- PatternRewriter &rewriter) const override {
- if (failed(filter.checkAndNotify(rewriter, op)))
- return failure();
-
- auto tileAndFuseResult = returningMatchAndRewrite(op, rewriter);
- if (failed(tileAndFuseResult)) {
- return failure();
- }
- filter.replaceLinalgTransformationFilter(
- rewriter, tileAndFuseResult->tiledAndFusedOps.front());
- return success();
- }
-
-private:
- linalg::LinalgTransformationFilter filter;
-};
-
-/// Test pass for testing the use of `TilingInterface`.
struct TestTilingInterfacePass
: public PassWrapper<TestTilingInterfacePass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTilingInterfacePass)
@@ -126,63 +82,29 @@ struct TestTilingInterfacePass
return "Test tiling using TilingInterface";
}
- Option<bool> testTiling{
- *this, "tile-using-scf-for",
- llvm::cl::desc(
- "Test tiling using TilingInterface with scf.for operations"),
- llvm::cl::init(false)};
-
- Option<bool> testTileConsumerAndFuseProducer{
- *this, "tile-consumer-and-fuse-producer-using-scf-for",
- llvm::cl::desc("Test tile and fuse transformation using TilingInterface "
- "with scf.for operations"),
- llvm::cl::init(false)};
-
void runOnOperation() override;
-
-private:
- void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns);
};
} // namespace
-template <class Pattern>
-static void
-addPatternForTiling(MLIRContext *context, ArrayRef<int64_t> tileSizes,
- StringRef filterName, RewritePatternSet &patterns) {
- scf::SCFTilingOptions tilingOptions;
- tilingOptions.setTileSizes(tileSizes);
- linalg::LinalgTransformationFilter filter(
- StringAttr::get(context, filterName), StringAttr::get(context, "tiled"));
- patterns.add<Pattern>(context, tilingOptions, filter);
-}
-
-void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
- RewritePatternSet &patterns) {
- if (testTiling) {
- // 1. Tiling M and N dims of `linalg.matmul` on tensors.
- addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
- context, {10, 20}, "simple_gemm", patterns);
- // 2. Tiling M, N and K of `linalg.matmul` on buffers.
- addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
- context, {10, 20, 30}, "simple_gemm_memref", patterns);
- // 3. Tiling 3D parallel generic op which implements a transpose
- addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
- context, {10, 0, 20}, "parallel_generic_transpose", patterns);
- // 4. Tiling 2D conv op.
- addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
- context, {0, 0, 0, 0, 10, 20, 30}, "simple_conv", patterns);
- return;
- }
- if (testTileConsumerAndFuseProducer) {
- // 1. Tile and fuse of gemm with bias-add operation.
- addPatternForTiling<
- TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
- context, {10, 20}, "fusion", patterns);
- addPatternForTiling<
- TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
- context, {10}, "gemm_fusion", patterns);
- return;
- }
+static void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns) {
+ auto addPatternForTiling = [&](ArrayRef<int64_t> tileSizes,
+ StringRef filterName) {
+ scf::SCFTilingOptions tilingOptions;
+ tilingOptions.setTileSizes(tileSizes);
+ linalg::LinalgTransformationFilter filter(
+ StringAttr::get(context, filterName),
+ StringAttr::get(context, "tiled"));
+ patterns.add<TestTileUsingSCFForOpWithFilter>(context, tilingOptions,
+ filter);
+ };
+ // 1. Tiling M and N dims of `linalg.matmul` on tensors.
+ addPatternForTiling({10, 20}, "simple_gemm");
+ // 2. Tiling M, N and K of `linalg.matmul` on buffers.
+ addPatternForTiling({10, 20, 30}, "simple_gemm_memref");
+ // 3. Tiling 3D parallel generic op which implements a transpose
+ addPatternForTiling({10, 0, 20}, "parallel_generic_transpose");
+ // 4. Tiling 2D conv op.
+ addPatternForTiling({0, 0, 0, 0, 10, 20, 30}, "simple_conv");
}
void TestTilingInterfacePass::runOnOperation() {
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index f0813db443a5c..8ef7787306916 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1881,7 +1881,6 @@ cc_library(
":SCFUtils",
":Support",
":TensorDialect",
- ":TensorTransforms",
":TilingInterface",
":Transforms",
"//llvm:Support",
@@ -5029,7 +5028,6 @@ cc_library(
":SCFDialect",
":TensorDialect",
":TensorPassIncGen",
- ":TilingInterface",
":Transforms",
"//llvm:Support",
],
More information about the Mlir-commits mailing list