[Mlir-commits] [mlir] 9fe27bc - [mlir][Linalg] Allow decompose to handle ops when value of `outs` operand is used in payload.

Mahesh Ravishankar llvmlistbot at llvm.org
Thu Jul 28 09:47:32 PDT 2022


Author: Mahesh Ravishankar
Date: 2022-07-28T16:42:54Z
New Revision: 9fe27bca7191308f9d1c4045b81554fb53da3271

URL: https://github.com/llvm/llvm-project/commit/9fe27bca7191308f9d1c4045b81554fb53da3271
DIFF: https://github.com/llvm/llvm-project/commit/9fe27bca7191308f9d1c4045b81554fb53da3271.diff

LOG: [mlir][Linalg] Allow decompose to handle ops when value of `outs` operand is used in payload.

The current implementation of decomposition of Linalg operations wouldn't
work if the `outs` operand values were used within the body of the
operation. Relax this restriction. This potentially sets the stage for
decomposing ops with reduction iterator types (but that is not done here
since it requires more study).

Differential Revision: https://reviews.llvm.org/D130527

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
    mlir/test/Dialect/Linalg/decompose-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
index 889c6f034f897..662ab86d0dc5a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
@@ -156,42 +156,41 @@ DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
   SmallVector<Value> newInitValues;
   SmallVector<Type> newResultTypes;
 
-  /// The indexing map to use for the new results is obtained by
-  /// - Check if the result is yielded. If so use the same indexing map as the
-  /// corresponding output
-  /// - Identity indexing map if the result is not yielded.
-  Operation *yieldOp = body->getTerminator();
-  auto getResultIndexingMap = [&](OpResult scalarOpResult) -> AffineMap {
-    OpOperand *firstUseInYield = nullptr, *identityUseInYield = nullptr;
-    for (OpOperand &use : scalarOpResult.getUses()) {
-      if (use.getOwner() != yieldOp)
-        continue;
-      if (!firstUseInYield)
-        firstUseInYield = &use;
-      OpResult genericOpResult =
-          genericOp.getResult(use.getOperandNumber()).cast<OpResult>();
-      AffineMap indexingMap =
-          genericOp.getTiedIndexingMapForResult(genericOpResult);
-      if (indexingMap.isIdentity())
-        identityUseInYield = &use;
+  // Add as many new results as the number of results of the peeled scalar op.
+  for (auto scalarOpResult : peeledScalarOperation->getResults()) {
+    // If the result is yielded by the original op, use the operand, indexing
+    // map and result type that correspond to the yielded value.
+
+    Optional<unsigned> resultNumber;
+    for (auto user : scalarOpResult.getUsers()) {
+      if (auto yieldOp = dyn_cast<YieldOp>(user)) {
+        // Find the first use of the `scalarOpResult` in the yield op.
+        for (OpOperand &yieldOperand : yieldOp->getOpOperands()) {
+          if (yieldOperand.get() == scalarOpResult) {
+            resultNumber = yieldOperand.getOperandNumber();
+            break;
+          }
+        }
+        assert(resultNumber && "unable to find use of a value in its user");
+        break;
+      }
+    }
+    if (resultNumber) {
+      newInitValues.push_back(genericOp.getOutputOperand(*resultNumber)->get());
+      OpResult result = genericOp.getResult(*resultNumber).cast<OpResult>();
+      newResultTypes.push_back(result.getType());
+      peeledGenericOpIndexingMaps.push_back(
+          genericOp.getTiedIndexingMapForResult(result));
+      continue;
     }
-    if (identityUseInYield || !firstUseInYield)
-      return rewriter.getMultiDimIdentityMap(domain.size());
-    OpResult genericOpResult =
-        genericOp.getResult(firstUseInYield->getOperandNumber())
-            .cast<OpResult>();
-    return genericOp.getTiedIndexingMapForResult(genericOpResult);
-  };
-
-  for (auto scalarResult : peeledScalarOperation->getResults()) {
-    AffineMap resultIndexingMap = getResultIndexingMap(scalarResult);
-    SmallVector<OpFoldResult> initSize =
-        permuteValues(domain, resultIndexingMap);
+
+    // Fall back path, use an `init_tensor` and identity indexing map.
+    AffineMap indexingMap = rewriter.getMultiDimIdentityMap(domain.size());
     Value initTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, initSize, scalarResult.getType());
+        loc, domain, scalarOpResult.getType());
     newInitValues.push_back(initTensor);
     newResultTypes.push_back(initTensor.getType());
-    peeledGenericOpIndexingMaps.push_back(resultIndexingMap);
+    peeledGenericOpIndexingMaps.push_back(indexingMap);
   }
 
   /// Create the peeled generic op with an empty body.
@@ -263,17 +262,6 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
         genericOp, "only operations with tensor semantics are handled");
   }
 
-  // TODO: For now only decompose operations where the `outs` operands values
-  // are not accessed within the payload. This might be relaxed in future, but
-  // needs a bit more reasoning to ensure that it is safe.
-  if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *outOperand) {
-        return genericOp.payloadUsesValueFromOperand(outOperand);
-      })) {
-    return rewriter.notifyMatchFailure(
-        genericOp, "unhandled decomposition of generic op with use of out "
-                   "operand value in payload");
-  }
-
   if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *outOperand) {
         return !genericOp.getTiedIndexingMap(outOperand).isPermutation();
       })) {

diff  --git a/mlir/test/Dialect/Linalg/decompose-ops.mlir b/mlir/test/Dialect/Linalg/decompose-ops.mlir
index 648a58eb87b30..cd8bd9852e955 100644
--- a/mlir/test/Dialect/Linalg/decompose-ops.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-ops.mlir
@@ -147,10 +147,10 @@ func.func @simple_op_permuted_outputs(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x
 //  CHECK-DAG:   %[[INIT1:.+]] = linalg.init_tensor [%[[D1]], %[[D0]]]
 //  CHECK-DAG:   %[[INIT2:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
 //  CHECK-DAG:   %[[GENERIC1:.+]]:4 = linalg.generic
-// CHECK-SAME:       [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP0]], #[[MAP0]], #[[MAP0]]]
+// CHECK-SAME:       [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP0]], #[[MAP0]], #[[MAP3]]]
 // CHECK-SAME:       ["parallel", "parallel"]
 // CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] :
-// CHECK-SAME:       outs(%[[INIT1]], %[[INIT2]], %[[INIT2]], %[[INIT2]] :
+// CHECK-SAME:       outs(%[[INIT1]], %[[INIT2]], %[[INIT2]], %[[INIT1]] :
 // CHECK-NEXT:   ^bb0(
 // CHECK-SAME:       %[[B0:[a-zA-Z0-9]+]]: f32
 // CHECK-SAME:       %[[B1:[a-zA-Z0-9]+]]: f32
@@ -162,7 +162,7 @@ func.func @simple_op_permuted_outputs(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x
 // CHECK-NEXT:     %[[S0:.+]] = arith.addf %[[B0]], %[[B1]]
 // CHECK-NEXT:     linalg.yield %[[S0]], %{{[a-zA-Z0-9]+}}, %[[S0]]
 //      CHECK:   %[[GENERIC2:.+]]:3 = linalg.generic
-// CHECK-SAME:       [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]], #[[MAP3]], #[[MAP0]], #[[MAP0]]]
+// CHECK-SAME:       [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP3]], #[[MAP0]], #[[MAP0]]]
 // CHECK-SAME:       ["parallel", "parallel"]
 // CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[GENERIC1]]#3 :
 // CHECK-SAME:       outs(%[[INIT1]], %[[INIT2]], %[[INIT2]] :
@@ -203,9 +203,9 @@ func.func @simple_op_permuted_outputs(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x
 // CANONICALIZECHECK-NEXT:     %[[S0:.+]] = arith.addf %[[B0]], %[[B1]]
 // CANONICALIZECHECK-NEXT:     linalg.yield %[[S0]], %[[S0]]
 //      CANONICALIZECHECK:   %[[GENERIC2:.+]] = linalg.generic
-// CANONICALIZECHECK-SAME:       [#[[MAP3]], #[[MAP0]], #[[MAP0]]]
+// CANONICALIZECHECK-SAME:       [#[[MAP3]], #[[MAP2]], #[[MAP0]]]
 // CANONICALIZECHECK-SAME:       ["parallel", "parallel"]
-// CANONICALIZECHECK-SAME:       ins(%[[ARG2]], %[[GENERIC1]]#1 :
+// CANONICALIZECHECK-SAME:       ins(%[[ARG2]], %[[GENERIC1]]#0 :
 // CANONICALIZECHECK-SAME:       outs(%[[INIT2]] :
 // CANONICALIZECHECK-NEXT:   ^bb0(
 // CANONICALIZECHECK-SAME:       %[[B4:[a-zA-Z0-9]+]]: f32
@@ -324,3 +324,95 @@ func.func @multi_statement(%arg0 : tensor<10x20xf32>, %arg1 : tensor<10xi32>) ->
 // CANONICALIZECHECK-NEXT:       %[[S2:.+]] = arith.addf %[[B4]], %[[B5]] : f64
 // CANONICALIZECHECK-NEXT:       linalg.yield %[[S2]]
 //      CANONICALIZECHECK:   return %[[GENERIC2]]
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0)>
+#map1 = affine_map<(d0, d1) -> (d1)>
+#map2 = affine_map<(d0, d1) -> (d0, d1)>
+#map3 = affine_map<(d0, d1) -> (d1, d0)>
+func.func @destination_passing_style(
+    %arg0 : tensor<?xf32>, %arg1 : tensor<?xf32>,
+    %arg2 : tensor<?x?xf32>, %arg3 : tensor<?x?xf32>)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %0:2 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
+      outs(%arg2, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32) :
+        %1 = arith.addf %b0, %b2 : f32
+        %2 = arith.mulf %b1, %b3 : f32
+        linalg.yield %1, %2 : f32, f32
+    } -> (tensor<?x?xf32>, tensor<?x?xf32>)
+  return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+//      CHECK: func.func @destination_passing_style(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+//      CHECK:   %[[GENERIC1:.+]]:3 = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP2]]]
+// CHECK-SAME:       iterator_types = ["parallel", "parallel"]
+// CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] :
+// CHECK-SAME:       outs(%[[ARG2]], %[[ARG3]], %[[ARG2]] :
+// CHECK-NEXT:   ^bb0(
+// CHECK-SAME:       %[[ARG4:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[ARG5:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[ARG6:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[ARG7:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[ARG8:[a-zA-Z0-9]+]]: f32
+// CHECK-NEXT:     %[[S1:.+]] = arith.addf %[[ARG4]], %[[ARG6]]
+// CHECK-NEXT:     linalg.yield %[[S1]], %{{.+}}, %[[S1]]
+//      CHECK:   %[[GENERIC2:.+]]:2 = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME:       iterator_types = ["parallel", "parallel"]
+// CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]], %[[GENERIC1]]#2 :
+// CHECK-SAME:       outs(%[[ARG2]], %[[ARG3]] :
+// CHECK-NEXT:   ^bb0(
+// CHECK-SAME:       %[[ARG9:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[ARG10:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[ARG11:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[ARG12:[a-zA-Z0-9]+]]: f32
+// CHECK-SAME:       %[[ARG13:[a-zA-Z0-9]+]]: f32
+// CHECK-NEXT:     %[[S2:.+]] = arith.mulf %[[ARG10]], %[[ARG12]]
+// CHECK-NEXT:     linalg.yield %[[ARG6]], %[[S2]]
+//      CHECK:   return %[[GENERIC1]]#0, %[[GENERIC2]]#1
+
+//  CANONICALIZECHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0)>
+//  CANONICALIZECHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//  CANONICALIZECHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d1)>
+//  CANONICALIZECHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+//      CANONICALIZECHECK: func.func @destination_passing_style(
+// CANONICALIZECHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?xf32>
+// CANONICALIZECHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?xf32>
+// CANONICALIZECHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CANONICALIZECHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+//      CANONICALIZECHECK:   %[[GENERIC1:.+]] = linalg.generic
+// CANONICALIZECHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CANONICALIZECHECK-SAME:       iterator_types = ["parallel", "parallel"]
+// CANONICALIZECHECK-SAME:       ins(%[[ARG0]] :
+// CANONICALIZECHECK-SAME:       outs(%[[ARG2]] :
+// CANONICALIZECHECK-NEXT:   ^bb0(
+// CANONICALIZECHECK-SAME:       %[[ARG4:[a-zA-Z0-9]+]]: f32
+// CANONICALIZECHECK-SAME:       %[[ARG5:[a-zA-Z0-9]+]]: f32
+// CANONICALIZECHECK-NEXT:     %[[S1:.+]] = arith.addf %[[ARG4]], %[[ARG5]]
+// CANONICALIZECHECK-NEXT:     linalg.yield %[[S1]]
+//      CANONICALIZECHECK:   %[[GENERIC2:.+]]:2 = linalg.generic
+// CANONICALIZECHECK-SAME:       indexing_maps = [#[[MAP2]], #[[MAP1]], #[[MAP1]], #[[MAP3]]]
+// CANONICALIZECHECK-SAME:       iterator_types = ["parallel", "parallel"]
+// CANONICALIZECHECK-SAME:       ins(%[[ARG1]], %[[GENERIC1]] :
+// CANONICALIZECHECK-SAME:       outs(%[[ARG2]], %[[ARG3]] :
+// CANONICALIZECHECK-NEXT:   ^bb0(
+// CANONICALIZECHECK-SAME:       %[[ARG4:[a-zA-Z0-9]+]]: f32
+// CANONICALIZECHECK-SAME:       %[[ARG5:[a-zA-Z0-9]+]]: f32
+// CANONICALIZECHECK-SAME:       %[[ARG6:[a-zA-Z0-9]+]]: f32
+// CANONICALIZECHECK-SAME:       %[[ARG7:[a-zA-Z0-9]+]]: f32
+// CANONICALIZECHECK-NEXT:     %[[S2:.+]] = arith.mulf %[[ARG4]], %[[ARG6]]
+// CANONICALIZECHECK-NEXT:     linalg.yield %[[ARG5]], %[[S2]]
+//      CANONICALIZECHECK:   return %[[GENERIC1]], %[[GENERIC2]]#1


        


More information about the Mlir-commits mailing list