[Mlir-commits] [mlir] 06bb9cf - [mlir][linalg] Remove IndexedGenericOp support from LinalgInterchangePattern...

Tobias Gysi llvmlistbot at llvm.org
Wed May 12 06:02:36 PDT 2021


Author: Tobias Gysi
Date: 2021-05-12T13:01:37Z
New Revision: 06bb9cf30d11247540d5b3f2a714f3aa640353e6

URL: https://github.com/llvm/llvm-project/commit/06bb9cf30d11247540d5b3f2a714f3aa640353e6
DIFF: https://github.com/llvm/llvm-project/commit/06bb9cf30d11247540d5b3f2a714f3aa640353e6.diff

LOG: [mlir][linalg] Remove IndexedGenericOp support from LinalgInterchangePattern...

after introducing the IndexedGenericOp to GenericOp canonicalization (https://reviews.llvm.org/D101612).

Differential Revision: https://reviews.llvm.org/D102245

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/transform-patterns.mlir
    mlir/test/lib/Transforms/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index de0f5888550f..93c99038e322 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -213,8 +213,8 @@ tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
 /// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
 /// integers, in the range 0..`op.rank` without duplications
 /// (i.e. `[1,1,2]` is an invalid permutation).
-void interchange(PatternRewriter &rewriter, LinalgOp op,
-                 ArrayRef<unsigned> interchangeVector);
+void interchangeGenericOp(PatternRewriter &rewriter, GenericOp genericOp,
+                          ArrayRef<unsigned> interchangeVector);
 
 /// Callback function type used to perform the allocation for the promoted
 /// `subView`. In `boundingSubViewsize` a best attempt is made to find the
@@ -363,11 +363,11 @@ LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter,
 // Preconditions that ensure the corresponding transformation succeeds and can
 // be applied as a rewrite pattern.
 //===----------------------------------------------------------------------===//
-/// Emits a `generic` or `indexed_generic` operation with the `indexing_maps`
-/// and `iterator_types` permutated according to `permutation`.
+/// Emits a `generic` operation with the `indexing_maps` and `iterator_types`
+/// permutated according to `permutation`.
 LogicalResult
-interchangeGenericLinalgOpPrecondition(Operation *op,
-                                       ArrayRef<unsigned> interchangeVector);
+interchangeGenericOpPrecondition(GenericOp genericOp,
+                                 ArrayRef<unsigned> interchangeVector);
 
 /// Promote std.subviews feeding linalg operations.
 LogicalResult promoteSubviewsPrecondition(Operation *op,
@@ -630,18 +630,18 @@ struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern {
 };
 
 ///
-/// Linalg interchange patterns.
+/// Linalg generic interchange pattern.
 ///
 /// Apply the `interchange` transformation as a pattern.
 /// `filter` controls LinalgTransformMarker matching and update when specified.
 /// See `interchange` for more details.
-struct LinalgBaseInterchangePattern : public RewritePattern {
-  LinalgBaseInterchangePattern(
-      StringRef opName, MLIRContext *context,
-      ArrayRef<unsigned> interchangeVector,
+struct GenericOpInterchangePattern : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+  GenericOpInterchangePattern(
+      MLIRContext *context, ArrayRef<unsigned> interchangeVector,
       LinalgTransformationFilter filter = LinalgTransformationFilter(),
       PatternBenefit benefit = 1);
-  LogicalResult matchAndRewrite(Operation *op,
+  LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override;
 
 private:
@@ -651,16 +651,6 @@ struct LinalgBaseInterchangePattern : public RewritePattern {
   SmallVector<unsigned, 8> interchangeVector;
 };
 
-template <typename OpTy>
-struct LinalgInterchangePattern : public LinalgBaseInterchangePattern {
-  LinalgInterchangePattern(
-      MLIRContext *context, ArrayRef<unsigned> interchangeVector,
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
-      PatternBenefit benefit = 1)
-      : LinalgBaseInterchangePattern(OpTy::getOperationName(), context,
-                                     interchangeVector, filter, benefit) {}
-};
-
 ///
 /// Linalg promotion patterns.
 ///

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
index 6d13765b7f54..e03d8cb01bc1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
@@ -32,68 +32,65 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
-LogicalResult mlir::linalg::interchangeGenericLinalgOpPrecondition(
-    Operation *op, ArrayRef<unsigned> interchangeVector) {
-  // Transformation applies to generic ops only.
-  if (!isa<GenericOp, IndexedGenericOp>(op))
-    return failure();
-  LinalgOp linalgOp = cast<LinalgOp>(op);
+LogicalResult mlir::linalg::interchangeGenericOpPrecondition(
+    GenericOp genericOp, ArrayRef<unsigned> interchangeVector) {
   // Interchange vector must be non-empty and match the number of loops.
   if (interchangeVector.empty() ||
-      linalgOp.getNumLoops() != interchangeVector.size())
+      genericOp.getNumLoops() != interchangeVector.size())
     return failure();
   // Permutation map must be invertible.
-  if (!inversePermutation(
-          AffineMap::getPermutationMap(interchangeVector, op->getContext())))
+  if (!inversePermutation(AffineMap::getPermutationMap(interchangeVector,
+                                                       genericOp.getContext())))
     return failure();
   return success();
 }
 
-void mlir::linalg::interchange(PatternRewriter &rewriter, LinalgOp op,
-                               ArrayRef<unsigned> interchangeVector) {
+void mlir::linalg::interchangeGenericOp(PatternRewriter &rewriter,
+                                        GenericOp genericOp,
+                                        ArrayRef<unsigned> interchangeVector) {
   // 1. Compute the inverse permutation map.
-  MLIRContext *context = op.getContext();
+  MLIRContext *context = genericOp.getContext();
   AffineMap permutationMap = inversePermutation(
       AffineMap::getPermutationMap(interchangeVector, context));
   assert(permutationMap && "expected permutation to be invertible");
-  assert(interchangeVector.size() == op.getNumLoops() &&
+  assert(interchangeVector.size() == genericOp.getNumLoops() &&
          "expected interchange vector to have entry for every loop");
 
   // 2. Compute the interchanged indexing maps.
   SmallVector<Attribute, 4> newIndexingMaps;
-  ArrayRef<Attribute> indexingMaps = op.indexing_maps().getValue();
-  for (unsigned i = 0, e = op.getNumShapedOperands(); i != e; ++i) {
+  ArrayRef<Attribute> indexingMaps = genericOp.indexing_maps().getValue();
+  for (unsigned i = 0, e = genericOp.getNumShapedOperands(); i != e; ++i) {
     AffineMap m = indexingMaps[i].cast<AffineMapAttr>().getValue();
     if (!permutationMap.isEmpty())
       m = m.compose(permutationMap);
     newIndexingMaps.push_back(AffineMapAttr::get(m));
   }
-  op->setAttr(getIndexingMapsAttrName(),
-              ArrayAttr::get(context, newIndexingMaps));
+  genericOp->setAttr(getIndexingMapsAttrName(),
+                     ArrayAttr::get(context, newIndexingMaps));
 
   // 3. Compute the interchanged iterator types.
-  ArrayRef<Attribute> itTypes = op.iterator_types().getValue();
+  ArrayRef<Attribute> itTypes = genericOp.iterator_types().getValue();
   SmallVector<Attribute, 4> itTypesVector;
   llvm::append_range(itTypesVector, itTypes);
   applyPermutationToVector(itTypesVector, interchangeVector);
-  op->setAttr(getIteratorTypesAttrName(),
-              ArrayAttr::get(context, itTypesVector));
+  genericOp->setAttr(getIteratorTypesAttrName(),
+                     ArrayAttr::get(context, itTypesVector));
 
   // 4. Transform the index operations by applying the permutation map.
-  if (op.hasIndexSemantics()) {
+  if (genericOp.hasIndexSemantics()) {
     // TODO: Remove the assertion and add a getBody() method to LinalgOp
     // interface once every LinalgOp has a body.
-    assert(op->getNumRegions() == 1 &&
-           op->getRegion(0).getBlocks().size() == 1 &&
+    assert(genericOp->getNumRegions() == 1 &&
+           genericOp->getRegion(0).getBlocks().size() == 1 &&
            "expected generic operation to have one block.");
-    Block &block = op->getRegion(0).front();
+    Block &block = genericOp->getRegion(0).front();
     OpBuilder::InsertionGuard guard(rewriter);
     for (IndexOp indexOp :
          llvm::make_early_inc_range(block.getOps<IndexOp>())) {
       rewriter.setInsertionPoint(indexOp);
       SmallVector<Value> allIndices;
-      allIndices.reserve(op.getNumLoops());
-      llvm::transform(llvm::seq<uint64_t>(0, op.getNumLoops()),
+      allIndices.reserve(genericOp.getNumLoops());
+      llvm::transform(llvm::seq<uint64_t>(0, genericOp.getNumLoops()),
                       std::back_inserter(allIndices), [&](uint64_t dim) {
                         return rewriter.create<IndexOp>(indexOp->getLoc(), dim);
                       });

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index acf460982784..736d298a713a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -393,30 +393,26 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
   return success();
 }
 
-/// Linalg base interchange pattern.
-mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
-    StringRef opName, MLIRContext *context,
-    ArrayRef<unsigned> interchangeVector, LinalgTransformationFilter filter,
-    PatternBenefit benefit)
-    : RewritePattern(opName, benefit, context, {}), filter(filter),
+/// Linalg generic interchange pattern.
+mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
+    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
+    LinalgTransformationFilter filter, PatternBenefit benefit)
+    : OpRewritePattern(context, benefit), filter(filter),
       interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
 
-LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
-  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
-  if (!linalgOp)
+LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
+    GenericOp genericOp, PatternRewriter &rewriter) const {
+  if (failed(filter.checkAndNotify(rewriter, genericOp)))
     return failure();
-  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
-    return failure();
-  if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector)))
+  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
     return failure();
 
   // TODO: figure out how this interplays with named ops. In particular this
   // should break the named op property.
-  rewriter.updateRootInPlace(op, [&]() {
-    interchange(rewriter, linalgOp, interchangeVector);
+  rewriter.updateRootInPlace(genericOp, [&]() {
+    interchangeGenericOp(rewriter, genericOp, interchangeVector);
     // New filter if specified.
-    filter.replaceLinalgTransformationFilter(rewriter, op);
+    filter.replaceLinalgTransformationFilter(rewriter, genericOp);
   });
   return success();
 }

diff  --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index f3092efbc580..347acb673f1f 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -125,37 +125,6 @@ func @permute_generic(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 // CHECK-SAME:     memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
 // CHECK-SAME:     memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
 
-#indexed_matmul_trait = {
-  args_in = 2,
-  args_out = 1,
-  indexing_maps = #matmul_accesses,
-  library_call = "linalg_matmul_indexed",
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-func @permute_generic_indexed(
-    %A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
-    %B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
-    %C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
-  linalg.indexed_generic #indexed_matmul_trait
-    ins(%A, %B : memref<?x?xf32, offset: ?, strides: [?, 1]>,
-                 memref<?x?xf32, offset: ?, strides: [?, 1]>)
-   outs(%C : memref<?x?xf32, offset: ?, strides: [?, 1]>) {
-    ^bb(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32):
-      %d = mulf %a, %b: f32
-      %e = addf %c, %d: f32
-      linalg.yield %e: f32
-  }
-  return
-}
-// CHECK-LABEL:  func @permute_generic_indexed
-// CHECK:        linalg.indexed_generic {
-// CHECK-SAME:     indexing_maps = [#[[$kn]], #[[$nm]], #[[$km]]],
-// CHECK-SAME:     iterator_types = ["parallel", "reduction", "parallel"],
-// CHECK-SAME:     library_call = "linalg_matmul_indexed"}
-// CHECK:            memref<?x?xf32, #[[$STRIDED_2D_u_1]]>,
-// CHECK-SAME:       memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
-// CHECK-SAME:       memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
-
 func @matvec_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
              %x: memref<?xf32, offset: ?, strides: [1]>,
              %y: memref<?xf32, offset: ?, strides: [1]>) {

diff  --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index 94ab9b951c37..90df19c2009e 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -194,14 +194,9 @@ static void applyPatterns(FuncOp funcOp) {
                .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
 
   //===--------------------------------------------------------------------===//
-  // Linalg generic permutation patterns.
+  // Linalg generic interchange pattern.
   //===--------------------------------------------------------------------===//
-  patterns.add<LinalgInterchangePattern<GenericOp>>(
-      ctx,
-      /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
-      LinalgTransformationFilter(ArrayRef<Identifier>{},
-                                 Identifier::get("PERMUTED", ctx)));
-  patterns.add<LinalgInterchangePattern<IndexedGenericOp>>(
+  patterns.add<GenericOpInterchangePattern>(
       ctx,
       /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
       LinalgTransformationFilter(ArrayRef<Identifier>{},
@@ -551,7 +546,7 @@ static void applyInterchangePattern(FuncOp funcOp,
                                     ArrayRef<unsigned> interchangeVector) {
   MLIRContext *context = funcOp.getContext();
   RewritePatternSet interchangePattern(context);
-  interchangePattern.add<LinalgInterchangePattern<GenericOp>>(
+  interchangePattern.add<GenericOpInterchangePattern>(
       context, interchangeVector,
       LinalgTransformationFilter(ArrayRef<Identifier>{},
                                  Identifier::get("interchange", context)));


        


More information about the Mlir-commits mailing list