[Mlir-commits] [mlir] aa6eb2a - [MLIR][LinAlg] Implement detensoring cost-modelling.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 13 00:10:44 PDT 2021
Author: KareemErgawy-TomTom
Date: 2021-04-13T09:07:18+02:00
New Revision: aa6eb2af10094e427827343b67b25d606dde10b7
URL: https://github.com/llvm/llvm-project/commit/aa6eb2af10094e427827343b67b25d606dde10b7
DIFF: https://github.com/llvm/llvm-project/commit/aa6eb2af10094e427827343b67b25d606dde10b7.diff
LOG: [MLIR][LinAlg] Implement detensoring cost-modelling.
This patch introduces the necessary infrastructure changes to implement
cost-modelling for detensoring. In particular, it introduces the
following changes:
- An extension to the dialect conversion framework to selectively
convert sub-set of non-entry BB arguments.
- An extension to branch conversion pattern to selectively convert
sub-set of a branch's operands.
- An interface for detensoring cost-modelling.
- 2 simple implementations of 2 different cost models.
This sets the stage to expose cost-modelling for detensoring in an
easier way. We still need to come up with better cost models.
Reviewed By: silvas
Differential Revision: https://reviews.llvm.org/D99945
Added:
mlir/test/Dialect/Linalg/detensorize_if.mlir
mlir/test/Dialect/Linalg/detensorize_trivial.mlir
mlir/test/Dialect/Linalg/detensorize_while.mlir
mlir/test/Dialect/Linalg/detensorize_while_failure.mlir
mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
Modified:
mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/test/Dialect/Linalg/detensorized_0d.mlir
Removed:
mlir/test/Dialect/Linalg/detensorized_while.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
index b932d1e009834..4a27d3caf5d73 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
@@ -13,9 +13,13 @@
#ifndef MLIR_DIALECT_STANDARDOPS_TRANSFORMS_FUNCCONVERSIONS_H_
#define MLIR_DIALECT_STANDARDOPS_TRANSFORMS_FUNCCONVERSIONS_H_
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/STLExtras.h"
+
namespace mlir {
// Forward declarations.
+class BranchOpInterface;
class ConversionTarget;
class MLIRContext;
class Operation;
@@ -32,8 +36,15 @@ void populateCallOpTypeConversionPattern(RewritePatternSet &patterns,
/// operands that have been legalized by the conversion framework. This can only
/// be done if the branch operation implements the BranchOpInterface. Only
/// needed for partial conversions.
-void populateBranchOpInterfaceTypeConversionPattern(RewritePatternSet &patterns,
- TypeConverter &converter);
+///
+/// If for some branch ops, we need to convert/legalize only a sub-set of the
+/// op's operands, such filtering behavior can be specified in
+/// shouldConvertBranchOperand. This callback should return true if branchOp's
+/// operand at index idx should be converted.
+void populateBranchOpInterfaceTypeConversionPattern(
+ RewritePatternSet &patterns, TypeConverter &converter,
+ function_ref<bool(BranchOpInterface branchOp, int idx)>
+ shouldConvertBranchOperand = nullptr);
/// Return true if op is a BranchOpInterface op whose operands are all legal
/// according to converter.
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index c7598ca4f577d..b0de3a170e667 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -498,8 +498,13 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// Convert the types of block arguments within the given region except for
/// the entry region. This replaces each non-entry block with a new block
/// containing the updated signature.
- LogicalResult convertNonEntryRegionTypes(Region *region,
- TypeConverter &converter);
+ ///
+ /// If special conversion behavior is needed for the non-entry blocks (for
+ /// example, we need to convert only a subset of a BB arguments), such
+ /// behavior can be specified in blockConversions.
+ LogicalResult convertNonEntryRegionTypes(
+ Region *region, TypeConverter &converter,
+ ArrayRef<TypeConverter::SignatureConversion> blockConversions);
/// Replace all the uses of the block argument `from` with value `to`.
void replaceUsesOfBlockArgument(BlockArgument from, Value to);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 85b9836d5d369..c23b84cf1f62d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -39,13 +39,21 @@ namespace {
/// Defines the criteria a TensorType must follow in order to be considered
/// "detensorable".
///
-/// NOTE: For now, only 0-D are supported.
+/// NOTE: For now, only 0-D tensors are supported.
///
/// Returns true if tensorType can be detensored.
bool canBeDetensored(TensorType tensorType) {
return tensorType.hasRank() && tensorType.getRank() == 0;
}
+bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
+ GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
+ return genericOp && llvm::all_of(genericOp.getShapedOperandTypes(),
+ [&](ShapedType shapedType) {
+ return !typeConverter.isLegal(shapedType);
+ });
+}
+
/// A conversion patttern for detensoring `linalg.generic` ops.
class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
public:
@@ -82,16 +90,35 @@ class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
/// function.
struct FunctionNonEntryBlockConversion : public ConversionPattern {
FunctionNonEntryBlockConversion(StringRef functionLikeOpName,
- MLIRContext *ctx, TypeConverter &converter)
- : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
+ MLIRContext *ctx, TypeConverter &converter,
+ DenseSet<BlockArgument> blockArgsToDetensor)
+ : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx),
+ blockArgsToDetensor(blockArgsToDetensor) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.startRootUpdate(op);
+ Region ®ion = mlir::impl::getFunctionBody(op);
+ SmallVector<TypeConverter::SignatureConversion, 2> conversions;
+
+ for (Block &block : llvm::drop_begin(region, 1)) {
+ conversions.emplace_back(block.getNumArguments());
+ TypeConverter::SignatureConversion &back = conversions.back();
+
+ for (BlockArgument blockArgument : block.getArguments()) {
+ int idx = blockArgument.getArgNumber();
+
+ if (blockArgsToDetensor.count(blockArgument))
+ back.addInputs(idx, {getTypeConverter()->convertType(
+ block.getArgumentTypes()[idx])});
+ else
+ back.addInputs(idx, {block.getArgumentTypes()[idx]});
+ }
+ }
- if (failed(rewriter.convertNonEntryRegionTypes(
- &mlir::impl::getFunctionBody(op), *typeConverter))) {
+ if (failed(rewriter.convertNonEntryRegionTypes(®ion, *typeConverter,
+ conversions))) {
rewriter.cancelRootUpdate(op);
return failure();
}
@@ -99,6 +126,9 @@ struct FunctionNonEntryBlockConversion : public ConversionPattern {
rewriter.finalizeRootUpdate(op);
return success();
}
+
+private:
+ const DenseSet<BlockArgument> blockArgsToDetensor;
};
class DetensorizeTypeConverter : public TypeConverter {
@@ -160,46 +190,309 @@ struct ExtractFromReshapeFromElements
/// @see LinalgDetensorize in Linalg/Passes.td for more details.
struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
+ LinalgDetensorize() = default;
+ LinalgDetensorize(const LinalgDetensorize &pass) {}
+
+ class CostModel {
+ public:
+ virtual ~CostModel() = default;
+
+ /// A cost model algorithm computes the following outputs:
+ ///
+ /// - opsToDetensor: the list of linalg ops that should be
+ /// detensored.
+ ///
+ /// - blockArgsToDetensor: since the operands and results of detensored
+ /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come
+ /// from a BB argument and a linalg op's output can be passed to successor
+ /// BBs), we need to maintain the sub-set of arguments that should be
+ /// detensored (i.e. converted by typeConverter) for each affected BB.
+ ///
+ /// Example:
+ ///
+ /// For the following snippet:
+ /// ...
+ /// ^bb1(%6: tensor<i32>, %9: tensor<i32>):
+ /// %7 = linalg.init_tensor [] : tensor<i32>
+ /// %8 = linalg.generic #attrs
+ /// ins(%6, %6 : tensor<i32>, tensor<i32>)
+ /// outs(%7 : tensor<i32>) {
+ /// ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
+ /// %9 = addi %arg0, %arg1 : i32
+ /// linalg.yield %9 : i32
+ /// } -> tensor<i32>
+ /// %10 = "some.op"(%9)
+ /// br ^bb2(%8 : tensor<i32>)
+ /// ...
+ ///
+ /// if the cost model decides that the linalg.generic op should be
+ /// detensored, then:
+ /// - opsToDetensor should be = {linalg.generic{add}}.
+ /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}.
+ virtual void compute(FuncOp func, DetensorizeTypeConverter typeConverter,
+ DenseSet<Operation *> &opsToDetensor,
+ DenseSet<BlockArgument> &blockArgsToDetensor) = 0;
+
+ /// From the blockArgsToDetensor set computed by a CostModel
+ /// implementation, this method computes the corresponding branch op
+ /// detensoring. The result is a map from a branch op to a subset of indices
+ /// of its operands. The indices specify which of the branch op's operands
+ /// should be detensored.
+ ///
+ /// For the previous example, this method would compute: {bb2 -> {0}}.
+ static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring(
+ const DenseSet<BlockArgument> &blockArgsToDetensor) {
+ DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
+
+ for (auto blockArgumentElem : blockArgsToDetensor) {
+ Block *block = blockArgumentElem.getOwner();
+
+ for (PredecessorIterator pred = block->pred_begin();
+ pred != block->pred_end(); ++pred) {
+ BranchOpInterface terminator =
+ dyn_cast<BranchOpInterface>((*pred)->getTerminator());
+ auto blockOperands =
+ terminator.getSuccessorOperands(pred.getSuccessorIndex());
+
+ if (!blockOperands || blockOperands->empty())
+ continue;
+
+ detensorableBranchOps[terminator].insert(
+ blockOperands->getBeginOperandIndex() +
+ blockArgumentElem.getArgNumber());
+ }
+ }
+
+ return detensorableBranchOps;
+ }
+ };
+
+ /// Detensorize linalg ops involved in control-flow within a function.
+ ///
+ /// This model starts from CondBranchOps within a function. For each cond_br,
+ /// the model then walks the use-def chain for the branch's condition
+ /// backwards in order to understand where the condition's value comes from.
+ /// If the condition value is (indirectly) computed by a linalg op that can be
+ /// detensored, the model then continues walking the use-def chain in order to
+ /// understand where the linalg op's operands come from. This leads to
+ /// discovering a "detensoring component". A detensoring component is the set
+ /// of operations + block arguments that are involved in control-flow AND can
+ /// be detensored.
+ ///
+ /// For examples where this model succeeds to discover a detensoring
+ /// component, see:
+ /// - test/Dialect/Linalg/detensorize_while.mlir
+ /// - test/Dialect/Linalg/detesorize_while_pure_cf.mlir.
+ ///
+ /// For an example where this model marks control-flow as "non-detensorable",
+ /// see:
+ /// - test/Dialect/Linalg/detensorize_while_failure.mlir
+ class PureControlFlowDetectionModel : public CostModel {
+ public:
+ void compute(FuncOp func, DetensorizeTypeConverter typeConverter,
+ DenseSet<Operation *> &opsToDetensor,
+ DenseSet<BlockArgument> &blockArgsToDetensor) override {
+ SmallVector<Value> workList;
+
+ func.walk(
+ [&](CondBranchOp condBr) { workList.push_back(condBr.condition()); });
+
+ DenseSet<Value> visitedValues;
+ DenseSet<Operation *> visitedOps;
+
+ while (!workList.empty()) {
+ Value currentItem = workList.pop_back_val();
+
+ if (!visitedValues.insert(currentItem).second)
+ continue;
+
+ // The current item is defined by a block argument.
+ if (auto bbarg = currentItem.dyn_cast<BlockArgument>()) {
+ BlockArgument currentItemBlockArgument =
+ currentItem.cast<BlockArgument>();
+ Block *ownerBlock = currentItemBlockArgument.getOwner();
+
+ // Function arguments are not detensored/converted.
+ if (&*ownerBlock->getParent()->begin() == ownerBlock)
+ continue;
+
+ // This inner-block argument is involved in control-flow, it should be
+ // detensored.
+ blockArgsToDetensor.insert(currentItemBlockArgument);
+
+ for (PredecessorIterator pred = ownerBlock->pred_begin();
+ pred != ownerBlock->pred_end(); ++pred) {
+ BranchOpInterface terminator =
+ dyn_cast<BranchOpInterface>((*pred)->getTerminator());
+
+ // TODO: For now, we give up if any of the control-flow components
+ // in a function is not detensorable. Fix that.
+ if (!terminator) {
+ opsToDetensor.clear();
+ blockArgsToDetensor.clear();
+ return;
+ }
+
+ auto ownerBlockOperands =
+ terminator.getSuccessorOperands(pred.getSuccessorIndex());
+
+ if (!ownerBlockOperands || ownerBlockOperands->empty())
+ continue;
+
+ // For each predecessor, add the value it passes to that argument to
+ // workList to find out how it's computed.
+ workList.push_back(
+ ownerBlockOperands
+ .getValue()[currentItemBlockArgument.getArgNumber()]);
+ }
+
+ continue;
+ }
+
+ Operation *currentItemDefiningOp = currentItem.getDefiningOp();
+
+ if (!visitedOps.insert(currentItemDefiningOp).second)
+ continue;
+
+ // The current item is computed by a GenericOp.
+ if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
+ // The op was encountered already, no need to inspect it again.
+ if (opsToDetensor.count(genericOp))
+ continue;
+
+ // TODO: For now, we give up if any of the control-flow components
+ // in a function is not detensorable. Fix that.
+ if (!shouldBeDetensored(genericOp, typeConverter)) {
+ opsToDetensor.clear();
+ blockArgsToDetensor.clear();
+ return;
+ }
+
+ opsToDetensor.insert(genericOp);
+
+ for (Value genericOpOperand : genericOp.inputs())
+ workList.push_back(genericOpOperand);
+
+ continue;
+ }
+
+ // The current item is the result of a FromElemntsOp, it will be
+ // trivially detensored later as part of canonicalization patterns
+ // applied at the end of detensoring.
+ //
+ // Note: No need to check whether the result type of this op is
+ // detensorable since if it wasn't we wouldn't reach that point in the
+ // work list.
+ if (dyn_cast<tensor::FromElementsOp>(currentItemDefiningOp))
+ continue;
+
+ // The current item is the result of a scalar op, add all its operands
+ // to the work list.
+ if (llvm::all_of(
+ currentItemDefiningOp->getResultTypes(),
+ [&](Type resultType) { return resultType.isIntOrFloat(); }))
+ for (Value scalarOpOperand : currentItemDefiningOp->getOperands())
+ workList.push_back(scalarOpOperand);
+ }
+ }
+ };
+
+ /// Detensorize everything that can detensored.
+ class AggressiveDetensoringModel : public CostModel {
+ public:
+ void compute(FuncOp func, DetensorizeTypeConverter typeConverter,
+ DenseSet<Operation *> &opsToDetensor,
+ DenseSet<BlockArgument> &blockArgsToDetensor) override {
+ func.walk([&](GenericOp genericOp) {
+ if (shouldBeDetensored(genericOp, typeConverter))
+ opsToDetensor.insert(genericOp);
+ });
+
+ for (Block &block : llvm::drop_begin(func.getBody(), 1))
+ for (BlockArgument blockArgument : block.getArguments())
+ blockArgsToDetensor.insert(blockArgument);
+ }
+ };
+
void runOnFunction() override {
- auto *context = &getContext();
+ MLIRContext *context = &getContext();
DetensorizeTypeConverter typeConverter;
RewritePatternSet patterns(context);
ConversionTarget target(*context);
+ DenseSet<Operation *> opsToDetensor;
+ DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
+ DenseSet<BlockArgument> blockArgsToDetensor;
+
+ if (aggressiveMode.getValue()) {
+ AggressiveDetensoringModel costModel;
+ costModel.compute(getFunction(), typeConverter, opsToDetensor,
+ blockArgsToDetensor);
+
+ } else {
+ PureControlFlowDetectionModel costModel;
+ costModel.compute(getFunction(), typeConverter, opsToDetensor,
+ blockArgsToDetensor);
+ }
- target.addDynamicallyLegalOp<GenericOp>([&](GenericOp op) {
- // If any of the operands or results cannot be detensored (i.e. they are
- // all legal according the DetensorizeTypeConverter), the op is considered
- // legal and won't be detensored.
- return llvm::any_of(op.getShapedOperandTypes(),
- [&](ShapedType shapedType) {
- return typeConverter.isLegal(shapedType);
- });
- });
+ detensorableBranchOps =
+ CostModel::computeBranchOpDetensoring(blockArgsToDetensor);
+
+ target.addDynamicallyLegalOp<GenericOp>(
+ [&](GenericOp op) { return !opsToDetensor.count(op); });
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
- // A function is legal if all of its non-entry blocks are legal. We don't
- // legalize the entry block (i.e. the function's signature) since
- // detensoring can't happen along external calling convention boundaries,
- // which we conservatively approximate as all function signatures.
+ // A function is legal if all of its non-entry blocks are legal. We
+ // don't legalize the entry block (i.e. the function's signature) since
+ // detensoring can't happen along external calling convention
+ // boundaries, which we conservatively approximate as all function
+ // signatures.
return llvm::all_of(llvm::drop_begin(op.getBody(), 1), [&](Block &block) {
- return typeConverter.isLegal(block.getArgumentTypes());
+ if (llvm::any_of(blockArgsToDetensor, [&](BlockArgument blockArgument) {
+ return blockArgument.getOwner() == &block &&
+ !typeConverter.isLegal(blockArgument.getType());
+ })) {
+ return false;
+ }
+ return true;
});
});
target.markUnknownOpDynamicallyLegal([&](Operation *op) {
- return isNotBranchOpInterfaceOrReturnLikeOp(op) ||
- isLegalForBranchOpInterfaceTypeConversionPattern(op,
- typeConverter) ||
- isLegalForReturnOpTypeConversionPattern(
- op, typeConverter, /*returnOpAlwaysLegal*/ true);
+ if (isNotBranchOpInterfaceOrReturnLikeOp(op) ||
+ isLegalForReturnOpTypeConversionPattern(op, typeConverter,
+ /*returnOpAlwaysLegal*/ true))
+ return true;
+
+ if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
+ if (!detensorableBranchOps.count(branchOp))
+ return true;
+
+ for (auto operandIdx : detensorableBranchOps[branchOp])
+ if (!typeConverter.isLegal(
+ branchOp->getOperand(operandIdx).getType()))
+ return false;
+
+ return true;
+ }
+
+ return false;
});
- patterns.add<DetensorizeGenericOp>(typeConverter, context);
- patterns.add<FunctionNonEntryBlockConversion>(FuncOp::getOperationName(),
- context, typeConverter);
- // Since non-entry block arguments get detensorized, we also need to update
- // the control flow inside the function to reflect the correct types.
- populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
+ patterns.insert<DetensorizeGenericOp>(typeConverter, context);
+ patterns.insert<FunctionNonEntryBlockConversion>(FuncOp::getOperationName(),
+ context, typeConverter,
+ blockArgsToDetensor);
+ // Since non-entry block arguments get detensorized, we also need to
+ // update the control flow inside the function to reflect the correct
+ // types.
+ auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
+ int operandIdx) -> bool {
+ return detensorableBranchOps.count(branchOp) &&
+ detensorableBranchOps[branchOp].count(operandIdx);
+ };
+
+ populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
+ shouldConvertBranchOperand);
if (failed(applyFullConversion(getFunction(), target, std::move(patterns))))
signalPassFailure();
@@ -210,6 +503,11 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
std::move(canonPatterns))))
signalPassFailure();
}
+
+ Option<bool> aggressiveMode{
+ *this, "aggressive-mode",
+ llvm::cl::desc("Detensorize all ops that qualify for detensoring along "
+ "with branch operands and basic-block arguments.")};
};
} // namespace
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
index bf2dcd69e9cad..218efaccb6e74 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
@@ -52,6 +52,12 @@ class BranchOpInterfaceTypeConversion
using OpInterfaceConversionPattern<
BranchOpInterface>::OpInterfaceConversionPattern;
+ BranchOpInterfaceTypeConversion(
+ TypeConverter &typeConverter, MLIRContext *ctx,
+ function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand)
+ : OpInterfaceConversionPattern(typeConverter, ctx, /*benefit=*/1),
+ shouldConvertBranchOperand(shouldConvertBranchOperand) {}
+
LogicalResult
matchAndRewrite(BranchOpInterface op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
@@ -61,18 +67,23 @@ class BranchOpInterfaceTypeConversion
for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors();
succIdx < succEnd; ++succIdx) {
auto successorOperands = op.getSuccessorOperands(succIdx);
- if (!successorOperands)
+ if (!successorOperands || successorOperands->empty())
continue;
+
for (int idx = successorOperands->getBeginOperandIndex(),
eidx = idx + successorOperands->size();
idx < eidx; ++idx) {
- newOperands[idx] = operands[idx];
+ if (!shouldConvertBranchOperand || shouldConvertBranchOperand(op, idx))
+ newOperands[idx] = operands[idx];
}
}
rewriter.updateRootInPlace(
op, [newOperands, op]() { op->setOperands(newOperands); });
return success();
}
+
+private:
+ function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand;
};
} // end anonymous namespace
@@ -98,9 +109,10 @@ class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
} // end anonymous namespace
void mlir::populateBranchOpInterfaceTypeConversionPattern(
- RewritePatternSet &patterns, TypeConverter &typeConverter) {
- patterns.add<BranchOpInterfaceTypeConversion>(typeConverter,
- patterns.getContext());
+ RewritePatternSet &patterns, TypeConverter &typeConverter,
+ function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand) {
+ patterns.insert<BranchOpInterfaceTypeConversion>(
+ typeConverter, patterns.getContext(), shouldConvertBranchOperand);
}
bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index bbdaec68364c9..54dce7417e219 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -495,8 +495,17 @@ Block *ArgConverter::applySignatureConversion(
// to pack the new values. For 1->1 mappings, if there is no materialization
// provided, use the argument directly instead.
auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
- Value newArg = converter.materializeArgumentConversion(
- rewriter, origArg.getLoc(), origArg.getType(), replArgs);
+ Value newArg;
+
+ // If this is a 1->1 mapping and the types of new and replacement arguments
+ // match (i.e. it's an identity map), then the argument is mapped to its
+ // original type.
+ if (replArgs.size() == 1 && replArgs[0].getType() == origArg.getType())
+ newArg = replArgs[0];
+ else
+ newArg = converter.materializeArgumentConversion(
+ rewriter, origArg.getLoc(), origArg.getType(), replArgs);
+
if (!newArg) {
assert(replArgs.size() == 1 &&
"couldn't materialize the result of 1->N conversion");
@@ -754,8 +763,9 @@ struct ConversionPatternRewriterImpl {
TypeConverter::SignatureConversion *entryConversion);
/// Convert the types of non-entry block arguments within the given region.
- LogicalResult convertNonEntryRegionTypes(Region *region,
- TypeConverter &converter);
+ LogicalResult convertNonEntryRegionTypes(
+ Region *region, TypeConverter &converter,
+ ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
//===--------------------------------------------------------------------===//
// Rewriter Notification Hooks
@@ -1173,15 +1183,30 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
}
LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
- Region *region, TypeConverter &converter) {
+ Region *region, TypeConverter &converter,
+ ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
argConverter.setConverter(region, &converter);
if (region->empty())
return success();
// Convert the arguments of each block within the region.
- for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1)))
- if (failed(convertBlockSignature(&block, converter)))
+ int blockIdx = 0;
+ assert((blockConversions.empty() ||
+ blockConversions.size() == region->getBlocks().size() - 1) &&
+ "expected either to provide no SignatureConversions at all or to "
+ "provide a SignatureConversion for each non-entry block");
+
+ for (Block &block :
+ llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
+ TypeConverter::SignatureConversion *blockConversion =
+ blockConversions.empty()
+ ? nullptr
+ : const_cast<TypeConverter::SignatureConversion *>(
+ &blockConversions[blockIdx++]);
+
+ if (failed(convertBlockSignature(&block, converter, blockConversion)))
return failure();
+ }
return success();
}
@@ -1351,8 +1376,9 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
}
LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
- Region *region, TypeConverter &converter) {
- return impl->convertNonEntryRegionTypes(region, converter);
+ Region *region, TypeConverter &converter,
+ ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
+ return impl->convertNonEntryRegionTypes(region, converter, blockConversions);
}
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
diff --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir
new file mode 100644
index 0000000000000..05a3720b50891
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize | FileCheck %s
+
+#map0 = affine_map<() -> ()>
+
+#attrs = {
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = []
+}
+
+func @main() -> (tensor<i32>) attributes {} {
+ %c0 = constant 0 : i32
+ %0 = tensor.from_elements %c0 : tensor<1xi32>
+ %reshaped0 = linalg.tensor_reshape %0 [] : tensor<1xi32> into tensor<i32>
+ %c10 = constant 10 : i32
+ %1 = tensor.from_elements %c10 : tensor<1xi32>
+ %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor<i32>
+ br ^bb1(%reshaped0 : tensor<i32>)
+
+^bb1(%2: tensor<i32>): // 2 preds: ^bb0, ^bb2
+ %3 = linalg.init_tensor [] : tensor<i1>
+ %4 = linalg.generic #attrs
+ ins(%2, %reshaped1 : tensor<i32>, tensor<i32>)
+ outs(%3 : tensor<i1>) {
+ ^bb0(%arg0: i32, %arg1: i32, %arg2: i1): // no predecessors
+ %8 = cmpi slt, %arg0, %arg1 : i32
+ linalg.yield %8 : i1
+ } -> tensor<i1>
+ %5 = tensor.extract %4[] : tensor<i1>
+ cond_br %5, ^bb2(%2 : tensor<i32>), ^bb3(%2 : tensor<i32>)
+
+^bb2(%6: tensor<i32>): // pred: ^bb1
+ %7 = linalg.init_tensor [] : tensor<i32>
+ %8 = linalg.generic #attrs
+ ins(%6, %6 : tensor<i32>, tensor<i32>)
+ outs(%7 : tensor<i32>) {
+ ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): // no predecessors
+ %9 = addi %arg0, %arg1 : i32
+ linalg.yield %9 : i32
+ } -> tensor<i32>
+ br ^bb3(%8 : tensor<i32>)
+
+^bb3(%10: tensor<i32>): // pred: ^bb1
+ return %10 : tensor<i32>
+}
+
+// CHECK-LABEL: func @main()
+// CHECK-NEXT: constant 0
+// CHECK-NEXT: constant 10
+// CHECK-NEXT: br ^[[bb1:.*]](%{{.*}}: i32)
+// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32):
+// CHECK-NEXT: tensor.from_elements %{{.*}}
+// CHECK-NEXT: linalg.tensor_reshape %{{.*}}
+// CHECK-NEXT: cmpi slt, %{{.*}}, %{{.*}}
+// CHECK-NEXT: cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : tensor<i32>), ^bb3(%{{.*}} : tensor<i32>)
+// CHECK-NEXT: ^[[bb2]](%{{.*}}: tensor<i32>)
+// CHECK-NEXT: linalg.init_tensor
+// CHECK-NEXT: linalg.generic
+// CHECK-NEXT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32)
+// CHECK-NEXT: addi %{{.*}}, %{{.*}}
+// CHECK-NEXT: linalg.yield %{{.*}}
+// CHECK-NEXT: } -> tensor<i32>
+// CHECK-NEXT: br ^[[bb3:.*]](%{{.*}} : tensor<i32>)
+// CHECK-NEXT: ^[[bb3]](%{{.*}}: tensor<i32>)
+// CHECK-NEXT: return %{{.*}}
+// CHECK-NEXT: }
diff --git a/mlir/test/Dialect/Linalg/detensorize_trivial.mlir b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir
new file mode 100644
index 0000000000000..6fcd056f9b365
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s -linalg-detensorize=aggressive-mode | FileCheck %s -check-prefix=DET-ALL
+// RUN: mlir-opt %s -linalg-detensorize | FileCheck %s -check-prefix=DET-CF
+
+
+#map0 = affine_map<() -> ()>
+
+#attrs = {
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = []
+}
+
+func @main(%farg0 : tensor<i32>) -> (tensor<i1>) attributes {} {
+ %c10 = constant 10 : i32
+ %1 = tensor.from_elements %c10 : tensor<1xi32>
+ %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor<i32>
+ %3 = linalg.init_tensor [] : tensor<i1>
+ %4 = linalg.generic #attrs
+ ins(%farg0, %reshaped1 : tensor<i32>, tensor<i32>)
+ outs(%3 : tensor<i1>) {
+ ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):
+ %8 = cmpi slt, %arg0, %arg1 : i32
+ linalg.yield %8 : i1
+ } -> tensor<i1>
+ return %4 : tensor<i1>
+}
+
+
+// DET-ALL-LABEL: func @main(%{{.*}}: tensor<i32>)
+// DET-ALL-NEXT: constant 10
+// DET-ALL-NEXT: tensor.extract %{{.*}}[]
+// DET-ALL-NEXT: cmpi slt, %{{.*}}, %{{.*}}
+// DET-ALL-NEXT: tensor.from_elements %{{.*}}
+// DET-ALL-NEXT: linalg.tensor_reshape %{{.*}}
+// DET-ALL-NEXT: return %{{.*}} : tensor<i1>
+// DET-ALL-NEXT: }
+
+// DET-CF-LABEL: func @main(%{{.*}}: tensor<i32>)
+// DET-CF-NEXT: constant 10 : i32
+// DET-CF-NEXT: tensor.from_elements %{{.*}}
+// DET-CF-NEXT: linalg.tensor_reshape %{{.*}}
+// DET-CF-NEXT: linalg.init_tensor [] : tensor<i1>
+// DET-CF-NEXT: linalg.generic
+// DET-CF-NEXT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i1)
+// DET-CF-NEXT: cmpi slt, %{{.*}}, %{{.*}}
+// DET-CF-NEXT: linalg.yield %{{.*}}
+// DET-CF-NEXT: } -> tensor<i1>
+// DET-CF-NEXT: return %{{.*}}
+// DET-CF-NEXT: }
diff --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir
new file mode 100644
index 0000000000000..72390f0d76087
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s -linalg-detensorize=aggressive-mode | FileCheck %s -check-prefix=DET-ALL
+// RUN: mlir-opt %s -linalg-detensorize | FileCheck %s -check-prefix=DET-CF
+
+#map0 = affine_map<() -> ()>
+
+#attrs = {
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = []
+}
+
+func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attributes {} {
+ br ^bb1(%farg0 : tensor<i32>)
+
+^bb1(%0: tensor<i32>): // 2 preds: ^bb0, ^bb2
+ %1 = linalg.init_tensor [] : tensor<i1>
+ %2 = linalg.generic #attrs
+ ins(%0, %farg1 : tensor<i32>, tensor<i32>)
+ outs(%1 : tensor<i1>) {
+ ^bb0(%arg0: i32, %arg1: i32, %arg2: i1): // no predecessors
+ %8 = cmpi slt, %arg0, %arg1 : i32
+ linalg.yield %8 : i1
+ } -> tensor<i1>
+ %3 = tensor.extract %2[] : tensor<i1>
+ cond_br %3, ^bb2(%0 : tensor<i32>), ^bb3(%0 : tensor<i32>)
+
+^bb2(%4: tensor<i32>): // pred: ^bb1
+ %5 = linalg.init_tensor [] : tensor<i32>
+ %6 = linalg.generic #attrs
+ ins(%4, %4 : tensor<i32>, tensor<i32>)
+ outs(%5 : tensor<i32>) {
+ ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): // no predecessors
+ %8 = addi %arg0, %arg1 : i32
+ linalg.yield %8 : i32
+ } -> tensor<i32>
+ br ^bb1(%6 : tensor<i32>)
+
+^bb3(%7: tensor<i32>): // pred: ^bb1
+ return %7 : tensor<i32>
+}
+
+// Test aggresively detensoring all detensorable ops.
+//
+// DET-ALL-LABEL: func @main
+// DET-ALL-SAME: (%{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>)
+// DET-ALL: tensor.extract {{.*}}
+// DET-ALL: br ^[[bb1:.*]](%{{.*}} : i32)
+// DET-ALL: ^[[bb1]](%{{.*}}: i32)
+// DET-ALL: cmpi slt, {{.*}}
+// DET-ALL: cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
+// DET-ALL: ^[[bb2]](%{{.*}}: i32)
+// DET-ALL: addi {{.*}}
+// DET-ALL: br ^[[bb1]](%{{.*}} : i32)
+// DET-ALL: ^[[bb3]](%{{.*}}: i32)
+// DET-ALL: tensor.from_elements {{.*}}
+// DET-ALL: linalg.tensor_reshape {{.*}}
+// DET-ALL: return %{{.*}} : tensor<i32>
+
+// Test detensoring only ops involed in control-flow.
+//
+// DET-CF-LABEL: func @main
+// DET-CF-SAME: (%{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>)
+// DET-CF: tensor.extract {{.*}}
+// DET-CF: br ^[[bb1:.*]](%{{.*}} : i32)
+// DET-CF: ^[[bb1]](%{{.*}}: i32)
+// DET-CF-DAG tensor.from_elements {{.*}}
+// DET-CF-DAG: linalg.tensor_reshape {{.*}}
+// DET-CF-DAG: cmpi slt, {{.*}}
+// DET-CF: cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : tensor<i32>)
+// DET-CF: ^[[bb2]](%{{.*}}: i32)
+// DET-CF: addi {{.*}}
+// DET-CF: br ^[[bb1]](%{{.*}} : i32)
+// DET-CF: ^[[bb3]](%{{.*}}: tensor<i32>)
+// DET-CF: return %{{.*}} : tensor<i32>
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir b/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir
new file mode 100644
index 0000000000000..36361d5492110
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir
@@ -0,0 +1,111 @@
+// RUN: mlir-opt %s -linalg-detensorize=aggressive-mode | FileCheck %s -check-prefix=DET-ALL
+// RUN: mlir-opt %s -linalg-detensorize | FileCheck %s -check-prefix=DET-CF
+
+#map0 = affine_map<() -> ()>
+#map1 = affine_map<(i) -> ()>
+#map2 = affine_map<(i) -> (i)>
+
+#attrs = {
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = []
+}
+
+#sum_reduction_attrs = {
+ indexing_maps = [#map2, #map1],
+ iterator_types = ["reduction"]
+}
+
+
+#broadcast_attrs = {
+ indexing_maps = [#map1, #map2],
+ iterator_types = ["parallel"]
+}
+
+func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attributes {} {
+ br ^bb1(%farg0 : tensor<10xi32>)
+
+^bb1(%0: tensor<10xi32>): // 2 preds: ^bb0, ^bb2
+ %1 = linalg.init_tensor [] : tensor<i32>
+ %2 = linalg.generic #sum_reduction_attrs
+ ins(%0: tensor<10xi32>)
+ outs(%1: tensor<i32>) {
+ ^bb(%a: i32, %x: i32):
+ %b = addi %x, %a : i32
+ linalg.yield %b : i32
+ } -> tensor<i32>
+
+ %3 = linalg.init_tensor [] : tensor<i1>
+ %4 = linalg.generic #attrs
+ ins(%2, %farg1 : tensor<i32>, tensor<i32>)
+ outs(%3 : tensor<i1>) {
+ ^bb0(%arg0: i32, %arg1: i32, %arg2: i1): // no predecessors
+ %8 = cmpi slt, %arg0, %arg1 : i32
+ linalg.yield %8 : i1
+ } -> tensor<i1>
+ %5 = tensor.extract %4[] : tensor<i1>
+ cond_br %5, ^bb2(%2 : tensor<i32>), ^bb3(%2 : tensor<i32>)
+
+^bb2(%6: tensor<i32>): // pred: ^bb1
+ %7 = linalg.init_tensor [10] : tensor<10xi32>
+ %9 = linalg.generic #broadcast_attrs
+ ins(%6: tensor<i32>)
+ outs(%7: tensor<10xi32>) {
+ ^bb(%a: i32, %b: i32) :
+ linalg.yield %a : i32
+ } -> tensor<10xi32>
+
+ br ^bb1(%9 : tensor<10xi32>)
+
+^bb3(%10: tensor<i32>): // pred: ^bb1
+ return %10 : tensor<i32>
+}
+
+// Test aggressively detensoring all detensorable ops.
+//
+// DET-ALL-LABEL: func @main
+// DET-ALL-SAME: (%{{.*}}: tensor<10xi32>, %{{.*}}: tensor<i32>)
+// DET-ALL: br ^[[bb1:.*]](%{{.*}} : tensor<10xi32>)
+// DET-ALL: ^[[bb1]](%{{.*}}: tensor<10xi32>)
+// DET-ALL: linalg.init_tensor [] : tensor<i32>
+// DET-ALL: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor<i32>) {
+// DET-ALL: ^bb0(%{{.*}}: i32, %{{.*}}: i32): // no predecessors
+// DET-ALL: %{{.*}} = addi %{{.*}}, %{{.*}}
+// DET-ALL: linalg.yield %{{.*}} : i32
+// DET-ALL: } -> tensor<i32>
+// DET-ALL: tensor.extract %{{.*}}[] : tensor<i32>
+// DET-ALL: tensor.extract %{{.*}}[] : tensor<i32>
+// DET-ALL: cmpi slt, %{{.*}}, %{{.*}} : i32
+// DET-ALL: tensor.extract %{{.*}}[] : tensor<i32>
+// DET-ALL: tensor.extract %{{.*}}[] : tensor<i32>
+// DET-ALL: cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
+// DET-ALL: ^[[bb2]](%{{.*}}: i32)
+// DET-ALL: tensor.from_elements %{{.*}} : tensor<1xi32>
+// DET-ALL: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// DET-ALL: linalg.init_tensor [10] : tensor<10xi32>
+// DET-ALL: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
+// DET-ALL: ^bb0(%{{.*}}: i32, %{{.*}}: i32):
+// DET-ALL: linalg.yield %{{.*}} : i32
+// DET-ALL: } -> tensor<10xi32>
+// DET-ALL: br ^[[bb1]](%{{.*}} : tensor<10xi32>)
+// DET-ALL: ^[[bb3]](%{{.*}}: i32)
+// DET-ALL: tensor.from_elements %{{.*}} : tensor<1xi32>
+// DET-ALL: linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// DET-ALL: return %{{.*}} : tensor<i32>
+// DET-ALL: }
+
+// Try to detensor pure control-flow. However, that fails since the potential
+// detensorable component contains some ops that cannot be detensored.
+//
+// DET-CF-LABEL: func @main
+// DET-CF-SAME: (%{{.*}}: tensor<10xi32>, %{{.*}}: tensor<i32>)
+// DET-CF: br ^[[bb1:.*]](%{{.*}} : tensor<10xi32>)
+// DET-CF: ^bb1(%{{.*}}: tensor<10xi32>)
+// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor<i32>) {
+// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<i32>, tensor<i32>) outs(%{{.*}} : tensor<i1>) {
+// DET-CF: cond_br %{{.*}}, ^bb2(%{{.*}} : tensor<i32>), ^bb3(%{{.*}} : tensor<i32>)
+// DET-CF: ^bb2(%{{.*}}: tensor<i32>)
+// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
+// DET-CF: br ^bb1(%{{.*}} : tensor<10xi32>)
+// DET-CF: ^bb3(%{{.*}}: tensor<i32>)
+// DET-CF: return %{{.*}} : tensor<i32>
+// DET-CF: }
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
new file mode 100644
index 0000000000000..b0d88efadca70
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize | FileCheck %s
+
+#map0 = affine_map<() -> ()>
+
+#attrs = {
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = []
+}
+
+func @main() -> () attributes {} {
+ %c0 = constant 0 : i32
+ %0 = tensor.from_elements %c0 : tensor<1xi32>
+ %reshaped0 = linalg.tensor_reshape %0 [] : tensor<1xi32> into tensor<i32>
+ %c10 = constant 10 : i32
+ %1 = tensor.from_elements %c10 : tensor<1xi32>
+ %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor<i32>
+ br ^bb1(%reshaped0 : tensor<i32>)
+
+^bb1(%2: tensor<i32>): // 2 preds: ^bb0, ^bb2
+ %3 = linalg.init_tensor [] : tensor<i1>
+ %4 = linalg.generic #attrs
+ ins(%2, %reshaped1 : tensor<i32>, tensor<i32>)
+ outs(%3 : tensor<i1>) {
+ ^bb0(%arg0: i32, %arg1: i32, %arg2: i1): // no predecessors
+ %8 = cmpi slt, %arg0, %arg1 : i32
+ linalg.yield %8 : i1
+ } -> tensor<i1>
+ %5 = tensor.extract %4[] : tensor<i1>
+ cond_br %5, ^bb2(%2 : tensor<i32>), ^bb3
+
+^bb2(%6: tensor<i32>): // pred: ^bb1
+ %7 = linalg.init_tensor [] : tensor<i32>
+ %8 = linalg.generic #attrs
+ ins(%6, %6 : tensor<i32>, tensor<i32>)
+ outs(%7 : tensor<i32>) {
+ ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): // no predecessors
+ %9 = addi %arg0, %arg1 : i32
+ linalg.yield %9 : i32
+ } -> tensor<i32>
+ br ^bb1(%8 : tensor<i32>)
+
+^bb3: // pred: ^bb1
+ return
+}
+
+// CHECK-LABEL: func @main
+// CHECK-NEXT: constant 0 : i32
+// CHECK-NEXT: constant 10
+// CHECK-NEXT: br ^[[bb1:.*]](%{{.*}} : i32)
+// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32)
+// CHECK-NEXT: %{{.*}} = cmpi slt, %{{.*}}, %{{.*}}
+// CHECK-NEXT: cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]]
+// CHECK-NEXT: ^[[bb2]](%{{.*}}: i32)
+// CHECK-NEXT: %{{.*}} = addi %{{.*}}, %{{.*}}
+// CHECK-NEXT: br ^[[bb1]](%{{.*}} : i32)
+// CHECK-NEXT: ^[[bb3]]:
+// CHECK-NEXT: return
+// CHECK-NEXT: }
diff --git a/mlir/test/Dialect/Linalg/detensorized_0d.mlir b/mlir/test/Dialect/Linalg/detensorized_0d.mlir
index e35a34fd157d4..91e3f080bedcb 100644
--- a/mlir/test/Dialect/Linalg/detensorized_0d.mlir
+++ b/mlir/test/Dialect/Linalg/detensorized_0d.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize | FileCheck %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize=aggressive-mode | FileCheck %s
#map = affine_map<() -> ()>
diff --git a/mlir/test/Dialect/Linalg/detensorized_while.mlir b/mlir/test/Dialect/Linalg/detensorized_while.mlir
deleted file mode 100644
index a227e753006c9..0000000000000
--- a/mlir/test/Dialect/Linalg/detensorized_while.mlir
+++ /dev/null
@@ -1,53 +0,0 @@
-// RUN: mlir-opt %s -linalg-detensorize | FileCheck %s
-
-#map0 = affine_map<() -> ()>
-
-#attrs = {
- indexing_maps = [#map0, #map0, #map0],
- iterator_types = []
-}
-
-func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attributes {} {
- br ^bb1(%farg0 : tensor<i32>)
-
-^bb1(%0: tensor<i32>): // 2 preds: ^bb0, ^bb2
- %1 = linalg.init_tensor [] : tensor<i1>
- %2 = linalg.generic #attrs
- ins(%0, %farg1 : tensor<i32>, tensor<i32>)
- outs(%1 : tensor<i1>) {
- ^bb0(%arg0: i32, %arg1: i32, %arg2: i1): // no predecessors
- %8 = cmpi slt, %arg0, %arg1 : i32
- linalg.yield %8 : i1
- } -> tensor<i1>
- %3 = tensor.extract %2[] : tensor<i1>
- cond_br %3, ^bb2(%0 : tensor<i32>), ^bb3(%0 : tensor<i32>)
-
-^bb2(%4: tensor<i32>): // pred: ^bb1
- %5 = linalg.init_tensor [] : tensor<i32>
- %6 = linalg.generic #attrs
- ins(%4, %4 : tensor<i32>, tensor<i32>)
- outs(%5 : tensor<i32>) {
- ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): // no predecessors
- %8 = addi %arg0, %arg1 : i32
- linalg.yield %8 : i32
- } -> tensor<i32>
- br ^bb1(%6 : tensor<i32>)
-
-^bb3(%7: tensor<i32>): // pred: ^bb1
- return %7 : tensor<i32>
-}
-
-// CHECK-LABEL: func @main
-// CHECK-SAME: (%{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>)
-// CHECK: tensor.extract {{.*}}
-// CHECK: br ^[[bb1:.*]](%{{.*}} : i32)
-// CHECK: ^[[bb1]](%{{.*}}: i32)
-// CHECK: cmpi slt, {{.*}}
-// CHECK: cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK: ^[[bb2]](%{{.*}}: i32)
-// CHECK: addi {{.*}}
-// CHECK: br ^[[bb1]](%{{.*}} : i32)
-// CHECK: ^[[bb3]](%{{.*}}: i32)
-// CHECK: tensor.from_elements {{.*}}
-// CHECK: linalg.tensor_reshape {{.*}}
-// CHECK: return %{{.*}} : tensor<i32>
More information about the Mlir-commits
mailing list