[Mlir-commits] [mlir] [mlir][linalg] Add folder for `linalg.index` (PR #136640)

Jakub Kuderski llvmlistbot at llvm.org
Tue Apr 22 08:29:17 PDT 2025


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/136640

>From 492439a533a552b2859951d6e331663f1c3244b2 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Mon, 21 Apr 2025 21:26:41 -0400
Subject: [PATCH 1/4] [mlir][linalg] Add folder for `linalg.index`

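Fold `linalg.index` to the constant 0 when the size of the corresponding loop dimension is statically known to be 1 (a unit dim): a loop with a single iteration can only ever have index 0.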
---
 .../mlir/Dialect/Linalg/IR/LinalgOps.td       |  1 +
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 29 +++++++
 mlir/test/Dialect/Linalg/canonicalize.mlir    | 80 +++++++++++++++++++
 3 files changed, 110 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index f8df828f74851..1b48bf5fcb237 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -88,6 +88,7 @@ def Linalg_IndexOp : Linalg_Op<"index", [Pure]>,
 
   let assemblyFormat = [{ $dim attr-dict `:` type($result) }];
   let hasVerifier = 1;
+  let hasFolder = 1;
 }
 
 def Linalg_SoftmaxOp : Linalg_Op<"softmax",
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 6c680498af2ad..a3787f101afa3 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2283,6 +2283,35 @@ LogicalResult IndexOp::verify() {
   return success();
 }
 
+OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
+  auto linalgOp = cast<LinalgOp>((*this)->getParentOp());
+  int64_t flatDimPos =
+      cast<AffineDimExpr>(linalgOp.getShapesToLoopsMap().getResult(getDim()))
+          .getPosition();
+
+  // Find the flat dimension position among the operands.
+  int64_t flatPosOffset = 0;
+  for (Value operand : linalgOp->getOperands()) {
+    assert(flatDimPos >= flatPosOffset && "invalid position");
+    auto shapedType = dyn_cast<ShapedType>(operand.getType());
+    if (!shapedType)
+      break;
+
+    int64_t rank = shapedType.getRank();
+    if (flatDimPos < flatPosOffset + rank) {
+      // Found the dimension within this shape. Now we can either fold if the
+      // dim size is 1, or bail out otherwise.
+      int64_t pos = flatDimPos - flatPosOffset;
+      if (shapedType.getDimSize(pos) != 1)
+        break;
+
+      return IntegerAttr::get(IndexType::get(getContext()), 0);
+    }
+    flatPosOffset += rank;
+  }
+  return OpFoldResult{};
+}
+
 /////// Operations corresponding to library calls defined with Tablegen ////////
 
 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 86cb8f58abe02..3daf221f4402d 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -305,6 +305,86 @@ func.func @self_copy(%arg0 : memref<2x3x?x4xf32>) {
 }
 
 // -----
+
+// CHECK: func @fold_linalg_index_tensor_static
+func.func @fold_linalg_index_tensor_static(%0: tensor<4x16xi32>, %1: tensor<1x16xi32>,
+                                           %2: tensor<4x1xi32>) -> tensor<4x1xi32> {
+// CHECK-NEXT: linalg.generic
+// CHECK:        %[[IDX_0:.+]] = linalg.index 0 : index
+// CHECK-NOT:    linalg.index 1
+// CHECK:        %[[IDX_2:.+]] = linalg.index 2 : index
+// CHECK:        %[[ADD:.+]] = arith.addi %[[IDX_0]], %[[IDX_2]]
+// CHECK:        %[[CAST:.+]] = arith.index_cast %[[ADD]]
+// CHECK:        linalg.yield %[[CAST]]
+  %res = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                          affine_map<(d0, d1, d2) -> (d1, d2)>,
+                                          affine_map<(d0, d1, d2) -> (d0, d1)>],
+                         iterator_types = ["parallel", "parallel", "reduction"]}
+    ins(%0, %1 : tensor<4x16xi32>, tensor<1x16xi32>)
+    outs(%2 : tensor<4x1xi32>) {
+      ^bb0(%lhs: i32, %rhs: i32, %out: i32):
+        %idx0 = linalg.index 0 : index
+        %idx1 = linalg.index 1 : index
+        %idx2 = linalg.index 2 : index
+        %add0 = arith.addi %idx0, %idx1 : index
+        %add1 = arith.addi %add0, %idx2 : index
+        %int = arith.index_cast %add1 : index to i32
+        linalg.yield %int : i32
+    } -> tensor<4x1xi32>
+  return %res : tensor<4x1xi32>
+}
+
+// -----
+
+// CHECK: func @fold_linalg_index_tensor_dynamic
+func.func @fold_linalg_index_tensor_dynamic(%0: tensor<?x1xi32>,
+                                            %1: tensor<?x1xi32>) -> tensor<?x1xi32> {
+// CHECK-NEXT: linalg.generic
+// CHECK:        %[[IDX_0:.+]] = linalg.index 0 : index
+// CHECK-NOT:    linalg.index 1
+// CHECK:        %[[CAST:.+]] = arith.index_cast %[[IDX_0]]
+// CHECK:        linalg.yield %[[CAST]]
+  %res = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                          affine_map<(d0, d1) -> (d1, d1)>],
+                         iterator_types = ["parallel", "parallel"]}
+    ins(%0 : tensor<?x1xi32>)
+    outs(%1 : tensor<?x1xi32>) {
+      ^bb0(%lhs: i32, %out: i32):
+        %idx0 = linalg.index 0 : index
+        %idx1 = linalg.index 1 : index
+        %add = arith.addi %idx0, %idx1 : index
+        %int = arith.index_cast %add : index to i32
+        linalg.yield %int : i32
+    } -> tensor<?x1xi32>
+  return %res : tensor<?x1xi32>
+}
+
+// -----
+
+// CHECK: func @fold_linalg_index_memref
+func.func @fold_linalg_index_memref(%0: memref<1x?xi32>, %1: memref<1x?xi32>) {
+// CHECK-NEXT: linalg.generic
+// CHECK-NOT:    linalg.index 0
+// CHECK:        %[[IDX_1:.+]] = linalg.index 1 : index
+// CHECK:        %[[CAST:.+]] = arith.index_cast %[[IDX_1]]
+// CHECK:        linalg.yield %[[CAST]]
+  linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                   affine_map<(d0, d1) -> (d1, d1)>],
+                  iterator_types = ["parallel", "parallel"]}
+    ins(%0 : memref<1x?xi32>)
+    outs(%1 : memref<1x?xi32>) {
+      ^bb0(%lhs: i32, %out: i32):
+        %idx0 = linalg.index 0 : index
+        %idx1 = linalg.index 1 : index
+        %add = arith.addi %idx0, %idx1 : index
+        %int = arith.index_cast %add : index to i32
+        linalg.yield %int : i32
+    }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_fill_reshape()
 func.func @fold_fill_reshape() -> tensor<6x4xf32> {
   %zero = arith.constant 0.0 : f32

>From b3e5afe9fbf42b095eaf232bc7f3ffa52b4864f1 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Mon, 21 Apr 2025 21:32:52 -0400
Subject: [PATCH 2/4] Update test

---
 mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index 375fa37bd84b0..01eafafc8ea29 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -278,12 +278,11 @@ func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load(%
 // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32>
 // CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex>
 // CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
-// CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<1xindex>
 // CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex>
 // CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex>
 // CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
-// CHECK: %[[B3:.*]] = vector.broadcast %[[B2]] : vector<1xindex> to vector<8x1xindex>
-// CHECK: %[[ADDI:.*]] = arith.addi %[[B3]], %[[T]] : vector<8x1xindex>
+// CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<8x1xindex>
+// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[T]] : vector<8x1xindex>
 // CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
 // CHECK: vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
 

>From cbcdb3346e6296552fec9bbcbc3838764c59e6ec Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Mon, 21 Apr 2025 21:51:05 -0400
Subject: [PATCH 3/4] Simplify

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 28 +++++-------------------
 1 file changed, 6 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index a3787f101afa3..967b7685cd89c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2285,30 +2285,14 @@ LogicalResult IndexOp::verify() {
 
 OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
   auto linalgOp = cast<LinalgOp>((*this)->getParentOp());
-  int64_t flatDimPos =
-      cast<AffineDimExpr>(linalgOp.getShapesToLoopsMap().getResult(getDim()))
-          .getPosition();
-
-  // Find the flat dimension position among the operands.
-  int64_t flatPosOffset = 0;
-  for (Value operand : linalgOp->getOperands()) {
-    assert(flatDimPos >= flatPosOffset && "invalid position");
-    auto shapedType = dyn_cast<ShapedType>(operand.getType());
-    if (!shapedType)
-      break;
 
-    int64_t rank = shapedType.getRank();
-    if (flatDimPos < flatPosOffset + rank) {
-      // Found the dimension within this shape. Now we can either fold if the
-      // dim size is 1, or bail out otherwise.
-      int64_t pos = flatDimPos - flatPosOffset;
-      if (shapedType.getDimSize(pos) != 1)
-        break;
+  // Index of unit dims is always 0.
+  SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
+  uint64_t dim = getDim();
+  assert(dim < loopBounds.size());
+  if (loopBounds[dim] == 1)
+    return IntegerAttr::get(IndexType::get(getContext()), 0);
 
-      return IntegerAttr::get(IndexType::get(getContext()), 0);
-    }
-    flatPosOffset += rank;
-  }
   return OpFoldResult{};
 }
 

>From 12ba5a3e8dbc9f80b67601df54da9482dc39ddf1 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Tue, 22 Apr 2025 11:29:06 -0400
Subject: [PATCH 4/4] Add assert message

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 967b7685cd89c..72fb3308a2549 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2289,7 +2289,7 @@ OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
   // Index of unit dims is always 0.
   SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
   uint64_t dim = getDim();
-  assert(dim < loopBounds.size());
+  assert(dim < loopBounds.size() && "Dim is out of bounds");
   if (loopBounds[dim] == 1)
     return IntegerAttr::get(IndexType::get(getContext()), 0);
 



More information about the Mlir-commits mailing list