[Mlir-commits] [mlir] eb56fa9 - [MLIR][Shape] Fix `shape.broadcast` to standard lowering

Frederik Gossen llvmlistbot at llvm.org
Thu Apr 29 01:09:34 PDT 2021


Author: Frederik Gossen
Date: 2021-04-29T10:09:15+02:00
New Revision: eb56fa97de96856bb63e31340598a356056470c5

URL: https://github.com/llvm/llvm-project/commit/eb56fa97de96856bb63e31340598a356056470c5
DIFF: https://github.com/llvm/llvm-project/commit/eb56fa97de96856bb63e31340598a356056470c5.diff

LOG: [MLIR][Shape] Fix `shape.broadcast` to standard lowering

Differential Revision: https://reviews.llvm.org/D101456

Added: 
    

Modified: 
    mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
    mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index e0342f6162c5..9e0020a8e859 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -155,17 +155,18 @@ LogicalResult BroadcastOpConverter::matchAndRewrite(
                        return lb.create<SubIOp>(indexTy, maxRank, v);
                      }));
 
-  rewriter.replaceOp(
-      op, lb.create<tensor::GenerateOp>(
-                getExtentTensorType(lb.getContext()), ValueRange{maxRank},
-                [&](OpBuilder &b, Location loc, ValueRange args) {
-                  Value broadcastedDim = getBroadcastedDim(
-                      ImplicitLocOpBuilder(loc, b), transformed.shapes(),
-                      rankDiffs, args[0]);
-
-                  b.create<tensor::YieldOp>(loc, broadcastedDim);
-                })
-              ->getResults());
+  Value replacement = lb.create<tensor::GenerateOp>(
+      getExtentTensorType(lb.getContext()), ValueRange{maxRank},
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value broadcastedDim =
+            getBroadcastedDim(ImplicitLocOpBuilder(loc, b),
+                              transformed.shapes(), rankDiffs, args[0]);
+
+        b.create<tensor::YieldOp>(loc, broadcastedDim);
+      });
+  if (replacement.getType() != op.getType())
+    replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
+  rewriter.replaceOp(op, replacement);
   return success();
 }
 

diff  --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 751f5002703b..98000445443c 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -593,6 +593,17 @@ func @broadcast_3_shapes_different_extents(%a : tensor<2xindex>,
   return
 }
 
+// -----
+
+// CHECK-LABEL: @broadcast_to_known_rank
+func @broadcast_to_known_rank(%a : tensor<1xindex>, %b : tensor<3xindex>)
+    -> tensor<3xindex> {
+  // CHECK: %[[RES:.*]] = tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
+  // CHECK: return %[[RES]] : tensor<3xindex>
+  %0 = shape.broadcast %a, %b : tensor<1xindex>, tensor<3xindex> -> tensor<3xindex>
+  return %0 : tensor<3xindex>
+}
+
 // -----
 
 // Lower `split_at`


        


More information about the Mlir-commits mailing list