[Mlir-commits] [mlir] c3728d2 - [mlir] support !shape.value_shape when replace WithOp in OutlineShapeComputationPass.
Xiang Li
llvmlistbot at llvm.org
Mon Jan 23 19:36:23 PST 2023
Author: Xiang Li
Date: 2023-01-23T22:24:23-05:00
New Revision: c3728d28821e212bd3658261e58e744421668720
URL: https://github.com/llvm/llvm-project/commit/c3728d28821e212bd3658261e58e744421668720
DIFF: https://github.com/llvm/llvm-project/commit/c3728d28821e212bd3658261e58e744421668720.diff
LOG: [mlir] support !shape.value_shape when replace WithOp in OutlineShapeComputationPass.
Fixes #60069 https://github.com/llvm/llvm-project/issues/60069
In case like:
%1 = shape.with_shape %arg1, %0 : !shape.value_shape, !shape.shape
%2 = shape.value_of %1 : tensor<?xf32>
%2 cannot simply be replaced with %arg1, because %arg1 is a !shape.value_shape while %2 is a tensor<?xf32>.
Instead, transform it into:
%2 = shape.value_of %arg1 : tensor<?xf32>
Differential Revision: https://reviews.llvm.org/D142275
Added:
mlir/test/Dialect/Shape/arg_with_shape.mlir
Modified:
mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
mlir/test/Dialect/Shape/outline-shape-computation.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
index 372ec80978c6..f23a090a25a0 100644
--- a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
@@ -232,9 +232,24 @@ void OutlineShapeComputationPass::runOnOperation() {
for (shape::WithOp withOp : allWithOps) {
Value value = withOp.getOperand();
- for (Operation *user : withOp.getResult().getUsers()) {
- if (Value valueOf = llvm::dyn_cast<shape::ValueOfOp>(user))
- valueOf.replaceAllUsesExcept(value, withOp);
+ for (Operation *user :
+ llvm::make_early_inc_range(withOp.getResult().getUsers())) {
+ if (auto valueOf = llvm::dyn_cast<shape::ValueOfOp>(user)) {
+ // For pattern like
+ // %1 = shape.with_shape %arg1, %0
+ // %2 = shape.value_of %1
+ // because shape.value doesn't care the shape, the shape.with_shape is
+ // redundant.
+ // If type of %arg1 and %2 has same type, just
+ // replaced %2 with %arg1.
+ // If type of %arg1 has different type like !shape.value_shape,
+ // transform into
+ // %2 = shape.value_of %arg1
+ if (valueOf.getType() == value.getType())
+ valueOf.replaceAllUsesWith(value);
+ else
+ valueOf.setOperand(value);
+ }
}
}
diff --git a/mlir/test/Dialect/Shape/arg_with_shape.mlir b/mlir/test/Dialect/Shape/arg_with_shape.mlir
new file mode 100644
index 000000000000..089c5031fa55
--- /dev/null
+++ b/mlir/test/Dialect/Shape/arg_with_shape.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt -outline-shape-computation -split-input-file %s 2>%t | FileCheck %s
+
+func.func @func1(%arg0: !shape.value_shape, %arg1: !shape.value_shape) -> !shape.shape {
+ %0 = shape.shape_of %arg0 : !shape.value_shape -> !shape.shape
+ %1 = shape.shape_of %arg1 : !shape.value_shape -> !shape.shape
+ %2 = shape.meet %0, %1 : !shape.shape, !shape.shape -> !shape.shape
+ return %2 : !shape.shape
+}
+// Make sure with_shape used by call not crash.
+// CHECK-LABEL:func.func @func
+func.func @func(%arg0: !shape.value_shape, %arg1: !shape.value_shape) -> !shape.shape {
+ %0 = shape.shape_of %arg0 : !shape.value_shape -> !shape.shape
+ %1 = shape.with_shape %arg1, %0 : !shape.value_shape, !shape.shape
+ %2 = call @func1(%arg0, %1) : (!shape.value_shape, !shape.value_shape) -> !shape.shape
+ return %2 : !shape.shape
+}
diff --git a/mlir/test/Dialect/Shape/outline-shape-computation.mlir b/mlir/test/Dialect/Shape/outline-shape-computation.mlir
index 73f6e3a5a2cb..4aa15a927140 100644
--- a/mlir/test/Dialect/Shape/outline-shape-computation.mlir
+++ b/mlir/test/Dialect/Shape/outline-shape-computation.mlir
@@ -207,3 +207,13 @@ func.func @multiple_reused(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>) -> (t
// CHECK-DAG: %[[V5:.*]] = from_extents %[[V4]], %c4 : index, index
// CHECK-DAG: return %[[V5]] : !shape.shape
+// Make sure redundant with_shape is removed when with_shape input is !shape.value_shape.
+func.func @value_shape_with_shape(%arg0: !shape.value_shape, %arg1: !shape.value_shape) -> tensor<?xf32> {
+ %1 = shape.shape_of %arg0 : !shape.value_shape -> !shape.shape
+ %2 = shape.with_shape %arg1, %1 : !shape.value_shape, !shape.shape
+ %3 = shape.value_of %2 : tensor<?xf32>
+ return %3 : tensor<?xf32>
+}
+// CHECK-LABEL:func.func @value_shape_with_shape
+// CHECK-NEXT:%0 = shape.value_of %arg1 : tensor<?xf32>
+// CHECK-NEXT:return %0 : tensor<?xf32>
More information about the Mlir-commits
mailing list