[Mlir-commits] [mlir] 6d10d31 - [MLIR][Shape] Support transforming shape.num_elements on tensors
Stephan Herhut
llvmlistbot at llvm.org
Tue Jul 28 05:13:20 PDT 2020
Author: Stephan Herhut
Date: 2020-07-28T14:13:06+02:00
New Revision: 6d10d317d8b0f1975dbb17850efd7c069f6ee8fd
URL: https://github.com/llvm/llvm-project/commit/6d10d317d8b0f1975dbb17850efd7c069f6ee8fd
DIFF: https://github.com/llvm/llvm-project/commit/6d10d317d8b0f1975dbb17850efd7c069f6ee8fd.diff
LOG: [MLIR][Shape] Support transforming shape.num_elements on tensors
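The current lowering of shape.num_elements to shape.reduce does not support operands of tensor type (e.g. tensor<?xindex>).
This adds the changes required to make that work, including fixing the shape.reduce builder so that its body block's element argument takes the tensor's element type instead of always !shape.size.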
Differential Revision: https://reviews.llvm.org/D84744
Added:
Modified:
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
mlir/test/Dialect/Shape/shape-to-shape.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 4887c87c1e5f..3c71e3409923 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -834,7 +834,13 @@ void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
bodyRegion->push_back(new Block);
Block &bodyBlock = bodyRegion->front();
bodyBlock.addArgument(builder.getIndexType());
- bodyBlock.addArgument(SizeType::get(builder.getContext()));
+
+ Type elementType;
+ if (auto tensorType = shape.getType().dyn_cast<TensorType>())
+ elementType = tensorType.getElementType();
+ else
+ elementType = SizeType::get(builder.getContext());
+ bodyBlock.addArgument(elementType);
for (Type initValType : initVals.getTypes()) {
bodyBlock.addArgument(initValType);
diff --git a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
index bb2b03b8ec08..a84fad1f9460 100644
--- a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
@@ -9,6 +9,7 @@
#include "PassDetail.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
@@ -32,14 +33,18 @@ LogicalResult
NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
PatternRewriter &rewriter) const {
auto loc = op.getLoc();
- Value init = rewriter.create<ConstSizeOp>(loc, rewriter.getIndexAttr(1));
+ Type valueType = op.getResult().getType();
+ Value init = op.getDialect()
+ ->materializeConstant(rewriter, rewriter.getIndexAttr(1),
+ valueType, loc)
+ ->getResult(0);
ReduceOp reduce = rewriter.create<ReduceOp>(loc, op.shape(), init);
// Generate reduce operator.
Block *body = reduce.getBody();
OpBuilder b = OpBuilder::atBlockEnd(body);
- Value product = b.create<MulOp>(loc, b.getType<SizeType>(),
- body->getArgument(1), body->getArgument(2));
+ Value product = b.create<MulOp>(loc, valueType, body->getArgument(1),
+ body->getArgument(2));
b.create<YieldOp>(loc, product);
rewriter.replaceOp(op, reduce.result());
@@ -60,7 +65,7 @@ void ShapeToShapeLowering::runOnFunction() {
populateShapeRewritePatterns(&ctx, patterns);
ConversionTarget target(getContext());
- target.addLegalDialect<ShapeDialect>();
+ target.addLegalDialect<ShapeDialect, StandardOpsDialect>();
target.addIllegalOp<NumElementsOp>();
if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
signalPassFailure();
diff --git a/mlir/test/Dialect/Shape/shape-to-shape.mlir b/mlir/test/Dialect/Shape/shape-to-shape.mlir
index d1b00bc12a22..481d682942bb 100644
--- a/mlir/test/Dialect/Shape/shape-to-shape.mlir
+++ b/mlir/test/Dialect/Shape/shape-to-shape.mlir
@@ -14,3 +14,18 @@ func @num_elements_to_reduce(%shape : !shape.shape) -> !shape.size {
// CHECK: }
// CHECK: return [[NUM_ELEMENTS]] : !shape.size
+// -----
+
+// CHECK-LABEL: func @num_elements_to_reduce_on_index
+// CHECK-SAME: ([[ARG:%.*]]: tensor<?xindex>) -> index
+func @num_elements_to_reduce_on_index(%shape : tensor<?xindex>) -> index {
+ %num_elements = shape.num_elements %shape : tensor<?xindex> -> index
+ return %num_elements : index
+}
+// CHECK: [[C1:%.*]] = constant 1 : index
+// CHECK: [[NUM_ELEMENTS:%.*]] = shape.reduce([[ARG]], [[C1]]) : tensor<?xindex> -> index
+// CHECK: ^bb0({{.*}}: index, [[DIM:%.*]]: index, [[ACC:%.*]]: index
+// CHECK: [[NEW_ACC:%.*]] = shape.mul [[DIM]], [[ACC]]
+// CHECK: shape.yield [[NEW_ACC]] : index
+// CHECK: }
+// CHECK: return [[NUM_ELEMENTS]] : index
More information about the Mlir-commits
mailing list