[Mlir-commits] [mlir] ef30179 - [mlir][linalg] lower index operations during linalg to loop lowering.

Tobias Gysi llvmlistbot at llvm.org
Tue Apr 13 02:05:12 PDT 2021


Author: Tobias Gysi
Date: 2021-04-13T09:04:09Z
New Revision: ef30179efff24a02d5f7a3380a7f3cab247b1338

URL: https://github.com/llvm/llvm-project/commit/ef30179efff24a02d5f7a3380a7f3cab247b1338
DIFF: https://github.com/llvm/llvm-project/commit/ef30179efff24a02d5f7a3380a7f3cab247b1338.diff

LOG: [mlir][linalg] lower index operations during linalg to loop lowering.

The patch extends the linalg-to-loops lowering pass to replace all linalg index operations with the induction variables of the generated loop nests.

Differential Revision: https://reviews.llvm.org/D100364

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/test/Dialect/Linalg/loop-order.mlir
    mlir/test/Dialect/Linalg/loops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 8b2b0cde8f9a..bb1000a51c35 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -516,6 +516,47 @@ linalgOpToLoopsImpl(Operation *op, OpBuilder &builder,
   return loops;
 }
 
+/// Replace the index operations in the body of the loop nest by the matching
+/// induction variables. A non-empty interchange vector maps loop position p to
+/// dimension interchangeVector[p]; it is inverted to find each dimension's IV.
+static void replaceIndexOpsByInductionVariables(
+    LinalgOp linalgOp, PatternRewriter &rewriter, ArrayRef<Operation *> loopOps,
+    ArrayRef<unsigned> interchangeVector) {
+  // Extract the induction variables of the loop nest from outer to inner.
+  SmallVector<Value> allIvs;
+  for (Operation *loopOp : loopOps) {
+    llvm::TypeSwitch<Operation *>(loopOp)
+        .Case([&](scf::ParallelOp parallelOp) {
+          allIvs.append(parallelOp.getInductionVars().begin(),
+                        parallelOp.getInductionVars().end());
+        })
+        .Case([&](scf::ForOp forOp) {
+          allIvs.push_back(forOp.getInductionVar());
+        })
+        .Case([&](AffineForOp affineForOp) {
+          allIvs.push_back(affineForOp.getInductionVar());
+        })
+        .Default([](Operation *) { llvm_unreachable("unexpected op"); });
+  }
+  assert(linalgOp.getNumLoops() == allIvs.size() &&
+         "expected the number of loops and induction variables to match");
+  // The interchange vector is loop-invariant: validate it once, up front.
+  assert((interchangeVector.empty() ||
+          interchangeVector.size() == allIvs.size()) && "invalid interchange");
+  // Replace the index operations in the body of the innermost loop op.
+  if (loopOps.empty())
+    return;
+  LoopLikeOpInterface loopOp = loopOps.back();
+  for (IndexOp indexOp :
+       llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>())) {
+    const auto *it = llvm::find(interchangeVector, indexOp.dim());
+    uint64_t loopPos = it != interchangeVector.end()
+                           ? std::distance(interchangeVector.begin(), it)
+                           : indexOp.dim();
+    rewriter.replaceOp(indexOp, allIvs[loopPos]);
+  }
+}
+
 namespace {
 template <typename LoopType>
 class LinalgRewritePattern : public RewritePattern {
@@ -528,11 +569,15 @@ class LinalgRewritePattern : public RewritePattern {
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     auto linalgOp = dyn_cast<LinalgOp>(op);
-    // TODO: remove hasIndexSemantics check once index ops are supported.
-    if (!linalgOp || linalgOp.hasIndexSemantics())
+    if (!linalgOp)
       return failure();
-    if (!linalgOpToLoopsImpl<LoopType>(op, rewriter, interchangeVector))
+    // Emit the loop nest first; its IVs then replace the linalg.index ops.
+    Optional<LinalgLoops> loopOps =
+        linalgOpToLoopsImpl<LoopType>(op, rewriter, interchangeVector);
+    if (!loopOps)
       return failure();
+    replaceIndexOpsByInductionVariables(linalgOp, rewriter, *loopOps,
+                                        interchangeVector);
     rewriter.eraseOp(op);
     return success();
   }

diff  --git a/mlir/test/Dialect/Linalg/loop-order.mlir b/mlir/test/Dialect/Linalg/loop-order.mlir
index 968ffdc5e747..c572967e6d10 100644
--- a/mlir/test/Dialect/Linalg/loop-order.mlir
+++ b/mlir/test/Dialect/Linalg/loop-order.mlir
@@ -24,22 +24,52 @@ func @copy(%input: memref<1x2x3x4x5xf32>, %output: memref<1x2x3x4x5xf32>) {
 
 // -----
 
-func @index_op(%arg0: memref<4x8xindex>) {
-  linalg.generic {
-    indexing_maps = [affine_map<(i, j) -> (i, j)>],
-    iterator_types = ["parallel", "parallel"]}
-  outs(%arg0 : memref<4x8xindex>) {
-  ^bb0(%arg1: index):   // no predecessors
-    %0 = linalg.index 1 : index
-    linalg.yield %0 : index
+#map = affine_map<(i, j, k, l, m) -> (i, j, k, l, m)>
+func @generic(%output: memref<1x2x3x4x5xindex>) {
+  linalg.generic {indexing_maps = [#map],
+                  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+    outs(%output : memref<1x2x3x4x5xindex>) {
+    ^bb0(%arg0 : index):
+    %i = linalg.index 0 : index
+    %j = linalg.index 1 : index
+    %k = linalg.index 2 : index
+    %l = linalg.index 3 : index
+    %m = linalg.index 4 : index
+    %0 = addi %i, %j : index
+    %1 = addi %0, %k : index
+    %2 = addi %1, %l : index
+    %3 = addi %2, %m : index
+    linalg.yield %3: index
   }
   return
 }
-// LOOP-LABEL: @index_op
-//      LOOP:   linalg.generic
 
-// PARALLEL-LABEL: @index_op
-//      PARALLEL:   linalg.generic
+// LOOP-LABEL: @generic
+// LOOP: scf.for %[[m:.*]] = %c0 to %c5 step %c1
+// LOOP:   scf.for %[[i:.*]] = %c0 to %c1 step %c1
+// LOOP:     scf.for %[[l:.*]] = %c0 to %c4 step %c1
+// LOOP:       scf.for %[[j:.*]] = %c0 to %c2 step %c1
+// LOOP:         scf.for %[[k:.*]] = %c0 to %c3 step %c1
+// LOOP:           %{{.*}} = addi %[[i]], %[[j]] : index
+// LOOP:           %{{.*}} = addi %{{.*}}, %[[k]] : index
+// LOOP:           %{{.*}} = addi %{{.*}}, %[[l]] : index
+// LOOP:           %{{.*}} = addi %{{.*}}, %[[m]] : index
 
-// AFFINE-LABEL: @index_op
-//      AFFINE:   linalg.generic
+// PARALLEL-LABEL: @generic
+// PARALLEL: scf.parallel (%[[m:.*]], %[[i:.*]], %[[l:.*]], %[[j:.*]], %[[k:.*]]) =
+// PARALLEL-SAME:   to (%c5, %c1, %c4, %c2, %c3)
+// PARALLEL:        %{{.*}} = addi %[[i]], %[[j]] : index
+// PARALLEL:        %{{.*}} = addi %{{.*}}, %[[k]] : index
+// PARALLEL:        %{{.*}} = addi %{{.*}}, %[[l]] : index
+// PARALLEL:        %{{.*}} = addi %{{.*}}, %[[m]] : index
+
+// AFFINE-LABEL: @generic
+// AFFINE: affine.for %[[m:.*]] = 0 to 5
+// AFFINE:   affine.for %[[i:.*]] = 0 to 1
+// AFFINE:     affine.for %[[l:.*]] = 0 to 4
+// AFFINE:       affine.for %[[j:.*]] = 0 to 2
+// AFFINE:         affine.for %[[k:.*]] = 0 to 3
+// AFFINE:           %{{.*}} = addi %[[i]], %[[j]] : index
+// AFFINE:           %{{.*}} = addi %{{.*}}, %[[k]] : index
+// AFFINE:           %{{.*}} = addi %{{.*}}, %[[l]] : index
+// AFFINE:           %{{.*}} = addi %{{.*}}, %[[m]] : index

diff  --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 5be3525854f9..b0ffc7f0053e 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -880,6 +880,61 @@ func @generic_region(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1:
   library_call = "some_external_function_name_2",
   doc = "B(i,j,k), C(i,k,j) = foo(A(i, j) * B(i,j,k), i * j * k + C(i,k,j))"
 }
+func @generic_index_region(
+        %arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+        %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
+        %arg2: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
+  linalg.generic #trait4
+      ins(%arg0 : memref<?x?xf32, offset: ?, strides: [?, 1]>)
+     outs(%arg1, %arg2 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
+                         memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
+    ^bb0(%in: f32, %out_0: f32, %out_1: f32):
+      %idx0 = linalg.index 0 : index
+      %idx1 = linalg.index 1 : index
+      %idx2 = linalg.index 2 : index
+      %prod = mulf %in, %out_0 : f32
+      // Convert the sum of the three iteration indices to f32.
+      %sum01 = addi %idx0, %idx1 : index
+      %sum012 = addi %sum01, %idx2 : index
+      %sum_i32 = index_cast %sum012 : index to i32
+      %sum_f32 = sitofp %sum_i32 : i32 to f32
+
+      %res = addf %out_1, %sum_f32 : f32
+      linalg.yield %prod, %res : f32, f32
+  }
+  return
+}
+
+// CHECKLOOP-LABEL: @generic_index_region
+//       CHECKLOOP: scf.for %[[I:.*]] = {{.*}}
+//       CHECKLOOP:   scf.for %[[J:.*]] = {{.*}}
+//       CHECKLOOP:     scf.for %[[K:.*]] = {{.*}}
+//       CHECKLOOP:       %[[A:.*]] = memref.load %{{.*}}[%[[I]], %[[J]]]
+//       CHECKLOOP:       %[[B:.*]] = memref.load %{{.*}}[%[[I]], %[[J]], %[[K]]]
+//       CHECKLOOP:       %[[C:.*]] = memref.load %{{.*}}[%[[I]], %[[K]], %[[J]]]
+//       CHECKLOOP:       %[[PROD:.*]] = mulf %[[A]], %[[B]] : f32
+//       CHECKLOOP:       %[[IJ:.*]] = addi %[[I]], %[[J]] : index
+//       CHECKLOOP:       %[[IJK:.*]] = addi %[[IJ]], %[[K]] : index
+//       CHECKLOOP:       %[[IJK_INT:.*]] = index_cast %[[IJK]] : index to i32
+//       CHECKLOOP:       %[[IJK_FLOAT:.*]] = sitofp %[[IJK_INT]] : i32 to f32
+//       CHECKLOOP:       %[[RES:.*]] = addf %[[C]], %[[IJK_FLOAT]] : f32
+//       CHECKLOOP:       store %[[PROD]], %{{.*}}[%[[I]], %[[J]], %[[K]]]
+//       CHECKLOOP:       store %[[RES]], %{{.*}}[%[[I]], %[[K]], %[[J]]]
+
+// CHECKPARALLEL-LABEL: @generic_index_region
+//       CHECKPARALLEL: scf.parallel (%[[I:[a-zA-Z0-9_]*]], %[[J:[a-zA-Z0-9_]*]], %[[K:[a-zA-Z0-9_]*]])
+//       CHECKPARALLEL:   %[[A:.*]] = memref.load %{{.*}}[%[[I]], %[[J]]]
+//       CHECKPARALLEL:   %[[B:.*]] = memref.load %{{.*}}[%[[I]], %[[J]], %[[K]]]
+//       CHECKPARALLEL:   %[[C:.*]] = memref.load %{{.*}}[%[[I]], %[[K]], %[[J]]]
+//       CHECKPARALLEL:   %[[PROD:.*]] = mulf %[[A]], %[[B]] : f32
+//       CHECKPARALLEL:   %[[IJ:.*]] = addi %[[I]], %[[J]] : index
+//       CHECKPARALLEL:   %[[IJK:.*]] = addi %[[IJ]], %[[K]] : index
+//       CHECKPARALLEL:   %[[IJK_INT:.*]] = index_cast %[[IJK]] : index to i32
+//       CHECKPARALLEL:   %[[IJK_FLOAT:.*]] = sitofp %[[IJK_INT]] : i32 to f32
+//       CHECKPARALLEL:   %[[RES:.*]] = addf %[[C]], %[[IJK_FLOAT]] : f32
+//       CHECKPARALLEL:   store %[[PROD]], %{{.*}}[%[[I]], %[[J]], %[[K]]]
+//       CHECKPARALLEL:   store %[[RES]], %{{.*}}[%[[I]], %[[K]], %[[J]]]
+
 func @indexed_generic_region(
         %arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
         %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
@@ -973,6 +1028,43 @@ func @generic_op_zero_rank(%arg0: memref<f32>, %arg1: memref<3x4xf32>)
 //       CHECKPARALLEL:   %[[a:.*]] = memref.load %[[ARG0]][]
 //       CHECKPARALLEL:   store %[[a]], %[[ARG1]][%[[i]], %[[j]]]
 
+func @generic_index_op_zero_rank(%arg0: memref<i32>, %arg1: memref<3x4xi32>) {
+  // Add the sum of both iteration indices to the zero-rank input value.
+  linalg.generic #trait_broadcast
+      ins(%arg0 : memref<i32>)
+     outs(%arg1 : memref<3x4xi32>) {
+    ^bb(%in: i32, %out: i32) :
+      %idx0 = linalg.index 0 : index
+      %idx1 = linalg.index 1 : index
+      %sum = addi %idx0, %idx1 : index
+      %sum_i32 = index_cast %sum : index to i32
+      %res = addi %in, %sum_i32 : i32
+      linalg.yield %res : i32
+  }
+  return
+}
+
+// CHECKLOOP-LABEL: @generic_index_op_zero_rank
+//  CHECKLOOP-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<i32>
+//  CHECKLOOP-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xi32>
+//       CHECKLOOP: scf.for %[[I:.*]] = {{.*}}
+//       CHECKLOOP:   scf.for %[[J:.*]] = {{.*}}
+//       CHECKLOOP:     %[[A:.*]] = memref.load %[[ARG0]][
+//       CHECKLOOP:     %[[IJ:.*]] = addi %[[I]], %[[J]] : index
+//       CHECKLOOP:     %[[IJ_INT:.*]] = index_cast %[[IJ]] : index to i32
+//       CHECKLOOP:     %[[RES:.*]] = addi %[[A]], %[[IJ_INT]] : i32
+//       CHECKLOOP:     store %[[RES]], %[[ARG1]][%[[I]], %[[J]]]
+
+// CHECKPARALLEL-LABEL: @generic_index_op_zero_rank
+//  CHECKPARALLEL-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<i32>
+//  CHECKPARALLEL-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<3x4xi32>
+//       CHECKPARALLEL: scf.parallel (%[[I:[a-zA-Z0-9_]*]], %[[J:[a-zA-Z0-9_]*]])
+//       CHECKPARALLEL:   %[[A:.*]] = memref.load %[[ARG0]][
+//       CHECKPARALLEL:   %[[IJ:.*]] = addi %[[I]], %[[J]] : index
+//       CHECKPARALLEL:   %[[IJ_INT:.*]] = index_cast %[[IJ]] : index to i32
+//       CHECKPARALLEL:   %[[RES:.*]] = addi %[[A]], %[[IJ_INT]] : i32
+//       CHECKPARALLEL:   store %[[RES]], %[[ARG1]][%[[I]], %[[J]]]
+
 func @indexed_generic_op_zero_rank(%arg0: memref<i32>, %arg1: memref<3x4xi32>)
 {
   linalg.indexed_generic #trait_broadcast
@@ -1065,6 +1157,50 @@ func @generic_op_1D_reduce(%arg0: memref<?xf32>, %arg1: memref<f32>)
   library_call = "some_reduce_external_fn"
 }
 
+func @generic_index_op_1D_reduce(%arg0: memref<?xf32>,
+                                %arg1: memref<f32>,
+                                %arg2: memref<f32>)
+{
+  linalg.generic #trait_reduce_init_1D
+      ins(%arg0, %arg1 : memref<?xf32>, memref<f32>)
+     outs(%arg2 : memref<f32>) {
+    ^bb(%a: f32, %b: f32, %c: f32) :
+      %i = linalg.index 0 : index
+      %0 = constant 0 : index
+      %1 = cmpi eq, %0, %i : index
+      %2 = select %1, %b, %c : f32
+      %3 = addf %a, %2 : f32
+      linalg.yield %3 : f32
+  }
+  return
+}
+// linalg.index 0 must be replaced by the IV of the generated scf.for loop.
+// CHECKLOOP-LABEL: @generic_index_op_1D_reduce
+//  CHECKLOOP-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?xf32>
+//  CHECKLOOP-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<f32>
+//  CHECKLOOP-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<f32>
+//       CHECKLOOP: scf.for %[[i:.*]] = {{.*}}
+//       CHECKLOOP:   %[[a:.*]] = memref.load %[[ARG0]][%[[i]]]
+//       CHECKLOOP:   %[[b:.*]] = memref.load %[[ARG1]][]
+//       CHECKLOOP:   %[[c:.*]] = memref.load %[[ARG2]][]
+//       CHECKLOOP:   %[[cond:.*]] = cmpi eq, %{{.*}}, %[[i]] : index
+//       CHECKLOOP:   %[[d:.*]] = select %[[cond]], %[[b]], %[[c]]
+//       CHECKLOOP:   %[[e:.*]] = addf %[[a]], %[[d]]
+//       CHECKLOOP:   store %[[e]], %[[ARG2]][]
+
+// CHECKPARALLEL-LABEL: @generic_index_op_1D_reduce
+//  CHECKPARALLEL-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?xf32>
+//  CHECKPARALLEL-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<f32>
+//  CHECKPARALLEL-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<f32>
+//       CHECKPARALLEL: scf.for %[[i:.*]] = {{.*}}
+//       CHECKPARALLEL:   %[[a:.*]] = memref.load %[[ARG0]][%[[i]]]
+//       CHECKPARALLEL:   %[[b:.*]] = memref.load %[[ARG1]][]
+//       CHECKPARALLEL:   %[[c:.*]] = memref.load %[[ARG2]][]
+//       CHECKPARALLEL:   %[[cond:.*]] = cmpi eq, %{{.*}}, %[[i]] : index
+//       CHECKPARALLEL:   %[[d:.*]] = select %[[cond]], %[[b]], %[[c]]
+//       CHECKPARALLEL:   %[[e:.*]] = addf %[[a]], %[[d]]
+//       CHECKPARALLEL:   store %[[e]], %[[ARG2]][]
+
 func @indexed_generic_op_1D_reduce(%arg0: memref<?xf32>,
                                    %arg1: memref<f32>,
                                    %arg2: memref<f32>)


        


More information about the Mlir-commits mailing list