[Mlir-commits] [mlir] [mlir][linalg] Add constant folding for linalg.generic arith.extsi casts (PR #198113)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat May 16 10:06:10 PDT 2026
llvmorg-github-actions[bot] wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Hendrik_Klug (Jimmy2027)
<details>
<summary>Changes</summary>
Folds a `linalg.generic` whose body is a single `arith.extsi` of its input block argument, when the corresponding input operand is a constant, by replacing the op
with an `arith.constant` holding the sign-extended values.
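The check that all operand element types match has been moved out of `FoldConstantBase` into a new per-pattern CRTP hook, `checkElementTypes`. `FoldConstantTranspose` keeps the all-types-equal requirement, while `FoldConstantCast` allows the input and output element types to differ.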
---
Full diff: https://github.com/llvm/llvm-project/pull/198113.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp (+58-9)
- (modified) mlir/test/Dialect/Linalg/constant-fold.mlir (+68)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
index ba763ec2137e7..5ec8b4ee1e945 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
@@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/Matchers.h"
@@ -28,6 +29,7 @@ namespace {
/// `ConcreteType` should provide methods with signatures
///
/// ```c++
+/// bool checkElementTypes(LinalgOp linalgOp) const;
/// bool matchIndexingMaps(LinalgOp linalgOp) const;
/// RegionComputationFn getRegionComputeFn(LinalgOp) const;
/// ```
@@ -75,18 +77,18 @@ class FoldConstantBase : public OpInterfaceRewritePattern<LinalgOp> {
}))
return failure();
- // Make sure all element types are the same.
- auto getOperandElementType = [](Value value) {
- return cast<ShapedType>(value.getType()).getElementType();
- };
- if (!llvm::all_equal(
- llvm::map_range(linalgOp->getOperands(), getOperandElementType)))
- return failure();
-
// We can only handle the case where we have int/float elements.
auto elementType = outputType.getElementType();
if (!elementType.isIntOrFloat())
return failure();
+ for (Value input : linalgOp.getDpsInputs()) {
+ Type elemTy = cast<ShapedType>(input.getType()).getElementType();
+ if (!elemTy.isIntOrFloat())
+ return failure();
+ }
+
+ if (!static_cast<const ConcreteType *>(this)->checkElementTypes(linalgOp))
+ return failure();
// Require all indexing maps to be permutations for now. This is common and
// it simplifies input/output access greatly: we can do the data shuffling
@@ -267,6 +269,14 @@ struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
using FoldConstantBase::FoldConstantBase;
+  /// Transpose requires every operand (inputs and inits alike) to have the
+  /// same element type; returns false as soon as one differs from the first.
+  bool checkElementTypes(LinalgOp linalgOp) const {
+    Type expected;
+    for (Value operand : linalgOp->getOperands()) {
+      Type elemTy = cast<ShapedType>(operand.getType()).getElementType();
+      if (!expected)
+        expected = elemTy;
+      else if (elemTy != expected)
+        return false;
+    }
+    return true;
+  }
+
bool matchIndexingMaps(LinalgOp linalgOp) const {
// We should have one input and one output.
return linalgOp.getIndexingMapsArray().size() == 2;
@@ -300,10 +310,49 @@ struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
ControlFusionFn controlFn;
};
+
+/// Folds a linalg.generic whose body is a single arith cast op applied to the
+/// input block argument, when the input operand is a constant.
+///
+/// Only `arith.extsi` is supported for now. Other arith casts (extui, trunci,
+/// sitofp, uitofp, extf, truncf, fptosi, fptoui) could be added later. Casts
+/// between integer and float element types would also need
+/// `checkElementTypes` below relaxed. NOTE(review): before doing that, verify
+/// that the base class fills `apInts`/`apFloats` according to each input's own
+/// element type.
+struct FoldConstantCast : public FoldConstantBase<FoldConstantCast> {
+  using FoldConstantBase::FoldConstantBase;
+
+  /// Input and output element types may differ (that is the point of a cast).
+  /// However, the computation returned by `getRegionComputeFn` reads
+  /// `inputs.apInts` and produces an `apInt`, so every operand must have an
+  /// integer element type. Making this explicit keeps the pattern from reaching
+  /// `apInts.front()` with float data if new casts are added.
+  bool checkElementTypes(LinalgOp linalgOp) const {
+    return llvm::all_of(linalgOp->getOperands(), [](Value operand) {
+      return isa<IntegerType>(
+          cast<ShapedType>(operand.getType()).getElementType());
+    });
+  }
+
+  /// Requires exactly one input operand. The permutation requirement on the
+  /// indexing maps is enforced by the base class.
+  bool matchIndexingMaps(LinalgOp linalgOp) const {
+    return linalgOp.getNumDpsInputs() == 1;
+  }
+
+  /// Returns a per-element sign extension when the body is exactly
+  /// `%r = arith.extsi %arg0; linalg.yield %r`, and nullptr otherwise.
+  RegionComputationFn getRegionComputeFn(LinalgOp linalgOp) const {
+    Block &body = linalgOp->getRegion(0).front();
+    // Expect exactly two ops: the cast, then `linalg.yield`.
+    if (body.getOperations().size() != 2)
+      return nullptr;
+    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
+    if (!yieldOp || yieldOp.getValues().size() != 1)
+      return nullptr;
+
+    // The yielded value must come from the cast inside this body, not from a
+    // value captured from an enclosing region.
+    auto castOp = yieldOp.getValues().front().getDefiningOp<arith::ExtSIOp>();
+    if (!castOp || castOp->getBlock() != &body)
+      return nullptr;
+
+    // The cast must consume block argument 0, i.e. the input element.
+    auto inArg = dyn_cast<BlockArgument>(castOp.getIn());
+    if (!inArg || inArg.getOwner() != &body || inArg.getArgNumber() != 0)
+      return nullptr;
+
+    unsigned outBW = castOp.getResult().getType().getIntOrFloatBitWidth();
+    return [outBW](const APIntOrFloatArray &inputs) {
+      return APIntOrFloat{inputs.apInts.front().sext(outBW), std::nullopt};
+    };
+  }
+};
} // namespace
+/// Adds the linalg constant-folding patterns (transpose and `arith.extsi`
+/// casts) to `patterns`, each constructed with the same `controlFn`.
void mlir::linalg::populateConstantFoldLinalgOperations(
RewritePatternSet &patterns, const ControlFusionFn &controlFn) {
MLIRContext *context = patterns.getContext();
- patterns.insert<FoldConstantTranspose>(context, controlFn);
+ patterns.insert<FoldConstantTranspose, FoldConstantCast>(context, controlFn);
}
diff --git a/mlir/test/Dialect/Linalg/constant-fold.mlir b/mlir/test/Dialect/Linalg/constant-fold.mlir
index 3929c26a3382f..ddb605a09623d 100644
--- a/mlir/test/Dialect/Linalg/constant-fold.mlir
+++ b/mlir/test/Dialect/Linalg/constant-fold.mlir
@@ -145,4 +145,72 @@ func.func @named_transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf3
// -----
+// Positive case: a constant i32 input sign-extended to i64 is replaced by a
+// single i64 arith.constant that is returned directly.
+// CHECK-LABEL: @cast_fold_extsi_i32_to_i64
+func.func @cast_fold_extsi_i32_to_i64(%init: tensor<4xi64>) -> tensor<4xi64> {
+ %input = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+ // CHECK: %[[CST:.+]] = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]
+ } ins(%input : tensor<4xi32>) outs(%init : tensor<4xi64>) {
+ ^bb0(%arg1: i32, %arg2: i64):
+ %2 = arith.extsi %arg1 : i32 to i64
+ linalg.yield %2 : i64
+ } -> tensor<4xi64>
+ // CHECK: return %[[CST]]
+ return %1 : tensor<4xi64>
+}
+
+// -----
+
+// Negative inputs must be sign-extended, not zero-extended: -1 : i32 must fold
+// to -1 : i64.
+// CHECK-LABEL: @cast_fold_extsi_negative
+func.func @cast_fold_extsi_negative(%init: tensor<2xi64>) -> tensor<2xi64> {
+ %input = arith.constant dense<[-1, -2]> : tensor<2xi32>
+ // CHECK: %[[CST:.+]] = arith.constant dense<[-1, -2]> : tensor<2xi64>
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]
+ } ins(%input : tensor<2xi32>) outs(%init : tensor<2xi64>) {
+ ^bb0(%arg1: i32, %arg2: i64):
+ %2 = arith.extsi %arg1 : i32 to i64
+ linalg.yield %2 : i64
+ } -> tensor<2xi64>
+ // CHECK: return %[[CST]]
+ return %1 : tensor<2xi64>
+}
+
+// -----
+
+// No fold: the input is a function argument rather than a constant, so the
+// linalg.generic must remain.
+// CHECK-LABEL: @cast_nofold_non_cst_input
+func.func @cast_nofold_non_cst_input(%input: tensor<4xi32>, %init: tensor<4xi64>) -> tensor<4xi64> {
+ // CHECK: linalg.generic
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]
+ } ins(%input : tensor<4xi32>) outs(%init : tensor<4xi64>) {
+ ^bb0(%arg1: i32, %arg2: i64):
+ %2 = arith.extsi %arg1 : i32 to i64
+ linalg.yield %2 : i64
+ } -> tensor<4xi64>
+ return %1 : tensor<4xi64>
+}
+
+// -----
+
+// No fold: the body has an extra arith.muli after the cast, so it is not the
+// single cast + yield shape the pattern matches.
+// CHECK-LABEL: @cast_nofold_multi_ops_in_region
+func.func @cast_nofold_multi_ops_in_region(%init: tensor<4xi64>) -> tensor<4xi64> {
+ %input = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+ %two = arith.constant 2 : i64
+ // CHECK: linalg.generic
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]
+ } ins(%input : tensor<4xi32>) outs(%init : tensor<4xi64>) {
+ ^bb0(%arg1: i32, %arg2: i64):
+ %2 = arith.extsi %arg1 : i32 to i64
+ %3 = arith.muli %2, %two : i64
+ linalg.yield %3 : i64
+ } -> tensor<4xi64>
+ return %1 : tensor<4xi64>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/198113
More information about the Mlir-commits
mailing list