[Mlir-commits] [mlir] c3728d2 - [mlir] support !shape.value_shape when replace WithOp in OutlineShapeComputationPass.

Xiang Li llvmlistbot at llvm.org
Mon Jan 23 19:36:23 PST 2023


Author: Xiang Li
Date: 2023-01-23T22:24:23-05:00
New Revision: c3728d28821e212bd3658261e58e744421668720

URL: https://github.com/llvm/llvm-project/commit/c3728d28821e212bd3658261e58e744421668720
DIFF: https://github.com/llvm/llvm-project/commit/c3728d28821e212bd3658261e58e744421668720.diff

LOG: [mlir] support !shape.value_shape when replace WithOp in OutlineShapeComputationPass.

Fixes #60069  https://github.com/llvm/llvm-project/issues/60069

In a case like:
  %1 = shape.with_shape %arg1, %0 : !shape.value_shape, !shape.shape
  %2 = shape.value_of %1 : tensor<?xf32>
%2 cannot simply be replaced with %arg1, because their types differ
(!shape.value_shape vs. tensor<?xf32>). Instead, transform it into:
  %2 = shape.value_of %arg1 : tensor<?xf32>

Differential Revision: https://reviews.llvm.org/D142275

Added: 
    mlir/test/Dialect/Shape/arg_with_shape.mlir

Modified: 
    mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
    mlir/test/Dialect/Shape/outline-shape-computation.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
index 372ec80978c6..f23a090a25a0 100644
--- a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
@@ -232,9 +232,24 @@ void OutlineShapeComputationPass::runOnOperation() {
 
     for (shape::WithOp withOp : allWithOps) {
       Value value = withOp.getOperand();
-      for (Operation *user : withOp.getResult().getUsers()) {
-        if (Value valueOf = llvm::dyn_cast<shape::ValueOfOp>(user))
-          valueOf.replaceAllUsesExcept(value, withOp);
+      for (Operation *user :
+           llvm::make_early_inc_range(withOp.getResult().getUsers())) {
+        if (auto valueOf = llvm::dyn_cast<shape::ValueOfOp>(user)) {
+          // For a pattern like
+          //   %1 = shape.with_shape %arg1, %0
+          //   %2 = shape.value_of %1
+          // the shape.with_shape is redundant, because shape.value_of does not
+          // care about the shape.
+          // If %arg1 and %2 have the same type, just
+          //   replace %2 with %arg1.
+          // If %arg1 has a different type, like !shape.value_shape,
+          // transform into
+          //   %2 = shape.value_of %arg1
+          if (valueOf.getType() == value.getType())
+            valueOf.replaceAllUsesWith(value);
+          else
+            valueOf.setOperand(value);
+        }
       }
     }
 

diff  --git a/mlir/test/Dialect/Shape/arg_with_shape.mlir b/mlir/test/Dialect/Shape/arg_with_shape.mlir
new file mode 100644
index 000000000000..089c5031fa55
--- /dev/null
+++ b/mlir/test/Dialect/Shape/arg_with_shape.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt -outline-shape-computation -split-input-file %s 2>%t | FileCheck %s
+
+func.func @func1(%arg0: !shape.value_shape, %arg1: !shape.value_shape) -> !shape.shape {
+  %0 = shape.shape_of %arg0 : !shape.value_shape -> !shape.shape
+  %1 = shape.shape_of %arg1 : !shape.value_shape -> !shape.shape
+  %2 = shape.meet %0, %1 : !shape.shape, !shape.shape -> !shape.shape
+  return %2 : !shape.shape
+}
+// Make sure a with_shape whose result is used by a call does not crash.
+// CHECK-LABEL:func.func @func
+func.func @func(%arg0: !shape.value_shape, %arg1: !shape.value_shape) -> !shape.shape {
+  %0 = shape.shape_of %arg0 : !shape.value_shape -> !shape.shape
+  %1 = shape.with_shape %arg1, %0 : !shape.value_shape, !shape.shape
+  %2 = call @func1(%arg0, %1) : (!shape.value_shape, !shape.value_shape) -> !shape.shape
+  return %2 : !shape.shape
+}

diff  --git a/mlir/test/Dialect/Shape/outline-shape-computation.mlir b/mlir/test/Dialect/Shape/outline-shape-computation.mlir
index 73f6e3a5a2cb..4aa15a927140 100644
--- a/mlir/test/Dialect/Shape/outline-shape-computation.mlir
+++ b/mlir/test/Dialect/Shape/outline-shape-computation.mlir
@@ -207,3 +207,13 @@ func.func @multiple_reused(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>) -> (t
 // CHECK-DAG:   %[[V5:.*]] = from_extents %[[V4]], %c4 : index, index
 // CHECK-DAG:   return %[[V5]] : !shape.shape
 
+// Make sure redundant with_shape is removed when with_shape input is !shape.value_shape.
+func.func @value_shape_with_shape(%arg0: !shape.value_shape, %arg1: !shape.value_shape) -> tensor<?xf32> {
+  %1 = shape.shape_of %arg0 : !shape.value_shape -> !shape.shape
+  %2 = shape.with_shape %arg1, %1 : !shape.value_shape, !shape.shape
+  %3 = shape.value_of %2 : tensor<?xf32>
+  return %3 : tensor<?xf32>
+}
+// CHECK-LABEL:func.func @value_shape_with_shape
+// CHECK-NEXT:%0 = shape.value_of %arg1 : tensor<?xf32>
+// CHECK-NEXT:return %0 : tensor<?xf32>


        


More information about the Mlir-commits mailing list