[Mlir-commits] [mlir] 2b9b999 - [MLIR][Shape] Replace single operand broadcasts with appropriate cast

Frederik Gossen llvmlistbot at llvm.org
Tue Apr 27 05:49:15 PDT 2021


Author: Frederik Gossen
Date: 2021-04-27T14:48:56+02:00
New Revision: 2b9b999d4d35d02170dc12a48d0c4d0a3ad22739

URL: https://github.com/llvm/llvm-project/commit/2b9b999d4d35d02170dc12a48d0c4d0a3ad22739
DIFF: https://github.com/llvm/llvm-project/commit/2b9b999d4d35d02170dc12a48d0c4d0a3ad22739.diff

LOG: [MLIR][Shape] Replace single operand broadcasts with appropriate cast

Differential Revision: https://reviews.llvm.org/D101350

Added: 
    

Modified: 
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index c19acf33f535..ea2b10bdcf0e 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -558,14 +558,26 @@ struct BroadcastForwardSingleOperandPattern
 
   LogicalResult matchAndRewrite(BroadcastOp op,
                                 PatternRewriter &rewriter) const override {
-    if (op.getNumOperands() == 1) {
-      Value uniqueShapeOperand = op.shapes().front();
-      if (uniqueShapeOperand.getType() == op.getType()) {
-        rewriter.replaceOp(op, uniqueShapeOperand);
-        return success();
+    if (op.getNumOperands() != 1)
+      return failure();
+    Value replacement = op.shapes().front();
+
+    // Insert cast if needed.
+    if (replacement.getType() != op.getType()) {
+      auto loc = op.getLoc();
+      if (op.getType().isa<ShapeType>()) {
+        replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
+      } else {
+        assert(!op.getType().isa<ShapeType>() &&
+               !replacement.getType().isa<ShapeType>() &&
+               "expect extent tensor cast");
+        replacement =
+            rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
       }
     }
-    return failure();
+
+    rewriter.replaceOp(op, replacement);
+    return success();
   }
 };
 

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 68545d161374..a47db3044e0e 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1242,13 +1242,24 @@ func @broadcast_on_single_operand(%a : tensor<?xindex>) {
 
 // -----
 
-// CHECK-LABEL: @broadcast_on_single_operand
+// CHECK-LABEL: @broadcast_as_tensor_cast
 // CHECK-SAME: (%[[A:.*]]: tensor<3xindex>)
-func @broadcast_on_single_operand(%a : tensor<3xindex>) {
-  // CHECK: broadcast %[[A]]
+func @broadcast_as_tensor_cast(%a : tensor<3xindex>) -> tensor<?xindex> {
+  // CHECK: %[[RESULT:.*]] = tensor.cast %[[A]] : tensor<3xindex> to tensor<?xindex>
+  // CHECK: return %[[RESULT]] : tensor<?xindex>
   %0 = shape.broadcast %a : tensor<3xindex> -> tensor<?xindex>
-  "use"(%0) : (tensor<?xindex>) -> ()
-  return
+  return %0 : tensor<?xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_as_from_extent_tensor
+// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>)
+func @broadcast_as_from_extent_tensor(%a : tensor<?xindex>) -> !shape.shape {
+  // CHECK: %[[RESULT:.*]] = shape.from_extent_tensor %[[A]] : tensor<?xindex>
+  // CHECK: return %[[RESULT]] : !shape.shape
+  %0 = shape.broadcast %a : tensor<?xindex> -> !shape.shape
+  return %0 : !shape.shape
 }
 
 // -----


        


More information about the Mlir-commits mailing list