[Mlir-commits] [mlir] aa6eb2a - [MLIR][LinAlg] Implement detensoring cost-modelling.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 13 00:10:44 PDT 2021


Author: KareemErgawy-TomTom
Date: 2021-04-13T09:07:18+02:00
New Revision: aa6eb2af10094e427827343b67b25d606dde10b7

URL: https://github.com/llvm/llvm-project/commit/aa6eb2af10094e427827343b67b25d606dde10b7
DIFF: https://github.com/llvm/llvm-project/commit/aa6eb2af10094e427827343b67b25d606dde10b7.diff

LOG: [MLIR][LinAlg] Implement detensoring cost-modelling.

This patch introduces the necessary infrastructure changes to implement
cost-modelling for detensoring. In particular, it introduces the
following changes:
- An extension to the dialect conversion framework to selectively
convert sub-set of non-entry BB arguments.
- An extension to branch conversion pattern to selectively convert
sub-set of a branch's operands.
- An interface for detensoring cost-modelling.
- 2 simple implementations of 2 different cost models.

This sets the stage to expose cost-modelling for detensoring in an
easier way. We still need to come up with better cost models.

Reviewed By: silvas

Differential Revision: https://reviews.llvm.org/D99945

Added: 
    mlir/test/Dialect/Linalg/detensorize_if.mlir
    mlir/test/Dialect/Linalg/detensorize_trivial.mlir
    mlir/test/Dialect/Linalg/detensorize_while.mlir
    mlir/test/Dialect/Linalg/detensorize_while_failure.mlir
    mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir

Modified: 
    mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
    mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/test/Dialect/Linalg/detensorized_0d.mlir

Removed: 
    mlir/test/Dialect/Linalg/detensorized_while.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
index b932d1e009834..4a27d3caf5d73 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
@@ -13,9 +13,13 @@
 #ifndef MLIR_DIALECT_STANDARDOPS_TRANSFORMS_FUNCCONVERSIONS_H_
 #define MLIR_DIALECT_STANDARDOPS_TRANSFORMS_FUNCCONVERSIONS_H_
 
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/STLExtras.h"
+
 namespace mlir {
 
 // Forward declarations.
+class BranchOpInterface;
 class ConversionTarget;
 class MLIRContext;
 class Operation;
@@ -32,8 +36,15 @@ void populateCallOpTypeConversionPattern(RewritePatternSet &patterns,
 /// operands that have been legalized by the conversion framework. This can only
 /// be done if the branch operation implements the BranchOpInterface. Only
 /// needed for partial conversions.
-void populateBranchOpInterfaceTypeConversionPattern(RewritePatternSet &patterns,
-                                                    TypeConverter &converter);
+///
+/// If for some branch ops, we need to convert/legalize only a sub-set of the
+/// op's operands, such filtering behavior can be specified in
+/// shouldConvertBranchOperand. This callback should return true if branchOp's
+/// operand at index idx should be converted.
+void populateBranchOpInterfaceTypeConversionPattern(
+    RewritePatternSet &patterns, TypeConverter &converter,
+    function_ref<bool(BranchOpInterface branchOp, int idx)>
+        shouldConvertBranchOperand = nullptr);
 
 /// Return true if op is a BranchOpInterface op whose operands are all legal
 /// according to converter.

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index c7598ca4f577d..b0de3a170e667 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -498,8 +498,13 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// Convert the types of block arguments within the given region except for
   /// the entry region. This replaces each non-entry block with a new block
   /// containing the updated signature.
-  LogicalResult convertNonEntryRegionTypes(Region *region,
-                                           TypeConverter &converter);
+  ///
+  /// If special conversion behavior is needed for the non-entry blocks (for
+  /// example, we need to convert only a subset of a BB arguments), such
+  /// behavior can be specified in blockConversions.
+  LogicalResult convertNonEntryRegionTypes(
+      Region *region, TypeConverter &converter,
+      ArrayRef<TypeConverter::SignatureConversion> blockConversions);
 
   /// Replace all the uses of the block argument `from` with value `to`.
   void replaceUsesOfBlockArgument(BlockArgument from, Value to);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 85b9836d5d369..c23b84cf1f62d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -39,13 +39,21 @@ namespace {
 /// Defines the criteria a TensorType must follow in order to be considered
 /// "detensorable".
 ///
-/// NOTE: For now, only 0-D are supported.
+/// NOTE: For now, only 0-D tensors are supported.
 ///
 /// Returns true if tensorType can be detensored.
 bool canBeDetensored(TensorType tensorType) {
   return tensorType.hasRank() && tensorType.getRank() == 0;
 }
 
+bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
+  GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
+  return genericOp && llvm::all_of(genericOp.getShapedOperandTypes(),
+                                   [&](ShapedType shapedType) {
+                                     return !typeConverter.isLegal(shapedType);
+                                   });
+}
+
 /// A conversion patttern for detensoring `linalg.generic` ops.
 class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
 public:
@@ -82,16 +90,35 @@ class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
 /// function.
 struct FunctionNonEntryBlockConversion : public ConversionPattern {
   FunctionNonEntryBlockConversion(StringRef functionLikeOpName,
-                                  MLIRContext *ctx, TypeConverter &converter)
-      : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
+                                  MLIRContext *ctx, TypeConverter &converter,
+                                  DenseSet<BlockArgument> blockArgsToDetensor)
+      : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx),
+        blockArgsToDetensor(blockArgsToDetensor) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.startRootUpdate(op);
+    Region &region = mlir::impl::getFunctionBody(op);
+    SmallVector<TypeConverter::SignatureConversion, 2> conversions;
+
+    for (Block &block : llvm::drop_begin(region, 1)) {
+      conversions.emplace_back(block.getNumArguments());
+      TypeConverter::SignatureConversion &back = conversions.back();
+
+      for (BlockArgument blockArgument : block.getArguments()) {
+        int idx = blockArgument.getArgNumber();
+
+        if (blockArgsToDetensor.count(blockArgument))
+          back.addInputs(idx, {getTypeConverter()->convertType(
+                                  block.getArgumentTypes()[idx])});
+        else
+          back.addInputs(idx, {block.getArgumentTypes()[idx]});
+      }
+    }
 
-    if (failed(rewriter.convertNonEntryRegionTypes(
-            &mlir::impl::getFunctionBody(op), *typeConverter))) {
+    if (failed(rewriter.convertNonEntryRegionTypes(&region, *typeConverter,
+                                                   conversions))) {
       rewriter.cancelRootUpdate(op);
       return failure();
     }
@@ -99,6 +126,9 @@ struct FunctionNonEntryBlockConversion : public ConversionPattern {
     rewriter.finalizeRootUpdate(op);
     return success();
   }
+
+private:
+  const DenseSet<BlockArgument> blockArgsToDetensor;
 };
 
 class DetensorizeTypeConverter : public TypeConverter {
@@ -160,46 +190,309 @@ struct ExtractFromReshapeFromElements
 
 /// @see LinalgDetensorize in Linalg/Passes.td for more details.
 struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
+  LinalgDetensorize() = default;
+  LinalgDetensorize(const LinalgDetensorize &pass) {}
+
+  class CostModel {
+  public:
+    virtual ~CostModel() = default;
+
+    /// A cost model algorithm computes the following outputs:
+    ///
+    /// - opsToDetensor: the list of linalg ops that should be
+    /// detensored.
+    ///
+    /// - blockArgsToDetensor: since the operands and results of detensored
+    /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come
+    /// from a BB argument and a linalg op's output can be passed to successor
+    /// BBs), we need to maintain the sub-set of arguments that should be
+    /// detensored (i.e. converted by typeConverter) for each affected BB.
+    ///
+    /// Example:
+    ///
+    /// For the following snippet:
+    /// ...
+    /// ^bb1(%6: tensor<i32>, %9: tensor<i32>):
+    ///   %7 = linalg.init_tensor [] : tensor<i32>
+    ///   %8 = linalg.generic #attrs
+    ///     ins(%6, %6 : tensor<i32>, tensor<i32>)
+    ///     outs(%7 : tensor<i32>) {
+    ///     ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
+    ///       %9 = addi %arg0, %arg1 : i32
+    ///       linalg.yield %9 : i32
+    ///   } -> tensor<i32>
+    ///   %10 = "some.op"(%9)
+    ///   br ^bb2(%8 : tensor<i32>)
+    /// ...
+    ///
+    /// if the cost model decides that the linalg.generic op should be
+    /// detensored, then:
+    /// - opsToDetensor should be = {linalg.generic{add}}.
+    /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}.
+    virtual void compute(FuncOp func, DetensorizeTypeConverter typeConverter,
+                         DenseSet<Operation *> &opsToDetensor,
+                         DenseSet<BlockArgument> &blockArgsToDetensor) = 0;
+
+    /// From the blockArgsToDetensor set computed by a CostModel
+    /// implementation, this method computes the corresponding branch op
+    /// detensoring. The result is a map from a branch op to a subset of indices
+    /// of its operands. The indices specify which of the branch op's operands
+    /// should be detensored.
+    ///
+    /// For the previous example, this method would compute: {bb2 -> {0}}.
+    static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring(
+        const DenseSet<BlockArgument> &blockArgsToDetensor) {
+      DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
+
+      for (auto blockArgumentElem : blockArgsToDetensor) {
+        Block *block = blockArgumentElem.getOwner();
+
+        for (PredecessorIterator pred = block->pred_begin();
+             pred != block->pred_end(); ++pred) {
+          BranchOpInterface terminator =
+              dyn_cast<BranchOpInterface>((*pred)->getTerminator());
+          auto blockOperands =
+              terminator.getSuccessorOperands(pred.getSuccessorIndex());
+
+          if (!blockOperands || blockOperands->empty())
+            continue;
+
+          detensorableBranchOps[terminator].insert(
+              blockOperands->getBeginOperandIndex() +
+              blockArgumentElem.getArgNumber());
+        }
+      }
+
+      return detensorableBranchOps;
+    }
+  };
+
+  /// Detensorize linalg ops involved in control-flow within a function.
+  ///
+  /// This model starts from CondBranchOps within a function. For each cond_br,
+  /// the model then walks the use-def chain for the branch's condition
+  /// backwards in order to understand where the condition's value comes from.
+  /// If the condition value is (indirectly) computed by a linalg op that can be
+  /// detensored, the model then continues walking the use-def chain in order to
+  /// understand where the linalg op's operands come from. This leads to
+  /// discovering a "detensoring component". A detensoring component is the set
+  /// of operations + block arguments that are involved in control-flow AND can
+  /// be detensored.
+  ///
+  /// For examples where this model succeeds to discover a detensoring
+  /// component, see:
+  /// - test/Dialect/Linalg/detensorize_while.mlir
+  /// - test/Dialect/Linalg/detesorize_while_pure_cf.mlir.
+  ///
+  /// For an example where this model marks control-flow as "non-detensorable",
+  /// see:
+  /// - test/Dialect/Linalg/detensorize_while_failure.mlir
+  class PureControlFlowDetectionModel : public CostModel {
+  public:
+    void compute(FuncOp func, DetensorizeTypeConverter typeConverter,
+                 DenseSet<Operation *> &opsToDetensor,
+                 DenseSet<BlockArgument> &blockArgsToDetensor) override {
+      SmallVector<Value> workList;
+
+      func.walk(
+          [&](CondBranchOp condBr) { workList.push_back(condBr.condition()); });
+
+      DenseSet<Value> visitedValues;
+      DenseSet<Operation *> visitedOps;
+
+      while (!workList.empty()) {
+        Value currentItem = workList.pop_back_val();
+
+        if (!visitedValues.insert(currentItem).second)
+          continue;
+
+        // The current item is defined by a block argument.
+        if (auto bbarg = currentItem.dyn_cast<BlockArgument>()) {
+          BlockArgument currentItemBlockArgument =
+              currentItem.cast<BlockArgument>();
+          Block *ownerBlock = currentItemBlockArgument.getOwner();
+
+          // Function arguments are not detensored/converted.
+          if (&*ownerBlock->getParent()->begin() == ownerBlock)
+            continue;
+
+          // This inner-block argument is involved in control-flow, it should be
+          // detensored.
+          blockArgsToDetensor.insert(currentItemBlockArgument);
+
+          for (PredecessorIterator pred = ownerBlock->pred_begin();
+               pred != ownerBlock->pred_end(); ++pred) {
+            BranchOpInterface terminator =
+                dyn_cast<BranchOpInterface>((*pred)->getTerminator());
+
+            // TODO: For now, we give up if any of the control-flow components
+            // in a function is not detensorable. Fix that.
+            if (!terminator) {
+              opsToDetensor.clear();
+              blockArgsToDetensor.clear();
+              return;
+            }
+
+            auto ownerBlockOperands =
+                terminator.getSuccessorOperands(pred.getSuccessorIndex());
+
+            if (!ownerBlockOperands || ownerBlockOperands->empty())
+              continue;
+
+            // For each predecessor, add the value it passes to that argument to
+            // workList to find out how it's computed.
+            workList.push_back(
+                ownerBlockOperands
+                    .getValue()[currentItemBlockArgument.getArgNumber()]);
+          }
+
+          continue;
+        }
+
+        Operation *currentItemDefiningOp = currentItem.getDefiningOp();
+
+        if (!visitedOps.insert(currentItemDefiningOp).second)
+          continue;
+
+        // The current item is computed by a GenericOp.
+        if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
+          // The op was encountered already, no need to inspect it again.
+          if (opsToDetensor.count(genericOp))
+            continue;
+
+          // TODO: For now, we give up if any of the control-flow components
+          // in a function is not detensorable. Fix that.
+          if (!shouldBeDetensored(genericOp, typeConverter)) {
+            opsToDetensor.clear();
+            blockArgsToDetensor.clear();
+            return;
+          }
+
+          opsToDetensor.insert(genericOp);
+
+          for (Value genericOpOperand : genericOp.inputs())
+            workList.push_back(genericOpOperand);
+
+          continue;
+        }
+
+        // The current item is the result of a FromElemntsOp, it will be
+        // trivially detensored later as part of canonicalization patterns
+        // applied at the end of detensoring.
+        //
+        // Note: No need to check whether the result type of this op is
+        // detensorable since if it wasn't we wouldn't reach that point in the
+        // work list.
+        if (dyn_cast<tensor::FromElementsOp>(currentItemDefiningOp))
+          continue;
+
+        // The current item is the result of a scalar op, add all its operands
+        // to the work list.
+        if (llvm::all_of(
+                currentItemDefiningOp->getResultTypes(),
+                [&](Type resultType) { return resultType.isIntOrFloat(); }))
+          for (Value scalarOpOperand : currentItemDefiningOp->getOperands())
+            workList.push_back(scalarOpOperand);
+      }
+    }
+  };
+
+  /// Detensorize everything that can detensored.
+  class AggressiveDetensoringModel : public CostModel {
+  public:
+    void compute(FuncOp func, DetensorizeTypeConverter typeConverter,
+                 DenseSet<Operation *> &opsToDetensor,
+                 DenseSet<BlockArgument> &blockArgsToDetensor) override {
+      func.walk([&](GenericOp genericOp) {
+        if (shouldBeDetensored(genericOp, typeConverter))
+          opsToDetensor.insert(genericOp);
+      });
+
+      for (Block &block : llvm::drop_begin(func.getBody(), 1))
+        for (BlockArgument blockArgument : block.getArguments())
+          blockArgsToDetensor.insert(blockArgument);
+    }
+  };
+
   void runOnFunction() override {
-    auto *context = &getContext();
+    MLIRContext *context = &getContext();
     DetensorizeTypeConverter typeConverter;
     RewritePatternSet patterns(context);
     ConversionTarget target(*context);
+    DenseSet<Operation *> opsToDetensor;
+    DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
+    DenseSet<BlockArgument> blockArgsToDetensor;
+
+    if (aggressiveMode.getValue()) {
+      AggressiveDetensoringModel costModel;
+      costModel.compute(getFunction(), typeConverter, opsToDetensor,
+                        blockArgsToDetensor);
+
+    } else {
+      PureControlFlowDetectionModel costModel;
+      costModel.compute(getFunction(), typeConverter, opsToDetensor,
+                        blockArgsToDetensor);
+    }
 
-    target.addDynamicallyLegalOp<GenericOp>([&](GenericOp op) {
-      // If any of the operands or results cannot be detensored (i.e. they are
-      // all legal according the DetensorizeTypeConverter), the op is considered
-      // legal and won't be detensored.
-      return llvm::any_of(op.getShapedOperandTypes(),
-                          [&](ShapedType shapedType) {
-                            return typeConverter.isLegal(shapedType);
-                          });
-    });
+    detensorableBranchOps =
+        CostModel::computeBranchOpDetensoring(blockArgsToDetensor);
+
+    target.addDynamicallyLegalOp<GenericOp>(
+        [&](GenericOp op) { return !opsToDetensor.count(op); });
 
     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
-      // A function is legal if all of its non-entry blocks are legal. We don't
-      // legalize the entry block (i.e. the function's signature) since
-      // detensoring can't happen along external calling convention boundaries,
-      // which we conservatively approximate as all function signatures.
+      // A function is legal if all of its non-entry blocks are legal. We
+      // don't legalize the entry block (i.e. the function's signature) since
+      // detensoring can't happen along external calling convention
+      // boundaries, which we conservatively approximate as all function
+      // signatures.
       return llvm::all_of(llvm::drop_begin(op.getBody(), 1), [&](Block &block) {
-        return typeConverter.isLegal(block.getArgumentTypes());
+        if (llvm::any_of(blockArgsToDetensor, [&](BlockArgument blockArgument) {
+              return blockArgument.getOwner() == &block &&
+                     !typeConverter.isLegal(blockArgument.getType());
+            })) {
+          return false;
+        }
+        return true;
       });
     });
 
     target.markUnknownOpDynamicallyLegal([&](Operation *op) {
-      return isNotBranchOpInterfaceOrReturnLikeOp(op) ||
-             isLegalForBranchOpInterfaceTypeConversionPattern(op,
-                                                              typeConverter) ||
-             isLegalForReturnOpTypeConversionPattern(
-                 op, typeConverter, /*returnOpAlwaysLegal*/ true);
+      if (isNotBranchOpInterfaceOrReturnLikeOp(op) ||
+          isLegalForReturnOpTypeConversionPattern(op, typeConverter,
+                                                  /*returnOpAlwaysLegal*/ true))
+        return true;
+
+      if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
+        if (!detensorableBranchOps.count(branchOp))
+          return true;
+
+        for (auto operandIdx : detensorableBranchOps[branchOp])
+          if (!typeConverter.isLegal(
+                  branchOp->getOperand(operandIdx).getType()))
+            return false;
+
+        return true;
+      }
+
+      return false;
     });
 
-    patterns.add<DetensorizeGenericOp>(typeConverter, context);
-    patterns.add<FunctionNonEntryBlockConversion>(FuncOp::getOperationName(),
-                                                  context, typeConverter);
-    // Since non-entry block arguments get detensorized, we also need to update
-    // the control flow inside the function to reflect the correct types.
-    populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
+    patterns.insert<DetensorizeGenericOp>(typeConverter, context);
+    patterns.insert<FunctionNonEntryBlockConversion>(FuncOp::getOperationName(),
+                                                     context, typeConverter,
+                                                     blockArgsToDetensor);
+    // Since non-entry block arguments get detensorized, we also need to
+    // update the control flow inside the function to reflect the correct
+    // types.
+    auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
+                                          int operandIdx) -> bool {
+      return detensorableBranchOps.count(branchOp) &&
+             detensorableBranchOps[branchOp].count(operandIdx);
+    };
+
+    populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
+                                                   shouldConvertBranchOperand);
 
     if (failed(applyFullConversion(getFunction(), target, std::move(patterns))))
       signalPassFailure();
@@ -210,6 +503,11 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
                                             std::move(canonPatterns))))
       signalPassFailure();
   }
+
+  Option<bool> aggressiveMode{
+      *this, "aggressive-mode",
+      llvm::cl::desc("Detensorize all ops that qualify for detensoring along "
+                     "with branch operands and basic-block arguments.")};
 };
 } // namespace
 

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
index bf2dcd69e9cad..218efaccb6e74 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
@@ -52,6 +52,12 @@ class BranchOpInterfaceTypeConversion
   using OpInterfaceConversionPattern<
       BranchOpInterface>::OpInterfaceConversionPattern;
 
+  BranchOpInterfaceTypeConversion(
+      TypeConverter &typeConverter, MLIRContext *ctx,
+      function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand)
+      : OpInterfaceConversionPattern(typeConverter, ctx, /*benefit=*/1),
+        shouldConvertBranchOperand(shouldConvertBranchOperand) {}
+
   LogicalResult
   matchAndRewrite(BranchOpInterface op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
@@ -61,18 +67,23 @@ class BranchOpInterfaceTypeConversion
     for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors();
          succIdx < succEnd; ++succIdx) {
       auto successorOperands = op.getSuccessorOperands(succIdx);
-      if (!successorOperands)
+      if (!successorOperands || successorOperands->empty())
         continue;
+
       for (int idx = successorOperands->getBeginOperandIndex(),
                eidx = idx + successorOperands->size();
            idx < eidx; ++idx) {
-        newOperands[idx] = operands[idx];
+        if (!shouldConvertBranchOperand || shouldConvertBranchOperand(op, idx))
+          newOperands[idx] = operands[idx];
       }
     }
     rewriter.updateRootInPlace(
         op, [newOperands, op]() { op->setOperands(newOperands); });
     return success();
   }
+
+private:
+  function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand;
 };
 } // end anonymous namespace
 
@@ -98,9 +109,10 @@ class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
 } // end anonymous namespace
 
 void mlir::populateBranchOpInterfaceTypeConversionPattern(
-    RewritePatternSet &patterns, TypeConverter &typeConverter) {
-  patterns.add<BranchOpInterfaceTypeConversion>(typeConverter,
-                                                patterns.getContext());
+    RewritePatternSet &patterns, TypeConverter &typeConverter,
+    function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand) {
+  patterns.insert<BranchOpInterfaceTypeConversion>(
+      typeConverter, patterns.getContext(), shouldConvertBranchOperand);
 }
 
 bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index bbdaec68364c9..54dce7417e219 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -495,8 +495,17 @@ Block *ArgConverter::applySignatureConversion(
     // to pack the new values. For 1->1 mappings, if there is no materialization
     // provided, use the argument directly instead.
     auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
-    Value newArg = converter.materializeArgumentConversion(
-        rewriter, origArg.getLoc(), origArg.getType(), replArgs);
+    Value newArg;
+
+    // If this is a 1->1 mapping and the types of new and replacement arguments
+    // match (i.e. it's an identity map), then the argument is mapped to its
+    // original type.
+    if (replArgs.size() == 1 && replArgs[0].getType() == origArg.getType())
+      newArg = replArgs[0];
+    else
+      newArg = converter.materializeArgumentConversion(
+          rewriter, origArg.getLoc(), origArg.getType(), replArgs);
+
     if (!newArg) {
       assert(replArgs.size() == 1 &&
              "couldn't materialize the result of 1->N conversion");
@@ -754,8 +763,9 @@ struct ConversionPatternRewriterImpl {
                      TypeConverter::SignatureConversion *entryConversion);
 
   /// Convert the types of non-entry block arguments within the given region.
-  LogicalResult convertNonEntryRegionTypes(Region *region,
-                                           TypeConverter &converter);
+  LogicalResult convertNonEntryRegionTypes(
+      Region *region, TypeConverter &converter,
+      ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
 
   //===--------------------------------------------------------------------===//
   // Rewriter Notification Hooks
@@ -1173,15 +1183,30 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
 }
 
 LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
-    Region *region, TypeConverter &converter) {
+    Region *region, TypeConverter &converter,
+    ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
   argConverter.setConverter(region, &converter);
   if (region->empty())
     return success();
 
   // Convert the arguments of each block within the region.
-  for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1)))
-    if (failed(convertBlockSignature(&block, converter)))
+  int blockIdx = 0;
+  assert((blockConversions.empty() ||
+          blockConversions.size() == region->getBlocks().size() - 1) &&
+         "expected either to provide no SignatureConversions at all or to "
+         "provide a SignatureConversion for each non-entry block");
+
+  for (Block &block :
+       llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
+    TypeConverter::SignatureConversion *blockConversion =
+        blockConversions.empty()
+            ? nullptr
+            : const_cast<TypeConverter::SignatureConversion *>(
+                  &blockConversions[blockIdx++]);
+
+    if (failed(convertBlockSignature(&block, converter, blockConversion)))
       return failure();
+  }
   return success();
 }
 
@@ -1351,8 +1376,9 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
 }
 
 LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
-    Region *region, TypeConverter &converter) {
-  return impl->convertNonEntryRegionTypes(region, converter);
+    Region *region, TypeConverter &converter,
+    ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
+  return impl->convertNonEntryRegionTypes(region, converter, blockConversions);
 }
 
 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,

diff  --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir
new file mode 100644
index 0000000000000..05a3720b50891
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize | FileCheck %s
+
+#map0 = affine_map<() -> ()>
+
+#attrs = {
+  indexing_maps = [#map0, #map0, #map0],
+  iterator_types = []
+}
+
+func @main() -> (tensor<i32>) attributes {} {
+  %c0 = constant 0 : i32
+  %0 = tensor.from_elements %c0 : tensor<1xi32>
+  %reshaped0 = linalg.tensor_reshape %0 [] : tensor<1xi32> into tensor<i32>
+  %c10 = constant 10 : i32
+  %1 = tensor.from_elements %c10 : tensor<1xi32>
+  %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor<i32>
+  br ^bb1(%reshaped0 : tensor<i32>)
+
+^bb1(%2: tensor<i32>):  // 2 preds: ^bb0, ^bb2
+  %3 = linalg.init_tensor [] : tensor<i1>
+  %4 = linalg.generic #attrs
+    ins(%2, %reshaped1 : tensor<i32>, tensor<i32>)
+    outs(%3 : tensor<i1>) {
+    ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):  // no predecessors
+      %8 = cmpi slt, %arg0, %arg1 : i32
+      linalg.yield %8 : i1
+  } -> tensor<i1>
+  %5 = tensor.extract %4[] : tensor<i1>
+  cond_br %5, ^bb2(%2 : tensor<i32>), ^bb3(%2 : tensor<i32>)
+
+^bb2(%6: tensor<i32>):  // pred: ^bb1
+  %7 = linalg.init_tensor [] : tensor<i32>
+  %8 = linalg.generic #attrs
+    ins(%6, %6 : tensor<i32>, tensor<i32>)
+    outs(%7 : tensor<i32>) {
+    ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):  // no predecessors
+      %9 = addi %arg0, %arg1 : i32
+      linalg.yield %9 : i32
+  } -> tensor<i32>
+  br ^bb3(%8 : tensor<i32>)
+
+^bb3(%10: tensor<i32>):  // pred: ^bb1
+  return %10 : tensor<i32>
+}
+
+// CHECK-LABEL:  func @main()
+// CHECK-NEXT:     constant 0
+// CHECK-NEXT:     constant 10
+// CHECK-NEXT:     br ^[[bb1:.*]](%{{.*}}: i32)
+// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
+// CHECK-NEXT:     tensor.from_elements %{{.*}}
+// CHECK-NEXT:     linalg.tensor_reshape %{{.*}}
+// CHECK-NEXT:     cmpi slt, %{{.*}}, %{{.*}}
+// CHECK-NEXT:     cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : tensor<i32>), ^bb3(%{{.*}} : tensor<i32>)
+// CHECK-NEXT:   ^[[bb2]](%{{.*}}: tensor<i32>)
+// CHECK-NEXT:     linalg.init_tensor
+// CHECK-NEXT:     linalg.generic
+// CHECK-NEXT:     ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i32)
+// CHECK-NEXT:       addi %{{.*}}, %{{.*}}
+// CHECK-NEXT:       linalg.yield %{{.*}}
+// CHECK-NEXT:     } -> tensor<i32>
+// CHECK-NEXT:     br ^[[bb3:.*]](%{{.*}} : tensor<i32>)
+// CHECK-NEXT:   ^[[bb3]](%{{.*}}: tensor<i32>)
+// CHECK-NEXT:     return %{{.*}}
+// CHECK-NEXT:   }

diff  --git a/mlir/test/Dialect/Linalg/detensorize_trivial.mlir b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir
new file mode 100644
index 0000000000000..6fcd056f9b365
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s -linalg-detensorize=aggressive-mode | FileCheck %s -check-prefix=DET-ALL
+// RUN: mlir-opt %s -linalg-detensorize | FileCheck %s -check-prefix=DET-CF
+
+
+#map0 = affine_map<() -> ()>
+
+#attrs = {
+  indexing_maps = [#map0, #map0, #map0],
+  iterator_types = []
+}
+
+func @main(%farg0 : tensor<i32>) -> (tensor<i1>) attributes {} {
+  %c10 = constant 10 : i32
+  %1 = tensor.from_elements %c10 : tensor<1xi32>
+  %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor<i32>
+  %3 = linalg.init_tensor [] : tensor<i1>
+  %4 = linalg.generic #attrs
+    ins(%farg0, %reshaped1 : tensor<i32>, tensor<i32>)
+    outs(%3 : tensor<i1>) {
+    ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):
+      %8 = cmpi slt, %arg0, %arg1 : i32
+      linalg.yield %8 : i1
+  } -> tensor<i1>
+  return %4 : tensor<i1>
+}
+
+
+// DET-ALL-LABEL: func @main(%{{.*}}: tensor<i32>)
+// DET-ALL-NEXT:    constant 10
+// DET-ALL-NEXT:    tensor.extract %{{.*}}[]
+// DET-ALL-NEXT:    cmpi slt, %{{.*}}, %{{.*}}
+// DET-ALL-NEXT:    tensor.from_elements %{{.*}}
+// DET-ALL-NEXT:    linalg.tensor_reshape %{{.*}}
+// DET-ALL-NEXT:    return %{{.*}} : tensor<i1>
+// DET-ALL-NEXT:  }
+
+// DET-CF-LABEL: func @main(%{{.*}}: tensor<i32>)
+// DET-CF-NEXT:    constant 10 : i32
+// DET-CF-NEXT:    tensor.from_elements %{{.*}}
+// DET-CF-NEXT:    linalg.tensor_reshape %{{.*}}
+// DET-CF-NEXT:    linalg.init_tensor [] : tensor<i1>
+// DET-CF-NEXT:    linalg.generic
+// DET-CF-NEXT:    ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i1)
+// DET-CF-NEXT:      cmpi slt, %{{.*}}, %{{.*}}
+// DET-CF-NEXT:      linalg.yield %{{.*}}
+// DET-CF-NEXT:    } -> tensor<i1>
+// DET-CF-NEXT:    return %{{.*}}
+// DET-CF-NEXT:  }

diff  --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir
new file mode 100644
index 0000000000000..72390f0d76087
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s -linalg-detensorize=aggressive-mode | FileCheck %s -check-prefix=DET-ALL
+// RUN: mlir-opt %s -linalg-detensorize | FileCheck %s -check-prefix=DET-CF
+
+#map0 = affine_map<() -> ()>
+
+#attrs = {
+  indexing_maps = [#map0, #map0, #map0],
+  iterator_types = []
+}
+
+func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attributes {} {
+  br ^bb1(%farg0 : tensor<i32>)
+
+^bb1(%0: tensor<i32>):  // 2 preds: ^bb0, ^bb2
+  %1 = linalg.init_tensor [] : tensor<i1>
+  %2 = linalg.generic #attrs
+    ins(%0, %farg1 : tensor<i32>, tensor<i32>)
+    outs(%1 : tensor<i1>) {
+    ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):  // no predecessors
+      %8 = cmpi slt, %arg0, %arg1 : i32
+      linalg.yield %8 : i1
+  } -> tensor<i1>
+  %3 = tensor.extract %2[] : tensor<i1>
+  cond_br %3, ^bb2(%0 : tensor<i32>), ^bb3(%0 : tensor<i32>)
+
+^bb2(%4: tensor<i32>):  // pred: ^bb1
+  %5 = linalg.init_tensor [] : tensor<i32>
+  %6 = linalg.generic #attrs
+    ins(%4, %4 : tensor<i32>, tensor<i32>)
+    outs(%5 : tensor<i32>) {
+    ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):  // no predecessors
+      %8 = addi %arg0, %arg1 : i32
+      linalg.yield %8 : i32
+  } -> tensor<i32>
+  br ^bb1(%6 : tensor<i32>)
+
+^bb3(%7: tensor<i32>):  // pred: ^bb1
+  return %7 : tensor<i32>
+}
+
+// Test aggressively detensoring all detensorable ops.
+//
+// DET-ALL-LABEL: func @main
+// DET-ALL-SAME:    (%{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>)
+// DET-ALL:         tensor.extract {{.*}}
+// DET-ALL:         br ^[[bb1:.*]](%{{.*}} : i32)
+// DET-ALL:       ^[[bb1]](%{{.*}}: i32)
+// DET-ALL:         cmpi slt, {{.*}}
+// DET-ALL:         cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
+// DET-ALL:       ^[[bb2]](%{{.*}}: i32)
+// DET-ALL:         addi {{.*}}
+// DET-ALL:         br ^[[bb1]](%{{.*}} : i32)
+// DET-ALL:       ^[[bb3]](%{{.*}}: i32)
+// DET-ALL:         tensor.from_elements {{.*}}
+// DET-ALL:         linalg.tensor_reshape {{.*}}
+// DET-ALL:         return %{{.*}} : tensor<i32>
+
+// Test detensoring only ops involved in control-flow.
+//
+// DET-CF-LABEL: func @main
+// DET-CF-SAME:    (%{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>)
+// DET-CF:         tensor.extract {{.*}}
+// DET-CF:         br ^[[bb1:.*]](%{{.*}} : i32)
+// DET-CF:       ^[[bb1]](%{{.*}}: i32)
+// DET-CF-DAG:     tensor.from_elements {{.*}}
+// DET-CF-DAG:     linalg.tensor_reshape {{.*}}
+// DET-CF-DAG:     cmpi slt, {{.*}}
+// DET-CF:         cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : tensor<i32>)
+// DET-CF:       ^[[bb2]](%{{.*}}: i32)
+// DET-CF:         addi {{.*}}
+// DET-CF:         br ^[[bb1]](%{{.*}} : i32)
+// DET-CF:       ^[[bb3]](%{{.*}}: tensor<i32>)
+// DET-CF:         return %{{.*}} : tensor<i32>

diff  --git a/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir b/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir
new file mode 100644
index 0000000000000..36361d5492110
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir
@@ -0,0 +1,111 @@
+// RUN: mlir-opt %s -linalg-detensorize=aggressive-mode | FileCheck %s -check-prefix=DET-ALL
+// RUN: mlir-opt %s -linalg-detensorize | FileCheck %s -check-prefix=DET-CF
+
+#map0 = affine_map<() -> ()>
+#map1 = affine_map<(i) -> ()>
+#map2 = affine_map<(i) -> (i)>
+
+#attrs = {
+  indexing_maps = [#map0, #map0, #map0],
+  iterator_types = []
+}
+
+#sum_reduction_attrs = {
+  indexing_maps = [#map2, #map1],
+  iterator_types = ["reduction"]
+}
+
+
+#broadcast_attrs = {
+  indexing_maps = [#map1, #map2],
+  iterator_types = ["parallel"]
+}
+
+func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attributes {} {
+  br ^bb1(%farg0 : tensor<10xi32>)
+
+^bb1(%0: tensor<10xi32>):  // 2 preds: ^bb0, ^bb2
+  %1 = linalg.init_tensor [] : tensor<i32>
+  %2 = linalg.generic #sum_reduction_attrs
+    ins(%0: tensor<10xi32>)
+    outs(%1: tensor<i32>) {
+      ^bb(%a: i32, %x: i32):
+        %b = addi %x, %a : i32
+        linalg.yield %b : i32
+  } -> tensor<i32>
+
+  %3 = linalg.init_tensor [] : tensor<i1>
+  %4 = linalg.generic #attrs
+    ins(%2, %farg1 : tensor<i32>, tensor<i32>)
+    outs(%3 : tensor<i1>) {
+    ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):  // no predecessors
+      %8 = cmpi slt, %arg0, %arg1 : i32
+      linalg.yield %8 : i1
+  } -> tensor<i1>
+  %5 = tensor.extract %4[] : tensor<i1>
+  cond_br %5, ^bb2(%2 : tensor<i32>), ^bb3(%2 : tensor<i32>)
+
+^bb2(%6: tensor<i32>):  // pred: ^bb1
+  %7 = linalg.init_tensor [10] : tensor<10xi32>
+  %9 = linalg.generic #broadcast_attrs
+       ins(%6: tensor<i32>)
+      outs(%7: tensor<10xi32>) {
+    ^bb(%a: i32, %b: i32) :
+      linalg.yield %a : i32
+  } -> tensor<10xi32>
+
+  br ^bb1(%9 : tensor<10xi32>)
+
+^bb3(%10: tensor<i32>):  // pred: ^bb1
+  return %10 : tensor<i32>
+}
+
+// Test aggressively detensoring all detensorable ops.
+//
+// DET-ALL-LABEL: func @main
+// DET-ALL-SAME:    (%{{.*}}: tensor<10xi32>, %{{.*}}: tensor<i32>)
+// DET-ALL:         br ^[[bb1:.*]](%{{.*}} : tensor<10xi32>)
+// DET-ALL:       ^[[bb1]](%{{.*}}: tensor<10xi32>)
+// DET-ALL:         linalg.init_tensor [] : tensor<i32>
+// DET-ALL:         linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor<i32>) {
+// DET-ALL:         ^bb0(%{{.*}}: i32, %{{.*}}: i32):  // no predecessors
+// DET-ALL:           %{{.*}} = addi %{{.*}}, %{{.*}}
+// DET-ALL:           linalg.yield %{{.*}} : i32
+// DET-ALL:         } -> tensor<i32>
+// DET-ALL:         tensor.extract %{{.*}}[] : tensor<i32>
+// DET-ALL:         tensor.extract %{{.*}}[] : tensor<i32>
+// DET-ALL:         cmpi slt, %{{.*}}, %{{.*}} : i32
+// DET-ALL:         tensor.extract %{{.*}}[] : tensor<i32>
+// DET-ALL:         tensor.extract %{{.*}}[] : tensor<i32>
+// DET-ALL:         cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
+// DET-ALL:       ^[[bb2]](%{{.*}}: i32)
+// DET-ALL:         tensor.from_elements %{{.*}} : tensor<1xi32>
+// DET-ALL:         linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// DET-ALL:         linalg.init_tensor [10] : tensor<10xi32>
+// DET-ALL:         linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
+// DET-ALL:         ^bb0(%{{.*}}: i32, %{{.*}}: i32):
+// DET-ALL:           linalg.yield %{{.*}} : i32
+// DET-ALL:         } -> tensor<10xi32>
+// DET-ALL:         br ^[[bb1]](%{{.*}} : tensor<10xi32>)
+// DET-ALL:       ^[[bb3]](%{{.*}}: i32)
+// DET-ALL:         tensor.from_elements %{{.*}} : tensor<1xi32>
+// DET-ALL:         linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// DET-ALL:         return %{{.*}} : tensor<i32>
+// DET-ALL:       }
+
+// Try to detensor pure control-flow. However, that fails since the potential 
+// detensorable component contains some ops that cannot be detensored.
+//
+// DET-CF-LABEL: func @main
+// DET-CF-SAME:    (%{{.*}}: tensor<10xi32>, %{{.*}}: tensor<i32>)
+// DET-CF:         br ^[[bb1:.*]](%{{.*}} : tensor<10xi32>)
+// DET-CF:       ^bb1(%{{.*}}: tensor<10xi32>)
+// DET-CF:         %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor<i32>) {
+// DET-CF:         %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<i32>, tensor<i32>) outs(%{{.*}} : tensor<i1>) {
+// DET-CF:         cond_br %{{.*}}, ^bb2(%{{.*}} : tensor<i32>), ^bb3(%{{.*}} : tensor<i32>)
+// DET-CF:       ^bb2(%{{.*}}: tensor<i32>)
+// DET-CF:         %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
+// DET-CF:         br ^bb1(%{{.*}} : tensor<10xi32>)
+// DET-CF:       ^bb3(%{{.*}}: tensor<i32>)
+// DET-CF:         return %{{.*}} : tensor<i32>
+// DET-CF:       }

diff  --git a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
new file mode 100644
index 0000000000000..b0d88efadca70
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize | FileCheck %s
+
+#map0 = affine_map<() -> ()>
+
+#attrs = {
+  indexing_maps = [#map0, #map0, #map0],
+  iterator_types = []
+}
+
+func @main() -> () attributes {} {
+  %c0 = constant 0 : i32
+  %0 = tensor.from_elements %c0 : tensor<1xi32>
+  %reshaped0 = linalg.tensor_reshape %0 [] : tensor<1xi32> into tensor<i32>
+  %c10 = constant 10 : i32
+  %1 = tensor.from_elements %c10 : tensor<1xi32>
+  %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor<i32>
+  br ^bb1(%reshaped0 : tensor<i32>)
+
+^bb1(%2: tensor<i32>):  // 2 preds: ^bb0, ^bb2
+  %3 = linalg.init_tensor [] : tensor<i1>
+  %4 = linalg.generic #attrs
+    ins(%2, %reshaped1 : tensor<i32>, tensor<i32>)
+    outs(%3 : tensor<i1>) {
+    ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):  // no predecessors
+      %8 = cmpi slt, %arg0, %arg1 : i32
+      linalg.yield %8 : i1
+  } -> tensor<i1>
+  %5 = tensor.extract %4[] : tensor<i1>
+  cond_br %5, ^bb2(%2 : tensor<i32>), ^bb3
+
+^bb2(%6: tensor<i32>):  // pred: ^bb1
+  %7 = linalg.init_tensor [] : tensor<i32>
+  %8 = linalg.generic #attrs
+    ins(%6, %6 : tensor<i32>, tensor<i32>)
+    outs(%7 : tensor<i32>) {
+    ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):  // no predecessors
+      %9 = addi %arg0, %arg1 : i32
+      linalg.yield %9 : i32
+  } -> tensor<i32>
+  br ^bb1(%8 : tensor<i32>)
+
+^bb3:  // pred: ^bb1
+  return
+}
+
+// CHECK-LABEL: func @main
+// CHECK-NEXT:    constant 0 : i32
+// CHECK-NEXT:    constant 10
+// CHECK-NEXT:    br ^[[bb1:.*]](%{{.*}} : i32)
+// CHECK-NEXT:  ^[[bb1]](%{{.*}}: i32)
+// CHECK-NEXT:    %{{.*}} = cmpi slt, %{{.*}}, %{{.*}}
+// CHECK-NEXT:    cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]]
+// CHECK-NEXT:  ^[[bb2]](%{{.*}}: i32)
+// CHECK-NEXT:    %{{.*}} = addi %{{.*}}, %{{.*}}
+// CHECK-NEXT:    br ^[[bb1]](%{{.*}} : i32)
+// CHECK-NEXT:  ^[[bb3]]:
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }

diff  --git a/mlir/test/Dialect/Linalg/detensorized_0d.mlir b/mlir/test/Dialect/Linalg/detensorized_0d.mlir
index e35a34fd157d4..91e3f080bedcb 100644
--- a/mlir/test/Dialect/Linalg/detensorized_0d.mlir
+++ b/mlir/test/Dialect/Linalg/detensorized_0d.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize | FileCheck %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize=aggressive-mode | FileCheck %s
 
 #map = affine_map<() -> ()>
 

diff  --git a/mlir/test/Dialect/Linalg/detensorized_while.mlir b/mlir/test/Dialect/Linalg/detensorized_while.mlir
deleted file mode 100644
index a227e753006c9..0000000000000
--- a/mlir/test/Dialect/Linalg/detensorized_while.mlir
+++ /dev/null
@@ -1,53 +0,0 @@
-// RUN: mlir-opt %s -linalg-detensorize | FileCheck %s
-
-#map0 = affine_map<() -> ()>
-
-#attrs = {
-  indexing_maps = [#map0, #map0, #map0],
-  iterator_types = []
-}
-
-func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attributes {} {
-  br ^bb1(%farg0 : tensor<i32>)
-
-^bb1(%0: tensor<i32>):  // 2 preds: ^bb0, ^bb2
-  %1 = linalg.init_tensor [] : tensor<i1>
-  %2 = linalg.generic #attrs
-    ins(%0, %farg1 : tensor<i32>, tensor<i32>)
-    outs(%1 : tensor<i1>) {
-    ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):  // no predecessors
-      %8 = cmpi slt, %arg0, %arg1 : i32
-      linalg.yield %8 : i1
-  } -> tensor<i1>
-  %3 = tensor.extract %2[] : tensor<i1>
-  cond_br %3, ^bb2(%0 : tensor<i32>), ^bb3(%0 : tensor<i32>)
-
-^bb2(%4: tensor<i32>):  // pred: ^bb1
-  %5 = linalg.init_tensor [] : tensor<i32>
-  %6 = linalg.generic #attrs
-    ins(%4, %4 : tensor<i32>, tensor<i32>)
-    outs(%5 : tensor<i32>) {
-    ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):  // no predecessors
-      %8 = addi %arg0, %arg1 : i32
-      linalg.yield %8 : i32
-  } -> tensor<i32>
-  br ^bb1(%6 : tensor<i32>)
-
-^bb3(%7: tensor<i32>):  // pred: ^bb1
-  return %7 : tensor<i32>
-}
-
-// CHECK-LABEL: func @main
-// CHECK-SAME:    (%{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>)
-// CHECK:         tensor.extract {{.*}}
-// CHECK:         br ^[[bb1:.*]](%{{.*}} : i32)
-// CHECK:       ^[[bb1]](%{{.*}}: i32)
-// CHECK:         cmpi slt, {{.*}}
-// CHECK:         cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK:       ^[[bb2]](%{{.*}}: i32)
-// CHECK:         addi {{.*}}
-// CHECK:         br ^[[bb1]](%{{.*}} : i32)
-// CHECK:       ^[[bb3]](%{{.*}}: i32)
-// CHECK:         tensor.from_elements {{.*}}
-// CHECK:         linalg.tensor_reshape {{.*}}
-// CHECK:         return %{{.*}} : tensor<i32>


        


More information about the Mlir-commits mailing list