[Mlir-commits] [mlir] [mlir][linalg] raise generic to named ops. (PR #110421)

Javed Absar llvmlistbot at llvm.org
Thu Oct 10 04:35:50 PDT 2024


https://github.com/javedabsar1 updated https://github.com/llvm/llvm-project/pull/110421

>From a38ba01c84c78c09d462a1d432bfa6486b71ac12 Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Fri, 27 Sep 2024 09:35:04 -0400
Subject: [PATCH 1/2] [mlir][linalg] raise generic to named ops.

Add support for specializing linalg.broadcast and linalg.transpose
from generic. Also refactor the specialization checks so they can
be reused.
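
As an illustration (hypothetical IR, mirroring the roundtrip tests
added below), a generic such as

  %0 = linalg.generic
         {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>,
                           affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
          iterator_types = ["parallel", "parallel", "parallel"]}
         ins(%A : tensor<?x?xf32>) outs(%Out : tensor<?x?x?xf32>) {
       ^bb0(%in: f32, %out: f32):
         linalg.yield %in : f32
       } -> tensor<?x?x?xf32>

can now be raised back to the named form

  %broadcasted = linalg.broadcast ins(%A : tensor<?x?xf32>)
                   outs(%Out : tensor<?x?x?xf32>) dimensions = [0]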
---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h |  10 ++
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    | 115 +++++++++++++++---
 .../Dialect/Linalg/Transforms/Specialize.cpp  |  27 ++++
 .../Dialect/Linalg/roundtrip-broadcast.mlir   |  32 +++++
 .../Dialect/Linalg/roundtrip-transpose.mlir   |  11 ++
 .../Linalg/transform-op-specialize.mlir       |  12 --
 6 files changed, 180 insertions(+), 27 deletions(-)
 create mode 100644 mlir/test/Dialect/Linalg/roundtrip-broadcast.mlir
 create mode 100644 mlir/test/Dialect/Linalg/roundtrip-transpose.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 0fcaa96ade4031..6f1c243cc4396d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -120,6 +120,16 @@ bool isaConvolutionOpInterface(LinalgOp linalgOp,
 /// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
 bool isaCopyOpInterface(LinalgOp linalgOp);
 
+/// Checks whether `genericOp` is semantically equivalent to a
+///  `linalg.broadcast`. Returns the broadcast dimensions on success.
+std::optional<SmallVector<int64_t>>
+isaBroadcastOpInterface(GenericOp genericOp);
+
+/// Checks whether `genericOp` is semantically equivalent to a
+///  `linalg.transpose`. Returns the permuted dimensions on success.
+std::optional<SmallVector<int64_t>>
+isaTransposeOpInterface(GenericOp genericOp);
+
 /// Checks whether a given `genericOp` is semantically equivalent to a single
 /// linalg elementwise unary op, e.g. linalg.exp.
 /// A linalg.generic body could be a series of unary elementwise ops e.g.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 0b5191664a9e2f..5842128091972a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -22,6 +22,7 @@
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include <algorithm>
+#include <numeric>
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -49,18 +50,41 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
   return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
 }
 
+// Returns true if all loops of the linalgOp are parallel
+static bool isAllParallel(LinalgOp op) {
+  return op.getNumParallelLoops() == op.getNumLoops();
+}
+
+// Returns true if and only if linalgOp takes one input and one init.
+static bool isSingleInputOutput(LinalgOp op) {
+  return op.getNumDpsInputs() == 1 && op.getNumDpsInits() == 1;
+}
+// Returns true if genericOp body is just a yieldOp that yields
+// input operand as result.
+static bool isSingleYieldOp(GenericOp op) {
+  if (op.getNumDpsInputs() != 1 || op.getNumDpsInits() != 1)
+    return false;
+
+  Block *body = op.getBody();
+  if (body->getOperations().size() != 1)
+    return false;
+
+  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
+      yieldOp->getOperand(0) != body->getArgument(0))
+    return false;
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // CopyOpInterface implementation
 //===----------------------------------------------------------------------===//
 
 bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
-  // Structural.
-  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
+  // Structural and operands
+  if (!isAllParallel(linalgOp) || !isSingleInputOutput(linalgOp))
     return false;
 
-  // Operands and maps.
-  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
-    return false;
   auto mapRange = linalgOp.getIndexingMapsArray();
   if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
       !mapRange.back().isIdentity()) {
@@ -75,8 +99,8 @@ bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
 //===----------------------------------------------------------------------===//
 std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
   // Structural.
-  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
-      genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
+  if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
+      !isSingleYieldOp(genericOp))
     return std::nullopt;
 
   // Input should be referenced and init should not.
@@ -87,16 +111,78 @@ std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
   OpOperand *value = genericOp.getDpsInputOperand(0);
   if (!genericOp.isScalar(value))
     return std::nullopt;
+  return value->get();
+}
 
-  Block *body = genericOp.getBody();
-  if (body->getOperations().size() != 1)
+//===----------------------------------------------------------------------===//
+// BroadcastOpInterface implementation
+//===----------------------------------------------------------------------===//
+std::optional<SmallVector<int64_t>>
+linalg::isaBroadcastOpInterface(GenericOp genericOp) {
+  // Structural.
+  if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
+      !isSingleYieldOp(genericOp))
     return std::nullopt;
 
-  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
-  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
-      yieldOp->getOperand(0) != body->getArgument(0))
+  auto t0 = genericOp.getDpsInputOperand(0)->get().getType();
+  auto t1 = genericOp.getDpsInitOperand(0)->get().getType();
+  if (!isa<MemRefType, RankedTensorType>(t0) ||
+      !isa<MemRefType, RankedTensorType>(t1))
     return std::nullopt;
-  return value->get();
+
+  // Check output is identity map. Injective function could also be
+  // a permutation of indices and expressible in linalg.generic but
+  // is not expressible for named broadcast op.
+  auto dstMap = genericOp.getIndexingMapsArray()[1];
+  if (!dstMap.isIdentity())
+    return std::nullopt;
+
+  SmallVector<int64_t> position;
+  auto srcMap = genericOp.getIndexingMapsArray()[0];
+
+  // Check input map is monotonically increasing DimIds.
+  for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
+    auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
+    if (!expr)
+      return std::nullopt;
+    int64_t pos = expr.getPosition();
+    if (i > 0 && pos <= position[i - 1])
+      return std::nullopt;
+    position.push_back(expr.getPosition());
+  }
+
+  SmallVector<int64_t> broadcastedDims;
+  auto numDims = srcMap.getNumDims();
+  for (auto dim : llvm::seq<int64_t>(0, numDims)) {
+    if (!llvm::is_contained(position, dim))
+      broadcastedDims.push_back(dim);
+  }
+  return broadcastedDims;
+}
+
+//===----------------------------------------------------------------------===//
+// TransposeOpInterface implementation
+//===----------------------------------------------------------------------===//
+std::optional<SmallVector<int64_t>>
+linalg::isaTransposeOpInterface(GenericOp genericOp) {
+  // Structural.
+  if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
+      !isSingleYieldOp(genericOp))
+    return std::nullopt;
+
+  // mapping checks.
+  auto mapRange = genericOp.getIndexingMapsArray();
+  if (mapRange.size() != 2 || !mapRange.back().isIdentity() ||
+      !mapRange.front().isPermutation())
+    return std::nullopt;
+
+  SmallVector<int64_t> permutation;
+  auto map = mapRange.front();
+  for (unsigned i = 0; i < map.getNumResults(); ++i) {
+    auto expr = llvm::cast<AffineDimExpr>(map.getResults()[i]);
+    permutation.push_back(expr.getPosition());
+  }
+  return permutation;
 }
 
 //===----------------------------------------------------------------------===//
@@ -106,8 +192,7 @@ static bool
 isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
                                           unsigned arity) {
   // Check all loops are parallel.
-  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
-      genericOp.getNumLoops() < 1)
+  if (!isAllParallel(genericOp) || genericOp.getNumLoops() < 1)
     return false;
 
   // Check there are arity-inputs, 1-output and all are identity-maps.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 4d7b748d7200e2..dfafffce9d9b60 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -259,18 +259,43 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
 //===----------------------------------------------------------------------===//
 FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
                                                       GenericOp genericOp) {
+  // Copy
   if (isaCopyOpInterface(genericOp)) {
     LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
         genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
     return namedOp;
   }
 
+  // Fill
   if (isaFillOpInterface(genericOp)) {
     LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
         genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
     return namedOp;
   }
 
+  // Broadcast
+  std::optional<SmallVector<int64_t>> equivalentToBroadcast =
+      isaBroadcastOpInterface(genericOp);
+  if (equivalentToBroadcast) {
+    auto dims = *equivalentToBroadcast;
+    LinalgOp namedOp = rewriter.replaceOpWithNewOp<BroadcastOp>(
+        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
+        dims);
+    return namedOp;
+  }
+
+  // Transpose
+  std::optional<SmallVector<int64_t>> equivalentToTranspose =
+      isaTransposeOpInterface(genericOp);
+  if (equivalentToTranspose) {
+    auto permutation = *equivalentToTranspose;
+    LinalgOp namedOp = rewriter.replaceOpWithNewOp<TransposeOp>(
+        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
+        permutation);
+    return namedOp;
+  }
+
+  // Elementwise Unary
   if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
     Operation *op = &genericOp.getBody()->front();
     if (isa<math::ExpOp>(op)) {
@@ -279,6 +304,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
     }
   }
 
+  // Elementwise Binary
   if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
     bool swap = areBinOpsSwapped(genericOp);
     Operation *op = &genericOp.getBody()->front();
@@ -300,6 +326,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
     }
   }
 
+  // Contraction - e.g. matmul
   if (isaContractionOpInterface(genericOp)) {
     return specializeLinalgContractions(rewriter, genericOp);
   }
diff --git a/mlir/test/Dialect/Linalg/roundtrip-broadcast.mlir b/mlir/test/Dialect/Linalg/roundtrip-broadcast.mlir
new file mode 100644
index 00000000000000..d6915ec8fbbf6f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/roundtrip-broadcast.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
+
+// CHECK-LABEL: broadcast_first_dimension
+// CHECK-SAME:   %[[A:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?x?xf32>)
+// CHECK-NOT:     linalg.generic
+// CHECK:         %broadcasted = linalg.broadcast ins(%[[A]] : tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) dimensions = [0]
+//
+func.func @broadcast_first_dimension(%A: tensor<?x?xf32>, %Out: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+   %res = linalg.broadcast ins(%A: tensor<?x?xf32>) outs(%Out: tensor<?x?x?xf32>) dimensions = [0]
+  return %res : tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: broadcast_mid_dimension
+// CHECK-SAME:   %[[A:.+]]: tensor<3x5xf32>, %[[Out:.+]]: tensor<3x4x5xf32>)
+// CHECK-NOT:     linalg.generic
+// CHECK:         %broadcasted = linalg.broadcast ins(%[[A]] : tensor<3x5xf32>) outs(%[[Out]] : tensor<3x4x5xf32>) dimensions = [1]
+//
+func.func @broadcast_mid_dimension(%A: tensor<3x5xf32>, %Out: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> {
+   %res = linalg.broadcast ins(%A: tensor<3x5xf32>) outs(%Out: tensor<3x4x5xf32>) dimensions = [1]
+  return %res : tensor<3x4x5xf32>
+}
+
+
+// CHECK-LABEL: broadcast_multiple_dimensions
+// CHECK-SAME:   %[[A:.+]]: tensor<4x5x7xf32>, %[[Out:.+]]: tensor<3x4x5x6x7x8x9xf32>)
+// CHECK-NOT:     linalg.generic
+// CHECK:         %broadcasted = linalg.broadcast ins(%[[A]] : tensor<4x5x7xf32>) outs(%[[Out]] : tensor<3x4x5x6x7x8x9xf32>) dimensions = [0, 3, 5, 6]
+//
+func.func @broadcast_multiple_dimensions(%A: tensor<4x5x7xf32>, %Out: tensor<3x4x5x6x7x8x9xf32>) -> tensor<3x4x5x6x7x8x9xf32> {
+   %res = linalg.broadcast ins(%A: tensor<4x5x7xf32>) outs(%Out: tensor<3x4x5x6x7x8x9xf32>) dimensions = [0,3,5,6]
+  return %res : tensor<3x4x5x6x7x8x9xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip-transpose.mlir b/mlir/test/Dialect/Linalg/roundtrip-transpose.mlir
new file mode 100644
index 00000000000000..ebc42c903e6e3e
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/roundtrip-transpose.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
+
+// CHECK-LABEL: linalg_transpose
+// CHECK-SAME:  %[[A:.+]]: tensor<16x64xf32>, %[[Out:.+]]: tensor<64x16xf32>
+// CHECK-NOT:   linalg.generic
+// CHECK:  %transposed = linalg.transpose ins(%[[A]] : tensor<16x64xf32>) outs(%[[Out]] : tensor<64x16xf32>) permutation = [1, 0]
+//
+func.func @linalg_transpose(%A: tensor<16x64xf32>, %Out: tensor<64x16xf32>) -> tensor<64x16xf32> {
+  %res = linalg.transpose ins(%A: tensor<16x64xf32>) outs(%Out: tensor<64x16xf32>) permutation = [1,0]
+  return %res : tensor<64x16xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
index 35679db7412f30..31f2f6b1ab513f 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -4,18 +4,6 @@
 #map1 = affine_map<(d0, d1) -> (d0)>
 #map2 = affine_map<(d0, d1) -> (d1, d0)>
 
-func.func @broadcast_copy_expect_no_match(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>) {
-  // expected-note @below {{when applied to this op}}
-  linalg.generic {
-    indexing_maps = [#map1, #map], 
-    iterator_types = ["parallel", "parallel"]}
-    ins(%arg0 : memref<?xf32>) outs(%arg1 : memref<?x?xf32>) {
-    ^bb0(%in: f32, %out: f32):
-      linalg.yield %in : f32
-  }
-  return
-}
-
 func.func @not_a_copy_expect_no_match(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
   // expected-note @below {{when applied to this op}}
   linalg.generic {

>From b65f967d57cb69f5802707863a1f520f42879fd4 Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Tue, 8 Oct 2024 18:38:18 -0400
Subject: [PATCH 2/2] [mlir][linalg] revise based on review comments
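
In particular, isaTransposeOpInterface now computes the permutation
as the inverse of the input map, matching the linalg.transpose rule
dim(result, i) = dim(input, permutation[i]). A worked example
(assuming the generalized form of the transpose3D test below): for
the input map

  affine_map<(d0, d1, d2) -> (d1, d2, d0)>

the loop sets permutation[1] = 0, permutation[2] = 1, and
permutation[0] = 2, recovering permutation = [2, 0, 1].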

---
 .../Dialect/Linalg/IR/LinalgInterfaces.td     |  24 +++
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  18 +++
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    | 150 ++++++++----------
 .../Dialect/Linalg/roundtrip-transpose.mlir   |  15 +-
 .../Linalg/specialize-generic-ops-fail.mlir   |  16 ++
 5 files changed, 138 insertions(+), 85 deletions(-)
 create mode 100644 mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index fbf3f19cde0e9b..0a404194569c22 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -243,6 +243,18 @@ def LinalgStructuredInterface
                            utils::IteratorType::parallel);
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return true if all loops are parallel.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isAllParallelLoops",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return getNumParallelLoops() == getNumLoops();
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the dims that are parallel loops.
@@ -327,6 +339,18 @@ def LinalgStructuredInterface
         return !getBlock()->getArgument(bbArgNumber).use_empty();
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns true only if linalgOp takes one input and produces one result.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isSingleInputOutput",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return $_op.getNumDpsInputs() == 1 && $_op.getNumDpsInits() == 1;
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return true if `opOperand` is an init tensor. This is true when it is
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 31f29139247267..a27c666a2aba46 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -210,6 +210,24 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
     }
 
     MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
+
+    // Return true only if GenericOp has a single input and single
+    // output, and the body is a single yieldOp that yields the input.
+    // This check is useful when trying to determine if the op is
+    // essentially a transpose, broadcast, or copy.
+    bool isSingleYieldOp() {
+      if (!isSingleInputOutput())
+        return false;
+      Block *body = getBody();
+      if (body->getOperations().size() != 1)
+        return false;
+
+      auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+      if (!yieldOp || yieldOp.getNumOperands() != 1 ||
+          yieldOp->getOperand(0) != body->getArgument(0))
+        return false;
+      return true;
+    }
   }];
 
   let hasCanonicalizer = 1;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 5842128091972a..73d617fcffeba6 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -50,66 +50,40 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
   return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
 }
 
-// Returns true if all loops of the linalgOp are parallel
-static bool isAllParallel(LinalgOp op) {
-  return op.getNumParallelLoops() == op.getNumLoops();
-}
-
-// Returns true if and only if linalgOp takes one input and one init.
-static bool isSingleInputOutput(LinalgOp op) {
-  return op.getNumDpsInputs() == 1 && op.getNumDpsInits() == 1;
-}
-// Returns true if genericOp body is just a yieldOp that yields
-// input operand as result.
-static bool isSingleYieldOp(GenericOp op) {
-  if (op.getNumDpsInputs() != 1 || op.getNumDpsInits() != 1)
-    return false;
-
-  Block *body = op.getBody();
-  if (body->getOperations().size() != 1)
-    return false;
-
-  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
-  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
-      yieldOp->getOperand(0) != body->getArgument(0))
-    return false;
-  return true;
-}
-
 //===----------------------------------------------------------------------===//
 // CopyOpInterface implementation
 //===----------------------------------------------------------------------===//
 
-bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
-  // Structural and operands
-  if (!isAllParallel(linalgOp) || !isSingleInputOutput(linalgOp))
+bool linalg::isaCopyOpInterface(LinalgOp op) {
+  // Check all loops are parallel and the op has a single input and output.
+  if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
     return false;
 
-  auto mapRange = linalgOp.getIndexingMapsArray();
+  auto mapRange = op.getIndexingMapsArray();
   if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
       !mapRange.back().isIdentity()) {
     return false;
   }
   // Region.
-  return llvm::hasSingleElement(linalgOp.getBlock()->getOperations());
+  return llvm::hasSingleElement(op.getBlock()->getOperations());
 }
 
 //===----------------------------------------------------------------------===//
 // FillOpInterface implementation
 //===----------------------------------------------------------------------===//
-std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
+std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
   // Structural.
-  if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
-      !isSingleYieldOp(genericOp))
+  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
+      !op.isSingleYieldOp())
     return std::nullopt;
 
   // Input should be referenced and init should not.
-  if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)) ||
-      genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
+  if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
+      op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
     return std::nullopt;
 
-  OpOperand *value = genericOp.getDpsInputOperand(0);
-  if (!genericOp.isScalar(value))
+  OpOperand *value = op.getDpsInputOperand(0);
+  if (!op.isScalar(value))
     return std::nullopt;
   return value->get();
 }
@@ -118,27 +92,30 @@ std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
 // BroadcastOpInterface implementation
 //===----------------------------------------------------------------------===//
 std::optional<SmallVector<int64_t>>
-linalg::isaBroadcastOpInterface(GenericOp genericOp) {
+linalg::isaBroadcastOpInterface(GenericOp op) {
   // Structural.
-  if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
-      !isSingleYieldOp(genericOp))
+  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
+      !op.isSingleYieldOp())
     return std::nullopt;
 
-  auto t0 = genericOp.getDpsInputOperand(0)->get().getType();
-  auto t1 = genericOp.getDpsInitOperand(0)->get().getType();
-  if (!isa<MemRefType, RankedTensorType>(t0) ||
-      !isa<MemRefType, RankedTensorType>(t1))
+  auto srcTy = op.getDpsInputOperand(0)->get().getType();
+  auto dstTy = op.getDpsInitOperand(0)->get().getType();
+  if (!isa<MemRefType, RankedTensorType>(srcTy) ||
+      !isa<MemRefType, RankedTensorType>(dstTy))
     return std::nullopt;
 
-  // Check output is identity map. Injective function could also be
-  // a permutation of indices and expressible in linalg.generic but
-  // is not expressible for named broadcast op.
-  auto dstMap = genericOp.getIndexingMapsArray()[1];
+  // Check the output map is identity. A broadcast combined with a
+  // permutation of indices would be expressible in linalg.generic
+  // but is not expressible with the named broadcast op.
+  auto dstMap = op.getIndexingMapsArray()[1];
   if (!dstMap.isIdentity())
     return std::nullopt;
 
   SmallVector<int64_t> position;
-  auto srcMap = genericOp.getIndexingMapsArray()[0];
+  auto srcMap = op.getIndexingMapsArray()[0];
+
+  if (srcMap.getResults().size() >= dstMap.getResults().size())
+    return std::nullopt;
 
   // Check input map is monotonically increasing DimIds.
   for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
@@ -153,6 +130,7 @@ linalg::isaBroadcastOpInterface(GenericOp genericOp) {
 
   SmallVector<int64_t> broadcastedDims;
   auto numDims = srcMap.getNumDims();
+  // This is quadratic, but the number of dims is generally small.
   for (auto dim : llvm::seq<int64_t>(0, numDims)) {
     if (!llvm::is_contained(position, dim))
       broadcastedDims.push_back(dim);
@@ -164,23 +142,30 @@ linalg::isaBroadcastOpInterface(GenericOp genericOp) {
 // TransposeOpInterface implementation
 //===----------------------------------------------------------------------===//
 std::optional<SmallVector<int64_t>>
-linalg::isaTransposeOpInterface(GenericOp genericOp) {
-  // Structural.
-  if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
-      !isSingleYieldOp(genericOp))
+linalg::isaTransposeOpInterface(GenericOp op) {
+  // To specialize as a transpose op, the genericOp must have all
+  // parallel loops, a single input and a single output, and a body
+  // that just yields the input unchanged (no compute).
+  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
+      !op.isSingleYieldOp())
     return std::nullopt;
 
-  // mapping checks.
-  auto mapRange = genericOp.getIndexingMapsArray();
-  if (mapRange.size() != 2 || !mapRange.back().isIdentity() ||
-      !mapRange.front().isPermutation())
+  auto mapRange = op.getIndexingMapsArray();
+  if (mapRange.size() != 2)
     return std::nullopt;
 
-  SmallVector<int64_t> permutation;
-  auto map = mapRange.front();
-  for (unsigned i = 0; i < map.getNumResults(); ++i) {
-    auto expr = llvm::cast<AffineDimExpr>(map.getResults()[i]);
-    permutation.push_back(expr.getPosition());
+  auto mapOfInput = mapRange.front();
+  auto mapOfResult = mapRange.back();
+
+  // linalg.transpose permutes the dimensions of the input using the
+  // rule: dim(result, i) = dim(input, permutation[i]).
+  if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
+    return std::nullopt;
+
+  SmallVector<int64_t> permutation(mapOfInput.getNumDims());
+  for (unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
+    auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
+    permutation[expr.getPosition()] = i;
   }
   return permutation;
 }
@@ -188,62 +173,61 @@ linalg::isaTransposeOpInterface(GenericOp genericOp) {
 //===----------------------------------------------------------------------===//
 // Elementwise Single Unary/Binary-OpInterface implementation
 //===----------------------------------------------------------------------===//
-static bool
-isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
-                                          unsigned arity) {
+static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op,
+                                                      unsigned arity) {
   // Check all loops are parallel.
-  if (!isAllParallel(genericOp) || genericOp.getNumLoops() < 1)
+  if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
     return false;
 
   // Check there are arity-inputs, 1-output and all are identity-maps.
-  if (genericOp.getNumDpsInputs() != arity || genericOp.getNumDpsInits() != 1 ||
-      !llvm::all_of(genericOp.getIndexingMapsArray(),
+  if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
+      !llvm::all_of(op.getIndexingMapsArray(),
                     [](AffineMap map) { return map.isIdentity(); }))
     return false;
 
   // Init should not be referenced for elementwise operations.
-  if (genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
+  if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
     return false;
 
   // A linalg.generic could be series of elementwise ops e.g. exp(neg(x)) such
   // as resulting from producer-consumer fusion. Here, we restrict to two ops in
   // the body, where the first is the elementwise single op and the second a
   // yield.
-  Block *body = genericOp.getBody();
+  Block *body = op.getBody();
   if (body->getOperations().size() != 2)
     return false;
 
-  Operation *op = &body->front();
-  if (op->getNumOperands() != arity || op->getNumResults() != 1)
+  Operation *oper = &body->front();
+  if (oper->getNumOperands() != arity || oper->getNumResults() != 1)
     return false;
 
   auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
   if (!yieldOp || yieldOp.getNumOperands() != 1 ||
-      yieldOp->getOperand(0).getDefiningOp() != op)
+      yieldOp->getOperand(0).getDefiningOp() != oper)
     return false;
   return true;
 }
 
-bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp genericOp) {
+bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op) {
   // All basic elemwise checks.
-  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(genericOp, 1))
+  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 1))
     return false;
 
   // Check input is actually used.
-  if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)))
+  if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
     return false;
   return true;
 }
 
-bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp genericOp) {
-  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(genericOp, 2))
+bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) {
+  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 2))
     return false;
 
   // Check both inputs are used (elementwise).
-  OpOperand *inputOpOperand0 = genericOp.getDpsInputOperand(0);
-  OpOperand *inputOpOperand1 = genericOp.getDpsInputOperand(1);
-  if (!genericOp.payloadUsesValueFromOperand(inputOpOperand0) ||
-      !genericOp.payloadUsesValueFromOperand(inputOpOperand1))
+  OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
+  OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
+  if (!op.payloadUsesValueFromOperand(inputOpOperand0) ||
+      !op.payloadUsesValueFromOperand(inputOpOperand1))
     return false;
   return true;
 }
diff --git a/mlir/test/Dialect/Linalg/roundtrip-transpose.mlir b/mlir/test/Dialect/Linalg/roundtrip-transpose.mlir
index ebc42c903e6e3e..21b7b348f1c7f8 100644
--- a/mlir/test/Dialect/Linalg/roundtrip-transpose.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip-transpose.mlir
@@ -1,11 +1,22 @@
 // RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
 
-// CHECK-LABEL: linalg_transpose
+// CHECK-LABEL: transpose2D
 // CHECK-SAME:  %[[A:.+]]: tensor<16x64xf32>, %[[Out:.+]]: tensor<64x16xf32>
 // CHECK-NOT:   linalg.generic
 // CHECK:  %transposed = linalg.transpose ins(%[[A]] : tensor<16x64xf32>) outs(%[[Out]] : tensor<64x16xf32>) permutation = [1, 0]
 //
-func.func @linalg_transpose(%A: tensor<16x64xf32>, %Out: tensor<64x16xf32>) -> tensor<64x16xf32> {
+func.func @transpose2D(%A: tensor<16x64xf32>, %Out: tensor<64x16xf32>) -> tensor<64x16xf32> {
   %res = linalg.transpose ins(%A: tensor<16x64xf32>) outs(%Out: tensor<64x16xf32>) permutation = [1,0]
   return %res : tensor<64x16xf32>
 }
+
+
+// CHECK-LABEL: transpose3D
+// CHECK-SAME:  %[[A:.+]]: tensor<7x8x9xf32>, %[[Out:.+]]: tensor<9x7x8xf32>
+// CHECK-NOT:   linalg.generic
+// CHECK:  %transposed = linalg.transpose ins(%[[A]] : tensor<7x8x9xf32>) outs(%[[Out]] : tensor<9x7x8xf32>) permutation = [2, 0, 1]
+//
+func.func @transpose3D(%arg0: tensor<7x8x9xf32>, %arg1: tensor<9x7x8xf32>) -> tensor<9x7x8xf32> {
+  %transposed = linalg.transpose ins(%arg0 : tensor<7x8x9xf32>) outs(%arg1 : tensor<9x7x8xf32>) permutation = [2, 0, 1]
+  return %transposed : tensor<9x7x8xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
new file mode 100644
index 00000000000000..542a7ed4a198b8
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt %s -split-input-file --linalg-specialize-generic-ops | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// This test checks that linalg.generic is not incorrectly specialized to transpose or broadcast.
+// CHECK-LABEL: @transpose_and_broadcast
+// CHECK: linalg.generic
+func.func @transpose_and_broadcast(%arg0: tensor<7x8xf32>, %arg1: tensor<8x7x9xf32>) -> tensor<8x7x9xf32> {
+  %0 = linalg.generic
+        {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]}
+        ins(%arg0 : tensor<7x8xf32>) outs(%arg1 : tensor<8x7x9xf32>) {
+        ^bb0(%in: f32, %out: f32):
+           linalg.yield %in : f32
+  } -> tensor<8x7x9xf32>
+  return %0 : tensor<8x7x9xf32>
+}


