[Mlir-commits] [mlir] eb56fa9 - [MLIR][Shape] Fix `shape.broadcast` to standard lowering
Frederik Gossen
llvmlistbot at llvm.org
Thu Apr 29 01:09:34 PDT 2021
Author: Frederik Gossen
Date: 2021-04-29T10:09:15+02:00
New Revision: eb56fa97de96856bb63e31340598a356056470c5
URL: https://github.com/llvm/llvm-project/commit/eb56fa97de96856bb63e31340598a356056470c5
DIFF: https://github.com/llvm/llvm-project/commit/eb56fa97de96856bb63e31340598a356056470c5.diff
LOG: [MLIR][Shape] Fix `shape.broadcast` to standard lowering
Differential Revision: https://reviews.llvm.org/D101456
Added:
Modified:
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index e0342f6162c5..9e0020a8e859 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -155,17 +155,18 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
return lb.create<SubIOp>(indexTy, maxRank, v);
}));
- rewriter.replaceOp(
- op, lb.create<tensor::GenerateOp>(
- getExtentTensorType(lb.getContext()), ValueRange{maxRank},
- [&](OpBuilder &b, Location loc, ValueRange args) {
- Value broadcastedDim = getBroadcastedDim(
- ImplicitLocOpBuilder(loc, b), transformed.shapes(),
- rankDiffs, args[0]);
-
- b.create<tensor::YieldOp>(loc, broadcastedDim);
- })
- ->getResults());
+ Value replacement = lb.create<tensor::GenerateOp>(
+ getExtentTensorType(lb.getContext()), ValueRange{maxRank},
+ [&](OpBuilder &b, Location loc, ValueRange args) {
+ Value broadcastedDim =
+ getBroadcastedDim(ImplicitLocOpBuilder(loc, b),
+ transformed.shapes(), rankDiffs, args[0]);
+
+ b.create<tensor::YieldOp>(loc, broadcastedDim);
+ });
+ if (replacement.getType() != op.getType())
+ replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
+ rewriter.replaceOp(op, replacement);
return success();
}
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 751f5002703b..98000445443c 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -593,6 +593,17 @@ func @broadcast_3_shapes_different_extents(%a : tensor<2xindex>,
return
}
+// ----
+
+// CHECK-LABEL: @broadcast_to_known_rank
+func @broadcast_to_known_rank(%a : tensor<1xindex>, %b : tensor<3xindex>)
+ -> tensor<3xindex> {
+ // CHECK: %[[RES:.*]] = tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
+ // CHECK: return %[[RES]] : tensor<3xindex>
+ %0 = shape.broadcast %a, %b : tensor<1xindex>, tensor<3xindex> -> tensor<3xindex>
+ return %0 : tensor<3xindex>
+}
+
// -----
// Lower `split_at`
More information about the Mlir-commits mailing list