[Mlir-commits] [mlir] 06bb9cf - [mlir][linalg] Remove IndexedGenericOp support from LinalgInterchangePattern...
Tobias Gysi
llvmlistbot at llvm.org
Wed May 12 06:02:36 PDT 2021
Author: Tobias Gysi
Date: 2021-05-12T13:01:37Z
New Revision: 06bb9cf30d11247540d5b3f2a714f3aa640353e6
URL: https://github.com/llvm/llvm-project/commit/06bb9cf30d11247540d5b3f2a714f3aa640353e6
DIFF: https://github.com/llvm/llvm-project/commit/06bb9cf30d11247540d5b3f2a714f3aa640353e6.diff
LOG: [mlir][linalg] Remove IndexedGenericOp support from LinalgInterchangePattern...
after introducing the IndexedGenericOp to GenericOp canonicalization (https://reviews.llvm.org/D101612).
Differential Revision: https://reviews.llvm.org/D102245
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/transform-patterns.mlir
mlir/test/lib/Transforms/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index de0f5888550f..93c99038e322 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -213,8 +213,8 @@ tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
/// integers, in the range 0..`op.rank` without duplications
/// (i.e. `[1,1,2]` is an invalid permutation).
-void interchange(PatternRewriter &rewriter, LinalgOp op,
- ArrayRef<unsigned> interchangeVector);
+void interchangeGenericOp(PatternRewriter &rewriter, GenericOp genericOp,
+ ArrayRef<unsigned> interchangeVector);
/// Callback function type used to perform the allocation for the promoted
/// `subView`. In `boundingSubViewsize` a best attempt is made to find the
@@ -363,11 +363,11 @@ LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter,
// Preconditions that ensure the corresponding transformation succeeds and can
// be applied as a rewrite pattern.
//===----------------------------------------------------------------------===//
-/// Emits a `generic` or `indexed_generic` operation with the `indexing_maps`
-/// and `iterator_types` permutated according to `permutation`.
+/// Checks that the `indexing_maps` and `iterator_types` of a `generic` op
+/// can be permuted according to `interchangeVector`.
LogicalResult
-interchangeGenericLinalgOpPrecondition(Operation *op,
- ArrayRef<unsigned> interchangeVector);
+interchangeGenericOpPrecondition(GenericOp genericOp,
+ ArrayRef<unsigned> interchangeVector);
/// Promote std.subviews feeding linalg operations.
LogicalResult promoteSubviewsPrecondition(Operation *op,
@@ -630,18 +630,18 @@ struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern {
};
///
-/// Linalg interchange patterns.
+/// Linalg generic interchange pattern.
///
/// Apply the `interchange` transformation as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `interchange` for more details.
-struct LinalgBaseInterchangePattern : public RewritePattern {
- LinalgBaseInterchangePattern(
- StringRef opName, MLIRContext *context,
- ArrayRef<unsigned> interchangeVector,
+struct GenericOpInterchangePattern : public OpRewritePattern<GenericOp> {
+ using OpRewritePattern<GenericOp>::OpRewritePattern;
+ GenericOpInterchangePattern(
+ MLIRContext *context, ArrayRef<unsigned> interchangeVector,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
- LogicalResult matchAndRewrite(Operation *op,
+ LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override;
private:
@@ -651,16 +651,6 @@ struct LinalgBaseInterchangePattern : public RewritePattern {
SmallVector<unsigned, 8> interchangeVector;
};
-template <typename OpTy>
-struct LinalgInterchangePattern : public LinalgBaseInterchangePattern {
- LinalgInterchangePattern(
- MLIRContext *context, ArrayRef<unsigned> interchangeVector,
- LinalgTransformationFilter filter = LinalgTransformationFilter(),
- PatternBenefit benefit = 1)
- : LinalgBaseInterchangePattern(OpTy::getOperationName(), context,
- interchangeVector, filter, benefit) {}
-};
-
///
/// Linalg promotion patterns.
///
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
index 6d13765b7f54..e03d8cb01bc1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
@@ -32,68 +32,65 @@
using namespace mlir;
using namespace mlir::linalg;
-LogicalResult mlir::linalg::interchangeGenericLinalgOpPrecondition(
- Operation *op, ArrayRef<unsigned> interchangeVector) {
- // Transformation applies to generic ops only.
- if (!isa<GenericOp, IndexedGenericOp>(op))
- return failure();
- LinalgOp linalgOp = cast<LinalgOp>(op);
+LogicalResult mlir::linalg::interchangeGenericOpPrecondition(
+ GenericOp genericOp, ArrayRef<unsigned> interchangeVector) {
// Interchange vector must be non-empty and match the number of loops.
if (interchangeVector.empty() ||
- linalgOp.getNumLoops() != interchangeVector.size())
+ genericOp.getNumLoops() != interchangeVector.size())
return failure();
// Permutation map must be invertible.
- if (!inversePermutation(
- AffineMap::getPermutationMap(interchangeVector, op->getContext())))
+ if (!inversePermutation(AffineMap::getPermutationMap(interchangeVector,
+ genericOp.getContext())))
return failure();
return success();
}
-void mlir::linalg::interchange(PatternRewriter &rewriter, LinalgOp op,
- ArrayRef<unsigned> interchangeVector) {
+void mlir::linalg::interchangeGenericOp(PatternRewriter &rewriter,
+ GenericOp genericOp,
+ ArrayRef<unsigned> interchangeVector) {
// 1. Compute the inverse permutation map.
- MLIRContext *context = op.getContext();
+ MLIRContext *context = genericOp.getContext();
AffineMap permutationMap = inversePermutation(
AffineMap::getPermutationMap(interchangeVector, context));
assert(permutationMap && "expected permutation to be invertible");
- assert(interchangeVector.size() == op.getNumLoops() &&
+ assert(interchangeVector.size() == genericOp.getNumLoops() &&
"expected interchange vector to have entry for every loop");
// 2. Compute the interchanged indexing maps.
SmallVector<Attribute, 4> newIndexingMaps;
- ArrayRef<Attribute> indexingMaps = op.indexing_maps().getValue();
- for (unsigned i = 0, e = op.getNumShapedOperands(); i != e; ++i) {
+ ArrayRef<Attribute> indexingMaps = genericOp.indexing_maps().getValue();
+ for (unsigned i = 0, e = genericOp.getNumShapedOperands(); i != e; ++i) {
AffineMap m = indexingMaps[i].cast<AffineMapAttr>().getValue();
if (!permutationMap.isEmpty())
m = m.compose(permutationMap);
newIndexingMaps.push_back(AffineMapAttr::get(m));
}
- op->setAttr(getIndexingMapsAttrName(),
- ArrayAttr::get(context, newIndexingMaps));
+ genericOp->setAttr(getIndexingMapsAttrName(),
+ ArrayAttr::get(context, newIndexingMaps));
// 3. Compute the interchanged iterator types.
- ArrayRef<Attribute> itTypes = op.iterator_types().getValue();
+ ArrayRef<Attribute> itTypes = genericOp.iterator_types().getValue();
SmallVector<Attribute, 4> itTypesVector;
llvm::append_range(itTypesVector, itTypes);
applyPermutationToVector(itTypesVector, interchangeVector);
- op->setAttr(getIteratorTypesAttrName(),
- ArrayAttr::get(context, itTypesVector));
+ genericOp->setAttr(getIteratorTypesAttrName(),
+ ArrayAttr::get(context, itTypesVector));
// 4. Transform the index operations by applying the permutation map.
- if (op.hasIndexSemantics()) {
+ if (genericOp.hasIndexSemantics()) {
// TODO: Remove the assertion and add a getBody() method to LinalgOp
// interface once every LinalgOp has a body.
- assert(op->getNumRegions() == 1 &&
- op->getRegion(0).getBlocks().size() == 1 &&
+ assert(genericOp->getNumRegions() == 1 &&
+ genericOp->getRegion(0).getBlocks().size() == 1 &&
"expected generic operation to have one block.");
- Block &block = op->getRegion(0).front();
+ Block &block = genericOp->getRegion(0).front();
OpBuilder::InsertionGuard guard(rewriter);
for (IndexOp indexOp :
llvm::make_early_inc_range(block.getOps<IndexOp>())) {
rewriter.setInsertionPoint(indexOp);
SmallVector<Value> allIndices;
- allIndices.reserve(op.getNumLoops());
- llvm::transform(llvm::seq<uint64_t>(0, op.getNumLoops()),
+ allIndices.reserve(genericOp.getNumLoops());
+ llvm::transform(llvm::seq<uint64_t>(0, genericOp.getNumLoops()),
std::back_inserter(allIndices), [&](uint64_t dim) {
return rewriter.create<IndexOp>(indexOp->getLoc(), dim);
});
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index acf460982784..736d298a713a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -393,30 +393,26 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
return success();
}
-/// Linalg base interchange pattern.
-mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
- StringRef opName, MLIRContext *context,
- ArrayRef<unsigned> interchangeVector, LinalgTransformationFilter filter,
- PatternBenefit benefit)
- : RewritePattern(opName, benefit, context, {}), filter(filter),
+/// Linalg generic interchange pattern.
+mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
+ MLIRContext *context, ArrayRef<unsigned> interchangeVector,
+ LinalgTransformationFilter filter, PatternBenefit benefit)
+ : OpRewritePattern(context, benefit), filter(filter),
interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
-LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
- Operation *op, PatternRewriter &rewriter) const {
- LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
- if (!linalgOp)
+LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
+ GenericOp genericOp, PatternRewriter &rewriter) const {
+ if (failed(filter.checkAndNotify(rewriter, genericOp)))
return failure();
- if (failed(filter.checkAndNotify(rewriter, linalgOp)))
- return failure();
- if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector)))
+ if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
return failure();
// TODO: figure out how this interplays with named ops. In particular this
// should break the named op property.
- rewriter.updateRootInPlace(op, [&]() {
- interchange(rewriter, linalgOp, interchangeVector);
+ rewriter.updateRootInPlace(genericOp, [&]() {
+ interchangeGenericOp(rewriter, genericOp, interchangeVector);
// New filter if specified.
- filter.replaceLinalgTransformationFilter(rewriter, op);
+ filter.replaceLinalgTransformationFilter(rewriter, genericOp);
});
return success();
}
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index f3092efbc580..347acb673f1f 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -125,37 +125,6 @@ func @permute_generic(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
-#indexed_matmul_trait = {
- args_in = 2,
- args_out = 1,
- indexing_maps = #matmul_accesses,
- library_call = "linalg_matmul_indexed",
- iterator_types = ["parallel", "parallel", "reduction"]
-}
-func @permute_generic_indexed(
- %A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
- %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
- %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
- linalg.indexed_generic #indexed_matmul_trait
- ins(%A, %B : memref<?x?xf32, offset: ?, strides: [?, 1]>,
- memref<?x?xf32, offset: ?, strides: [?, 1]>)
- outs(%C : memref<?x?xf32, offset: ?, strides: [?, 1]>) {
- ^bb(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32):
- %d = mulf %a, %b: f32
- %e = addf %c, %d: f32
- linalg.yield %e: f32
- }
- return
-}
-// CHECK-LABEL: func @permute_generic_indexed
-// CHECK: linalg.indexed_generic {
-// CHECK-SAME: indexing_maps = [#[[$kn]], #[[$nm]], #[[$km]]],
-// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"],
-// CHECK-SAME: library_call = "linalg_matmul_indexed"}
-// CHECK: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>,
-// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
-// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
-
func @matvec_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%x: memref<?xf32, offset: ?, strides: [1]>,
%y: memref<?xf32, offset: ?, strides: [1]>) {
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index 94ab9b951c37..90df19c2009e 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -194,14 +194,9 @@ static void applyPatterns(FuncOp funcOp) {
.addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
//===--------------------------------------------------------------------===//
- // Linalg generic permutation patterns.
+ // Linalg generic interchange pattern.
//===--------------------------------------------------------------------===//
- patterns.add<LinalgInterchangePattern<GenericOp>>(
- ctx,
- /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
- LinalgTransformationFilter(ArrayRef<Identifier>{},
- Identifier::get("PERMUTED", ctx)));
- patterns.add<LinalgInterchangePattern<IndexedGenericOp>>(
+ patterns.add<GenericOpInterchangePattern>(
ctx,
/*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
LinalgTransformationFilter(ArrayRef<Identifier>{},
@@ -551,7 +546,7 @@ static void applyInterchangePattern(FuncOp funcOp,
ArrayRef<unsigned> interchangeVector) {
MLIRContext *context = funcOp.getContext();
RewritePatternSet interchangePattern(context);
- interchangePattern.add<LinalgInterchangePattern<GenericOp>>(
+ interchangePattern.add<GenericOpInterchangePattern>(
context, interchangeVector,
LinalgTransformationFilter(ArrayRef<Identifier>{},
Identifier::get("interchange", context)));
More information about the Mlir-commits
mailing list