[Mlir-commits] [mlir] [mlir][linalg] Add constant folding for linalg.generic arith.extsi casts (PR #198113)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat May 16 10:06:10 PDT 2026


llvmorg-github-actions[bot] wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Hendrik_Klug (Jimmy2027)

<details>
<summary>Changes</summary>

Folds a `linalg.generic` whose body is a single `arith.extsi` applied to its input block argument, when that input operand is a constant,
replacing the op with an `arith.constant` holding the sign-extended values.

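The base pattern previously required all operand element types to be equal, which rules out casts. That check has therefore moved into a new per-pattern CRTP hook, `checkElementTypes`: `FoldConstantTranspose` keeps the all-equal requirement, while `FoldConstantCast` accepts differing input/output element types.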

---
Full diff: https://github.com/llvm/llvm-project/pull/198113.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp (+58-9) 
- (modified) mlir/test/Dialect/Linalg/constant-fold.mlir (+68) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
index ba763ec2137e7..5ec8b4ee1e945 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/IR/Matchers.h"
@@ -28,6 +29,7 @@ namespace {
 /// `ConcreteType` should provide methods with signatures
 ///
 /// ```c++
+///   bool checkElementTypes(LinalgOp linalgOp) const;
 ///   bool matchIndexingMaps(LinalgOp linalgOp) const;
 ///   RegionComputationFn getRegionComputeFn(LinalgOp) const;
 /// ```
@@ -75,18 +77,18 @@ class FoldConstantBase : public OpInterfaceRewritePattern<LinalgOp> {
         }))
       return failure();
 
-    // Make sure all element types are the same.
-    auto getOperandElementType = [](Value value) {
-      return cast<ShapedType>(value.getType()).getElementType();
-    };
-    if (!llvm::all_equal(
-            llvm::map_range(linalgOp->getOperands(), getOperandElementType)))
-      return failure();
-
     // We can only handle the case where we have int/float elements.
     auto elementType = outputType.getElementType();
     if (!elementType.isIntOrFloat())
       return failure();
+    for (Value input : linalgOp.getDpsInputs()) {
+      Type elemTy = cast<ShapedType>(input.getType()).getElementType();
+      if (!elemTy.isIntOrFloat())
+        return failure();
+    }
+
+    if (!static_cast<const ConcreteType *>(this)->checkElementTypes(linalgOp))
+      return failure();
 
     // Require all indexing maps to be permutations for now. This is common and
     // it simplifies input/output access greatly: we can do the data shuffling
@@ -267,6 +269,14 @@ struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
 
   using FoldConstantBase::FoldConstantBase;
 
+  // Transpose requires all operand element types to match.
+  bool checkElementTypes(LinalgOp linalgOp) const {
+    auto getElem = [](Value v) {
+      return cast<ShapedType>(v.getType()).getElementType();
+    };
+    return llvm::all_equal(llvm::map_range(linalgOp->getOperands(), getElem));
+  }
+
   bool matchIndexingMaps(LinalgOp linalgOp) const {
     // We should have one input and one output.
     return linalgOp.getIndexingMapsArray().size() == 2;
@@ -300,10 +310,49 @@ struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
 
   ControlFusionFn controlFn;
 };
+
+// Folds a linalg.generic with one constant input whose body is a single
+// `arith.extsi` of that input's block argument. Only extsi is supported so
+// far; extui, trunci, extf and truncf could follow. Before adding int<->float
+// casts (sitofp, fptosi, ...), verify the base's APInt/APFloat input path.
+struct FoldConstantCast : public FoldConstantBase<FoldConstantCast> {
+  using FoldConstantBase::FoldConstantBase;
+
+  // Input/output element types may differ (e.g. i32 -> i64 for arith.extsi).
+  bool checkElementTypes(LinalgOp) const { return true; }
+
+  bool matchIndexingMaps(LinalgOp linalgOp) const {
+    return linalgOp.getNumDpsInputs() == 1;
+  }
+
+  RegionComputationFn getRegionComputeFn(LinalgOp linalgOp) const {
+    Block &body = linalgOp->getRegion(0).front();
+    // Expect exactly two ops: the cast, then `linalg.yield`.
+    if (body.getOperations().size() != 2)
+      return nullptr;
+    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
+    if (!yieldOp || yieldOp.getValues().size() != 1)
+      return nullptr;
+    // The yielded value must be an extsi defined in this body, not outside it.
+    auto castOp = yieldOp.getValues().front().getDefiningOp<arith::ExtSIOp>();
+    if (!castOp || castOp->getBlock() != &body)
+      return nullptr;
+
+    // The cast must consume this body's block arg 0, i.e. the input element.
+    auto inArg = dyn_cast<BlockArgument>(castOp.getIn());
+    if (!inArg || inArg.getOwner() != &body || inArg.getArgNumber() != 0)
+      return nullptr;
+    // Sign-extend each input element to the extsi result bit width.
+    unsigned outBW = castOp.getResult().getType().getIntOrFloatBitWidth();
+    return [outBW](const APIntOrFloatArray &inputs) {
+      return APIntOrFloat{inputs.apInts.front().sext(outBW), std::nullopt};
+    };
+  }
+};
 } // namespace
 
 void mlir::linalg::populateConstantFoldLinalgOperations(
     RewritePatternSet &patterns, const ControlFusionFn &controlFn) {
   MLIRContext *context = patterns.getContext();
-  patterns.insert<FoldConstantTranspose>(context, controlFn);
+  patterns.insert<FoldConstantTranspose, FoldConstantCast>(context, controlFn);
 }
diff --git a/mlir/test/Dialect/Linalg/constant-fold.mlir b/mlir/test/Dialect/Linalg/constant-fold.mlir
index 3929c26a3382f..ddb605a09623d 100644
--- a/mlir/test/Dialect/Linalg/constant-fold.mlir
+++ b/mlir/test/Dialect/Linalg/constant-fold.mlir
@@ -145,4 +145,72 @@ func.func @named_transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf3
 
 // -----
 
+// CHECK-LABEL: @cast_fold_extsi_i32_to_i64
+func.func @cast_fold_extsi_i32_to_i64(%init: tensor<4xi64>) -> tensor<4xi64> {
+  %input = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+  // CHECK: %[[CST:.+]] = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]
+  } ins(%input : tensor<4xi32>) outs(%init : tensor<4xi64>) {
+  ^bb0(%arg1: i32, %arg2: i64):
+    %2 = arith.extsi %arg1 : i32 to i64
+    linalg.yield %2 : i64
+  } -> tensor<4xi64>
+  // CHECK: return %[[CST]]
+  return %1 : tensor<4xi64>
+}
+
+// -----
+
+// CHECK-LABEL: @cast_fold_extsi_negative
+func.func @cast_fold_extsi_negative(%init: tensor<2xi64>) -> tensor<2xi64> {
+  %input = arith.constant dense<[-1, -2]> : tensor<2xi32>
+  // CHECK: %[[CST:.+]] = arith.constant dense<[-1, -2]> : tensor<2xi64>
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]
+  } ins(%input : tensor<2xi32>) outs(%init : tensor<2xi64>) {
+  ^bb0(%arg1: i32, %arg2: i64):
+    %2 = arith.extsi %arg1 : i32 to i64
+    linalg.yield %2 : i64
+  } -> tensor<2xi64>
+  // CHECK: return %[[CST]]
+  return %1 : tensor<2xi64>
+}
+
+// -----
+
+// CHECK-LABEL: @cast_nofold_non_cst_input
+func.func @cast_nofold_non_cst_input(%input: tensor<4xi32>, %init: tensor<4xi64>) -> tensor<4xi64> {
+  // CHECK: linalg.generic
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]
+  } ins(%input : tensor<4xi32>) outs(%init : tensor<4xi64>) {
+  ^bb0(%arg1: i32, %arg2: i64):
+    %2 = arith.extsi %arg1 : i32 to i64
+    linalg.yield %2 : i64
+  } -> tensor<4xi64>
+  return %1 : tensor<4xi64>
+}
+
+// -----
+
+// CHECK-LABEL: @cast_nofold_multi_ops_in_region
+func.func @cast_nofold_multi_ops_in_region(%init: tensor<4xi64>) -> tensor<4xi64> {
+  %input = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+  %two = arith.constant 2 : i64
+  // CHECK: linalg.generic
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]
+  } ins(%input : tensor<4xi32>) outs(%init : tensor<4xi64>) {
+  ^bb0(%arg1: i32, %arg2: i64):
+    %2 = arith.extsi %arg1 : i32 to i64
+    %3 = arith.muli %2, %two : i64
+    linalg.yield %3 : i64
+  } -> tensor<4xi64>
+  return %1 : tensor<4xi64>
+}
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/198113


More information about the Mlir-commits mailing list