[Mlir-commits] [mlir] 495e1d7 - [mlir][linalg] adding pass to run the interchange pattern.

Tobias Gysi llvmlistbot at llvm.org
Mon Apr 19 05:20:20 PDT 2021


Author: Tobias Gysi
Date: 2021-04-19T12:19:15Z
New Revision: 495e1d7e8a68e4343756b58b0dd7b4bd047bd847

URL: https://github.com/llvm/llvm-project/commit/495e1d7e8a68e4343756b58b0dd7b4bd047bd847
DIFF: https://github.com/llvm/llvm-project/commit/495e1d7e8a68e4343756b58b0dd7b4bd047bd847.diff

LOG: [mlir][linalg] adding pass to run the interchange pattern.

Instead of interchanging loops during the loop lowering, this pass performs the interchange by permuting the indexing maps. It also updates the iterator types and the `linalg.index` accesses in the body of the operation.

Differential Revision: https://reviews.llvm.org/D100627

Added: 
    mlir/test/Dialect/Linalg/interchange.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/lib/Transforms/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 4d45642b1d983..48b1eb8cdf502 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -194,16 +194,17 @@ tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
                      const LinalgDependenceGraph &dependenceGraph,
                      const LinalgTilingOptions &tilingOptions);
 
-/// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`.
-/// This is an in-place transformation controlled by `interchangeVector`.
-/// An empty vector is interpreted as the identity permutation and the
-/// transformation returns early.
+/// Interchanges the `iterator_types` and `indexing_maps` dimensions and adapts
+/// the `linalg.index` accesses of `op`. This is an in-place transformation
+/// controlled by `interchangeVector`, which must have one entry per loop of
+/// `op` (check with `interchangeGenericLinalgOpPrecondition` beforehand).
 ///
 /// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed with
 /// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
 /// integers, in the range 0..`op.rank` without duplications
 /// (i.e. `[1,1,2]` is an invalid permutation).
-LinalgOp interchange(LinalgOp op, ArrayRef<unsigned> interchangeVector);
+void interchange(PatternRewriter &rewriter, LinalgOp op,
+                 ArrayRef<unsigned> interchangeVector);
 
 /// Callback function type used to perform the allocation for the promoted
 /// `subView`. In `boundingSubViewsize` a best attempt is made to find the

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
index b893f2ba67211..29acd628d101f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
@@ -34,17 +34,13 @@ using namespace mlir::linalg;
 
 LogicalResult mlir::linalg::interchangeGenericLinalgOpPrecondition(
     Operation *op, ArrayRef<unsigned> interchangeVector) {
-  if (interchangeVector.empty())
-    return failure();
   // Transformation applies to generic ops only.
   if (!isa<GenericOp, IndexedGenericOp>(op))
     return failure();
-  LinalgOp linOp = cast<LinalgOp>(op);
-  // Transformation applies to buffers only.
-  if (!linOp.hasBufferSemantics())
-    return failure();
-  // Permutation must be applicable.
-  if (linOp.getIndexingMap(0).getNumInputs() != interchangeVector.size())
+  LinalgOp linalgOp = cast<LinalgOp>(op);
+  // Interchange vector must be non-empty and match the number of loops.
+  if (interchangeVector.empty() ||
+      linalgOp.getNumLoops() != interchangeVector.size())
     return failure();
   // Permutation map must be invertible.
   if (!inversePermutation(
@@ -53,33 +49,56 @@ LogicalResult mlir::linalg::interchangeGenericLinalgOpPrecondition(
   return success();
 }
 
-LinalgOp mlir::linalg::interchange(LinalgOp op,
-                                   ArrayRef<unsigned> interchangeVector) {
-  if (interchangeVector.empty())
-    return op;
-
+void mlir::linalg::interchange(PatternRewriter &rewriter, LinalgOp op,
+                               ArrayRef<unsigned> interchangeVector) {
+  // 1. Invert the permutation map described by `interchangeVector`.
   MLIRContext *context = op.getContext();
-  auto permutationMap = inversePermutation(
+  AffineMap permutationMap = inversePermutation(
       AffineMap::getPermutationMap(interchangeVector, context));
   assert(permutationMap && "expected permutation to be invertible");
+  assert(interchangeVector.size() == op.getNumLoops() &&
+         "expected interchange vector to have entry for every loop");
+
+  // 2. Compose every indexing map with the inverse permutation map.
   SmallVector<Attribute, 4> newIndexingMaps;
-  auto indexingMaps = op.indexing_maps().getValue();
+  ArrayRef<Attribute> indexingMaps = op.indexing_maps().getValue();
   for (unsigned i = 0, e = op.getNumShapedOperands(); i != e; ++i) {
     AffineMap m = indexingMaps[i].cast<AffineMapAttr>().getValue();
     if (!permutationMap.isEmpty())
       m = m.compose(permutationMap);
     newIndexingMaps.push_back(AffineMapAttr::get(m));
   }
-  auto itTypes = op.iterator_types().getValue();
-  SmallVector<Attribute, 4> itTypesVector;
-  for (unsigned i = 0, e = itTypes.size(); i != e; ++i)
-    itTypesVector.push_back(itTypes[i]);
-  applyPermutationToVector(itTypesVector, interchangeVector);
-
   op->setAttr(getIndexingMapsAttrName(),
               ArrayAttr::get(context, newIndexingMaps));
+
+  // 3. Permute the iterator types as specified by `interchangeVector`.
+  ArrayRef<Attribute> itTypes = op.iterator_types().getValue();
+  SmallVector<Attribute, 4> itTypesVector;
+  llvm::append_range(itTypesVector, itTypes);
+  applyPermutationToVector(itTypesVector, interchangeVector);
   op->setAttr(getIteratorTypesAttrName(),
               ArrayAttr::get(context, itTypesVector));
 
-  return op;
+  // 4. Replace `linalg.index k` by result k of the inverse permutation.
+  if (op.hasIndexSemantics()) {
+    // TODO: Remove the assertion and add a getBody() method to LinalgOp
+    // interface once every LinalgOp has a body.
+    assert(op->getNumRegions() == 1 &&
+           op->getRegion(0).getBlocks().size() == 1 &&
+           "expected generic operation to have one block.");
+    Block &block = op->getRegion(0).front();
+    OpBuilder::InsertionGuard guard(rewriter);
+    for (IndexOp indexOp :
+         llvm::make_early_inc_range(block.getOps<IndexOp>())) {
+      rewriter.setInsertionPoint(indexOp);
+      SmallVector<Value> allIndices; // One index op per loop.
+      allIndices.reserve(op.getNumLoops());
+      llvm::transform(llvm::seq<int64_t>(0, op.getNumLoops()),
+                      std::back_inserter(allIndices), [&](int64_t dim) {
+                        return rewriter.create<IndexOp>(indexOp->getLoc(), dim);
+                      });
+      rewriter.replaceOpWithNewOp<AffineApplyOp>(
+          indexOp, permutationMap.getSubMap(indexOp.dim()), allIndices);
+    }
+  }
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index c51c92930ab41..55402a737cbb1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -404,8 +404,7 @@ mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
 LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
     Operation *op, PatternRewriter &rewriter) const {
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
-  // TODO: remove hasIndexSemantics check once index ops are supported.
-  if (!linalgOp || linalgOp.hasIndexSemantics())
+  if (!linalgOp)
     return failure();
   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
     return failure();
@@ -415,7 +414,7 @@ LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
   // TODO: figure out how this interplays with named ops. In particular this
   // should break the named op property.
   rewriter.updateRootInPlace(op, [&]() {
-    interchange(linalgOp, interchangeVector);
+    interchange(rewriter, linalgOp, interchangeVector);
     // New filter if specified.
     filter.replaceLinalgTransformationFilter(rewriter, op);
   });

diff  --git a/mlir/test/Dialect/Linalg/interchange.mlir b/mlir/test/Dialect/Linalg/interchange.mlir
new file mode 100644
index 0000000000000..bc1d10b12f449
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/interchange.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-interchange-pattern=4,0,3,1,2 | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-interchange-pattern=4,0,3,1,2 -test-linalg-transform-patterns=test-interchange-pattern=1,3,4,2,0 | FileCheck --check-prefix=CANCEL-OUT %s
+// The second RUN line follows up with the inverse interchange (1,3,4,2,0);
+// the CANCEL-OUT checks verify that the op is restored to its original form.
+
+#map0 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+
+func @interchange_generic_op(%arg0 : memref<1x2x3x4x5xindex>, %arg1 : memref<1x2x4xindex>) {
+  linalg.generic {
+    indexing_maps = [#map0, #map1],
+    iterator_types = ["parallel", "parallel", "reduction", "parallel", "reduction"]}
+  ins(%arg0 : memref<1x2x3x4x5xindex>)
+  outs(%arg1 : memref<1x2x4xindex>) {
+      ^bb0(%arg2 : index, %arg3 : index) :
+        %0 = linalg.index 0 : index
+        %1 = linalg.index 1 : index
+        %2 = linalg.index 4 : index
+        %3 = subi %0, %1 : index
+        %4 = addi %3, %2 : index
+        %5 = addi %4, %arg2 : index
+        linalg.yield %5 : index
+      }
+  return
+}
+
+//    CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4, d2, d0)>
+//    CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d2)>
+//        CHECK: func @interchange_generic_op
+//        CHECK:   linalg.generic
+//   CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
+//   CHECK-SAME:     iterator_types = ["reduction", "parallel", "parallel", "parallel", "reduction"]
+//    CHECK-DAG:     %[[IDX0:.+]] = linalg.index 1 : index
+//    CHECK-DAG:     %[[IDX1:.+]] = linalg.index 3 : index
+//    CHECK-DAG:     %[[IDX4:.+]] = linalg.index 0 : index
+//        CHECK:     %[[T0:.+]] = subi %[[IDX0]], %[[IDX1]] : index
+//        CHECK:     %[[T1:.+]] = addi %[[T0]], %[[IDX4]] : index
+//        CHECK:     %[[T2:.+]] = addi %[[T1]], %{{.*}} : index
+
+//  CANCEL-OUT-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+//  CANCEL-OUT-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>
+//      CANCEL-OUT: func @interchange_generic_op
+//      CANCEL-OUT:   linalg.generic
+// CANCEL-OUT-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CANCEL-OUT-SAME:     iterator_types = ["parallel", "parallel", "reduction", "parallel", "reduction"]
+//  CANCEL-OUT-DAG:     %[[IDX0:.+]] = linalg.index 0 : index
+//  CANCEL-OUT-DAG:     %[[IDX1:.+]] = linalg.index 1 : index
+//  CANCEL-OUT-DAG:     %[[IDX4:.+]] = linalg.index 4 : index
+//      CANCEL-OUT:     %[[T0:.+]] = subi %[[IDX0]], %[[IDX1]] : index
+//      CANCEL-OUT:     %[[T1:.+]] = addi %[[T0]], %[[IDX4]] : index
+//      CANCEL-OUT:     %[[T2:.+]] = addi %[[T1]], %{{.*}} : index

diff  --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index a6fe895035d20..178de38039aa1 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -91,6 +91,9 @@ struct TestLinalgTransforms
       *this, "tile-sizes-for-padding",
       llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore,
       llvm::cl::MiscFlags::CommaSeparated};
+  ListOption<unsigned> testInterchangePattern{
+      *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated,
+      llvm::cl::desc("Test the interchange pattern.")};
 };
 } // end anonymous namespace
 
@@ -540,6 +543,17 @@ static void applyTileAndPadPattern(FuncOp funcOp, ArrayRef<int64_t> tileSizes) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
 }
 
+static void applyInterchangePattern(FuncOp funcOp,
+                                    ArrayRef<unsigned> interchangeVector) {
+  MLIRContext *context = funcOp.getContext();
+  RewritePatternSet interchangePattern(context);
+  interchangePattern.add<LinalgInterchangePattern<GenericOp>>(
+      context, interchangeVector, // Loop permutation to apply.
+      LinalgTransformationFilter(ArrayRef<Identifier>{}, // Empty match list.
+                                 Identifier::get("interchange", context)));
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnFunction() {
   auto lambda = [&](void *) {
@@ -580,6 +594,8 @@ void TestLinalgTransforms::runOnFunction() {
       (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding);
     });
   }
+  if (testInterchangePattern.hasValue())
+    return applyInterchangePattern(getFunction(), testInterchangePattern);
 }
 
 namespace mlir {


        


More information about the Mlir-commits mailing list