[Mlir-commits] [mlir] 6d10d31 - [MLIR][Shape] Support transforming shape.num_elements on tensors

Stephan Herhut llvmlistbot at llvm.org
Tue Jul 28 05:13:20 PDT 2020


Author: Stephan Herhut
Date: 2020-07-28T14:13:06+02:00
New Revision: 6d10d317d8b0f1975dbb17850efd7c069f6ee8fd

URL: https://github.com/llvm/llvm-project/commit/6d10d317d8b0f1975dbb17850efd7c069f6ee8fd
DIFF: https://github.com/llvm/llvm-project/commit/6d10d317d8b0f1975dbb17850efd7c069f6ee8fd.diff

LOG: [MLIR][Shape] Support transforming shape.num_elements on tensors

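The current transformation to shape.reduce does not support tensor values.
This change adds the required support, including fixing the builder for
shape.reduce so that, when the shape operand is a tensor, the element block
argument uses the tensor's element type instead of always !shape.size.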

Differential Revision: https://reviews.llvm.org/D84744

Added: 
    

Modified: 
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
    mlir/test/Dialect/Shape/shape-to-shape.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 4887c87c1e5f..3c71e3409923 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -834,7 +834,13 @@ void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
   bodyRegion->push_back(new Block);
   Block &bodyBlock = bodyRegion->front();
   bodyBlock.addArgument(builder.getIndexType());
-  bodyBlock.addArgument(SizeType::get(builder.getContext()));
+
+  Type elementType;
+  if (auto tensorType = shape.getType().dyn_cast<TensorType>())
+    elementType = tensorType.getElementType();
+  else
+    elementType = SizeType::get(builder.getContext());
+  bodyBlock.addArgument(elementType);
 
   for (Type initValType : initVals.getTypes()) {
     bodyBlock.addArgument(initValType);

diff  --git a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
index bb2b03b8ec08..a84fad1f9460 100644
--- a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
@@ -9,6 +9,7 @@
 #include "PassDetail.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
@@ -32,14 +33,18 @@ LogicalResult
 NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
                                         PatternRewriter &rewriter) const {
   auto loc = op.getLoc();
-  Value init = rewriter.create<ConstSizeOp>(loc, rewriter.getIndexAttr(1));
+  Type valueType = op.getResult().getType();
+  Value init = op.getDialect()
+                   ->materializeConstant(rewriter, rewriter.getIndexAttr(1),
+                                         valueType, loc)
+                   ->getResult(0);
   ReduceOp reduce = rewriter.create<ReduceOp>(loc, op.shape(), init);
 
   // Generate reduce operator.
   Block *body = reduce.getBody();
   OpBuilder b = OpBuilder::atBlockEnd(body);
-  Value product = b.create<MulOp>(loc, b.getType<SizeType>(),
-                                  body->getArgument(1), body->getArgument(2));
+  Value product = b.create<MulOp>(loc, valueType, body->getArgument(1),
+                                  body->getArgument(2));
   b.create<YieldOp>(loc, product);
 
   rewriter.replaceOp(op, reduce.result());
@@ -60,7 +65,7 @@ void ShapeToShapeLowering::runOnFunction() {
   populateShapeRewritePatterns(&ctx, patterns);
 
   ConversionTarget target(getContext());
-  target.addLegalDialect<ShapeDialect>();
+  target.addLegalDialect<ShapeDialect, StandardOpsDialect>();
   target.addIllegalOp<NumElementsOp>();
   if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
     signalPassFailure();

diff  --git a/mlir/test/Dialect/Shape/shape-to-shape.mlir b/mlir/test/Dialect/Shape/shape-to-shape.mlir
index d1b00bc12a22..481d682942bb 100644
--- a/mlir/test/Dialect/Shape/shape-to-shape.mlir
+++ b/mlir/test/Dialect/Shape/shape-to-shape.mlir
@@ -14,3 +14,18 @@ func @num_elements_to_reduce(%shape : !shape.shape) -> !shape.size {
 // CHECK: }
 // CHECK: return [[NUM_ELEMENTS]] : !shape.size
 
+// -----
+
+// CHECK-LABEL: func @num_elements_to_reduce_on_index
+// CHECK-SAME:  ([[ARG:%.*]]: tensor<?xindex>) -> index
+func @num_elements_to_reduce_on_index(%shape : tensor<?xindex>) -> index {
+  %num_elements = shape.num_elements %shape : tensor<?xindex> -> index
+  return %num_elements : index
+}
+// CHECK: [[C1:%.*]] = constant 1 : index
+// CHECK: [[NUM_ELEMENTS:%.*]] = shape.reduce([[ARG]], [[C1]]) : tensor<?xindex> -> index
+// CHECK: ^bb0({{.*}}: index, [[DIM:%.*]]: index, [[ACC:%.*]]: index
+// CHECK:   [[NEW_ACC:%.*]] = shape.mul [[DIM]], [[ACC]]
+// CHECK:   shape.yield [[NEW_ACC]] : index
+// CHECK: }
+// CHECK: return [[NUM_ELEMENTS]] : index


        


More information about the Mlir-commits mailing list