[Mlir-commits] [mlir] [mlir][linalg] Add folder for `linalg.index` (PR #136640)
llvmlistbot at llvm.org
Mon Apr 21 18:28:33 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Jakub Kuderski (kuhar)
We know that the index along a unit dimension is always 0, so `linalg.index` ops over unit dims can be folded to a constant zero, as sketched below.
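For illustration, a minimal before/after sketch of the fold (assuming the usual `arith` constant materialization during canonicalization; the full test cases appear in the diff below):

```mlir
// Before: loop dimension 1 only maps to operand dims of static size 1,
// so the index value can only ever be 0.
%idx1 = linalg.index 1 : index

// After folding, uses of %idx1 see a constant zero instead.
%c0 = arith.constant 0 : index
```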
---
Full diff: https://github.com/llvm/llvm-project/pull/136640.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td (+1)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+29)
- (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+80)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index f8df828f74851..1b48bf5fcb237 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -88,6 +88,7 @@ def Linalg_IndexOp : Linalg_Op<"index", [Pure]>,
let assemblyFormat = [{ $dim attr-dict `:` type($result) }];
let hasVerifier = 1;
+ let hasFolder = 1;
}
def Linalg_SoftmaxOp : Linalg_Op<"softmax",
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 6c680498af2ad..a3787f101afa3 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2283,6 +2283,35 @@ LogicalResult IndexOp::verify() {
return success();
}
+OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
+ auto linalgOp = cast<LinalgOp>((*this)->getParentOp());
+ int64_t flatDimPos =
+ cast<AffineDimExpr>(linalgOp.getShapesToLoopsMap().getResult(getDim()))
+ .getPosition();
+
+ // Find the flat dimension position among the operands.
+ int64_t flatPosOffset = 0;
+ for (Value operand : linalgOp->getOperands()) {
+ assert(flatDimPos >= flatPosOffset && "invalid position");
+ auto shapedType = dyn_cast<ShapedType>(operand.getType());
+ if (!shapedType)
+ break;
+
+ int64_t rank = shapedType.getRank();
+ if (flatDimPos < flatPosOffset + rank) {
+      // Found the dimension within this shape. Fold if the dim size is 1;
+      // bail out otherwise.
+ int64_t pos = flatDimPos - flatPosOffset;
+ if (shapedType.getDimSize(pos) != 1)
+ break;
+
+ return IntegerAttr::get(IndexType::get(getContext()), 0);
+ }
+ flatPosOffset += rank;
+ }
+ return OpFoldResult{};
+}
+
/////// Operations corresponding to library calls defined with Tablegen ////////
#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 86cb8f58abe02..3daf221f4402d 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -305,6 +305,86 @@ func.func @self_copy(%arg0 : memref<2x3x?x4xf32>) {
}
// -----
+
+// CHECK: func @fold_linalg_index_tensor_static
+func.func @fold_linalg_index_tensor_static(%0: tensor<4x16xi32>, %1: tensor<1x16xi32>,
+ %2: tensor<4x1xi32>) -> tensor<4x1xi32> {
+// CHECK-NEXT: linalg.generic
+// CHECK: %[[IDX_0:.+]] = linalg.index 0 : index
+// CHECK-NOT: linalg.index 1
+// CHECK: %[[IDX_2:.+]] = linalg.index 2 : index
+// CHECK: %[[ADD:.+]] = arith.addi %[[IDX_0]], %[[IDX_2]]
+// CHECK: %[[CAST:.+]] = arith.index_cast %[[ADD]]
+// CHECK: linalg.yield %[[CAST]]
+ %res = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%0, %1 : tensor<4x16xi32>, tensor<1x16xi32>)
+ outs(%2 : tensor<4x1xi32>) {
+ ^bb0(%lhs: i32, %rhs: i32, %out: i32):
+ %idx0 = linalg.index 0 : index
+ %idx1 = linalg.index 1 : index
+ %idx2 = linalg.index 2 : index
+ %add0 = arith.addi %idx0, %idx1 : index
+ %add1 = arith.addi %add0, %idx2 : index
+ %int = arith.index_cast %add1 : index to i32
+ linalg.yield %int : i32
+ } -> tensor<4x1xi32>
+ return %res : tensor<4x1xi32>
+}
+
+// -----
+
+// CHECK: func @fold_linalg_index_tensor_dynamic
+func.func @fold_linalg_index_tensor_dynamic(%0: tensor<?x1xi32>,
+ %1: tensor<?x1xi32>) -> tensor<?x1xi32> {
+// CHECK-NEXT: linalg.generic
+// CHECK: %[[IDX_0:.+]] = linalg.index 0 : index
+// CHECK-NOT: linalg.index 1
+// CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX_0]]
+// CHECK: linalg.yield %[[CAST]]
+ %res = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d1, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%0 : tensor<?x1xi32>)
+ outs(%1 : tensor<?x1xi32>) {
+ ^bb0(%lhs: i32, %out: i32):
+ %idx0 = linalg.index 0 : index
+ %idx1 = linalg.index 1 : index
+ %add = arith.addi %idx0, %idx1 : index
+ %int = arith.index_cast %add : index to i32
+ linalg.yield %int : i32
+ } -> tensor<?x1xi32>
+ return %res : tensor<?x1xi32>
+}
+
+// -----
+
+// CHECK: func @fold_linalg_index_memref
+func.func @fold_linalg_index_memref(%0: memref<1x?xi32>, %1: memref<1x?xi32>) {
+// CHECK-NEXT: linalg.generic
+// CHECK-NOT: linalg.index 0
+// CHECK: %[[IDX_1:.+]] = linalg.index 1 : index
+// CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX_1]]
+// CHECK: linalg.yield %[[CAST]]
+ linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d1, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%0 : memref<1x?xi32>)
+ outs(%1 : memref<1x?xi32>) {
+ ^bb0(%lhs: i32, %out: i32):
+ %idx0 = linalg.index 0 : index
+ %idx1 = linalg.index 1 : index
+ %add = arith.addi %idx0, %idx1 : index
+ %int = arith.index_cast %add : index to i32
+ linalg.yield %int : i32
+ }
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @fold_fill_reshape()
func.func @fold_fill_reshape() -> tensor<6x4xf32> {
%zero = arith.constant 0.0 : f32
``````````
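As a sketch of how the folder locates the operand dimension (a hypothetical trace, not part of the patch): in `@fold_linalg_index_tensor_static` above, the indexing maps concatenate to `(d0, d2, d1, d2, d0, d1)`, and `getShapesToLoopsMap()` inverts that, sending loop dim `d1` to flat shape-dim position 2.

```mlir
// Flat shape-dim positions in @fold_linalg_index_tensor_static:
//   %0 : tensor<4x16xi32>, map (d0, d2) -> flat positions 0, 1
//   %1 : tensor<1x16xi32>, map (d1, d2) -> flat positions 2, 3
//   %2 : tensor<4x1xi32>,  map (d0, d1) -> flat positions 4, 5
// d1 first appears at flat position 2, i.e. dim 0 of %1, whose static
// size is 1, so `linalg.index 1` folds to a constant 0.
```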
https://github.com/llvm/llvm-project/pull/136640