[Mlir-commits] [mlir] [mlir] Handle arith.const expr in dispatchIndexOpFoldResult func (PR #122432)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 10 00:28:56 PST 2025
https://github.com/rutkoor created https://github.com/llvm/llvm-project/pull/122432
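This PR extends the dispatchIndexOpFoldResult helper to handle values defined by arith.constant ops. Previously, the helper placed an OpFoldResult into **staticVec** only when it was an IntegerAttr. Any Value, even one produced by an integer arith.constant, was pushed to **dynamicVec**, with a ShapedType::kDynamic placeholder in **staticVec**. With this change, if the Value is defined by an arith.constant whose attribute is an IntegerAttr, that integer (sign-extended) is pushed directly into **staticVec**, and nothing is added to **dynamicVec**.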
>From 83245bf933adf20abce5cf058ded561797235376 Mon Sep 17 00:00:00 2001
From: rutkoor <quic_rutkoor at quicinc.com>
Date: Fri, 10 Jan 2025 02:11:15 -0600
Subject: [PATCH] Handle arith.const in dispatchIndexOpFoldResult func
Change-Id: I15280932f88d8ff638f5d0f964a1c03ce7a7881a
---
mlir/lib/Dialect/Utils/StaticValueUtils.cpp | 13 ++++++++++++
mlir/test/Dialect/Tensor/bubble-reshapes.mlir | 20 +++++++++++++++++++
2 files changed, 33 insertions(+)
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 5c8f6ded39ba4e..7ad4c982af2aae 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/APSInt.h"
@@ -54,6 +55,18 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr,
staticVec.push_back(apInt.getSExtValue());
return;
}
+
+ Operation *definingOp = v.getDefiningOp();
+ if (definingOp) {
+ // Check if definingOp is an arith.constant
+ if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
+ if (auto intAttr = mlir::dyn_cast<IntegerAttr>(constantOp.getValue())) {
+ staticVec.push_back(intAttr.getValue().getSExtValue());
+ return;
+ }
+ }
+ }
+
dynamicVec.push_back(v);
staticVec.push_back(ShapedType::kDynamic);
}
diff --git a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
index cf6b12852bcd39..15bc9b0435f6e6 100644
--- a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
@@ -20,6 +20,26 @@ func.func @bubble_parallel_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1:
// -----
+func.func @bubble_parallel_reshapes2(%arg0: tensor<?x2x2x6xf32>, %s0: index, %s1: index) -> tensor<?x4x2x3xf32> { // output_shape sizes for result dims 2 and 3 come from arith.constant ops rather than integer literals.
+ %c2 = arith.constant 2 : index // size for result dim 2 (static 2 in the result type)
+ %c3 = arith.constant 3 : index // size for result dim 3 (static 3 in the result type)
+ %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x2x2x6xf32> into tensor<?x4x6xf32> // merges dims 1 and 2 (2x2 -> 4)
+ %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]] // splits dim 2 (6 -> 2x3)
+ output_shape [%s0, %s1, %c2, %c3] : tensor<?x4x6xf32> into tensor<?x4x2x3xf32> // NOTE(review): %s1 is given for result dim 1, which is static (4) in the result type; confirm this is intended.
+ return %expand : tensor<?x4x2x3xf32>
+}
+// CHECK: func @bubble_parallel_reshapes2
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x2x2x6xf32>
+// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]]
+// CHECK-SAME: output_shape [%[[S0]], 2, 2, %[[C2]], %[[C3]]] : tensor<?x2x2x6xf32> into tensor<?x2x2x2x3xf32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor<?x2x2x2x3xf32> into tensor<?x4x2x3xf32>
+// CHECK: return %[[COLLAPSE]]
+
+// -----
+
func.func @no_bubble_full_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
%collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%expand = tensor.expand_shape %collapse [[0], [1, 2], [3]]
More information about the Mlir-commits
mailing list