[Mlir-commits] [mlir] 4a6ee23 - [mlir][linalg] Fix bug in the fusion on tensors index op handling.

Tobias Gysi llvmlistbot at llvm.org
Wed May 5 07:47:18 PDT 2021


Author: Tobias Gysi
Date: 2021-05-05T14:46:08Z
New Revision: 4a6ee23d832f823d71faf7d0dca1b6eec71df253

URL: https://github.com/llvm/llvm-project/commit/4a6ee23d832f823d71faf7d0dca1b6eec71df253
DIFF: https://github.com/llvm/llvm-project/commit/4a6ee23d832f823d71faf7d0dca1b6eec71df253.diff

LOG: [mlir][linalg] Fix bug in the fusion on tensors index op handling.

The old index op handling let the new index operations point back to the
producer block. As a result, after fusion some index operations in the
fused block had back references to the old producer block resulting in
illegal IR. The patch now relies on a block and value mapping to avoid
such back references.

Differential Revision: https://reviews.llvm.org/D101887

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/test/Dialect/Linalg/fusion-tensor.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 5af62dafe6d9..d1646e92b8d4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -145,11 +145,12 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
         fusedBlock->getArguments().take_front(numFusedOpIndices));
     mapper.map(std::get<0>(it), newIndex);
   }
-  // 2b. Replace the producer index operations by index operations placed in the
-  // fused block using the `consumerToProducerLoopsMap` to map the index spaces.
-  unsigned numFusedOpLoops =
-      std::max(producer.getNumLoops(), consumer.getNumLoops());
+  // 2b. Add an index operation for every fused loop dimension and use the
+  // `consumerToProducerLoopsMap` to map the producer indices.
   if (producer.hasIndexSemantics()) {
+    // Add an index operation for every fused loop dimension.
+    unsigned numFusedOpLoops =
+        std::max(producer.getNumLoops(), consumer.getNumLoops());
     SmallVector<Value> fusedIndices;
     fusedIndices.reserve(numFusedOpLoops);
     llvm::transform(llvm::seq<int64_t>(0, numFusedOpLoops),
@@ -161,10 +162,7 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
       Value newIndex = rewriter.create<mlir::AffineApplyOp>(
           producer.getLoc(),
           consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices);
-      // Replace the producer index operation by the index value computed in the
-      // fused block. All remaining operations in the producer block are later
-      // on cloned to the fused block.
-      rewriter.replaceOp(indexOp, newIndex);
+      mapper.map(indexOp.getResult(), newIndex);
     }
   }
   // TODO: allow fusing the producer of an output operand.
@@ -210,10 +208,12 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
   // TODO: allow fusion of multi-result producers.
   assert(producer->getNumResults() == 1 && "expected single result producer");
 
-  // 8. Clone operations from producer (except the yield operation) to the fused
-  // op.
-  for (auto &op : producerBlock.without_terminator())
-    rewriter.clone(op, mapper);
+  // 8. Clone all producer operations except for the yield and index operations
+  // to the fused operation.
+  for (auto &op : producerBlock.without_terminator()) {
+    if (!isa<IndexOp>(op))
+      rewriter.clone(op, mapper);
+  }
   // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
   // forward the yield operand.
   auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());

diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index 40c52657a853..1ba2d37fff3e 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -462,8 +462,7 @@ func @indexed_generic_op_generic_op_fusion(%arg0: tensor<?x?xi32>,
 // -----
 
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
-func @indexed_producer_consumer_fusion(%arg0: tensor<?x?xi32>,
-                                       %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
+func @indexed_producer_consumer_fusion(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
   %0 = memref.dim %arg0, %c0 : tensor<?x?xi32>
@@ -486,7 +485,7 @@ func @indexed_producer_consumer_fusion(%arg0: tensor<?x?xi32>,
   %4 = linalg.generic {
     indexing_maps = [#map0, #map0, #map0],
     iterator_types = ["parallel", "parallel"] }
-    ins(%3, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>)
+    ins(%3, %arg0 : tensor<?x?xi32>, tensor<?x?xi32>)
     outs(%2 : tensor<?x?xi32>) {
     ^bb0(%arg2: i32, %arg3: i32, %arg4: i32):       // no predecessors
       %10 = addi %arg2, %arg3 : i32
@@ -497,7 +496,7 @@ func @indexed_producer_consumer_fusion(%arg0: tensor<?x?xi32>,
 //   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func @indexed_producer_consumer_fusion
 //       CHECK: linalg.generic
-// CHECK-SAME:    indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]]
+// CHECK-SAME:    indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
 //      CHECK: ^{{[a-zA-Z0-9_]*}}
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32
@@ -507,7 +506,7 @@ func @indexed_producer_consumer_fusion(%arg0: tensor<?x?xi32>,
 //      CHECK:   %[[SUB_OPERAND:.+]] = index_cast %[[IDX1]] : index to i32
 //      CHECK:   %[[VAL1:.+]] = addi %[[ARG0]], %[[ADD_OPERAND]] : i32
 //      CHECK:   %[[VAL2:.+]] = subi %[[VAL1]], %[[SUB_OPERAND]] : i32
-//      CHECK:   %[[VAL3:.+]] = addi %[[VAL2]], %[[ARG1]] : i32
+//      CHECK:   %[[VAL3:.+]] = addi %[[VAL2]], %[[ARG0]] : i32
 //      CHECK:   linalg.yield %[[VAL3]] : i32
 //   CHECK-NOT: linalg.generic
 


        


More information about the Mlir-commits mailing list