[Mlir-commits] [mlir] 2b9b999 - [MLIR][Shape] Replace single operand broadcasts with appropriate cast
Frederik Gossen
llvmlistbot at llvm.org
Tue Apr 27 05:49:15 PDT 2021
Author: Frederik Gossen
Date: 2021-04-27T14:48:56+02:00
New Revision: 2b9b999d4d35d02170dc12a48d0c4d0a3ad22739
URL: https://github.com/llvm/llvm-project/commit/2b9b999d4d35d02170dc12a48d0c4d0a3ad22739
DIFF: https://github.com/llvm/llvm-project/commit/2b9b999d4d35d02170dc12a48d0c4d0a3ad22739.diff
LOG: [MLIR][Shape] Replace single operand broadcasts with appropriate cast
Differential Revision: https://reviews.llvm.org/D101350
Added:
Modified:
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index c19acf33f535..ea2b10bdcf0e 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -558,14 +558,26 @@ struct BroadcastForwardSingleOperandPattern
LogicalResult matchAndRewrite(BroadcastOp op,
PatternRewriter &rewriter) const override {
- if (op.getNumOperands() == 1) {
- Value uniqueShapeOperand = op.shapes().front();
- if (uniqueShapeOperand.getType() == op.getType()) {
- rewriter.replaceOp(op, uniqueShapeOperand);
- return success();
+ if (op.getNumOperands() != 1)
+ return failure();
+ Value replacement = op.shapes().front();
+
+ // Insert cast if needed.
+ if (replacement.getType() != op.getType()) {
+ auto loc = op.getLoc();
+ if (op.getType().isa<ShapeType>()) {
+ replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
+ } else {
+ assert(!op.getType().isa<ShapeType>() &&
+ !replacement.getType().isa<ShapeType>() &&
+ "expect extent tensor cast");
+ replacement =
+ rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
}
}
- return failure();
+
+ rewriter.replaceOp(op, replacement);
+ return success();
}
};
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 68545d161374..a47db3044e0e 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1242,13 +1242,24 @@ func @broadcast_on_single_operand(%a : tensor<?xindex>) {
// -----
-// CHECK-LABEL: @broadcast_on_single_operand
+// CHECK-LABEL: @broadcast_as_tensor_cast
// CHECK-SAME: (%[[A:.*]]: tensor<3xindex>)
-func @broadcast_on_single_operand(%a : tensor<3xindex>) {
- // CHECK: broadcast %[[A]]
+func @broadcast_as_tensor_cast(%a : tensor<3xindex>) -> tensor<?xindex> {
+ // CHECK: %[[RESULT:.*]] = tensor.cast %[[A]] : tensor<3xindex> to tensor<?xindex>
+ // CHECK: return %[[RESULT]] : tensor<?xindex>
%0 = shape.broadcast %a : tensor<3xindex> -> tensor<?xindex>
- "use"(%0) : (tensor<?xindex>) -> ()
- return
+ return %0 : tensor<?xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_as_from_extent_tensor
+// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>)
+func @broadcast_as_from_extent_tensor(%a : tensor<?xindex>) -> !shape.shape {
+ // CHECK: %[[RESULT:.*]] = shape.from_extent_tensor %[[A]] : tensor<?xindex>
+ // CHECK: return %[[RESULT]] : !shape.shape
+ %0 = shape.broadcast %a : tensor<?xindex> -> !shape.shape
+ return %0 : !shape.shape
}
// -----
More information about the Mlir-commits
mailing list