[Mlir-commits] [mlir] [mlir][linalg] Add splat transpose canonicalization patterns (PR #195991)
Hocky Yudhiono
llvmlistbot at llvm.org
Wed May 6 18:58:15 PDT 2026
https://github.com/hockyy updated https://github.com/llvm/llvm-project/pull/195991
>From bb484b848375d2cae3dd91b707b01926c5c95d5e Mon Sep 17 00:00:00 2001
From: Hocky Yudhiono <hocky.yudhiono at gmail.com>
Date: Wed, 6 May 2026 11:06:05 +0800
Subject: [PATCH] [mlir][linalg] Add splat transpose canonicalization patterns
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 48 +++++++++++++++++++++-
mlir/test/Dialect/Linalg/canonicalize.mlir | 46 +++++++++++++++++++++
2 files changed, 93 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 27988a451173c..3c86768098b19 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -30,6 +30,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
@@ -90,6 +91,15 @@ static Operation *getSlice(OpBuilder &b, Location loc, Value source,
.Default([&](Type t) -> Operation * { return nullptr; });
}
+static TypedAttr getScalarConstantAttrFromDenseSplat(Value input) {
+ DenseElementsAttr splatAttr;
+ matchPattern(input, m_Constant<DenseElementsAttr>(&splatAttr));
+ if (!splatAttr || !splatAttr.isSplat())
+ return {};
+
+ return dyn_cast<TypedAttr>(splatAttr.getSplatValue<Attribute>());
+}
+
//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//
@@ -2124,6 +2134,14 @@ LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
return success();
}
+ if (getInput().getType() == getInit().getType()) {
+ auto splatAttr = dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput());
+ if (splatAttr && splatAttr.isSplat()) {
+ result.push_back(getInput());
+ return success();
+ }
+ }
+
return failure();
}
@@ -2150,6 +2168,33 @@ struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
}
};
+/// Rewrite a transpose of a dense splat constant into a dense splat constant of
+/// the transposed output shape.
+struct FoldTransposeSplatConstant : OpRewritePattern<linalg::TransposeOp> {
+ using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+ PatternRewriter &rewriter) const override {
+ if (!transposeOp.hasPureTensorSemantics())
+ return failure();
+
+ TypedAttr splatValue =
+ getScalarConstantAttrFromDenseSplat(transposeOp.getInput());
+ if (!splatValue)
+ return failure();
+
+ auto resultType =
+ cast<RankedTensorType>(transposeOp.getResult()[0].getType());
+ if (!resultType.hasStaticShape())
+ return failure();
+
+ auto resultAttr = DenseElementsAttr::get(resultType, splatValue);
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(transposeOp, resultType,
+ resultAttr);
+ return success();
+ }
+};
+
/// This pattern canonicalize transpose by swapping the order of
/// broadcast and transpose:
/// transpose(broadcast(input)) -> broadcast(transpose(input))
@@ -2210,7 +2255,8 @@ struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
+ results.add<FoldTransposeWithTranspose, FoldTransposeSplatConstant,
+ SwapTransposeWithBroadcast>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 019b7433b2777..5efeaabcd8981 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1256,6 +1256,52 @@ func.func @transpose_identity_perm(%input: tensor<16x32x64xf32>,
// -----
+// CHECK-LABEL: @transpose_splat_constant_to_dense
+// CHECK: %[[CST:.+]] = arith.constant dense<1.000000e+00> : tensor<3x2xf32>
+// CHECK-NOT: linalg.transpose
+// CHECK: return %[[CST]] : tensor<3x2xf32>
+func.func @transpose_splat_constant_to_dense(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+ %cst = arith.constant dense<1.000000e+00> : tensor<2x3xf32>
+ %transpose = linalg.transpose
+ ins(%cst:tensor<2x3xf32>)
+ outs(%init:tensor<3x2xf32>)
+ permutation = [1, 0]
+ func.return %transpose : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_splat_constant_same_type
+// CHECK: %[[CST:.+]] = arith.constant dense<1.000000e+00> : tensor<2x2xf32>
+// CHECK-NOT: linalg.fill
+// CHECK-NOT: linalg.transpose
+// CHECK: return %[[CST]] : tensor<2x2xf32>
+func.func @transpose_splat_constant_same_type(%init: tensor<2x2xf32>) -> tensor<2x2xf32> {
+ %cst = arith.constant dense<1.000000e+00> : tensor<2x2xf32>
+ %transpose = linalg.transpose
+ ins(%cst:tensor<2x2xf32>)
+ outs(%init:tensor<2x2xf32>)
+ permutation = [1, 0]
+ func.return %transpose : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_non_splat_constant
+// CHECK: %[[CST:.+]] = arith.constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf32>
+// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[CST]] : tensor<2x3xf32>) outs({{.*}} : tensor<3x2xf32>) permutation = [1, 0]
+// CHECK: return %[[TRANSPOSE]] : tensor<3x2xf32>
+func.func @transpose_non_splat_constant(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+ %cst = arith.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf32>
+ %transpose = linalg.transpose
+ ins(%cst:tensor<2x3xf32>)
+ outs(%init:tensor<3x2xf32>)
+ permutation = [1, 0]
+ func.return %transpose : tensor<3x2xf32>
+}
+
+// -----
+
func.func @transpose_transpose_cancel(%input: tensor<5x4x3xf32>,
%init1: tensor<4x3x5xf32>,
%init2: tensor<5x4x3xf32>) -> tensor<5x4x3xf32> {
More information about the Mlir-commits
mailing list