[Mlir-commits] [mlir] d0774f7 - [mlir][linalg] update drop unit dims to support linalg index operations.

Tobias Gysi llvmlistbot at llvm.org
Mon Apr 19 22:24:07 PDT 2021


Author: Tobias Gysi
Date: 2021-04-20T04:54:00Z
New Revision: d0774f7f0a147411bf2b0e3a692b22853f3fd3de

URL: https://github.com/llvm/llvm-project/commit/d0774f7f0a147411bf2b0e3a692b22853f3fd3de
DIFF: https://github.com/llvm/llvm-project/commit/d0774f7f0a147411bf2b0e3a692b22853f3fd3de.diff

LOG: [mlir][linalg] update drop unit dims to support linalg index operations.

Update the dimensions of the index operations to account for dropped dimensions and replace the index operations of dropped dimensions by zero.

Differential Revision: https://reviews.llvm.org/D100395

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 0059178d30e5..15540596f75c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -145,6 +145,31 @@ static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
                             })));
 }
 
+/// Update the index accesses of linalg operations having index semantics.
+template <typename GenericOpTy>
+static void replaceUnitDimIndexOps(GenericOpTy op,
+                                   const DenseSet<unsigned> &unitDims,
+                                   PatternRewriter &rewriter) {
+  assert(op->getNumRegions() == 1 && op->getRegion(0).getBlocks().size() == 1 &&
+         "expected generic operation to have one block.");
+  Block &block = op->getRegion(0).front();
+
+  for (IndexOp indexOp : llvm::make_early_inc_range(block.getOps<IndexOp>())) {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(indexOp);
+    if (unitDims.count(indexOp.dim()) != 0) {
+      rewriter.replaceOpWithNewOp<ConstantIndexOp>(indexOp, 0);
+    } else {
+      // Update the dimension of the index operation if needed.
+      unsigned droppedDims = llvm::count_if(
+          unitDims, [&](unsigned dim) { return dim < indexOp.dim(); });
+      if (droppedDims != 0)
+        rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
+                                             indexOp.dim() - droppedDims);
+    }
+  }
+}
+
 /// Modify the region of indexed generic op to drop arguments corresponding to
 /// loops that are unit trip count.
 template <typename OpTy>
@@ -177,10 +202,6 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
   using OpRewritePattern<GenericOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(GenericOpTy op,
                                 PatternRewriter &rewriter) const override {
-    // TODO: remove once index ops are supported.
-    if (op.hasIndexSemantics())
-      return failure();
-
     SmallVector<AffineMap, 4> indexingMaps = op.getIndexingMaps();
     if (indexingMaps.empty())
       return failure();
@@ -253,6 +274,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
     op.indexing_mapsAttr(newIndexingMapAttr);
     op.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
     (void)replaceBlockArgForUnitDimLoops(op, unitDims, rewriter);
+    replaceUnitDimIndexOps(op, unitDims, rewriter);
     rewriter.finalizeRootUpdate(op);
     return success();
   }
@@ -325,10 +347,6 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
   using OpRewritePattern<GenericOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(GenericOpTy op,
                                 PatternRewriter &rewriter) const override {
-    // TODO: remove once index ops are supported.
-    if (op.hasIndexSemantics())
-      return failure();
-
     if (!op.hasTensorSemantics())
       return failure();
 

diff  --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index e65a408a0a0e..8a36f6dce78b 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -78,6 +78,56 @@ func @drop_one_trip_loops_indexed_generic
 
 // -----
 
+#accesses = [
+  affine_map<(i, j, k, l, m) -> (i, k, m)>,
+  affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
+]
+
+#trait = {
+  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+  indexing_maps = #accesses,
+  library_call = "some_external_func"
+}
+
+func @drop_one_trip_loops_indexed
+  (%arg0 : tensor<?x1x?xi32>, %shape: tensor<?x1x?x1x?xi32>) -> tensor<?x1x?x1x?xi32>
+{
+  %0 = linalg.generic #trait
+     ins(%arg0 : tensor<?x1x?xi32>)
+    outs(%shape: tensor<?x1x?x1x?xi32>) {
+       ^bb0(%arg6 : i32, %arg7 : i32) :
+         %idx0 = linalg.index 0 : index
+         %idx1 = linalg.index 1 : index
+         %idx2 = linalg.index 2 : index
+         %idx3 = linalg.index 3 : index
+         %idx4 = linalg.index 4 : index
+         %1 = addi %idx0, %idx1 : index
+         %2 = subi %1, %idx2 : index
+         %3 = subi %2, %idx3 : index
+         %4 = addi %3, %idx4 : index
+         %5 = index_cast %4 : index to i32
+         %6 = addi %5, %arg6 : i32
+         linalg.yield %6 : i32
+       } -> tensor<?x1x?x1x?xi32>
+  return %0 : tensor<?x1x?x1x?xi32>
+}
+// The subtractions disappear because the access map of the output tensor maps
+// its unit dimensions 1 and 3 to the index dimensions 2 and 3.
+// CHECK-LABEL: func @drop_one_trip_loops_indexed
+//       CHECK:   linalg.generic
+//       CHECK:   ^{{.+}}(
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: i32, %{{.*}}: i32)
+//       CHECK:     %[[IDX0:.+]] = linalg.index 0 : index
+//       CHECK:     %[[IDX1:.+]] = linalg.index 1 : index
+//       CHECK:     %[[IDX2:.+]] = linalg.index 2 : index
+//       CHECK:     %[[T3:.+]] = addi %[[IDX0]], %[[IDX1]]
+//       CHECK:     %[[T4:.+]] = addi %[[T3]], %[[IDX2]]
+//       CHECK:     %[[T5:.+]] = index_cast %[[T4]] : index to i32
+//       CHECK:     %[[T6:.+]] = addi %[[T5]], %[[ARG4]] : i32
+//       CHECK:     linalg.yield %[[T6]] : i32
+
+// -----
+
 #map0 = affine_map<(i, j) -> (i, j)>
 #access = [#map0, #map0]
 #trait = {
@@ -134,6 +184,37 @@ func @drop_all_loops_indexed_generic
 
 // -----
 
+#map0 = affine_map<(i, j) -> (i, j)>
+#access = [#map0, #map0]
+#trait = {
+  iterator_types = ["parallel", "parallel"],
+  indexing_maps = #access,
+  library_call = "some_external_func"
+}
+
+func @drop_all_loops_indexed
+  (%arg0 : tensor<1x1xi32>) -> tensor<1x1xi32>{
+  %0 = linalg.generic #trait
+     ins(%arg0 : tensor<1x1xi32>)
+    outs(%arg0 : tensor<1x1xi32>) {
+       ^bb0(%arg3: i32, %arg4: i32) :
+         %idx0 = linalg.index 0 : index
+         %idx1 = linalg.index 1 : index
+         %1 = addi %idx0, %idx1 : index
+         %2 = index_cast %1 : index to i32
+         %3 = addi %2, %arg3 : i32
+         linalg.yield %3 : i32
+       } -> tensor<1x1xi32>
+  return %0 : tensor<1x1xi32>
+}
+
+// CHECK-LABEL: func @drop_all_loops_indexed
+//       CHECK:   linalg.generic
+//       CHECK:   ^{{.+}}(%[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32)
+//       CHECK:     linalg.yield %[[ARG1]] : i32
+
+// -----
+
 #accesses = [
   affine_map<(d0) -> (0, d0)>,
   affine_map<(d0) -> (d0)>
@@ -566,19 +647,3 @@ func @unit_dim_for_reduction_inner(%arg0: tensor<?x1x?x1xf32>) -> tensor<?x1xf32
 // CHECK-SAME:     outs(%[[FILL]] : tensor<?xf32>)
 //      CHECK:   %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[RESULT]] [#[[MAP2]]]
 //      CHECK:   return %[[RESULT_RESHAPE]]
-
-// -----
-
-//  CHECK: #{{.+}} = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-LABEL: @index_op
-func @index_op(%arg0: memref<1x8xindex>) {
-  linalg.generic {
-    indexing_maps = [affine_map<(i, j) -> (i, j)>],
-    iterator_types = ["parallel", "parallel"]}
-  outs(%arg0 : memref<1x8xindex>) {
-  ^bb0(%arg1: index):   // no predecessors
-    %0 = linalg.index 1 : index
-    linalg.yield %0 : index
-  }
-  return
-}


        


More information about the Mlir-commits mailing list