[Mlir-commits] [mlir] [mlir][linalg] raise generic to named ops. (PR #110421)
Javed Absar
llvmlistbot at llvm.org
Sun Sep 29 03:57:08 PDT 2024
https://github.com/javedabsar1 created https://github.com/llvm/llvm-project/pull/110421
Add support for specializing linalg.generic ops into linalg.broadcast and linalg.transpose.
Also refactor the existing specialization checks into shared helpers.
>From a38ba01c84c78c09d462a1d432bfa6486b71ac12 Mon Sep 17 00:00:00 2001
From: Javed Absar <javed.absar at gmail.com>
Date: Fri, 27 Sep 2024 09:35:04 -0400
Subject: [PATCH] [mlir][linalg] raise generic to named ops.
Add support for specializing linalg.generic ops into linalg.broadcast and
linalg.transpose. Also refactor the specialization checks into shared helpers.
---
.../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 10 ++
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 115 +++++++++++++++---
.../Dialect/Linalg/Transforms/Specialize.cpp | 27 ++++
.../Dialect/Linalg/roundtrip-broadcast.mlir | 32 +++++
.../Dialect/Linalg/roundtrip-transpose.mlir | 11 ++
.../Linalg/transform-op-specialize.mlir | 12 --
6 files changed, 180 insertions(+), 27 deletions(-)
create mode 100644 mlir/test/Dialect/Linalg/roundtrip-broadcast.mlir
create mode 100644 mlir/test/Dialect/Linalg/roundtrip-transpose.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 0fcaa96ade4031..6f1c243cc4396d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -120,6 +120,16 @@ bool isaConvolutionOpInterface(LinalgOp linalgOp,
/// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
bool isaCopyOpInterface(LinalgOp linalgOp);
+/// Checks whether `genericOp` is semantically equivalent to a
+/// `linalg.broadcast`. If so, returns the op's `dimensions` attribute.
+std::optional<SmallVector<int64_t>>
+isaBroadcastOpInterface(GenericOp genericOp);
+
+/// Checks whether `genericOp` is semantically equivalent to a
+/// `linalg.transpose`. If so, returns the op's `permutation` attribute.
+std::optional<SmallVector<int64_t>>
+isaTransposeOpInterface(GenericOp genericOp);
+
/// Checks whether a given `genericOp` is semantically equivalent to a single
/// linalgelementwise unary op. e.g. linalg.exp.
/// A linalg.generic body could be a series of unary elementwise ops e.g.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 0b5191664a9e2f..5842128091972a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -22,6 +22,7 @@
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
+#include <numeric>
using namespace mlir;
using namespace mlir::linalg;
@@ -49,18 +50,41 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
}
+// Returns true if all loops of the linalgOp are parallel.
+static bool isAllParallel(LinalgOp op) {
+  return op.getNumParallelLoops() == op.getNumLoops();
+}
+
+// Returns true if and only if linalgOp takes one input and one init.
+static bool isSingleInputOutput(LinalgOp op) {
+  return op.getNumDpsInputs() == 1 && op.getNumDpsInits() == 1;
+}
+
+// Returns true if genericOp has one input and one init, and its body is
+// a lone linalg.yield that yields the input block argument unchanged.
+static bool isSingleYieldOp(GenericOp op) {
+  if (!isSingleInputOutput(op))
+    return false;
+
+  Block *body = op.getBody();
+  if (body->getOperations().size() != 1)
+    return false;
+
+  // The sole op must be a linalg.yield forwarding the input argument.
+  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+  return yieldOp && yieldOp.getNumOperands() == 1 &&
+         yieldOp->getOperand(0) == body->getArgument(0);
+}
+
//===----------------------------------------------------------------------===//
// CopyOpInterface implementation
//===----------------------------------------------------------------------===//
bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
-  // Structural.
-  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
+  // Structural and operands.
+  if (!isAllParallel(linalgOp) || !isSingleInputOutput(linalgOp))
    return false;
-  // Operands and maps.
-  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
-    return false;
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
      !mapRange.back().isIdentity()) {
@@ -75,8 +99,8 @@ bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
//===----------------------------------------------------------------------===//
std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
  // Structural.
-  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
-      genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
+  if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
+      !isSingleYieldOp(genericOp))
    return std::nullopt;
  // Input should be referenced and init should not.
@@ -87,16 +111,86 @@ std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
  OpOperand *value = genericOp.getDpsInputOperand(0);
  if (!genericOp.isScalar(value))
    return std::nullopt;
+  return value->get();
+}
-  Block *body = genericOp.getBody();
-  if (body->getOperations().size() != 1)
+//===----------------------------------------------------------------------===//
+// BroadcastOpInterface implementation
+//===----------------------------------------------------------------------===//
+std::optional<SmallVector<int64_t>>
+linalg::isaBroadcastOpInterface(GenericOp genericOp) {
+  // Structural.
+  if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
+      !isSingleYieldOp(genericOp))
    return std::nullopt;
-  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
-  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
-      yieldOp->getOperand(0) != body->getArgument(0))
+  Type srcTy = genericOp.getDpsInputOperand(0)->get().getType();
+  Type dstTy = genericOp.getDpsInitOperand(0)->get().getType();
+  if (!isa<MemRefType, RankedTensorType>(srcTy) ||
+      !isa<MemRefType, RankedTensorType>(dstTy))
    return std::nullopt;
-  return value->get();
+
+  // The init must be indexed by the identity map. A generic could also
+  // permute the init indices, but linalg.broadcast cannot express that.
+  SmallVector<AffineMap> maps = genericOp.getIndexingMapsArray();
+  AffineMap srcMap = maps[0];
+  AffineMap dstMap = maps[1];
+  if (!dstMap.isIdentity())
+    return std::nullopt;
+
+  // A broadcast must add at least one dimension; otherwise an identity
+  // (copy-like) generic would be reported with empty `dimensions`.
+  if (srcMap.getNumResults() >= dstMap.getNumResults())
+    return std::nullopt;
+
+  // The input map must select strictly increasing loop dims: the named
+  // op cannot transpose the input while broadcasting it.
+  SmallVector<int64_t> position;
+  for (AffineExpr result : srcMap.getResults()) {
+    auto expr = llvm::dyn_cast<AffineDimExpr>(result);
+    if (!expr)
+      return std::nullopt;
+    int64_t pos = expr.getPosition();
+    if (!position.empty() && pos <= position.back())
+      return std::nullopt;
+    position.push_back(pos);
+  }
+
+  // Every loop dimension that the input does not read is broadcast.
+  SmallVector<int64_t> broadcastedDims;
+  for (int64_t dim : llvm::seq<int64_t>(0, srcMap.getNumDims())) {
+    if (!llvm::is_contained(position, dim))
+      broadcastedDims.push_back(dim);
+  }
+  return broadcastedDims;
+}
+
+//===----------------------------------------------------------------------===//
+// TransposeOpInterface implementation
+//===----------------------------------------------------------------------===//
+std::optional<SmallVector<int64_t>>
+linalg::isaTransposeOpInterface(GenericOp genericOp) {
+  // Structural.
+  if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
+      !isSingleYieldOp(genericOp))
+    return std::nullopt;
+
+  // The init must use the identity map and the input a permutation map.
+  SmallVector<AffineMap> mapRange = genericOp.getIndexingMapsArray();
+  if (mapRange.size() != 2 || !mapRange.back().isIdentity() ||
+      !mapRange.front().isPermutation())
+    return std::nullopt;
+
+  // linalg.transpose requires dim(result, k) == dim(input, permutation[k]).
+  // If input-map result j is loop dim d_k, result dim k reads input dim j,
+  // so permutation[k] = j: the inverse of the input map, not its results.
+  AffineMap inputMap = mapRange.front();
+  SmallVector<int64_t> permutation(inputMap.getNumDims());
+  for (unsigned i = 0, e = inputMap.getNumResults(); i < e; ++i) {
+    auto expr = llvm::cast<AffineDimExpr>(inputMap.getResult(i));
+    permutation[expr.getPosition()] = i;
+  }
+  return permutation;
}
//===----------------------------------------------------------------------===//
@@ -106,8 +200,7 @@ static bool
isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
                                          unsigned arity) {
  // Check all loops are parallel.
-  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
-      genericOp.getNumLoops() < 1)
+  if (!isAllParallel(genericOp) || genericOp.getNumLoops() < 1)
    return false;
  // Check there are arity-inputs, 1-output and all are identity-maps.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 4d7b748d7200e2..dfafffce9d9b60 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -259,18 +259,43 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
//===----------------------------------------------------------------------===//
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
GenericOp genericOp) {
+  // Copy
  if (isaCopyOpInterface(genericOp)) {
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
    return namedOp;
  }
+  // Fill
  if (isaFillOpInterface(genericOp)) {
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
    return namedOp;
  }
+  // Broadcast
+  std::optional<SmallVector<int64_t>> equivalentToBroadcast =
+      isaBroadcastOpInterface(genericOp);
+  if (equivalentToBroadcast) {
+    const SmallVector<int64_t> &dims = *equivalentToBroadcast;
+    LinalgOp namedOp = rewriter.replaceOpWithNewOp<BroadcastOp>(
+        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
+        dims);
+    return namedOp;
+  }
+
+  // Transpose
+  std::optional<SmallVector<int64_t>> equivalentToTranspose =
+      isaTransposeOpInterface(genericOp);
+  if (equivalentToTranspose) {
+    const SmallVector<int64_t> &permutation = *equivalentToTranspose;
+    LinalgOp namedOp = rewriter.replaceOpWithNewOp<TransposeOp>(
+        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
+        permutation);
+    return namedOp;
+  }
+
+  // Elementwise Unary
  if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
    Operation *op = &genericOp.getBody()->front();
    if (isa<math::ExpOp>(op)) {
@@ -279,6 +304,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
    }
  }
+  // Elementwise Binary
  if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
    bool swap = areBinOpsSwapped(genericOp);
    Operation *op = &genericOp.getBody()->front();
@@ -300,6 +326,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
    }
  }
+  // Contraction - e.g. matmul
  if (isaContractionOpInterface(genericOp)) {
    return specializeLinalgContractions(rewriter, genericOp);
  }
diff --git a/mlir/test/Dialect/Linalg/roundtrip-broadcast.mlir b/mlir/test/Dialect/Linalg/roundtrip-broadcast.mlir
new file mode 100644
index 00000000000000..d6915ec8fbbf6f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/roundtrip-broadcast.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
+
+// CHECK-LABEL: broadcast_first_dimension
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?x?xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: %broadcasted = linalg.broadcast ins(%[[A]] : tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) dimensions = [0]
+// Round-trip a broadcast that adds a new leading dimension (dynamic shapes).
+func.func @broadcast_first_dimension(%A: tensor<?x?xf32>, %Out: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %res = linalg.broadcast ins(%A: tensor<?x?xf32>) outs(%Out: tensor<?x?x?xf32>) dimensions = [0]
+  return %res : tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: broadcast_mid_dimension
+// CHECK-SAME: %[[A:.+]]: tensor<3x5xf32>, %[[Out:.+]]: tensor<3x4x5xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: %broadcasted = linalg.broadcast ins(%[[A]] : tensor<3x5xf32>) outs(%[[Out]] : tensor<3x4x5xf32>) dimensions = [1]
+// Round-trip a broadcast that adds a dimension in the middle (static shapes).
+func.func @broadcast_mid_dimension(%A: tensor<3x5xf32>, %Out: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> {
+  %res = linalg.broadcast ins(%A: tensor<3x5xf32>) outs(%Out: tensor<3x4x5xf32>) dimensions = [1]
+  return %res : tensor<3x4x5xf32>
+}
+
+
+// CHECK-LABEL: broadcast_multiple_dimensions
+// CHECK-SAME: %[[A:.+]]: tensor<4x5x7xf32>, %[[Out:.+]]: tensor<3x4x5x6x7x8x9xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: %broadcasted = linalg.broadcast ins(%[[A]] : tensor<4x5x7xf32>) outs(%[[Out]] : tensor<3x4x5x6x7x8x9xf32>) dimensions = [0, 3, 5, 6]
+// Round-trip a broadcast that adds several non-adjacent dimensions.
+func.func @broadcast_multiple_dimensions(%A: tensor<4x5x7xf32>, %Out: tensor<3x4x5x6x7x8x9xf32>) -> tensor<3x4x5x6x7x8x9xf32> {
+  %res = linalg.broadcast ins(%A: tensor<4x5x7xf32>) outs(%Out: tensor<3x4x5x6x7x8x9xf32>) dimensions = [0,3,5,6]
+  return %res : tensor<3x4x5x6x7x8x9xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip-transpose.mlir b/mlir/test/Dialect/Linalg/roundtrip-transpose.mlir
new file mode 100644
index 00000000000000..ebc42c903e6e3e
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/roundtrip-transpose.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
+
+// CHECK-LABEL: linalg_transpose
+// CHECK-SAME: %[[A:.+]]: tensor<16x64xf32>, %[[Out:.+]]: tensor<64x16xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: %transposed = linalg.transpose ins(%[[A]] : tensor<16x64xf32>) outs(%[[Out]] : tensor<64x16xf32>) permutation = [1, 0]
+//
+func.func @linalg_transpose(%A: tensor<16x64xf32>, %Out: tensor<64x16xf32>) -> tensor<64x16xf32> {
+  %res = linalg.transpose ins(%A: tensor<16x64xf32>) outs(%Out: tensor<64x16xf32>) permutation = [1,0]
+  return %res : tensor<64x16xf32>
+}
+
+// Rank-3 permutation that is not its own inverse ([1, 2, 0] vs [2, 0, 1]).
+// CHECK-LABEL: transpose_rank3
+// CHECK-SAME: %[[A:.+]]: tensor<16x32x64xf32>, %[[Out:.+]]: tensor<32x64x16xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: %transposed = linalg.transpose ins(%[[A]] : tensor<16x32x64xf32>) outs(%[[Out]] : tensor<32x64x16xf32>) permutation = [1, 2, 0]
+//
+func.func @transpose_rank3(%A: tensor<16x32x64xf32>, %Out: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
+  %res = linalg.transpose ins(%A: tensor<16x32x64xf32>) outs(%Out: tensor<32x64x16xf32>) permutation = [1,2,0]
+  return %res : tensor<32x64x16xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
index 35679db7412f30..31f2f6b1ab513f 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -4,18 +4,6 @@
#map1 = affine_map<(d0, d1) -> (d0)>
#map2 = affine_map<(d0, d1) -> (d1, d0)>
-func.func @broadcast_copy_expect_no_match(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>) {
- // expected-note @below {{when applied to this op}}
- linalg.generic {
- indexing_maps = [#map1, #map],
- iterator_types = ["parallel", "parallel"]}
- ins(%arg0 : memref<?xf32>) outs(%arg1 : memref<?x?xf32>) {
- ^bb0(%in: f32, %out: f32):
- linalg.yield %in : f32
- }
- return
-}
-
func.func @not_a_copy_expect_no_match(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
// expected-note @below {{when applied to this op}}
linalg.generic {
More information about the Mlir-commits
mailing list