[Mlir-commits] [mlir] [MLIR][Linalg] Add specialization for linalg.broadcast (PR #104684)

Javed Absar llvmlistbot at llvm.org
Fri Sep 27 04:21:36 PDT 2024


https://github.com/javedabsar1 updated https://github.com/llvm/llvm-project/pull/104684

>From 9a1f106e7b477b102e979d9f572e18c8693fe8ba Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Sat, 17 Aug 2024 16:13:25 -0400
Subject: [PATCH 1/2] [MLIR][Linalg] Add specialization for linalg.broadcast

Specialize `linalg.generic` ops that are semantically equivalent to `linalg.broadcast`.
---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h |  5 ++
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    | 73 +++++++++++++++++--
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 15 ++++
 .../Dialect/Linalg/roundtrip-broadcast.mlir   | 36 +++++++++
 .../Linalg/transform-op-specialize.mlir       | 12 ---
 5 files changed, 121 insertions(+), 20 deletions(-)
 create mode 100644 mlir/test/Dialect/Linalg/roundtrip-broadcast.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 08afdf373f014a..221155a31c34da 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -131,6 +131,11 @@ bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp);
 /// Returns the scalar fill value if true.
 std::optional<Value> isaFillOpInterface(GenericOp genericOp);
 
+/// Checks whether `genericOp` is semantically equivalent to a
+/// `linalg.broadcast`. Returns the broadcast dimensions if true.
+std::optional<SmallVector<int64_t>>
+isaBroadcastOpInterface(GenericOp genericOp);
+
 namespace detail {
 
 /// Returns true if the block contains a contraction of the following form:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 6ee1810c2ff2b9..95c2fbbb00bc2c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -22,6 +22,7 @@
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include <algorithm>
+#include <numeric>
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -29,6 +30,24 @@ using namespace mlir::linalg;
 /// Include the definitions of the copy operation interface.
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
 
+/// Returns true if the body of `genericOp` consists of a single
+/// `linalg.yield` that yields the block argument of its sole input. The op
+/// must have exactly one DPS input and one DPS init (asserted).
+static bool bodyIsJustYieldOp(GenericOp genericOp) {
+  assert(genericOp.getNumDpsInputs() == 1 && genericOp.getNumDpsInits() == 1 &&
+         "expected single input output to linalg.generic");
+  Block *body = genericOp.getBody();
+  // hasSingleElement stops after two ops instead of counting the whole list.
+  if (!llvm::hasSingleElement(body->getOperations()))
+    return false;
+
+  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
+      yieldOp->getOperand(0) != body->getArgument(0))
+    return false;
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // Interface utility functions
 //===----------------------------------------------------------------------===//
@@ -52,7 +71,6 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
 //===----------------------------------------------------------------------===//
 // CopyOpInterface implementation
 //===----------------------------------------------------------------------===//
-
 bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
   // Structural.
   if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
@@ -85,18 +103,57 @@ std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
     return std::nullopt;
 
   OpOperand *value = genericOp.getDpsInputOperand(0);
-  if (!genericOp.isScalar(value))
+  if (!genericOp.isScalar(value) || !bodyIsJustYieldOp(genericOp))
+    return std::nullopt;
+  return value->get();
+}
+
+//===----------------------------------------------------------------------===//
+// BroadcastOpInterface implementation
+//===----------------------------------------------------------------------===//
+std::optional<SmallVector<int64_t>>
+linalg::isaBroadcastOpInterface(GenericOp genericOp) {
+  // Structural. bodyIsJustYieldOp asserts a single input/init; check it last.
+  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
+      genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1 ||
+      !bodyIsJustYieldOp(genericOp))
     return std::nullopt;

-  Block *body = genericOp.getBody();
-  if (body->getOperations().size() != 1)
+  auto srcTy = genericOp.getDpsInputOperand(0)->get().getType();
+  auto dstTy = genericOp.getDpsInitOperand(0)->get().getType();
+  if (!isa<MemRefType, RankedTensorType>(srcTy) ||
+      !isa<MemRefType, RankedTensorType>(dstTy))
     return std::nullopt;

-  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
-  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
-      yieldOp->getOperand(0) != body->getArgument(0))
+  // Check output is identity map. A linalg.generic may also permute the
+  // output indices, but the named linalg.broadcast op cannot express such
+  // a permutation.
+  auto dstMap = genericOp.getIndexingMapsArray()[1];
+  if (!dstMap.isIdentity())
     return std::nullopt;
-  return value->get();
+
+  // A broadcast must add at least one dimension to the input.
+  auto srcMap = genericOp.getIndexingMapsArray()[0];
+  if (srcMap.getNumResults() >= dstMap.getNumResults())
+    return std::nullopt;
+
+  // Check input map is strictly increasing AffineDimExprs.
+  SmallVector<int64_t> position;
+  for (AffineExpr result : srcMap.getResults()) {
+    auto expr = llvm::dyn_cast<AffineDimExpr>(result);
+    if (!expr)
+      return std::nullopt;
+    int64_t pos = expr.getPosition();
+    if (!position.empty() && pos <= position.back())
+      return std::nullopt;
+    position.push_back(pos);
+  }
+
+  SmallVector<int64_t> broadcastedDims;
+  for (int64_t dim : llvm::seq<int64_t>(0, srcMap.getNumDims()))
+    if (!llvm::is_contained(position, dim))
+      broadcastedDims.push_back(dim);
+  return broadcastedDims;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 4d7b748d7200e2..801622752c8404 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -259,18 +259,31 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
 //===----------------------------------------------------------------------===//
 FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
                                                       GenericOp genericOp) {
+  // Copy
   if (isaCopyOpInterface(genericOp)) {
     LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
         genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
     return namedOp;
   }
 
+  // Fill
   if (isaFillOpInterface(genericOp)) {
     LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
         genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
     return namedOp;
   }
 
+  // Broadcast
+  std::optional<SmallVector<int64_t>> equivalentToBroadcast
+          = isaBroadcastOpInterface(genericOp);
+  if (equivalentToBroadcast) {
+    auto dims = *equivalentToBroadcast;
+    LinalgOp namedOp = rewriter.replaceOpWithNewOp<BroadcastOp>(
+        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0], dims);
+    return namedOp;
+  }
+
+  // Elementwise Unary
   if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
     Operation *op = &genericOp.getBody()->front();
     if (isa<math::ExpOp>(op)) {
@@ -279,6 +292,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
     }
   }
 
+  // Elementwise Binary
   if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
     bool swap = areBinOpsSwapped(genericOp);
     Operation *op = &genericOp.getBody()->front();
@@ -300,6 +314,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
     }
   }
 
+  // Contraction
   if (isaContractionOpInterface(genericOp)) {
     return specializeLinalgContractions(rewriter, genericOp);
   }
diff --git a/mlir/test/Dialect/Linalg/roundtrip-broadcast.mlir b/mlir/test/Dialect/Linalg/roundtrip-broadcast.mlir
new file mode 100644
index 00000000000000..10d7ba826f79f9
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/roundtrip-broadcast.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
+
+// CHECK-LABEL: broadcast_first_dimension
+// CHECK-SAME:   %[[A:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?x?xf32>)
+// CHECK-NOT:     linalg.generic
+// CHECK:         %[[RES:.+]] = linalg.broadcast ins(%[[A]] : tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) dimensions = [0]
+// CHECK:         return %[[RES]]
+func.func @broadcast_first_dimension(%A: tensor<?x?xf32>, %Out: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %res = linalg.broadcast ins(%A: tensor<?x?xf32>) outs(%Out: tensor<?x?x?xf32>) dimensions = [0]
+  return %res : tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: broadcast_mid_dimension
+// CHECK-SAME:   %[[A:.+]]: tensor<3x5xf32>, %[[Out:.+]]: tensor<3x4x5xf32>)
+// CHECK-NOT:     linalg.generic
+// CHECK:         %[[RES:.+]] = linalg.broadcast ins(%[[A]] : tensor<3x5xf32>) outs(%[[Out]] : tensor<3x4x5xf32>) dimensions = [1]
+// CHECK:         return %[[RES]]
+func.func @broadcast_mid_dimension(%A: tensor<3x5xf32>, %Out: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> {
+  %res = linalg.broadcast ins(%A: tensor<3x5xf32>) outs(%Out: tensor<3x4x5xf32>) dimensions = [1]
+  return %res : tensor<3x4x5xf32>
+}
+
+
+// CHECK-LABEL: broadcast_multiple_dimensions
+// CHECK-SAME:   %[[A:.+]]: tensor<4x5x7xf32>, %[[Out:.+]]: tensor<3x4x5x6x7x8x9xf32>)
+// CHECK-NOT:     linalg.generic
+// CHECK:         %[[RES:.+]] = linalg.broadcast ins(%[[A]] : tensor<4x5x7xf32>) outs(%[[Out]] : tensor<3x4x5x6x7x8x9xf32>) dimensions = [0, 3, 5, 6]
+// CHECK:         return %[[RES]]
+func.func @broadcast_multiple_dimensions(%A: tensor<4x5x7xf32>, %Out: tensor<3x4x5x6x7x8x9xf32>) -> tensor<3x4x5x6x7x8x9xf32> {
+  %res = linalg.broadcast ins(%A: tensor<4x5x7xf32>) outs(%Out: tensor<3x4x5x6x7x8x9xf32>) dimensions = [0, 3, 5, 6]
+  return %res : tensor<3x4x5x6x7x8x9xf32>
+}
+
+
+
+
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
index 35679db7412f30..31f2f6b1ab513f 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -4,18 +4,6 @@
 #map1 = affine_map<(d0, d1) -> (d0)>
 #map2 = affine_map<(d0, d1) -> (d1, d0)>
 
-func.func @broadcast_copy_expect_no_match(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>) {
-  // expected-note @below {{when applied to this op}}
-  linalg.generic {
-    indexing_maps = [#map1, #map], 
-    iterator_types = ["parallel", "parallel"]}
-    ins(%arg0 : memref<?xf32>) outs(%arg1 : memref<?x?xf32>) {
-    ^bb0(%in: f32, %out: f32):
-      linalg.yield %in : f32
-  }
-  return
-}
-
 func.func @not_a_copy_expect_no_match(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
   // expected-note @below {{when applied to this op}}
   linalg.generic {

>From 83acd956069c1f9c12eee7117c9aa0d63ba5c3cd Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Sat, 17 Aug 2024 16:48:57 -0400
Subject: [PATCH 2/2] [MLIR][Linalg] Fix clang-format complaint

---
 mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 801622752c8404..c9afbcd2460512 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -274,12 +274,13 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
   }
 
   // Broadcast
-  std::optional<SmallVector<int64_t>> equivalentToBroadcast
-          = isaBroadcastOpInterface(genericOp);
+  std::optional<SmallVector<int64_t>> equivalentToBroadcast =
+      isaBroadcastOpInterface(genericOp);
   if (equivalentToBroadcast) {
     auto dims = *equivalentToBroadcast;
     LinalgOp namedOp = rewriter.replaceOpWithNewOp<BroadcastOp>(
-        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0], dims);
+        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
+        dims);
     return namedOp;
   }
 



More information about the Mlir-commits mailing list