[Mlir-commits] [mlir] [mlir][linalg] Add splat transpose canonicalization patterns (PR #195991)

Hocky Yudhiono llvmlistbot at llvm.org
Wed May 6 18:58:15 PDT 2026


https://github.com/hockyy updated https://github.com/llvm/llvm-project/pull/195991

>From bb484b848375d2cae3dd91b707b01926c5c95d5e Mon Sep 17 00:00:00 2001
From: Hocky Yudhiono <hocky.yudhiono at gmail.com>
Date: Wed, 6 May 2026 11:06:05 +0800
Subject: [PATCH] [mlir][linalg] Add splat transpose canonicalization patterns

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp   | 48 +++++++++++++++++++++-
 mlir/test/Dialect/Linalg/canonicalize.mlir | 46 +++++++++++++++++++++
 2 files changed, 93 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 27988a451173c..3c86768098b19 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -30,6 +30,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
@@ -90,6 +91,16 @@ static Operation *getSlice(OpBuilder &b, Location loc, Value source,
       .Default([&](Type t) -> Operation * { return nullptr; });
 }
 
+/// Returns the element value of the dense splat constant that defines `value`,
+/// or a null attribute otherwise. The value is returned as a plain Attribute
+/// because complex splats yield an ArrayAttr of (real, imag), not a TypedAttr.
+static Attribute getDenseSplatConstantValue(Value value) {
+  SplatElementsAttr splatAttr;
+  if (!matchPattern(value, m_Constant(&splatAttr)))
+    return {};
+  return splatAttr.getSplatValue<Attribute>();
+}
+
 //===----------------------------------------------------------------------===//
 // Helper functions
 //===----------------------------------------------------------------------===//
@@ -2124,6 +2135,14 @@ LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
     return success();
   }
 
+  // Transposing a splat into a value of the same type reproduces the input,
+  // since every element is identical.
+  if (getInput().getType() == getInit().getType() &&
+      isa_and_present<SplatElementsAttr>(adaptor.getInput())) {
+    result.push_back(getInput());
+    return success();
+  }
+
   return failure();
 }
 
@@ -2150,6 +2169,37 @@ struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
   }
 };
 
+/// Rewrite a transpose of a dense splat constant into a dense splat constant of
+/// the transposed output shape.
+struct FoldTransposeSplatConstant : OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    if (!transposeOp.hasPureTensorSemantics())
+      return failure();
+
+    Attribute splatValue = getDenseSplatConstantValue(transposeOp.getInput());
+    if (!splatValue)
+      return failure();
+
+    // DenseElementsAttr requires a static shape, so dynamically shaped results
+    // cannot be materialized as a constant.
+    auto resultType =
+        cast<RankedTensorType>(transposeOp.getResult()[0].getType());
+    if (!resultType.hasStaticShape())
+      return failure();
+
+    // All elements of a splat are equal, so the transpose only changes the
+    // shape. Rebuilding from the untyped element value keeps complex element
+    // types (whose splat value is an ArrayAttr) supported.
+    auto resultAttr = DenseElementsAttr::get(resultType, splatValue);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(transposeOp, resultType,
+                                                   resultAttr);
+    return success();
+  }
+};
+
 /// This pattern canonicalize transpose by swapping the order of
 /// broadcast and transpose:
 ///   transpose(broadcast(input)) -> broadcast(transpose(input))
@@ -2210,7 +2260,8 @@ struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
 
 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
+  results.add<FoldTransposeWithTranspose, FoldTransposeSplatConstant,
+              SwapTransposeWithBroadcast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 019b7433b2777..5efeaabcd8981 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1256,6 +1256,66 @@ func.func @transpose_identity_perm(%input: tensor<16x32x64xf32>,
 
 // -----
 
+// CHECK-LABEL: @transpose_splat_constant_to_dense
+//       CHECK:   %[[CST:.+]] = arith.constant dense<1.000000e+00> : tensor<3x2xf32>
+//   CHECK-NOT:   linalg.transpose
+//       CHECK:   return %[[CST]] : tensor<3x2xf32>
+func.func @transpose_splat_constant_to_dense(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  %cst = arith.constant dense<1.000000e+00> : tensor<2x3xf32>
+  %transpose = linalg.transpose
+      ins(%cst:tensor<2x3xf32>)
+      outs(%init:tensor<3x2xf32>)
+      permutation = [1, 0]
+  func.return %transpose : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_splat_constant_same_type
+//       CHECK:   %[[CST:.+]] = arith.constant dense<1.000000e+00> : tensor<2x2xf32>
+//   CHECK-NOT:   linalg.transpose
+//       CHECK:   return %[[CST]] : tensor<2x2xf32>
+func.func @transpose_splat_constant_same_type(%init: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  %cst = arith.constant dense<1.000000e+00> : tensor<2x2xf32>
+  %transpose = linalg.transpose
+      ins(%cst:tensor<2x2xf32>)
+      outs(%init:tensor<2x2xf32>)
+      permutation = [1, 0]
+  func.return %transpose : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_complex_splat_constant
+//       CHECK:   %[[CST:.+]] = arith.constant dense<(1.000000e+00,{{ ?}}2.000000e+00)> : tensor<3x2xcomplex<f32>>
+//   CHECK-NOT:   linalg.transpose
+//       CHECK:   return %[[CST]] : tensor<3x2xcomplex<f32>>
+func.func @transpose_complex_splat_constant(%init: tensor<3x2xcomplex<f32>>) -> tensor<3x2xcomplex<f32>> {
+  %cst = arith.constant dense<(1.000000e+00,2.000000e+00)> : tensor<2x3xcomplex<f32>>
+  %transpose = linalg.transpose
+      ins(%cst:tensor<2x3xcomplex<f32>>)
+      outs(%init:tensor<3x2xcomplex<f32>>)
+      permutation = [1, 0]
+  func.return %transpose : tensor<3x2xcomplex<f32>>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_non_splat_constant
+//       CHECK:   %[[CST:.+]] = arith.constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf32>
+//       CHECK:   %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[CST]] : tensor<2x3xf32>) outs({{.*}} : tensor<3x2xf32>) permutation = [1, 0]
+//       CHECK:   return %[[TRANSPOSE]] : tensor<3x2xf32>
+func.func @transpose_non_splat_constant(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  %cst = arith.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf32>
+  %transpose = linalg.transpose
+      ins(%cst:tensor<2x3xf32>)
+      outs(%init:tensor<3x2xf32>)
+      permutation = [1, 0]
+  func.return %transpose : tensor<3x2xf32>
+}
+
+// -----
+
 func.func @transpose_transpose_cancel(%input: tensor<5x4x3xf32>,
                                       %init1: tensor<4x3x5xf32>,
                                       %init2: tensor<5x4x3xf32>) -> tensor<5x4x3xf32> {



More information about the Mlir-commits mailing list