[Mlir-commits] [mlir] a70f2eb - [MLIR][Shape] Merge `shape` to `std`/`scf` lowerings.
Frederik Gossen
llvmlistbot at llvm.org
Mon Sep 7 07:40:01 PDT 2020
Author: Frederik Gossen
Date: 2020-09-07T14:39:37Z
New Revision: a70f2eb3e39a42a71ba077247f9deafbdf1e8092
URL: https://github.com/llvm/llvm-project/commit/a70f2eb3e39a42a71ba077247f9deafbdf1e8092
DIFF: https://github.com/llvm/llvm-project/commit/a70f2eb3e39a42a71ba077247f9deafbdf1e8092.diff
LOG: [MLIR][Shape] Merge `shape` to `std`/`scf` lowerings.
Merge the two lowering passes because they are not useful by themselves. The new
pass lowers to `std`, with `scf` considered an auxiliary dialect.
See also
https://llvm.discourse.group/t/conversions-with-multiple-target-dialects/1541/12
Differential Revision: https://reviews.llvm.org/D86779
Added:
Modified:
mlir/include/mlir/Conversion/Passes.h
mlir/include/mlir/Conversion/Passes.td
mlir/lib/Conversion/CMakeLists.txt
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
Removed:
mlir/include/mlir/Conversion/ShapeToSCF/ShapeToSCF.h
mlir/lib/Conversion/ShapeToSCF/CMakeLists.txt
mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 5dd10932981b..b04498598b29 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -23,7 +23,6 @@
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
-#include "mlir/Conversion/ShapeToSCF/ShapeToSCF.h"
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 1b27a7308c7a..d4b478dbf4ed 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -239,17 +239,7 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
let summary = "Convert operations from the shape dialect into the standard "
"dialect";
let constructor = "mlir::createConvertShapeToStandardPass()";
- let dependentDialects = ["StandardOpsDialect"];
-}
-
-//===----------------------------------------------------------------------===//
-// ShapeToSCF
-//===----------------------------------------------------------------------===//
-
-def ConvertShapeToSCF : FunctionPass<"convert-shape-to-scf"> {
- let summary = "Convert operations from the shape dialect to the SCF dialect";
- let constructor = "mlir::createConvertShapeToSCFPass()";
- let dependentDialects = ["scf::SCFDialect"];
+ let dependentDialects = ["StandardOpsDialect", "scf::SCFDialect"];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/ShapeToSCF/ShapeToSCF.h b/mlir/include/mlir/Conversion/ShapeToSCF/ShapeToSCF.h
deleted file mode 100644
index f953f6e2ddf1..000000000000
--- a/mlir/include/mlir/Conversion/ShapeToSCF/ShapeToSCF.h
+++ /dev/null
@@ -1,27 +0,0 @@
-//===- ShapeToSCF.h - Conversion utils from Shape to SCF dialect ----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_CONVERSION_SHAPETOSCF_SHAPETOSCF_H_
-#define MLIR_CONVERSION_SHAPETOSCF_SHAPETOSCF_H_
-
-#include <memory>
-
-namespace mlir {
-
-class MLIRContext;
-class FunctionPass;
-class OwningRewritePatternList;
-
-void populateShapeToSCFConversionPatterns(OwningRewritePatternList &patterns,
- MLIRContext *ctx);
-
-std::unique_ptr<FunctionPass> createConvertShapeToSCFPass();
-
-} // namespace mlir
-
-#endif // MLIR_CONVERSION_SHAPETOSCF_SHAPETOSCF_H_
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index c2bb2130569d..fe2af07b2a6a 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -12,7 +12,6 @@ add_subdirectory(OpenMPToLLVM)
add_subdirectory(SCFToGPU)
add_subdirectory(SCFToSPIRV)
add_subdirectory(SCFToStandard)
-add_subdirectory(ShapeToSCF)
add_subdirectory(ShapeToStandard)
add_subdirectory(SPIRVToLLVM)
add_subdirectory(StandardToLLVM)
diff --git a/mlir/lib/Conversion/ShapeToSCF/CMakeLists.txt b/mlir/lib/Conversion/ShapeToSCF/CMakeLists.txt
deleted file mode 100644
index 60dd2b8514da..000000000000
--- a/mlir/lib/Conversion/ShapeToSCF/CMakeLists.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-add_mlir_conversion_library(MLIRShapeToSCF
- ShapeToSCF.cpp
-
- ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ShapeToSCF
-
- DEPENDS
- MLIRConversionPassIncGen
-
- LINK_COMPONENTS
- Core
-
- LINK_LIBS PUBLIC
- MLIRIR
- MLIRShape
- MLIRPass
- MLIRSCF
- MLIRTransforms
- )
diff --git a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
deleted file mode 100644
index ae326c5c513e..000000000000
--- a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
+++ /dev/null
@@ -1,337 +0,0 @@
-//===- ShapeToSCF.cpp - conversion from Shape to SCF dialect --------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Conversion/ShapeToSCF/ShapeToSCF.h"
-
-#include "../PassDetail.h"
-#include "mlir/Dialect/SCF/SCF.h"
-#include "mlir/Dialect/Shape/IR/Shape.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-using namespace mlir;
-using namespace mlir::shape;
-using namespace mlir::scf;
-
-namespace {
-struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
- using OpConversionPattern<BroadcastOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override;
-};
-} // namespace
-
-LogicalResult BroadcastOpConverter::matchAndRewrite(
- BroadcastOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const {
- // For now, this lowering is only defined on `tensor<?xindex>` operands, not
- // on shapes.
- if (op.getType().isa<ShapeType>())
- return failure();
-
- assert(!op.lhs().getType().isa<ShapeType>() &&
- !op.rhs().getType().isa<ShapeType>());
- auto loc = op.getLoc();
- BroadcastOp::Adaptor transformed(operands);
- Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
- Value one = rewriter.create<ConstantIndexOp>(loc, 1);
-
- // Find smaller and greater rank and extent tensor.
- Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
- Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
- Value lhsSmaller =
- rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
- Type indexTy = rewriter.getIndexType();
- Type extentTensorTy = op.getType();
- auto ifOp = rewriter.create<IfOp>(
- loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
- lhsSmaller,
- [&](OpBuilder &b, Location loc) {
- b.create<scf::YieldOp>(loc, ValueRange{lhsRank, transformed.lhs(),
- rhsRank, transformed.rhs()});
- },
- [&](OpBuilder &b, Location loc) {
- b.create<scf::YieldOp>(loc, ValueRange{rhsRank, transformed.rhs(),
- lhsRank, transformed.lhs()});
- });
- Value smallerRank = ifOp.getResult(0);
- Value smallerOperand = ifOp.getResult(1);
- Value greaterRank = ifOp.getResult(2);
- Value greaterOperand = ifOp.getResult(3);
-
- // Allocate stack memory for the broadcasted extent tensor.
- Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
- Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{greaterRank});
-
- // Copy extents from greater operand that are not challenged.
- Value rankDiff =
- rewriter.create<SubIOp>(loc, indexTy, greaterRank, smallerRank);
- rewriter.create<ForOp>(loc, zero, rankDiff, one, llvm::None,
- [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
- Value extent = b.create<ExtractElementOp>(
- loc, greaterOperand, ValueRange{iv});
- b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
- b.create<scf::YieldOp>(loc);
- });
-
- // Determine remaining broadcasted extents.
- rewriter.create<ForOp>(
- loc, rankDiff, greaterRank, one, llvm::None,
- [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
- Value greaterOperandExtent =
- b.create<ExtractElementOp>(loc, greaterOperand, ValueRange{iv});
- Value greaterOperandExtentIsOne =
- b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterOperandExtent, one);
- auto ifOp = b.create<IfOp>(
- loc, TypeRange{indexTy}, greaterOperandExtentIsOne,
- [&](OpBuilder &b, Location loc) {
- Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
- Value smallerOperandExtent = b.create<ExtractElementOp>(
- loc, smallerOperand, ValueRange{ivShifted});
- b.create<scf::YieldOp>(loc, smallerOperandExtent);
- },
- [&](OpBuilder &b, Location loc) {
- b.create<scf::YieldOp>(loc, greaterOperandExtent);
- });
- Value extent = ifOp.getResult(0);
- b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
- b.create<scf::YieldOp>(loc);
- });
-
- // Load broadcasted shape as an extent tensor.
- rewriter.replaceOpWithNewOp<TensorLoadOp>(op, mem);
- return success();
-}
-
-namespace {
-/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
-/// only defined on `tensor<?xindex>` operands. The test for equality first
-/// compares their size and, if equal, checks every extent for equality.
-///
-/// Example:
-///
-/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
-///
-/// becomes
-///
-/// %c0 = constant 0 : index
-/// %0 = dim %arg0, %c0 : tensor<?xindex>
-/// %1 = dim %arg1, %c0 : tensor<?xindex>
-/// %2 = cmpi "eq", %0, %1 : index
-/// %result = scf.if %2 -> (i1) {
-/// %c1 = constant 1 : index
-/// %true = constant true
-/// %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
-/// %5 = extract_element %arg0[%arg2] : tensor<?xindex>
-/// %6 = extract_element %arg1[%arg2] : tensor<?xindex>
-/// %7 = cmpi "eq", %5, %6 : index
-/// %8 = and %arg3, %7 : i1
-/// scf.yield %8 : i1
-/// }
-/// scf.yield %4 : i1
-/// } else {
-/// %false = constant false
-/// scf.yield %false : i1
-/// }
-///
-struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
- using OpConversionPattern<ShapeEqOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override;
-};
-} // namespace
-
-LogicalResult
-ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const {
- // For now, this lowering is only defined on `tensor<?xindex>` operands, not
- // on shapes.
- if (op.lhs().getType().isa<ShapeType>() ||
- op.rhs().getType().isa<ShapeType>()) {
- return failure();
- }
-
- ShapeEqOp::Adaptor transformed(operands);
- auto loc = op.getLoc();
- Type indexTy = rewriter.getIndexType();
- Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
- Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
- Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
- Value eqRank =
- rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
- Type i1Ty = rewriter.getI1Type();
- rewriter.replaceOpWithNewOp<IfOp>(
- op, i1Ty, eqRank,
- [&](OpBuilder &b, Location loc) {
- Value one = b.create<ConstantIndexOp>(loc, 1);
- Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
- auto loop = b.create<scf::ForOp>(
- loc, zero, lhsRank, one, ValueRange{init},
- [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
- Value conj = args[0];
- Value lhsExtent =
- b.create<ExtractElementOp>(loc, transformed.lhs(), iv);
- Value rhsExtent =
- b.create<ExtractElementOp>(loc, transformed.rhs(), iv);
- Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
- lhsExtent, rhsExtent);
- Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
- b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
- });
- b.create<scf::YieldOp>(loc, loop.getResults());
- },
- [&](OpBuilder &b, Location loc) {
- Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
- b.create<scf::YieldOp>(loc, result);
- });
- return success();
-}
-
-namespace {
-/// Converts `shape.reduce` to `scf.for`.
-struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
-public:
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const final;
-};
-} // namespace
-
-LogicalResult
-ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const {
- // For now, this lowering is only defined on `tensor<?xindex>` operands.
- if (op.shape().getType().isa<ShapeType>())
- return failure();
-
- auto loc = op.getLoc();
- shape::ReduceOp::Adaptor transformed(operands);
-
- Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
- Value one = rewriter.create<ConstantIndexOp>(loc, 1);
- Type indexTy = rewriter.getIndexType();
- Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);
-
- auto loop = rewriter.create<scf::ForOp>(
- loc, zero, rank, one, op.initVals(),
- [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
- Value extent = b.create<ExtractElementOp>(loc, transformed.shape(), iv);
-
- SmallVector<Value, 2> mappedValues{iv, extent};
- mappedValues.append(args.begin(), args.end());
-
- BlockAndValueMapping mapping;
- Block *reduceBody = op.getBody();
- mapping.map(reduceBody->getArguments(), mappedValues);
- for (auto &nested : reduceBody->without_terminator())
- b.clone(nested, mapping);
-
- SmallVector<Value, 2> mappedResults;
- for (auto result : reduceBody->getTerminator()->getOperands())
- mappedResults.push_back(mapping.lookup(result));
- b.create<scf::YieldOp>(loc, mappedResults);
- });
-
- rewriter.replaceOp(op, loop.getResults());
- return success();
-}
-
-namespace {
-/// Converts `shape_of` to for loop for unranked tensors.
-class ShapeOfOpConverter : public OpConversionPattern<ShapeOfOp> {
-public:
- using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override;
-};
-} // namespace
-
-LogicalResult
-ShapeOfOpConverter::matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const {
- // For now, this lowering supports only error-free arguments.
- if (op.getType().isa<ShapeType>())
- return failure();
-
- // For ranked tensors `shape_of` lowers to `std` and the pattern can be
- // found in the corresponding pass.
- ShapeOfOp::Adaptor transformed(operands);
- Value arg = transformed.arg();
- Type argTy = arg.getType();
- if (argTy.isa<RankedTensorType>())
- return failure();
-
- // Allocate stack memory.
- auto loc = op.getLoc();
- Value rank = rewriter.create<mlir::RankOp>(loc, arg);
- Type indexTy = rewriter.getIndexType();
- Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
- Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{rank});
-
- // Copy shape extents to stack-allocated memory.
- Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
- Value one = rewriter.create<ConstantIndexOp>(loc, 1);
- rewriter.create<scf::ForOp>(
- loc, zero, rank, one, llvm::None,
- [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
- Value dim = rewriter.create<DimOp>(loc, arg, iv);
- rewriter.create<StoreOp>(loc, dim, mem, ValueRange{iv});
- rewriter.create<scf::YieldOp>(loc);
- });
-
- // Load extents to tensor value.
- rewriter.replaceOpWithNewOp<TensorLoadOp>(op.getOperation(), mem);
- return success();
-}
-
-namespace {
-struct ConvertShapeToSCFPass
- : public ConvertShapeToSCFBase<ConvertShapeToSCFPass> {
- void runOnFunction() override;
-};
-} // namespace
-
-void ConvertShapeToSCFPass::runOnFunction() {
- MLIRContext &ctx = getContext();
-
- // Populate conversion patterns.
- OwningRewritePatternList patterns;
- populateShapeToSCFConversionPatterns(patterns, &ctx);
-
- // Setup target legality.
- ConversionTarget target(getContext());
- target.addLegalDialect<SCFDialect, StandardOpsDialect>();
-
- // Apply conversion.
- if (failed(applyPartialConversion(getFunction(), target, patterns)))
- signalPassFailure();
-}
-
-void mlir::populateShapeToSCFConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx) {
- // clang-format off
- patterns.insert<
- BroadcastOpConverter,
- ShapeEqOpConverter,
- ReduceOpConverter,
- ShapeOfOpConverter>(ctx);
- // clang-format on
-}
-
-std::unique_ptr<FunctionPass> mlir::createConvertShapeToSCFPass() {
- return std::make_unique<ConvertShapeToSCFPass>();
-}
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index e92bb83d4f42..8c917e08f942 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -12,10 +12,12 @@
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace mlir::shape;
+using namespace mlir::scf;
/// Conversion patterns.
namespace {
@@ -63,67 +65,94 @@ class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
} // namespace
namespace {
-class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
-public:
- using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
- return success();
- }
-};
-} // namespace
-
-namespace {
-class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
-public:
- using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
+struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
+ using OpConversionPattern<BroadcastOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
+ matchAndRewrite(BroadcastOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace
-LogicalResult ShapeOfOpConversion::matchAndRewrite(
- ShapeOfOp op, ArrayRef<Value> operands,
+LogicalResult BroadcastOpConverter::matchAndRewrite(
+ BroadcastOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
-
- // For now, only error-free types are supported by this lowering.
+ // For now, this lowering is only defined on `tensor<?xindex>` operands, not
+ // on shapes.
if (op.getType().isa<ShapeType>())
return failure();
- // For unranked tensors `shape_of` lowers to `scf` and the pattern can be
- // found in the corresponding pass.
- ShapeOfOp::Adaptor transformed(operands);
- Value tensorVal = transformed.arg();
- Type tensorTy = tensorVal.getType();
- if (tensorTy.isa<UnrankedTensorType>())
- return failure();
-
- // Build values for individual dimensions.
- SmallVector<Value, 8> dimValues;
- RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
- int64_t rank = rankedTensorTy.getRank();
+ assert(!op.lhs().getType().isa<ShapeType>() &&
+ !op.rhs().getType().isa<ShapeType>());
auto loc = op.getLoc();
- for (int64_t i = 0; i < rank; i++) {
- if (rankedTensorTy.isDynamicDim(i)) {
- Value dimVal = rewriter.create<DimOp>(loc, tensorVal, i);
- dimValues.push_back(dimVal);
- } else {
- int64_t dim = rankedTensorTy.getDimSize(i);
- Value dimVal = rewriter.create<ConstantIndexOp>(loc, dim);
- dimValues.push_back(dimVal);
- }
- }
-
- // Materialize extent tensor.
- Value staticExtentTensor =
- rewriter.create<TensorFromElementsOp>(loc, dimValues);
- rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
- op.getType());
+ BroadcastOp::Adaptor transformed(operands);
+ Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+ Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+
+ // Find smaller and greater rank and extent tensor.
+ Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
+ Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
+ Value lhsSmaller =
+ rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
+ Type indexTy = rewriter.getIndexType();
+ Type extentTensorTy = op.getType();
+ auto ifOp = rewriter.create<IfOp>(
+ loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
+ lhsSmaller,
+ [&](OpBuilder &b, Location loc) {
+ b.create<scf::YieldOp>(loc, ValueRange{lhsRank, transformed.lhs(),
+ rhsRank, transformed.rhs()});
+ },
+ [&](OpBuilder &b, Location loc) {
+ b.create<scf::YieldOp>(loc, ValueRange{rhsRank, transformed.rhs(),
+ lhsRank, transformed.lhs()});
+ });
+ Value smallerRank = ifOp.getResult(0);
+ Value smallerOperand = ifOp.getResult(1);
+ Value greaterRank = ifOp.getResult(2);
+ Value greaterOperand = ifOp.getResult(3);
+
+ // Allocate stack memory for the broadcasted extent tensor.
+ Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
+ Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{greaterRank});
+
+ // Copy extents from greater operand that are not challenged.
+ Value rankDiff =
+ rewriter.create<SubIOp>(loc, indexTy, greaterRank, smallerRank);
+ rewriter.create<ForOp>(loc, zero, rankDiff, one, llvm::None,
+ [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
+ Value extent = b.create<ExtractElementOp>(
+ loc, greaterOperand, ValueRange{iv});
+ b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
+ b.create<scf::YieldOp>(loc);
+ });
+
+ // Determine remaining broadcasted extents.
+ rewriter.create<ForOp>(
+ loc, rankDiff, greaterRank, one, llvm::None,
+ [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
+ Value greaterOperandExtent =
+ b.create<ExtractElementOp>(loc, greaterOperand, ValueRange{iv});
+ Value greaterOperandExtentIsOne =
+ b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterOperandExtent, one);
+ auto ifOp = b.create<IfOp>(
+ loc, TypeRange{indexTy}, greaterOperandExtentIsOne,
+ [&](OpBuilder &b, Location loc) {
+ Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
+ Value smallerOperandExtent = b.create<ExtractElementOp>(
+ loc, smallerOperand, ValueRange{ivShifted});
+ b.create<scf::YieldOp>(loc, smallerOperandExtent);
+ },
+ [&](OpBuilder &b, Location loc) {
+ b.create<scf::YieldOp>(loc, greaterOperandExtent);
+ });
+ Value extent = ifOp.getResult(0);
+ b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
+ b.create<scf::YieldOp>(loc);
+ });
+
+ // Load broadcasted shape as an extent tensor.
+ rewriter.replaceOpWithNewOp<TensorLoadOp>(op, mem);
return success();
}
@@ -161,26 +190,23 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite(
}
namespace {
-class ToExtentTensorOpConversion
- : public OpConversionPattern<ToExtentTensorOp> {
+class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
- using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
+ using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- ToExtentTensorOpAdaptor adaptor(operands);
-
- if (!adaptor.input().getType().isa<RankedTensorType>())
- return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
-
- rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
- op.getType());
- return success();
- }
+ matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
};
} // namespace
+LogicalResult ConstSizeOpConversion::matchAndRewrite(
+ ConstSizeOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ rewriter.replaceOpWithNewOp<ConstantIndexOp>(op, op.value().getSExtValue());
+ return success();
+}
+
namespace {
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
using OpConversionPattern<GetExtentOp>::OpConversionPattern;
@@ -239,6 +265,236 @@ RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
return success();
}
+namespace {
+/// Converts `shape.reduce` to `scf.for`.
+struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final;
+};
+} // namespace
+
+LogicalResult
+ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ // For now, this lowering is only defined on `tensor<?xindex>` operands.
+ if (op.shape().getType().isa<ShapeType>())
+ return failure();
+
+ auto loc = op.getLoc();
+ shape::ReduceOp::Adaptor transformed(operands);
+
+ Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+ Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+ Type indexTy = rewriter.getIndexType();
+ Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);
+
+ auto loop = rewriter.create<scf::ForOp>(
+ loc, zero, rank, one, op.initVals(),
+ [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
+ Value extent = b.create<ExtractElementOp>(loc, transformed.shape(), iv);
+
+ SmallVector<Value, 2> mappedValues{iv, extent};
+ mappedValues.append(args.begin(), args.end());
+
+ BlockAndValueMapping mapping;
+ Block *reduceBody = op.getBody();
+ mapping.map(reduceBody->getArguments(), mappedValues);
+ for (auto &nested : reduceBody->without_terminator())
+ b.clone(nested, mapping);
+
+ SmallVector<Value, 2> mappedResults;
+ for (auto result : reduceBody->getTerminator()->getOperands())
+ mappedResults.push_back(mapping.lookup(result));
+ b.create<scf::YieldOp>(loc, mappedResults);
+ });
+
+ rewriter.replaceOp(op, loop.getResults());
+ return success();
+}
+
+namespace {
+/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
+/// only defined on `tensor<?xindex>` operands. The test for equality first
+/// compares their size and, if equal, checks every extent for equality.
+///
+/// Example:
+///
+/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
+///
+/// becomes
+///
+/// %c0 = constant 0 : index
+/// %0 = dim %arg0, %c0 : tensor<?xindex>
+/// %1 = dim %arg1, %c0 : tensor<?xindex>
+/// %2 = cmpi "eq", %0, %1 : index
+/// %result = scf.if %2 -> (i1) {
+/// %c1 = constant 1 : index
+/// %true = constant true
+/// %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
+/// %5 = extract_element %arg0[%arg2] : tensor<?xindex>
+/// %6 = extract_element %arg1[%arg2] : tensor<?xindex>
+/// %7 = cmpi "eq", %5, %6 : index
+/// %8 = and %arg3, %7 : i1
+/// scf.yield %8 : i1
+/// }
+/// scf.yield %4 : i1
+/// } else {
+/// %false = constant false
+/// scf.yield %false : i1
+/// }
+///
+struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
+ using OpConversionPattern<ShapeEqOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult
+ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ // For now, this lowering is only defined on `tensor<?xindex>` operands, not
+ // on shapes.
+ if (op.lhs().getType().isa<ShapeType>() ||
+ op.rhs().getType().isa<ShapeType>()) {
+ return failure();
+ }
+
+ ShapeEqOp::Adaptor transformed(operands);
+ auto loc = op.getLoc();
+ Type indexTy = rewriter.getIndexType();
+ Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+ Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
+ Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
+ Value eqRank =
+ rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
+ Type i1Ty = rewriter.getI1Type();
+ rewriter.replaceOpWithNewOp<IfOp>(
+ op, i1Ty, eqRank,
+ [&](OpBuilder &b, Location loc) {
+ Value one = b.create<ConstantIndexOp>(loc, 1);
+ Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
+ auto loop = b.create<scf::ForOp>(
+ loc, zero, lhsRank, one, ValueRange{init},
+ [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
+ Value conj = args[0];
+ Value lhsExtent =
+ b.create<ExtractElementOp>(loc, transformed.lhs(), iv);
+ Value rhsExtent =
+ b.create<ExtractElementOp>(loc, transformed.rhs(), iv);
+ Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
+ lhsExtent, rhsExtent);
+ Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
+ b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
+ });
+ b.create<scf::YieldOp>(loc, loop.getResults());
+ },
+ [&](OpBuilder &b, Location loc) {
+ Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
+ b.create<scf::YieldOp>(loc, result);
+ });
+ return success();
+}
+
+namespace {
+class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
+public:
+ using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult ShapeOfOpConversion::matchAndRewrite(
+ ShapeOfOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+
+ // For now, only error-free types are supported by this lowering.
+ if (op.getType().isa<ShapeType>())
+ return failure();
+
+ // For ranked tensor arguments, lower to `tensor_from_elements`.
+ ShapeOfOp::Adaptor transformed(operands);
+ Value tensor = transformed.arg();
+ Type tensorTy = tensor.getType();
+ if (tensorTy.isa<RankedTensorType>()) {
+
+ // Build values for individual extents.
+ SmallVector<Value, 8> extentValues;
+ RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
+ int64_t rank = rankedTensorTy.getRank();
+ auto loc = op.getLoc();
+ for (int64_t i = 0; i < rank; i++) {
+ if (rankedTensorTy.isDynamicDim(i)) {
+ Value extent = rewriter.create<DimOp>(loc, tensor, i);
+ extentValues.push_back(extent);
+ } else {
+ Value extent =
+ rewriter.create<ConstantIndexOp>(loc, rankedTensorTy.getDimSize(i));
+ extentValues.push_back(extent);
+ }
+ }
+
+ // Materialize extent tensor.
+ Value staticExtentTensor =
+ rewriter.create<TensorFromElementsOp>(loc, extentValues);
+ rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
+ op.getType());
+ return success();
+ }
+
+ // Allocate stack memory.
+ auto loc = op.getLoc();
+ Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
+ Type indexTy = rewriter.getIndexType();
+ Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
+ Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{rank});
+
+ // Copy shape extents to stack-allocated memory.
+ Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+ Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+ rewriter.create<scf::ForOp>(
+ loc, zero, rank, one, llvm::None,
+ [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
+ Value dim = rewriter.create<DimOp>(loc, tensor, iv);
+ rewriter.create<StoreOp>(loc, dim, mem, ValueRange{iv});
+ rewriter.create<scf::YieldOp>(loc);
+ });
+
+ // Load extents to tensor value.
+ rewriter.replaceOpWithNewOp<TensorLoadOp>(op.getOperation(), mem);
+ return success();
+}
+
+namespace {
+class ToExtentTensorOpConversion
+ : public OpConversionPattern<ToExtentTensorOp> {
+public:
+ using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ ToExtentTensorOpAdaptor adaptor(operands);
+
+ if (!adaptor.input().getType().isa<RankedTensorType>())
+ return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
+
+ rewriter.replaceOpWithNewOp<TensorCastOp>(op, adaptor.input(),
+ op.getType());
+ return success();
+ }
+};
+} // namespace
+
namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
@@ -252,7 +508,7 @@ void ConvertShapeToStandardPass::runOnOperation() {
// Setup target legality.
MLIRContext &ctx = getContext();
ConversionTarget target(ctx);
- target.addLegalDialect<StandardOpsDialect>();
+ target.addLegalDialect<StandardOpsDialect, SCFDialect>();
target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();
// Setup conversion patterns.
@@ -271,11 +527,14 @@ void mlir::populateShapeToStandardConversionPatterns(
patterns.insert<
AnyOpConversion,
BinaryOpConversion<AddOp, AddIOp>,
- ConstShapeOpConverter,
BinaryOpConversion<MulOp, MulIOp>,
+ BroadcastOpConverter,
+ ConstShapeOpConverter,
ConstSizeOpConversion,
GetExtentOpConverter,
RankOpConverter,
+ ReduceOpConverter,
+ ShapeEqOpConverter,
ShapeOfOpConversion,
ToExtentTensorOpConversion>(ctx);
// clang-format on
diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
deleted file mode 100644
index cc384496dff0..000000000000
--- a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
+++ /dev/null
@@ -1,132 +0,0 @@
-// RUN: mlir-opt -convert-shape-to-scf -split-input-file %s | FileCheck %s
-
-// CHECK-LABEL: @shape_reduce
-// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> index
-func @shape_reduce(%shape : tensor<?xindex>) -> index {
- %init = constant 1 : index
- %num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
- ^bb0(%index : index, %extent : index, %acc: index):
- %new_acc = muli %acc, %extent : index
- shape.yield %new_acc : index
- }
- return %num_elements : index
-}
-// CHECK-NEXT: %[[INIT:.*]] = constant 1 : index
-// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
-// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
-// CHECK-NEXT: %[[RANK:.*]] = dim %[[SHAPE]], %[[C0]] : tensor<?xindex>
-// CHECK-NEXT: %[[RESULT:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) -> (index)
-// CHECK-NEXT: %[[EXTENT:.*]] = extract_element %[[SHAPE]][%[[I]]]
-// CHECK-NEXT: %[[NEW_ACC:.*]] = muli %[[ACC]], %[[EXTENT]] : index
-// CHECK-NEXT: scf.yield %[[NEW_ACC]] : index
-// CHECK-NEXT: }
-// CHECK-NEXT: return %[[RESULT]] : index
-
-// -----
-
-// Don't lower `shape_of` for result type of `shape.shape`.
-// CHECK-LABEL: @shape_of
-// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
-func @shape_of(%arg : tensor<*xf32>) {
- // CHECK: shape.shape
- %shape = shape.shape_of %arg : tensor<*xf32> -> !shape.shape
- return
-}
-
-// -----
-
-// Lower `shape_of` for unranked tensors.
-// CHECK-LABEL: @shape_of_unranked
-// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
-func @shape_of_unranked(%arg : tensor<*xf32>) {
- // CHECK: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
- // CHECK: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xindex>
- // CHECK: %[[C0:.*]] = constant 0 : index
- // CHECK: %[[C1:.*]] = constant 1 : index
- // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
- // CHECK: %[[DIM:.]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
- // CHECK: store %[[DIM]], %[[SHAPE_MEM]][%[[I]]] : memref<?xindex>
- // CHECK: }
- // CHECK: %[[SHAPE:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xindex>
- %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
- return
-}
-
-// -----
-
-// CHECK-LABEL: @shape_eq
-// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>) -> i1
-func @shape_eq(%a : tensor<?xindex>, %b : tensor<?xindex>) -> i1 {
- // CHECK: %[[C0:.*]] = constant 0 : index
- // CHECK: %[[RANK_A:.*]] = dim %[[A]], %[[C0]] : tensor<?xindex>
- // CHECK: %[[RANK_B:.*]] = dim %[[B]], %[[C0]] : tensor<?xindex>
- // CHECK: %[[RANK_EQ:.*]] = cmpi "eq", %[[RANK_A]], %[[RANK_B]]
- // CHECK: %[[SHAPE_EQ:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
- // CHECK: %[[C1:.*]] = constant 1 : index
- // CHECK: %[[INIT:.*]] = constant true
- // CHECK: %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
- // CHECK: %[[EXTENT_A:.*]] = extract_element %[[A]][%[[I]]] : tensor<?xindex>
- // CHECK: %[[EXTENT_B:.*]] = extract_element %[[B]][%[[I]]] : tensor<?xindex>
- // CHECK: %[[EXTENT_EQ:.*]] = cmpi "eq", %[[EXTENT_A]], %[[EXTENT_B]]
- // CHECK: %[[CONJ_NEXT:.*]] = and %[[CONJ]], %[[EXTENT_EQ]]
- // CHECK: scf.yield %[[CONJ_NEXT]] : i1
- // CHECK: }
- // CHECK: scf.yield %[[SHAPE_EQ_INNER]] : i1
- // CHECK: } else {
- // CHECK: %[[SHAPE_EQ_INNER:.*]] = constant false
- // CHECK: scf.yield %[[SHAPE_EQ_INNER]] : i1
- // CHECK: }
- // CHECK: return %[[SHAPE_EQ]] : i1
- %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
- return %result : i1
-}
-
-// -----
-
-// Don't lower `shape.broadcast` if a `shape.shape` type is involved.
-// CHECK-LABEL: @broadcast
-func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
- // CHECK: shape.broadcast
- %c = shape.broadcast %a, %b : tensor<?xindex>, !shape.shape -> !shape.shape
- return %c : !shape.shape
-}
-
-// -----
-
-// CHECK-LABEL: @broadcast
-// CHECK-SAME: (%[[LHS:.*]]: tensor<?xindex>, %[[RHS:.*]]: tensor<?xindex>)
-func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
- // CHECK: %[[C0:.*]] = constant 0 : index
- // CHECK: %[[C1:.*]] = constant 1 : index
- // CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
- // CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
- // CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]]
- // CHECK: %[[ARG:.*]]:4 = scf.if %[[LHS_SMALLER]] -> (index, tensor<?xindex>, index, tensor<?xindex>) {
- // CHECK: scf.yield %[[LHS_RANK]], %[[LHS]], %[[RHS_RANK]], %[[RHS]] : index, tensor<?xindex>, index, tensor<?xindex>
- // CHECK: } else {
- // CHECK: scf.yield %[[RHS_RANK]], %[[RHS]], %[[LHS_RANK]], %[[LHS]] : index, tensor<?xindex>, index, tensor<?xindex>
- // CHECK: }
- // CHECK: %[[MEM:.*]] = alloca(%[[ARG]]#2) : memref<?xindex>
- // CHECK: %[[RANK_DIFF:.*]] = subi %[[ARG]]#2, %[[ARG]]#0 : index
- // CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
- // CHECK: %[[EXTENT:.*]] = extract_element %[[ARG]]#3[%[[IV]]] : tensor<?xindex>
- // CHECK: store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
- // CHECK: }
- // CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[ARG]]#2 step %[[C1]] {
- // CHECK: %[[GREATER_OPERAND_EXTENT:.*]] = extract_element %[[ARG]]#3[%[[IV]]] : tensor<?xindex>
- // CHECK: %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_OPERAND_EXTENT]], %[[C1]] : index
- // CHECK: %[[EXTENT:.*]] = scf.if %[[GREATER_OPERAND_EXTENT_IS_ONE]] -> (index) {
- // CHECK: %[[IV_SHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
- // CHECK: %[[SMALLER_OPERAND_EXTENT:.*]] = extract_element %[[ARG]]#1[%[[IV_SHIFTED]]] : tensor<?xindex>
- // CHECK: scf.yield %[[SMALLER_OPERAND_EXTENT]] : index
- // CHECK: } else {
- // CHECK: scf.yield %[[GREATER_OPERAND_EXTENT]] : index
- // CHECK: }
- // CHECK: store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
- // CHECK: }
- // CHECK: %[[BROADCASTED:.*]] = tensor_load %[[MEM]] : memref<?xindex>
- %0 = shape.broadcast %a, %b
- : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
- return
-}
-
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index b0fb5bac9071..bf8e74e5143e 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -26,46 +26,6 @@ func @binary_ops_on_size(%lhs : !shape.size, %rhs : !shape.size) {
// -----
-// Don't lower `shape_of` with `shape.shape` type.
-// CHECK-LABEL: @shape_of
-// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
-func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
- // CHECK: shape.shape_of %[[ARG]] : tensor<1x2x3xf32> -> !shape.shape
- %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> !shape.shape
- return
-}
-
-// -----
-
-// Lower `shape_of` for statically shaped tensor.
-// CHECK-LABEL: @shape_of_stat
-// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
-func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
- // CHECK-DAG: %[[C1:.*]] = constant 1 : index
- // CHECK-DAG: %[[C2:.*]] = constant 2 : index
- // CHECK-DAG: %[[C3:.*]] = constant 3 : index
- // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex>
- %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> tensor<?xindex>
- return
-}
-
-// -----
-
-// Lower `shape_of` for dynamically shaped tensor.
-// CHECK-LABEL: @shape_of_dyn
-// CHECK-SAME: (%[[ARG:.*]]: tensor<1x5x?xf32>)
-func @shape_of_dyn(%arg : tensor<1x5x?xf32>) {
- // CHECK-DAG: %[[C1:.*]] = constant 1 : index
- // CHECK-DAG: %[[C5:.*]] = constant 5 : index
- // CHECK-DAG: %[[C2:.*]] = constant 2 : index
- // CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32>
- // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex>
- %shape = shape.shape_of %arg : tensor<1x5x?xf32> -> tensor<?xindex>
- return
-}
-
-// -----
-
// Convert `rank` to `dim` of the first dimension.
// CHECK-LABEL: @rank
// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> index
@@ -190,3 +150,174 @@ func @to_extent_tensor(%arg: tensor<?xindex>) -> tensor<3xindex> {
// CHECK: return %[[RES]]
return %casted : tensor<3xindex>
}
+
+// CHECK-LABEL: @shape_reduce
+// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> index
+func @shape_reduce(%shape : tensor<?xindex>) -> index {
+ %init = constant 1 : index
+ %num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
+ ^bb0(%index : index, %extent : index, %acc: index):
+ %new_acc = muli %acc, %extent : index
+ shape.yield %new_acc : index
+ }
+ return %num_elements : index
+}
+// CHECK-NEXT: %[[INIT:.*]] = constant 1 : index
+// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
+// CHECK-NEXT: %[[C1:.*]] = constant 1 : index
+// CHECK-NEXT: %[[RANK:.*]] = dim %[[SHAPE]], %[[C0]] : tensor<?xindex>
+// CHECK-NEXT: %[[RESULT:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) -> (index)
+// CHECK-NEXT: %[[EXTENT:.*]] = extract_element %[[SHAPE]][%[[I]]]
+// CHECK-NEXT: %[[NEW_ACC:.*]] = muli %[[ACC]], %[[EXTENT]] : index
+// CHECK-NEXT: scf.yield %[[NEW_ACC]] : index
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[RESULT]] : index
+
+// -----
+
+// Don't lower `shape_of` for result type of `shape.shape`.
+// CHECK-LABEL: @shape_of
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
+func @shape_of(%arg : tensor<*xf32>) {
+ // CHECK: shape.shape
+ %shape = shape.shape_of %arg : tensor<*xf32> -> !shape.shape
+ return
+}
+
+// -----
+
+// Lower `shape_of` for unranked tensors.
+// CHECK-LABEL: @shape_of_unranked
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
+func @shape_of_unranked(%arg : tensor<*xf32>) {
+ // CHECK: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
+ // CHECK: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xindex>
+ // CHECK: %[[C0:.*]] = constant 0 : index
+ // CHECK: %[[C1:.*]] = constant 1 : index
+ // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
+ // CHECK: %[[DIM:.*]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
+ // CHECK: store %[[DIM]], %[[SHAPE_MEM]][%[[I]]] : memref<?xindex>
+ // CHECK: }
+ // CHECK: %[[SHAPE:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xindex>
+ %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
+ return
+}
+
+// -----
+
+// Don't lower `shape_of` with `shape.shape` type.
+// CHECK-LABEL: @shape_of
+// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
+func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
+ // CHECK: shape.shape_of %[[ARG]] : tensor<1x2x3xf32> -> !shape.shape
+ %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> !shape.shape
+ return
+}
+
+// -----
+
+// Lower `shape_of` for statically shaped tensor.
+// CHECK-LABEL: @shape_of_stat
+// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>)
+func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
+ // CHECK-DAG: %[[C1:.*]] = constant 1 : index
+ // CHECK-DAG: %[[C2:.*]] = constant 2 : index
+ // CHECK-DAG: %[[C3:.*]] = constant 3 : index
+ // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex>
+ %shape = shape.shape_of %arg : tensor<1x2x3xf32> -> tensor<?xindex>
+ return
+}
+
+// -----
+
+// Lower `shape_of` for dynamically shaped tensor.
+// CHECK-LABEL: @shape_of_dyn
+// CHECK-SAME: (%[[ARG:.*]]: tensor<1x5x?xf32>)
+func @shape_of_dyn(%arg : tensor<1x5x?xf32>) {
+ // CHECK-DAG: %[[C1:.*]] = constant 1 : index
+ // CHECK-DAG: %[[C5:.*]] = constant 5 : index
+ // CHECK-DAG: %[[C2:.*]] = constant 2 : index
+ // CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32>
+ // CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex>
+ %shape = shape.shape_of %arg : tensor<1x5x?xf32> -> tensor<?xindex>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @shape_eq
+// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>) -> i1
+func @shape_eq(%a : tensor<?xindex>, %b : tensor<?xindex>) -> i1 {
+ // CHECK: %[[C0:.*]] = constant 0 : index
+ // CHECK: %[[RANK_A:.*]] = dim %[[A]], %[[C0]] : tensor<?xindex>
+ // CHECK: %[[RANK_B:.*]] = dim %[[B]], %[[C0]] : tensor<?xindex>
+ // CHECK: %[[RANK_EQ:.*]] = cmpi "eq", %[[RANK_A]], %[[RANK_B]]
+ // CHECK: %[[SHAPE_EQ:.*]] = scf.if %[[RANK_EQ]] -> (i1) {
+ // CHECK: %[[C1:.*]] = constant 1 : index
+ // CHECK: %[[INIT:.*]] = constant true
+ // CHECK: %[[SHAPE_EQ_INNER:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[RANK_A]] step %[[C1]] iter_args(%[[CONJ:.*]] = %[[INIT]]) -> (i1) {
+ // CHECK: %[[EXTENT_A:.*]] = extract_element %[[A]][%[[I]]] : tensor<?xindex>
+ // CHECK: %[[EXTENT_B:.*]] = extract_element %[[B]][%[[I]]] : tensor<?xindex>
+ // CHECK: %[[EXTENT_EQ:.*]] = cmpi "eq", %[[EXTENT_A]], %[[EXTENT_B]]
+ // CHECK: %[[CONJ_NEXT:.*]] = and %[[CONJ]], %[[EXTENT_EQ]]
+ // CHECK: scf.yield %[[CONJ_NEXT]] : i1
+ // CHECK: }
+ // CHECK: scf.yield %[[SHAPE_EQ_INNER]] : i1
+ // CHECK: } else {
+ // CHECK: %[[SHAPE_EQ_INNER:.*]] = constant false
+ // CHECK: scf.yield %[[SHAPE_EQ_INNER]] : i1
+ // CHECK: }
+ // CHECK: return %[[SHAPE_EQ]] : i1
+ %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
+ return %result : i1
+}
+
+// -----
+
+// Don't lower `shape.broadcast` if a `shape.shape` type is involved.
+// CHECK-LABEL: @broadcast
+func @broadcast(%a : tensor<?xindex>, %b : !shape.shape) -> !shape.shape {
+ // CHECK: shape.broadcast
+ %c = shape.broadcast %a, %b : tensor<?xindex>, !shape.shape -> !shape.shape
+ return %c : !shape.shape
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast
+// CHECK-SAME: (%[[LHS:.*]]: tensor<?xindex>, %[[RHS:.*]]: tensor<?xindex>)
+func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) {
+ // CHECK: %[[C0:.*]] = constant 0 : index
+ // CHECK: %[[C1:.*]] = constant 1 : index
+ // CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
+ // CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
+ // CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]]
+ // CHECK: %[[ARG:.*]]:4 = scf.if %[[LHS_SMALLER]] -> (index, tensor<?xindex>, index, tensor<?xindex>) {
+ // CHECK: scf.yield %[[LHS_RANK]], %[[LHS]], %[[RHS_RANK]], %[[RHS]] : index, tensor<?xindex>, index, tensor<?xindex>
+ // CHECK: } else {
+ // CHECK: scf.yield %[[RHS_RANK]], %[[RHS]], %[[LHS_RANK]], %[[LHS]] : index, tensor<?xindex>, index, tensor<?xindex>
+ // CHECK: }
+ // CHECK: %[[MEM:.*]] = alloca(%[[ARG]]#2) : memref<?xindex>
+ // CHECK: %[[RANK_DIFF:.*]] = subi %[[ARG]]#2, %[[ARG]]#0 : index
+ // CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
+ // CHECK: %[[EXTENT:.*]] = extract_element %[[ARG]]#3[%[[IV]]] : tensor<?xindex>
+ // CHECK: store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
+ // CHECK: }
+ // CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[ARG]]#2 step %[[C1]] {
+ // CHECK: %[[GREATER_OPERAND_EXTENT:.*]] = extract_element %[[ARG]]#3[%[[IV]]] : tensor<?xindex>
+ // CHECK: %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_OPERAND_EXTENT]], %[[C1]] : index
+ // CHECK: %[[EXTENT:.*]] = scf.if %[[GREATER_OPERAND_EXTENT_IS_ONE]] -> (index) {
+ // CHECK: %[[IV_SHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
+ // CHECK: %[[SMALLER_OPERAND_EXTENT:.*]] = extract_element %[[ARG]]#1[%[[IV_SHIFTED]]] : tensor<?xindex>
+ // CHECK: scf.yield %[[SMALLER_OPERAND_EXTENT]] : index
+ // CHECK: } else {
+ // CHECK: scf.yield %[[GREATER_OPERAND_EXTENT]] : index
+ // CHECK: }
+ // CHECK: store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
+ // CHECK: }
+ // CHECK: %[[BROADCASTED:.*]] = tensor_load %[[MEM]] : memref<?xindex>
+ %0 = shape.broadcast %a, %b
+ : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+ return
+}
+
More information about the Mlir-commits
mailing list