[Mlir-commits] [mlir] [mlir][arith] Add infer exact from dlti (PR #184631)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 4 07:11:48 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-arith
Author: Erick Ochoa Lopez (amd-eochoalo)
<details>
<summary>Changes</summary>
https://github.com/llvm/llvm-project/pull/183395 introduced the `exact` flag to index_cast operations and updated canonicalization patterns.
Pattern `IndexCastOfIndexCast` now only triggers when the inner cast contains the `exact` flag.
```
def IndexCastOfIndexCast :
Pat<(Arith_IndexCastOp:$res (Arith_IndexCastOp $x, $exact1), $exact2),
(replaceWithValue $x),
[(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x),
(Constraint<CPred<"(bool)$0">> $exact1)]>;
```
Previously, the pattern also fired when a value was narrowed without the `exact` flag (possibly truncating some bits) and then widened again, which made it unsound.
Widening a value and then narrowing it back is always sound. However, whether a cast to or from `index` is a widening depends on the bitwidth of `index`, which may not be known at this stage of compilation.
The new `ArithInferExactFromDLTI` pass lets users automatically annotate casts with `exact` when the bitwidth of `index` is known in the current context via a `dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 64 : i32>>`. When the pass runs, it uses that index bitwidth to decide whether each `arith.index_cast` and `arith.index_castui` is a widening (or same-width) cast, and if so annotates it with the `exact` flag. This allows the canonicalization pattern to fold such round-trip casts.
---
Full diff: https://github.com/llvm/llvm-project/pull/184631.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/Arith/Transforms/Passes.h (+5)
- (modified) mlir/include/mlir/Dialect/Arith/Transforms/Passes.td (+15)
- (modified) mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/Arith/Transforms/InferExactFromDLTI.cpp (+83)
- (added) mlir/test/Dialect/Arith/infer-exact-from-dlti.mlir (+52)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 18ac0dbc8d13e..3a47c93ad1638 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -82,6 +82,11 @@ void populateUnsignedWhenEquivalentPatterns(RewritePatternSet &patterns,
/// Create a pass which do optimizations based on integer range analysis.
std::unique_ptr<Pass> createIntRangeOptimizationsPass();
+/// Add patterns that set the `exact` flag on `arith.index_cast` and
+/// `arith.index_castui` ops whose source bitwidth is not wider than their
+/// destination bitwidth, treating `index` as `indexBitwidth` bits wide. Such
+/// casts are widening (or same-width) and therefore cannot lose information.
+void populateInferExactFromDLTIPatterns(RewritePatternSet &patterns,
+                                        unsigned indexBitwidth);
+
/// Add patterns for int range based narrowing.
void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
DataFlowSolver &solver,
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index c7370b83fdb6c..35161c19d5519 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -117,4 +117,19 @@ def ArithEmulateWideInt : Pass<"arith-emulate-wide-int"> {
let dependentDialects = ["vector::VectorDialect"];
}
+def ArithInferExactFromDLTI : Pass<"arith-infer-exact-from-dlti"> {
+  let summary = "Infer exact flags on index casts using DLTI index bitwidth";
+  let description = [{
+    Uses the closest enclosing DLTI data layout to determine the target index
+    bitwidth. For `arith.index_cast` and `arith.index_castui` operations whose
+    source type is not wider than the destination type, the `exact` flag is
+    added, since the cast is provably lossless (widening or same-width). This
+    covers both `iN -> index` (when `N <= index_bitwidth`) and `index -> iN`
+    (when `index_bitwidth <= N`). Narrowing casts are left untouched.
+
+    If no enclosing operation carries a data layout specification, MLIR's
+    default data layout applies, which treats `index` as 64 bits wide. Only
+    run this pass when that default matches the target, or attach an explicit
+    `dlti.dl_spec` entry for `index`.
+
+    This enables downstream canonicalization patterns such as
+    `IndexCastOfIndexCast` to fire in more cases.
+  }];
+}
+
#endif // MLIR_DIALECT_ARITH_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 637e16a3963d6..8faf57fc3704e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRArithTransforms
EmulateWideInt.cpp
EmulateNarrowType.cpp
ExpandOps.cpp
+ InferExactFromDLTI.cpp
IntRangeOptimizations.cpp
ReifyValueBounds.cpp
ShardingInterfaceImpl.cpp
diff --git a/mlir/lib/Dialect/Arith/Transforms/InferExactFromDLTI.cpp b/mlir/lib/Dialect/Arith/Transforms/InferExactFromDLTI.cpp
new file mode 100644
index 0000000000000..83c7b2da70832
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/InferExactFromDLTI.cpp
@@ -0,0 +1,83 @@
+//===- InferExactFromDLTI.cpp - Infer exact flags from DLTI -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+
+#include "mlir/Analysis/DataLayoutAnalysis.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+namespace mlir {
+namespace arith {
+#define GEN_PASS_DEF_ARITHINFEREXACTFROMDLTI
+#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
+} // namespace arith
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::arith;
+
+/// Returns the bitwidth of `type`, or of its element type for shaped types.
+/// `index` has no intrinsic width, so it is reported as `indexBitwidth`.
+static unsigned getBitwidth(Type type, unsigned indexBitwidth) {
+  Type scalar = getElementTypeOrSelf(type);
+  return isa<IndexType>(scalar) ? indexBitwidth
+                                : scalar.getIntOrFloatBitWidth();
+}
+
+/// Returns true if casting from `srcType` to `dstType` cannot lose
+/// information, i.e. the source is not wider than the destination when
+/// `index` is treated as `indexBitwidth` bits wide.
+static bool isLosslessCast(Type srcType, Type dstType, unsigned indexBitwidth) {
+  return getBitwidth(srcType, indexBitwidth) <=
+         getBitwidth(dstType, indexBitwidth);
+}
+
+namespace {
+/// Sets the `exact` flag on an `arith.index_cast` / `arith.index_castui` whose
+/// source is not wider than its destination. The `index` bitwidth is fixed at
+/// construction time, so callers must supply the width valid for every op the
+/// pattern may be applied to.
+template <typename CastOp>
+struct InferExactOnIndexCast final : OpRewritePattern<CastOp> {
+  InferExactOnIndexCast(MLIRContext *context, unsigned indexBitwidth)
+      : OpRewritePattern<CastOp>(context), indexBitwidth(indexBitwidth) {}
+
+  LogicalResult matchAndRewrite(CastOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getExact())
+      return rewriter.notifyMatchFailure(op, "already exact");
+    if (!isLosslessCast(op.getIn().getType(), op.getType(), indexBitwidth))
+      return rewriter.notifyMatchFailure(op, "source is wider than dest");
+
+    rewriter.modifyOpInPlace(op, [&] { op.setExact(true); });
+    return success();
+  }
+
+private:
+  unsigned indexBitwidth;
+};
+
+/// Pass driver. The `index` bitwidth is looked up per cast from the closest
+/// enclosing data layout rather than once at the pass root. A nested scope
+/// (e.g. a nested module) may declare a different `index` width than the
+/// root. Using the root's width there could mark a truncating cast `exact`.
+struct ArithInferExactFromDLTIPass final
+    : arith::impl::ArithInferExactFromDLTIBase<ArithInferExactFromDLTIPass> {
+  void runOnOperation() override {
+    Type indexType = IndexType::get(&getContext());
+    const DataLayoutAnalysis &layouts = getAnalysis<DataLayoutAnalysis>();
+
+    getOperation()->walk([&](Operation *op) {
+      // Only the `exact` attribute is modified (no ops are created, erased or
+      // moved), so updating ops in place during the walk is safe.
+      auto markExact = [&](auto castOp) {
+        if (castOp.getExact())
+          return;
+        unsigned indexBitwidth =
+            layouts.getAbove(op).getTypeSizeInBits(indexType);
+        if (isLosslessCast(castOp.getIn().getType(), castOp.getType(),
+                           indexBitwidth))
+          castOp.setExact(true);
+      };
+      if (auto indexCast = dyn_cast<IndexCastOp>(op))
+        markExact(indexCast);
+      else if (auto indexCastUI = dyn_cast<IndexCastUIOp>(op))
+        markExact(indexCastUI);
+    });
+  }
+};
+} // namespace
+
+/// Registers the `exact`-inference pattern for both signed (`index_cast`) and
+/// unsigned (`index_castui`) index casts, using a fixed `index` bitwidth.
+void mlir::arith::populateInferExactFromDLTIPatterns(
+    RewritePatternSet &patterns, unsigned indexBitwidth) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<InferExactOnIndexCast<IndexCastOp>>(ctx, indexBitwidth);
+  patterns.add<InferExactOnIndexCast<IndexCastUIOp>>(ctx, indexBitwidth);
+}
diff --git a/mlir/test/Dialect/Arith/infer-exact-from-dlti.mlir b/mlir/test/Dialect/Arith/infer-exact-from-dlti.mlir
new file mode 100644
index 0000000000000..225db935c8676
--- /dev/null
+++ b/mlir/test/Dialect/Arith/infer-exact-from-dlti.mlir
@@ -0,0 +1,52 @@
+// RUN: mlir-opt --split-input-file --arith-infer-exact-from-dlti %s | FileCheck %s --check-prefixes=ALL,INFER
+// RUN: mlir-opt --split-input-file --arith-infer-exact-from-dlti --canonicalize %s | FileCheck %s --check-prefixes=ALL,CANON
+
+// Signed narrowing stays non-exact; signed widening becomes exact.
+// ALL-LABEL: func @narrowing_and_widening
+// INFER: arith.index_cast %arg0 : index to i8
+// INFER-NOT: exact
+// INFER: arith.index_cast %arg1 exact : i8 to index
+module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 64 : i32>> } {
+  func.func @narrowing_and_widening(%arg0: index, %arg1: i8) -> (i8, index) {
+    %0 = arith.index_cast %arg0 : index to i8
+    %1 = arith.index_cast %arg1 : i8 to index
+    return %0, %1 : i8, index
+  }
+}
+
+// -----
+
+// Unsigned narrowing stays non-exact.
+// ALL-LABEL: func @narrowing_unsigned
+// INFER: arith.index_castui %arg0 : index to i16
+// INFER-NOT: exact
+// INFER: return
+module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 64 : i32>> } {
+  func.func @narrowing_unsigned(%arg0: index) -> i16 {
+    %0 = arith.index_castui %arg0 : index to i16
+    return %0 : i16
+  }
+}
+
+// -----
+
+// ALL-LABEL: func @widen_to_index
+// INFER: arith.index_cast %arg0 exact : i8 to index
+// INFER: arith.index_castui %arg1 exact : i16 to index
+module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 64 : i32>> } {
+  func.func @widen_to_index(%arg0: i8, %arg1: i16) -> (index, index) {
+    %0 = arith.index_cast %arg0 : i8 to index
+    %1 = arith.index_castui %arg1 : i16 to index
+    return %0, %1 : index, index
+  }
+}
+
+// -----
+
+// ALL-LABEL: func @widen_to_int
+// INFER: arith.index_cast %arg0 exact : index to i64
+// INFER: arith.index_castui %arg0 exact : index to i64
+module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32 : i32>> } {
+  func.func @widen_to_int(%arg0: index) -> (i64, i64) {
+    %0 = arith.index_cast %arg0 : index to i64
+    %1 = arith.index_castui %arg0 : index to i64
+    return %0, %1 : i64, i64
+  }
+}
+
+// -----
+
+// Same-width casts are lossless too (boundary of the `<=` check).
+// ALL-LABEL: func @same_width
+// INFER: arith.index_cast %arg0 exact : i64 to index
+// INFER: arith.index_castui %arg1 exact : index to i64
+module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 64 : i32>> } {
+  func.func @same_width(%arg0: i64, %arg1: index) -> (index, i64) {
+    %0 = arith.index_cast %arg0 : i64 to index
+    %1 = arith.index_castui %arg1 : index to i64
+    return %0, %1 : index, i64
+  }
+}
+
+// -----
+
+// ALL-LABEL: func @roundtrip_folds
+// CANON-NEXT: return %arg0 : i8
+module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 64 : i32>> } {
+  func.func @roundtrip_folds(%arg0: i8) -> i8 {
+    %0 = arith.index_cast %arg0 : i8 to index
+    %1 = arith.index_cast %0 : index to i8
+    return %1 : i8
+  }
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/184631
More information about the Mlir-commits
mailing list