[Mlir-commits] [mlir] [mlir] Handle arith.const expr in dispatchIndexOpFoldResult func (PR #122432)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 10 01:45:44 PST 2025


https://github.com/rutkoor updated https://github.com/llvm/llvm-project/pull/122432

>From 533c396a52e7532f07cc8fd1b403b4cb79c6bfee Mon Sep 17 00:00:00 2001
From: rutkoor <quic_rutkoor at quicinc.com>
Date: Fri, 10 Jan 2025 02:11:15 -0600
Subject: [PATCH] Handle arith.const in dispatchIndexOpFoldResult func

Change-Id: I15280932f88d8ff638f5d0f964a1c03ce7a7881a
---
 mlir/lib/Dialect/Utils/StaticValueUtils.cpp   |  8 ++++++++
 mlir/test/Dialect/Tensor/bubble-reshapes.mlir | 20 +++++++++++++++++++
 2 files changed, 28 insertions(+)

diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 5c8f6ded39ba4e..163481069be42d 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -54,6 +54,14 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr,
     staticVec.push_back(apInt.getSExtValue());
     return;
   }
+
+  OpFoldResult result = getAsOpFoldResult(v);
+  if (auto attr = result.dyn_cast<Attribute>()) {
+    APInt apInt = cast<IntegerAttr>(attr).getValue();
+    staticVec.push_back(apInt.getSExtValue());
+    return;
+  }
+
   dynamicVec.push_back(v);
   staticVec.push_back(ShapedType::kDynamic);
 }
diff --git a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
index cf6b12852bcd39..15bc9b0435f6e6 100644
--- a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
@@ -20,6 +20,26 @@ func.func @bubble_parallel_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1:
 
 // -----
 
+func.func @bubble_parallel_reshapes2(%arg0: tensor<?x2x2x6xf32>, %s0: index, %s1: index) -> tensor<?x4x2x3xf32> {  // Static output_shape sizes 2 and 3 are supplied as arith.constant SSA values (%c2, %c3), not integer literals.
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x2x2x6xf32> into tensor<?x4x6xf32>  // merges dims 1 and 2 (2x2 -> 4)
+  %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]  // splits the last dim (6 -> 2x3)
+              output_shape [%s0, %s1, %c2, %c3] : tensor<?x4x6xf32> into tensor<?x4x2x3xf32>  // %c2/%c3 correspond to the static result dims 2 and 3
+  return %expand : tensor<?x4x2x3xf32>
+}
+//      CHECK: func @bubble_parallel_reshapes2
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x2x2x6xf32>
+// CHECK-SAME:   %[[S0:.+]]: index, %[[S1:.+]]: index
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
+//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]]
+// CHECK-SAME:       output_shape [%[[S0]], 2, 2, %[[C2]], %[[C3]]] : tensor<?x2x2x6xf32> into tensor<?x2x2x2x3xf32>
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor<?x2x2x2x3xf32> into tensor<?x4x2x3xf32>
+//      CHECK:   return %[[COLLAPSE]]
+
+// -----
+
 func.func @no_bubble_full_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
   %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
   %expand = tensor.expand_shape %collapse [[0], [1, 2], [3]]



More information about the Mlir-commits mailing list