[Mlir-commits] [mlir] 015192c - [mlir:DialectConversion] Restructure how argument/target materializations get invoked
River Riddle
llvmlistbot at llvm.org
Tue Oct 26 19:12:32 PDT 2021
Author: River Riddle
Date: 2021-10-27T02:09:04Z
New Revision: 015192c63415ae828c3a9bc8b56d0667abcca8ae
URL: https://github.com/llvm/llvm-project/commit/015192c63415ae828c3a9bc8b56d0667abcca8ae
DIFF: https://github.com/llvm/llvm-project/commit/015192c63415ae828c3a9bc8b56d0667abcca8ae.diff
LOG: [mlir:DialectConversion] Restructure how argument/target materializations get invoked
The current implementation invokes materializations
whenever an input operand does not have a mapping for the
desired type, i.e. it materializes at the earliest possible
point. This conflicts with the goal of dialect conversion (and
with the current documentation), which states that a
materialization is only required if it is supposed to persist
after the conversion process has finished.
This revision refactors this such that whenever a target
materialization "might" be necessary, we insert an
unrealized_conversion_cast to act as a temporary materialization.
This allows for deferring the invocation of the user
materialization hooks until the end of the conversion process,
where we have a better sense of whether a materialization is
actually necessary. This has several benefits:
* In some cases a target materialization hook is no longer
necessary
When performing a full conversion, there are some situations
where only a temporary materialization is necessary. Moving
forward, users in these situations won't need to provide any
target materializations, as the temporary materializations do
not require user-provided hooks.
* getRemappedValue can now handle values that haven't been
converted yet
Before this commit, getting the remapped value of a value that
hadn't been converted yet wasn't well supported (making it
difficult/impossible to convert multiple operations in many
situations). This commit updates getRemappedValue to properly
handle this case by inserting temporary materializations when
necessary.
Another code-health benefit is that with this change we can
move the majority of the complexity related to materializations
to the end of the conversion process, instead of handling it
ad hoc while the conversion is happening.
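
For reference, a target materialization hook is registered on a
TypeConverter roughly like the following. This is a minimal,
hypothetical sketch (an index-to-i64 converter, not part of this
patch), illustrating the kind of hook that some full-conversion
users may now be able to drop when the cast never needs to outlive
the conversion:

```
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"

// A minimal sketch, assuming a hypothetical converter that lowers
// 'index' to 'i64'. After this change, hooks like this are only
// required if the materialized cast must persist once conversion has
// finished; purely temporary casts are handled by the framework via
// unrealized_conversion_cast.
struct IndexToI64Converter : public mlir::TypeConverter {
  IndexToI64Converter() {
    addConversion([](mlir::Type type) { return type; });
    addConversion([](mlir::IndexType type) -> mlir::Type {
      return mlir::IntegerType::get(type.getContext(), 64);
    });
    addTargetMaterialization([](mlir::OpBuilder &builder, mlir::Type type,
                                mlir::ValueRange inputs, mlir::Location loc)
                                 -> llvm::Optional<mlir::Value> {
      if (inputs.size() != 1)
        return llvm::None;
      return builder
          .create<mlir::UnrealizedConversionCastOp>(loc, type, inputs)
          .getResult(0);
    });
  }
};
```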
Differential Revision: https://reviews.llvm.org/D111620
Added:
Modified:
mlir/docs/Tutorials/Toy/Ch-5.md
mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h
mlir/include/mlir/IR/BlockAndValueMapping.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
mlir/lib/IR/OperationSupport.cpp
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir
mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
mlir/test/Conversion/StandardToLLVM/calling-convention.mlir
mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/ArmSVE/memcpy.mlir
mlir/test/Dialect/Linalg/bufferize.mlir
mlir/test/Dialect/Linalg/detensorize_0d.mlir
mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
mlir/test/Dialect/SCF/bufferize.mlir
mlir/test/Dialect/Standard/bufferize.mlir
mlir/test/Dialect/Standard/func-bufferize.mlir
mlir/test/Dialect/Tensor/bufferize.mlir
mlir/test/Transforms/test-legalize-remapped-value.mlir
mlir/test/Transforms/test-legalize-type-conversion.mlir
mlir/test/Transforms/test-legalizer.mlir
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Tutorials/Toy/Ch-5.md b/mlir/docs/Tutorials/Toy/Ch-5.md
index 3d58e4e2d9a0c..20cb4967d4c00 100644
--- a/mlir/docs/Tutorials/Toy/Ch-5.md
+++ b/mlir/docs/Tutorials/Toy/Ch-5.md
@@ -70,9 +70,14 @@ void ToyToAffineLoweringPass::runOnFunction() {
// We also define the Toy dialect as Illegal so that the conversion will fail
// if any of these operations are *not* converted. Given that we actually want
// a partial lowering, we explicitly mark the Toy operations that don't want
- // to lower, `toy.print`, as *legal*.
+ // to lower, `toy.print`, as *legal*. `toy.print` will still need its operands
+ // to be updated though (as we convert from TensorType to MemRefType), so we
+ // only treat it as `legal` if its operands are legal.
target.addIllegalDialect<ToyDialect>();
- target.addLegalOp<PrintOp>();
+ target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
+ return llvm::none_of(op->getOperandTypes(),
+ [](Type type) { return type.isa<TensorType>(); });
+ });
...
}
```
diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
index 10f133709e77e..4a03534b502d1 100644
--- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
@@ -197,6 +197,24 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ToyToAffine RewritePatterns: Print operations
+//===----------------------------------------------------------------------===//
+
+struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
+ using OpConversionPattern<toy::PrintOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(toy::PrintOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ // We don't lower "toy.print" in this pass, but we need to update its
+ // operands.
+ rewriter.updateRootInPlace(op,
+ [&] { op->setOperands(adaptor.getOperands()); });
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Return operations
//===----------------------------------------------------------------------===//
@@ -294,15 +312,21 @@ void ToyToAffineLoweringPass::runOnFunction() {
// We also define the Toy dialect as Illegal so that the conversion will fail
// if any of these operations are *not* converted. Given that we actually want
// a partial lowering, we explicitly mark the Toy operations that don't want
- // to lower, `toy.print`, as `legal`.
+ // to lower, `toy.print`, as `legal`. `toy.print` will still need its operands
+ // to be updated though (as we convert from TensorType to MemRefType), so we
+ // only treat it as `legal` if its operands are legal.
target.addIllegalDialect<toy::ToyDialect>();
- target.addLegalOp<toy::PrintOp>();
+ target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
+ return llvm::none_of(op->getOperandTypes(),
+ [](Type type) { return type.isa<TensorType>(); });
+ });
// Now that the conversion target has been defined, we just need to provide
// the set of patterns that will lower the Toy operations.
RewritePatternSet patterns(&getContext());
patterns.add<AddOpLowering, ConstantOpLowering, MulOpLowering,
- ReturnOpLowering, TransposeOpLowering>(&getContext());
+ PrintOpLowering, ReturnOpLowering, TransposeOpLowering>(
+ &getContext());
// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
index 24ff8e174ab80..4a03534b502d1 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
@@ -57,9 +57,9 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
/// induction variables for the iteration. It returns a value to store at the
/// current index of the iteration.
using LoopIterationFn = function_ref<Value(
- OpBuilder &builder, ValueRange memRefOperands, ValueRange loopIvs)>;
+ OpBuilder &rewriter, ValueRange memRefOperands, ValueRange loopIvs)>;
-static void lowerOpToLoops(Operation *op, ArrayRef<Value> operands,
+static void lowerOpToLoops(Operation *op, ValueRange operands,
PatternRewriter &rewriter,
LoopIterationFn processIteration) {
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
@@ -162,6 +162,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
constantIndices.push_back(
rewriter.create<arith::ConstantIndexOp>(loc, 0));
}
+
// The constant operation represents a multi-dimensional constant, so we
// will need to generate a store for each of the elements. The following
// functor recursively walks the dimensions of the constant shape,
@@ -196,6 +197,24 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ToyToAffine RewritePatterns: Print operations
+//===----------------------------------------------------------------------===//
+
+struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
+ using OpConversionPattern<toy::PrintOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(toy::PrintOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ // We don't lower "toy.print" in this pass, but we need to update its
+ // operands.
+ rewriter.updateRootInPlace(op,
+ [&] { op->setOperands(adaptor.getOperands()); });
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Return operations
//===----------------------------------------------------------------------===//
@@ -293,15 +312,21 @@ void ToyToAffineLoweringPass::runOnFunction() {
// We also define the Toy dialect as Illegal so that the conversion will fail
// if any of these operations are *not* converted. Given that we actually want
// a partial lowering, we explicitly mark the Toy operations that don't want
- // to lower, `toy.print`, as `legal`.
+ // to lower, `toy.print`, as `legal`. `toy.print` will still need its operands
+ // to be updated though (as we convert from TensorType to MemRefType), so we
+ // only treat it as `legal` if its operands are legal.
target.addIllegalDialect<toy::ToyDialect>();
- target.addLegalOp<toy::PrintOp>();
+ target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
+ return llvm::none_of(op->getOperandTypes(),
+ [](Type type) { return type.isa<TensorType>(); });
+ });
// Now that the conversion target has been defined, we just need to provide
// the set of patterns that will lower the Toy operations.
RewritePatternSet patterns(&getContext());
patterns.add<AddOpLowering, ConstantOpLowering, MulOpLowering,
- ReturnOpLowering, TransposeOpLowering>(&getContext());
+ PrintOpLowering, ReturnOpLowering, TransposeOpLowering>(
+ &getContext());
// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index 10f133709e77e..4a03534b502d1 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -197,6 +197,24 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ToyToAffine RewritePatterns: Print operations
+//===----------------------------------------------------------------------===//
+
+struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
+ using OpConversionPattern<toy::PrintOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(toy::PrintOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ // We don't lower "toy.print" in this pass, but we need to update its
+ // operands.
+ rewriter.updateRootInPlace(op,
+ [&] { op->setOperands(adaptor.getOperands()); });
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Return operations
//===----------------------------------------------------------------------===//
@@ -294,15 +312,21 @@ void ToyToAffineLoweringPass::runOnFunction() {
// We also define the Toy dialect as Illegal so that the conversion will fail
// if any of these operations are *not* converted. Given that we actually want
// a partial lowering, we explicitly mark the Toy operations that don't want
- // to lower, `toy.print`, as `legal`.
+ // to lower, `toy.print`, as `legal`. `toy.print` will still need its operands
+ // to be updated though (as we convert from TensorType to MemRefType), so we
+ // only treat it as `legal` if its operands are legal.
target.addIllegalDialect<toy::ToyDialect>();
- target.addLegalOp<toy::PrintOp>();
+ target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
+ return llvm::none_of(op->getOperandTypes(),
+ [](Type type) { return type.isa<TensorType>(); });
+ });
// Now that the conversion target has been defined, we just need to provide
// the set of patterns that will lower the Toy operations.
RewritePatternSet patterns(&getContext());
patterns.add<AddOpLowering, ConstantOpLowering, MulOpLowering,
- ReturnOpLowering, TransposeOpLowering>(&getContext());
+ PrintOpLowering, ReturnOpLowering, TransposeOpLowering>(
+ &getContext());
// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`
diff --git a/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h b/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h
index c135eee27ee72..74b14bdf2efb1 100644
--- a/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h
+++ b/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h
@@ -25,7 +25,7 @@ class SPIRVToLLVMConversion : public OpConversionPattern<SPIRVOp> {
public:
SPIRVToLLVMConversion(MLIRContext *context, LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
- : OpConversionPattern<SPIRVOp>(context, benefit),
+ : OpConversionPattern<SPIRVOp>(typeConverter, context, benefit),
typeConverter(typeConverter) {}
protected:
diff --git a/mlir/include/mlir/IR/BlockAndValueMapping.h b/mlir/include/mlir/IR/BlockAndValueMapping.h
index 9959d883f89cb..0988012da8ea6 100644
--- a/mlir/include/mlir/IR/BlockAndValueMapping.h
+++ b/mlir/include/mlir/IR/BlockAndValueMapping.h
@@ -27,10 +27,8 @@ class BlockAndValueMapping {
public:
/// Inserts a new mapping for 'from' to 'to'. If there is an existing mapping,
/// it is overwritten.
- void map(Block *from, Block *to) { valueMap[from] = to; }
- void map(Value from, Value to) {
- valueMap[from.getAsOpaquePointer()] = to.getAsOpaquePointer();
- }
+ void map(Block *from, Block *to) { blockMap[from] = to; }
+ void map(Value from, Value to) { valueMap[from] = to; }
template <
typename S, typename T,
@@ -42,14 +40,12 @@ class BlockAndValueMapping {
}
/// Erases a mapping for 'from'.
- void erase(Block *from) { valueMap.erase(from); }
- void erase(Value from) { valueMap.erase(from.getAsOpaquePointer()); }
+ void erase(Block *from) { blockMap.erase(from); }
+ void erase(Value from) { valueMap.erase(from); }
/// Checks to see if a mapping for 'from' exists.
- bool contains(Block *from) const { return valueMap.count(from); }
- bool contains(Value from) const {
- return valueMap.count(from.getAsOpaquePointer());
- }
+ bool contains(Block *from) const { return blockMap.count(from); }
+ bool contains(Value from) const { return valueMap.count(from); }
/// Lookup a mapped value within the map. If a mapping for the provided value
/// does not exist then return nullptr.
@@ -76,28 +72,26 @@ class BlockAndValueMapping {
/// Clears all mappings held by the mapper.
void clear() { valueMap.clear(); }
- /// Returns a new mapper containing the inverse mapping.
- BlockAndValueMapping getInverse() const {
- BlockAndValueMapping result;
- for (const auto &pair : valueMap)
- result.valueMap.try_emplace(pair.second, pair.first);
- return result;
- }
+ /// Return the held value mapping.
+ const DenseMap<Value, Value> &getValueMap() const { return valueMap; }
+
+ /// Return the held block mapping.
+ const DenseMap<Block *, Block *> &getBlockMap() const { return blockMap; }
private:
/// Utility lookupOrValue that looks up an existing key or returns the
/// provided value.
Block *lookupOrValue(Block *from, Block *value) const {
- auto it = valueMap.find(from);
- return it != valueMap.end() ? reinterpret_cast<Block *>(it->second) : value;
+ auto it = blockMap.find(from);
+ return it != blockMap.end() ? it->second : value;
}
Value lookupOrValue(Value from, Value value) const {
- auto it = valueMap.find(from.getAsOpaquePointer());
- return it != valueMap.end() ? Value::getFromOpaquePointer(it->second)
- : value;
+ auto it = valueMap.find(from);
+ return it != valueMap.end() ? it->second : value;
}
- DenseMap<void *, void *> valueMap;
+ DenseMap<Value, Value> valueMap;
+ DenseMap<Block *, Block *> blockMap;
};
} // end namespace mlir
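
With getInverse() gone, clients build any inverse view they need from
the exposed maps. A minimal sketch, mirroring what
DialectConversion.cpp now does later in this patch; note the inverse
is one-to-many, since several keys may map to the same value:

```
#include "mlir/IR/BlockAndValueMapping.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"

// Invert a value mapping; each target value collects all of the source
// values that were mapped to it.
static llvm::DenseMap<mlir::Value, llvm::SmallVector<mlir::Value>>
invertValueMapping(const mlir::BlockAndValueMapping &mapping) {
  llvm::DenseMap<mlir::Value, llvm::SmallVector<mlir::Value>> inverse;
  for (const auto &it : mapping.getValueMap())
    inverse[it.second].push_back(it.first);
  return inverse;
}
```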
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 198522885041a..0f74021184b3e 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -151,21 +151,8 @@ class alignas(8) Operation final
/// Replace all uses of results of this operation with the provided 'values'.
template <typename ValuesT>
- std::enable_if_t<!std::is_convertible<ValuesT, Operation *>::value>
- replaceAllUsesWith(ValuesT &&values) {
- assert(std::distance(values.begin(), values.end()) == getNumResults() &&
- "expected 'values' to correspond 1-1 with the number of results");
-
- auto valueIt = values.begin();
- for (unsigned i = 0, e = getNumResults(); i != e; ++i)
- getResult(i).replaceAllUsesWith(*(valueIt++));
- }
-
- /// Replace all uses of results of this operation with results of 'op'.
- void replaceAllUsesWith(Operation *op) {
- assert(getNumResults() == op->getNumResults());
- for (unsigned i = 0, e = getNumResults(); i != e; ++i)
- getResult(i).replaceAllUsesWith(op->getResult(i));
+ void replaceAllUsesWith(ValuesT &&values) {
+ getResults().replaceAllUsesWith(std::forward<ValuesT>(values));
}
/// Destroys this operation and its subclass data.
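
Both former overloads now funnel into ResultRange::replaceAllUsesWith
(declared in OperationSupport.h below), so existing call sites keep
compiling unchanged. A small sketch of the two call forms, with
hypothetical helper names:

```
#include "mlir/IR/Operation.h"
#include "llvm/ADT/ArrayRef.h"

// Hypothetical helpers; 'op' and 'newOp' are assumed to produce the
// same number of results.
static void replaceWithOp(mlir::Operation *op, mlir::Operation *newOp) {
  // Dispatches to ResultRange::replaceAllUsesWith(Operation *).
  op->replaceAllUsesWith(newOp);
}

static void replaceWithValues(mlir::Operation *op,
                              llvm::ArrayRef<mlir::Value> newValues) {
  // Asserts that 'newValues' corresponds 1-1 with op's results.
  op->replaceAllUsesWith(newValues);
}
```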
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 1e85b4f0f7f91..5c41123878982 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -903,6 +903,7 @@ class ResultRange final
ResultRange, detail::OpResultImpl *, OpResult, OpResult, OpResult> {
public:
using RangeBaseT::RangeBaseT;
+ ResultRange(OpResult result);
//===--------------------------------------------------------------------===//
// Types
@@ -934,6 +935,22 @@ class ResultRange final
[](OpResult result) { return result.use_empty(); });
}
+ /// Replace all uses of results of this range with the provided 'values'. The
+ /// size of `values` must match the size of this range.
+ template <typename ValuesT>
+ std::enable_if_t<!std::is_convertible<ValuesT, Operation *>::value>
+ replaceAllUsesWith(ValuesT &&values) {
+ assert(static_cast<size_t>(std::distance(values.begin(), values.end())) ==
+ size() &&
+ "expected 'values' to correspond 1-1 with the number of results");
+
+ for (auto it : llvm::zip(*this, values))
+ std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
+ }
+
+ /// Replace all uses of results of this range with results of 'op'.
+ void replaceAllUsesWith(Operation *op);
+
//===--------------------------------------------------------------------===//
// Users
//===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 86d79e192a0dd..1d89eccf9a8c7 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -118,9 +118,8 @@ class TypeConverter {
/// must return a Value of the converted type on success, an `llvm::None` if
/// it failed but other materialization can be attempted, and `nullptr` on
/// unrecoverable failure. It will only be called for (sub)types of `T`.
- /// Materialization functions must be provided when a type conversion
- /// results in more than one type, or if a type conversion may persist after
- /// the conversion has finished.
+ /// Materialization functions must be provided when a type conversion may
+ /// persist after the conversion has finished.
///
/// This method registers a materialization that will be called when
/// converting an illegal block argument type, to a legal type.
@@ -551,10 +550,17 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// Replace all the uses of the block argument `from` with value `to`.
void replaceUsesOfBlockArgument(BlockArgument from, Value to);
- /// Return the converted value that replaces 'key'. Return 'key' if there is
- /// no such a converted value.
+ /// Return the converted value of 'key' with a type defined by the type
+ /// converter of the currently executing pattern. Return nullptr in the case
+ /// of failure, the remapped value otherwise.
Value getRemappedValue(Value key);
+ /// Return the converted values that replace 'keys' with types defined by the
+ /// type converter of the currently executing pattern. Returns failure if the
+ /// remap failed, success otherwise.
+ LogicalResult getRemappedValues(ValueRange keys,
+ SmallVectorImpl<Value> &results);
+
//===--------------------------------------------------------------------===//
// PatternRewriter Hooks
//===--------------------------------------------------------------------===//
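
A rough sketch of the extended remapping API in use; the op name is
hypothetical, and the in-place update mirrors the PrintOpLowering
pattern added to the Toy examples above:

```
#include "mlir/Transforms/DialectConversion.h"

// A minimal sketch, assuming a hypothetical op 'my::FooOp'. Both
// getRemappedValue and getRemappedValues may now insert temporary
// materializations for values that haven't been converted yet,
// instead of failing outright. (For the matched op's own operands the
// adaptor already provides remapped values; the API also works for
// arbitrary values.)
struct FooOpConversion : public mlir::OpConversionPattern<my::FooOp> {
  using mlir::OpConversionPattern<my::FooOp>::OpConversionPattern;

  mlir::LogicalResult
  matchAndRewrite(my::FooOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const final {
    llvm::SmallVector<mlir::Value> remapped;
    if (mlir::failed(
            rewriter.getRemappedValues(op->getOperands(), remapped)))
      return mlir::failure();
    rewriter.updateRootInPlace(op, [&] { op->setOperands(remapped); });
    return mlir::success();
  }
};
```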
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index dc5bdc78a054c..f581a0a2dfd5f 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -56,7 +56,7 @@ class SCFToSPIRVPattern : public OpConversionPattern<OpTy> {
public:
SCFToSPIRVPattern<OpTy>(MLIRContext *context, SPIRVTypeConverter &converter,
ScfToSPIRVContextImpl *scfToSPIRVContext)
- : OpConversionPattern<OpTy>::OpConversionPattern(context),
+ : OpConversionPattern<OpTy>::OpConversionPattern(converter, context),
scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {}
protected:
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 889b733434411..51df6c763c3fe 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -218,7 +218,7 @@ static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
}
/// Utility for `spv.Load` and `spv.Store` conversion.
-static LogicalResult replaceWithLoadOrStore(Operation *op,
+static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter,
unsigned alignment, bool isVolatile,
@@ -228,12 +228,14 @@ static LogicalResult replaceWithLoadOrStore(Operation *op,
if (!dstType)
return failure();
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
- loadOp, dstType, loadOp.ptr(), alignment, isVolatile, isNonTemporal);
+ loadOp, dstType, spirv::LoadOpAdaptor(operands).ptr(), alignment,
+ isVolatile, isNonTemporal);
return success();
}
auto storeOp = cast<spirv::StoreOp>(op);
- rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, storeOp.value(),
- storeOp.ptr(), alignment,
+ spirv::StoreOpAdaptor adaptor(operands);
+ rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.value(),
+ adaptor.ptr(), alignment,
isVolatile, isNonTemporal);
return success();
}
@@ -308,7 +310,7 @@ class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
if (!dstType)
return failure();
// To use GEP we need to add a first 0 index to go through the pointer.
- auto indices = llvm::to_vector<4>(op.indices());
+ auto indices = llvm::to_vector<4>(adaptor.indices());
Type indexType = op.indices().front().getType();
auto llvmIndexType = typeConverter.convertType(indexType);
if (!llvmIndexType)
@@ -316,7 +318,7 @@ class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
Value zero = rewriter.create<LLVM::ConstantOp>(
op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
indices.insert(indices.begin(), zero);
- rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, op.base_ptr(),
+ rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, adaptor.base_ptr(),
indices);
return success();
}
@@ -572,11 +574,11 @@ class CompositeExtractPattern
IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
Value index = createI32ConstantOf(loc, rewriter, value.getInt());
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
- op, dstType, op.composite(), index);
+ op, dstType, adaptor.composite(), index);
return success();
}
rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
- op, dstType, op.composite(), op.indices());
+ op, dstType, adaptor.composite(), op.indices());
return success();
}
};
@@ -602,11 +604,11 @@ class CompositeInsertPattern
IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
Value index = createI32ConstantOf(loc, rewriter, value.getInt());
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- op, dstType, op.composite(), op.object(), index);
+ op, dstType, adaptor.composite(), adaptor.object(), index);
return success();
}
rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
- op, dstType, op.composite(), op.object(), op.indices());
+ op, dstType, adaptor.composite(), adaptor.object(), op.indices());
return success();
}
};
@@ -897,9 +899,10 @@ class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!op.memory_access().hasValue()) {
- return replaceWithLoadOrStore(
- op, rewriter, this->typeConverter, /*alignment=*/0,
- /*isVolatile=*/false, /*isNonTemporal=*/false);
+ return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
+ this->typeConverter, /*alignment=*/0,
+ /*isVolatile=*/false,
+ /*isNonTemporal=*/false);
}
auto memoryAccess = op.memory_access().getValue();
switch (memoryAccess) {
@@ -911,8 +914,9 @@ class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
memoryAccess == spirv::MemoryAccess::Aligned ? *op.alignment() : 0;
bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
- return replaceWithLoadOrStore(op, rewriter, this->typeConverter,
- alignment, isVolatile, isNonTemporal);
+ return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
+ this->typeConverter, alignment, isVolatile,
+ isNonTemporal);
}
default:
// There is no support of other memory access attributes.
@@ -1178,13 +1182,13 @@ class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
Value extended;
if (isUnsignedIntegerOrVector(op2Type)) {
extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType,
- operation.operand2());
+ adaptor.operand2());
} else {
extended = rewriter.template create<LLVM::SExtOp>(loc, dstType,
- operation.operand2());
+ adaptor.operand2());
}
Value result = rewriter.template create<LLVMOp>(
- loc, dstType, operation.operand1(), extended);
+ loc, dstType, adaptor.operand1(), extended);
rewriter.replaceOp(operation, result);
return success();
}
@@ -1268,7 +1272,7 @@ class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
return success();
}
Value allocated = rewriter.create<LLVM::AllocaOp>(loc, dstType, size);
- rewriter.create<LLVM::StoreOp>(loc, init, allocated);
+ rewriter.create<LLVM::StoreOp>(loc, adaptor.initializer(), allocated);
rewriter.replaceOp(varOp, allocated);
return success();
}
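
The many op.xyz() -> adaptor.xyz() switches in this file follow from
the deferred materializations: during conversion an op's own operands
may still carry unconverted (illegal) types, so patterns must read the
remapped values from the adaptor. A minimal sketch of the idiom with a
hypothetical op (the LLVM::FNegOp builder form is assumed):

```
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Transforms/DialectConversion.h"

// A minimal sketch, assuming a hypothetical unary op 'my::NegFOp' with
// an 'operand()' accessor. Reading op.operand() could observe a value
// that hasn't been materialized to the target type yet; the adaptor is
// guaranteed to hold the remapped operand.
struct NegFOpConversion : public mlir::OpConversionPattern<my::NegFOp> {
  using mlir::OpConversionPattern<my::NegFOp>::OpConversionPattern;

  mlir::LogicalResult
  matchAndRewrite(my::NegFOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const final {
    mlir::Value input = adaptor.operand(); // not op.operand()
    rewriter.replaceOpWithNewOp<mlir::LLVM::FNegOp>(op, input.getType(),
                                                    input);
    return mlir::success();
  }
};
```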
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 1046c37588f69..966d3f3b8fceb 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -503,7 +503,6 @@ class VectorShuffleOpConversion
// For all other cases, insert the individual values individually.
Type eltType;
- llvm::errs() << llvmType << "\n";
if (auto arrayType = llvmType.dyn_cast<LLVM::LLVMArrayType>())
eltType = arrayType.getElementType();
else
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index cb34804ee262d..b4ff4af903564 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -24,6 +24,9 @@ using namespace mlir::linalg;
static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
ValueRange inputs, Location loc) {
assert(inputs.size() == 1);
+ if (inputs[0].getType().isa<TensorType>())
+ return nullptr;
+
// A detensored value is converted back by creating a new tensor from its
// element(s).
auto createNewTensorOp = builder.create<tensor::FromElementsOp>(
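
Returning a null Value from a materialization hook signals
unrecoverable failure (whereas llvm::None means "try the next
registered materialization"), so detensorization now refuses to
re-tensorize a value that is already a tensor. For context, a callback
like this is attached roughly as follows; this wiring is illustrative,
see the pass itself for the real setup:

```
// Illustrative wiring: the callback converts detensored values back to
// their original tensor type when that type must persist after
// conversion finishes.
static void setupDetensorizeConverter(mlir::TypeConverter &converter) {
  converter.addSourceMaterialization(sourceMaterializationCallback);
  converter.addArgumentMaterialization(sourceMaterializationCallback);
}
```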
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
index fe6433f6777ee..2a46819f1d168 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
@@ -72,11 +72,29 @@ class SPIRVAddressOfOpLayoutInfoDecoration
return success();
}
};
+
+template <typename OpT>
+class SPIRVPassThroughConversion : public OpConversionPattern<OpT> {
+public:
+ using OpConversionPattern<OpT>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(OpT op, typename OpT::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.updateRootInPlace(op,
+ [&] { op->setOperands(adaptor.getOperands()); });
+ return success();
+ }
+};
} // namespace
static void populateSPIRVLayoutInfoPatterns(RewritePatternSet &patterns) {
patterns.add<SPIRVGlobalVariableOpLayoutInfoDecoration,
- SPIRVAddressOfOpLayoutInfoDecoration>(patterns.getContext());
+ SPIRVAddressOfOpLayoutInfoDecoration,
+ SPIRVPassThroughConversion<spirv::AccessChainOp>,
+ SPIRVPassThroughConversion<spirv::LoadOp>,
+ SPIRVPassThroughConversion<spirv::StoreOp>>(
+ patterns.getContext());
}
namespace {
@@ -104,8 +122,17 @@ void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() {
return VulkanLayoutUtils::isLegalType(op.pointer().getType());
});
- // TODO: Change the type for the indirect users such as spv.Load, spv.Store,
- // spv.FunctionCall and so on.
+ // Change the type for the indirect users.
+ target.addDynamicallyLegalOp<spirv::AccessChainOp, spirv::LoadOp,
+ spirv::StoreOp>([&](Operation *op) {
+ for (Value operand : op->getOperands()) {
+ auto addrOp = operand.getDefiningOp<spirv::AddressOfOp>();
+ if (addrOp && !VulkanLayoutUtils::isLegalType(addrOp.pointer().getType()))
+ return false;
+ }
+ return true;
+ });
+
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
for (auto spirvModule : module.getOps<spirv::ModuleOp>())
if (failed(applyFullConversion(spirvModule, target, frozenPatterns)))
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 955898dc6b569..002d746c1a81f 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -555,6 +555,10 @@ MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
//===----------------------------------------------------------------------===//
// ResultRange
+ResultRange::ResultRange(OpResult result)
+ : ResultRange(static_cast<detail::OpResultImpl *>(Value(result).getImpl()),
+ 1) {}
+
ResultRange::use_range ResultRange::getUses() const {
return {use_begin(), use_end()};
}
@@ -605,6 +609,10 @@ void ResultRange::UseIterator::skipOverResultsWithNoUsers() {
use = (*it).use_begin();
}
+void ResultRange::replaceAllUsesWith(Operation *op) {
+ replaceAllUsesWith(op->getResults());
+}
+
//===----------------------------------------------------------------------===//
// ValueRange
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index da4f125a18a56..b63ed594d537a 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -31,7 +31,8 @@ using namespace mlir::detail;
/// regions pre-filtered to avoid considering them for legalization.
static LogicalResult
computeConversionSet(iterator_range<Region::iterator> region,
- Location regionLoc, std::vector<Operation *> &toConvert,
+ Location regionLoc,
+ SmallVectorImpl<Operation *> &toConvert,
ConversionTarget *target = nullptr) {
if (llvm::empty(region))
return success();
@@ -114,16 +115,32 @@ struct ConversionValueMapping {
/// Lookup a mapped value within the map, or return null if a mapping does not
/// exist. If a mapping exists, this follows the same behavior of
/// `lookupOrDefault`.
- Value lookupOrNull(Value from) const;
+ Value lookupOrNull(Value from, Type desiredType = nullptr) const;
/// Map a value to the one provided.
- void map(Value oldVal, Value newVal) { mapping.map(oldVal, newVal); }
+ void map(Value oldVal, Value newVal) {
+ LLVM_DEBUG({
+ for (Value it = newVal; it; it = mapping.lookupOrNull(it))
+ assert(it != oldVal && "inserting cyclic mapping");
+ });
+ mapping.map(oldVal, newVal);
+ }
+
+ /// Try to map a value to the one provided. Returns false if a transitive
+ /// mapping from the new value to the old value already exists, true if the
+ /// map was updated.
+ bool tryMap(Value oldVal, Value newVal);
/// Drop the last mapping for the given value.
void erase(Value value) { mapping.erase(value); }
/// Returns the inverse raw value mapping (without recursive query support).
- BlockAndValueMapping getInverse() const { return mapping.getInverse(); }
+ DenseMap<Value, SmallVector<Value>> getInverse() const {
+ DenseMap<Value, SmallVector<Value>> inverse;
+ for (auto &it : mapping.getValueMap())
+ inverse[it.second].push_back(it.first);
+ return inverse;
+ }
private:
/// Current value mappings.
@@ -158,9 +175,19 @@ Value ConversionValueMapping::lookupOrDefault(Value from,
return desiredValue ? desiredValue : from;
}
-Value ConversionValueMapping::lookupOrNull(Value from) const {
- Value result = lookupOrDefault(from);
- return result == from ? nullptr : result;
+Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const {
+ Value result = lookupOrDefault(from, desiredType);
+ if (result == from || (desiredType && result.getType() != desiredType))
+ return nullptr;
+ return result;
+}
+
+bool ConversionValueMapping::tryMap(Value oldVal, Value newVal) {
+ for (Value it = newVal; it; it = mapping.lookupOrNull(it))
+ if (it == oldVal)
+ return false;
+ map(oldVal, newVal);
+ return true;
}
//===----------------------------------------------------------------------===//
@@ -170,10 +197,13 @@ namespace {
/// This class contains a snapshot of the current conversion rewriter state.
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
- RewriterState(unsigned numCreatedOps, unsigned numReplacements,
- unsigned numArgReplacements, unsigned numBlockActions,
- unsigned numIgnoredOperations, unsigned numRootUpdates)
- : numCreatedOps(numCreatedOps), numReplacements(numReplacements),
+ RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
+ unsigned numReplacements, unsigned numArgReplacements,
+ unsigned numBlockActions, unsigned numIgnoredOperations,
+ unsigned numRootUpdates)
+ : numCreatedOps(numCreatedOps),
+ numUnresolvedMaterializations(numUnresolvedMaterializations),
+ numReplacements(numReplacements),
numArgReplacements(numArgReplacements),
numBlockActions(numBlockActions),
numIgnoredOperations(numIgnoredOperations),
@@ -182,6 +212,9 @@ struct RewriterState {
/// The current number of created operations.
unsigned numCreatedOps;
+ /// The current number of unresolved materializations.
+ unsigned numUnresolvedMaterializations;
+
/// The current number of replacements queued.
unsigned numReplacements;
@@ -321,8 +354,103 @@ struct BlockAction {
MergeInfo mergeInfo;
};
};
+
+//===----------------------------------------------------------------------===//
+// UnresolvedMaterialization
+
+/// This class represents an unresolved materialization, i.e. a materialization
+/// that was inserted during conversion that needs to be legalized at the end of
+/// the conversion process.
+class UnresolvedMaterialization {
+public:
+ /// The type of materialization.
+ enum Kind {
+ /// This materialization materializes a conversion for an illegal block
+ /// argument type, to a legal one.
+ Argument,
+
+ /// This materialization materializes a conversion from an illegal type to a
+ /// legal one.
+ Target
+ };
+
+ UnresolvedMaterialization(UnrealizedConversionCastOp op = nullptr,
+ TypeConverter *converter = nullptr,
+ Kind kind = Target, Type origOutputType = nullptr)
+ : op(op), converterAndKind(converter, kind),
+ origOutputType(origOutputType) {}
+
+ /// Return the temporary conversion operation inserted for this
+ /// materialization.
+ UnrealizedConversionCastOp getOp() const { return op; }
+
+ /// Return the type converter of this materialization (which may be null).
+ TypeConverter *getConverter() const { return converterAndKind.getPointer(); }
+
+ /// Return the kind of this materialization.
+ Kind getKind() const { return converterAndKind.getInt(); }
+
+ /// Set the kind of this materialization.
+ void setKind(Kind kind) { converterAndKind.setInt(kind); }
+
+ /// Return the original illegal output type of the input values.
+ Type getOrigOutputType() const { return origOutputType; }
+
+private:
+ /// The unresolved materialization operation created during conversion.
+ UnrealizedConversionCastOp op;
+
+ /// The corresponding type converter to use when resolving this
+ /// materialization, and the kind of this materialization.
+ llvm::PointerIntPair<TypeConverter *, 1, Kind> converterAndKind;
+
+ /// The original output type. This is only used for argument conversions.
+ Type origOutputType;
+};
} // end anonymous namespace
+/// Build an unresolved materialization operation given an output type and set
+/// of input operands.
+static Value buildUnresolvedMaterialization(
+ UnresolvedMaterialization::Kind kind, Block *insertBlock,
+ Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType,
+ Type origOutputType, TypeConverter *converter,
+ SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
+ // Avoid materializing an unnecessary cast.
+ if (inputs.size() == 1 && inputs.front().getType() == outputType)
+ return inputs.front();
+
+ // Create an unresolved materialization. We use a new OpBuilder to avoid
+ // tracking the materialization like we do for other operations.
+ OpBuilder builder(insertBlock, insertPt);
+ auto convertOp =
+ builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
+ unresolvedMaterializations.emplace_back(convertOp, converter, kind,
+ origOutputType);
+ return convertOp.getResult(0);
+}
+static Value buildUnresolvedArgumentMaterialization(
+ PatternRewriter &rewriter, Location loc, ValueRange inputs,
+ Type origOutputType, Type outputType, TypeConverter *converter,
+ SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
+ return buildUnresolvedMaterialization(
+ UnresolvedMaterialization::Argument, rewriter.getInsertionBlock(),
+ rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType,
+ converter, unresolvedMaterializations);
+}
+static Value buildUnresolvedTargetMaterialization(
+ Location loc, Value input, Type outputType, TypeConverter *converter,
+ SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
+ Block *insertBlock = input.getParentBlock();
+ Block::iterator insertPt = insertBlock->begin();
+ if (OpResult inputRes = input.dyn_cast<OpResult>())
+ insertPt = ++inputRes.getOwner()->getIterator();
+
+ return buildUnresolvedMaterialization(
+ UnresolvedMaterialization::Target, insertBlock, insertPt, loc, input,
+ outputType, outputType, converter, unresolvedMaterializations);
+}
+
//===----------------------------------------------------------------------===//
// ArgConverter
//===----------------------------------------------------------------------===//
@@ -332,7 +460,11 @@ namespace {
/// types and extracting the block that contains the old illegal types to allow
/// for undoing pending rewrites in the case of failure.
struct ArgConverter {
- ArgConverter(PatternRewriter &rewriter) : rewriter(rewriter) {}
+ ArgConverter(
+ PatternRewriter &rewriter,
+ SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations)
+ : rewriter(rewriter),
+ unresolvedMaterializations(unresolvedMaterializations) {}
/// This structure contains the information pertaining to an argument that has
/// been converted.
@@ -356,8 +488,8 @@ struct ArgConverter {
/// This structure contains information pertaining to a block that has had its
/// signature converted.
struct ConvertedBlockInfo {
- ConvertedBlockInfo(Block *origBlock, TypeConverter &converter)
- : origBlock(origBlock), converter(&converter) {}
+ ConvertedBlockInfo(Block *origBlock, TypeConverter *converter)
+ : origBlock(origBlock), converter(converter) {}
/// The original block that was requested to have its signature converted.
Block *origBlock;
@@ -420,7 +552,7 @@ struct ArgConverter {
/// block is returned containing the new arguments. Returns `block` if it did
/// not require conversion.
FailureOr<Block *>
- convertSignature(Block *block, TypeConverter &converter,
+ convertSignature(Block *block, TypeConverter *converter,
ConversionValueMapping &mapping,
SmallVectorImpl<BlockArgument> &argReplacements);
@@ -431,7 +563,7 @@ struct ArgConverter {
/// translate between the origin argument types and those specified in the
/// signature conversion.
Block *applySignatureConversion(
- Block *block, TypeConverter &converter,
+ Block *block, TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion,
ConversionValueMapping &mapping,
SmallVectorImpl<BlockArgument> &argReplacements);
@@ -456,6 +588,9 @@ struct ArgConverter {
/// The pattern rewriter to use when materializing conversions.
PatternRewriter &rewriter;
+
+ /// An ordered set of unresolved materializations during conversion.
+ SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations;
};
} // end anonymous namespace
@@ -519,7 +654,7 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
// Handle the case of a 1->0 value mapping.
if (!argInfo) {
- if (Value newArg = mapping.lookupOrNull(origArg))
+ if (Value newArg = mapping.lookupOrNull(origArg, origArg.getType()))
origArg.replaceAllUsesWith(newArg);
continue;
}
@@ -529,8 +664,10 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
// If the argument is still used, replace it with the generated cast.
- if (!origArg.use_empty())
- origArg.replaceAllUsesWith(mapping.lookupOrDefault(castValue));
+ if (!origArg.use_empty()) {
+ origArg.replaceAllUsesWith(
+ mapping.lookupOrDefault(castValue, origArg.getType()));
+ }
}
}
}
@@ -545,31 +682,38 @@ LogicalResult ArgConverter::materializeLiveConversions(
// Process the remapping for each of the original arguments.
for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
- // FIXME: We should run the below checks even if the type conversion was
- // 1->N, but a lot of existing lowering rely on the block argument being
- // blindly replaced. Those usages should be updated, and this if should be
- // removed.
- if (blockInfo.argInfo[i])
+ // FIXME: We should run the below checks even if a type converter wasn't
+ // provided, but a lot of existing lowering rely on the block argument
+ // being blindly replaced. We should rework argument materialization to be
+ // more robust for temporary source materializations, update existing
+ // patterns, and remove these checks.
+ if (!blockInfo.converter && blockInfo.argInfo[i])
continue;
// If the type of this argument changed and the argument is still live, we
// need to materialize a conversion.
BlockArgument origArg = origBlock->getArgument(i);
- auto argReplacementValue = mapping.lookupOrDefault(origArg);
- bool isDroppedArg = argReplacementValue == origArg;
- if (argReplacementValue.getType() == origArg.getType() && !isDroppedArg)
+ if (mapping.lookupOrNull(origArg, origArg.getType()))
continue;
Operation *liveUser = findLiveUser(origArg);
if (!liveUser)
continue;
- if (OpResult result = argReplacementValue.dyn_cast<OpResult>())
- rewriter.setInsertionPointAfter(result.getOwner());
- else
+ Value replacementValue = mapping.lookupOrDefault(origArg);
+ bool isDroppedArg = replacementValue == origArg;
+ if (isDroppedArg)
rewriter.setInsertionPointToStart(newBlock);
- Value newArg = blockInfo.converter->materializeSourceConversion(
- rewriter, origArg.getLoc(), origArg.getType(),
- isDroppedArg ? ValueRange() : ValueRange(argReplacementValue));
+ else
+ rewriter.setInsertionPointAfterValue(replacementValue);
+ Value newArg;
+ if (blockInfo.converter) {
+ newArg = blockInfo.converter->materializeSourceConversion(
+ rewriter, origArg.getLoc(), origArg.getType(),
+ isDroppedArg ? ValueRange() : ValueRange(replacementValue));
+ assert((!newArg || newArg.getType() == origArg.getType()) &&
+ "materialization hook did not provide a value of the expected "
+ "type");
+ }
if (!newArg) {
InFlightDiagnostic diag =
emitError(origArg.getLoc())
@@ -577,7 +721,7 @@ LogicalResult ArgConverter::materializeLiveConversions(
<< " that remained live after conversion, type was "
<< origArg.getType();
if (!isDroppedArg)
- diag << ", with target type " << argReplacementValue.getType();
+ diag << ", with target type " << replacementValue.getType();
diag.attachNote(liveUser->getLoc())
<< "see existing live user here: " << *liveUser;
return failure();
@@ -592,22 +736,26 @@ LogicalResult ArgConverter::materializeLiveConversions(
// Conversion
FailureOr<Block *> ArgConverter::convertSignature(
- Block *block, TypeConverter &converter, ConversionValueMapping &mapping,
+ Block *block, TypeConverter *converter, ConversionValueMapping &mapping,
SmallVectorImpl<BlockArgument> &argReplacements) {
// Check if the block was already converted. If the block is detached,
// conservatively assume it is going to be deleted.
if (hasBeenConverted(block) || !block->getParent())
return block;
+ // If a converter wasn't provided, and the block wasn't already converted,
+ // there is nothing we can do.
+ if (!converter)
+ return failure();
// Try to convert the signature for the block with the provided converter.
- if (auto conversion = converter.convertBlockSignature(block))
+ if (auto conversion = converter->convertBlockSignature(block))
return applySignatureConversion(block, converter, *conversion, mapping,
argReplacements);
return failure();
}
Block *ArgConverter::applySignatureConversion(
- Block *block, TypeConverter &converter,
+ Block *block, TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion,
ConversionValueMapping &mapping,
SmallVectorImpl<BlockArgument> &argReplacements) {
@@ -649,26 +797,35 @@ Block *ArgConverter::applySignatureConversion(
continue;
}
- // Otherwise, this is a 1->1+ mapping. Call into the provided type converter
- // to pack the new values. For 1->1 mappings, if there is no materialization
- // provided, use the argument directly instead.
+ // Otherwise, this is a 1->1+ mapping.
auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
Value newArg;
// If this is a 1->1 mapping and the types of new and replacement arguments
// match (i.e. it's an identity map), then the argument is mapped to its
// original type.
- if (replArgs.size() == 1 && replArgs[0].getType() == origArg.getType())
+ // FIXME: We simply pass through the replacement argument if there wasn't a
+ // converter, which isn't great as it allows implicit type conversions to
+ // appear. We should properly restructure this code to handle cases where a
+ // converter isn't provided and also to properly handle the case where an
+ // argument materialization is actually a temporary source materialization
+ // (e.g. in the case of 1->N).
+ if (replArgs.size() == 1 &&
+ (!converter || replArgs[0].getType() == origArg.getType())) {
newArg = replArgs.front();
- else
- newArg = converter.materializeArgumentConversion(
- rewriter, origArg.getLoc(), origArg.getType(), replArgs);
+ } else {
+ Type origOutputType = origArg.getType();
- if (!newArg) {
- assert(replArgs.size() == 1 &&
- "couldn't materialize the result of 1->N conversion");
- newArg = replArgs.front();
+ // Legalize the argument output type.
+ Type outputType = origOutputType;
+ if (Type legalOutputType = converter->convertType(outputType))
+ outputType = legalOutputType;
+
+ newArg = buildUnresolvedArgumentMaterialization(
+ rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
+ converter, unresolvedMaterializations);
}
+
mapping.map(origArg, newArg);
argReplacements.push_back(origArg);
info.argInfo[i] =
@@ -702,7 +859,7 @@ namespace mlir {
namespace detail {
struct ConversionPatternRewriterImpl {
ConversionPatternRewriterImpl(PatternRewriter &rewriter)
- : argConverter(rewriter) {}
+ : argConverter(rewriter, unresolvedMaterializations) {}
/// Cleanup and destroy any generated rewrite operations. This method is
/// invoked when the conversion process fails.
@@ -730,13 +887,12 @@ struct ConversionPatternRewriterImpl {
/// "numActionsToKeep" actions remains.
void undoBlockActions(unsigned numActionsToKeep = 0);
- /// Remap the given operands to those with potentially different types. The
- /// provided type converter is used to ensure that the remapped types are
- /// legal. Returns success if the operands could be remapped, failure
- /// otherwise.
- LogicalResult remapValues(Location loc, PatternRewriter &rewriter,
- TypeConverter *converter,
- Operation::operand_range operands,
+ /// Remap the given values to those with potentially different types. Returns
+ /// success if the values could be remapped, failure otherwise. `valueDiagTag`
+ /// is the tag used when describing a value within a diagnostic, e.g.
+ /// "operand".
+ LogicalResult remapValues(StringRef valueDiagTag, Optional<Location> inputLoc,
+ PatternRewriter &rewriter, ValueRange values,
SmallVectorImpl<Value> &remapped);
/// Returns true if the given operation is ignored, and does not need to be
@@ -753,7 +909,7 @@ struct ConversionPatternRewriterImpl {
/// Convert the signature of the given block.
FailureOr<Block *> convertBlockSignature(
- Block *block, TypeConverter &converter,
+ Block *block, TypeConverter *converter,
TypeConverter::SignatureConversion *conversion = nullptr);
/// Apply a signature conversion on the given region, using `converter` for
@@ -817,7 +973,11 @@ struct ConversionPatternRewriterImpl {
ArgConverter argConverter;
/// Ordered vector of all of the newly created operations during conversion.
- std::vector<Operation *> createdOps;
+ SmallVector<Operation *> createdOps;
+
+ /// Ordered vector of all unresolved type conversion materializations during
+ /// conversion.
+ SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
/// Ordered map of requested operation replacements.
llvm::MapVector<Operation *, OpReplacement> replacements;
@@ -847,10 +1007,6 @@ struct ConversionPatternRewriterImpl {
/// 1->N conversion of some kind.
SmallVector<unsigned, 4> operationsWithChangedResults;
- /// A default type converter, used when block conversions do not have one
- /// explicitly provided.
- TypeConverter defaultTypeConverter;
-
/// The current type converter, or nullptr if no type converter is currently
/// active.
TypeConverter *currentTypeConverter = nullptr;
@@ -896,6 +1052,8 @@ void ConversionPatternRewriterImpl::discardRewrites() {
undoBlockActions();
// Remove any newly created ops.
+ for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
+ detachNestedAndErase(materialization.getOp());
for (auto *op : llvm::reverse(createdOps))
detachNestedAndErase(op);
}
@@ -904,7 +1062,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
// Apply all of the rewrites replacements requested during conversion.
for (auto &repl : replacements) {
for (OpResult result : repl.first->getResults())
- if (Value newValue = mapping.lookupOrNull(result))
+ if (Value newValue = mapping.lookupOrNull(result, result.getType()))
result.replaceAllUsesWith(newValue);
// If this operation defines any regions, drop any pending argument
@@ -915,7 +1073,10 @@ void ConversionPatternRewriterImpl::applyRewrites() {
// Apply all of the requested argument replacements.
for (BlockArgument arg : argReplacements) {
- Value repl = mapping.lookupOrDefault(arg);
+ Value repl = mapping.lookupOrNull(arg, arg.getType());
+ if (!repl)
+ continue;
+
if (repl.isa<BlockArgument>()) {
arg.replaceAllUsesWith(repl);
continue;
@@ -932,6 +1093,13 @@ void ConversionPatternRewriterImpl::applyRewrites() {
});
}
+ // Drop all of the unresolved materialization operations created during
+ // conversion.
+ for (auto &mat : unresolvedMaterializations) {
+ mat.getOp()->dropAllUses();
+ mat.getOp()->erase();
+ }
+
// In a second pass, erase all of the replaced operations in reverse. This
// allows processing nested operations before their parent region is
// destroyed. Because we process in reverse order, producers may be deleted
@@ -952,9 +1120,10 @@ void ConversionPatternRewriterImpl::applyRewrites() {
// State Management
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
- return RewriterState(createdOps.size(), replacements.size(),
- argReplacements.size(), blockActions.size(),
- ignoredOps.size(), rootUpdates.size());
+ return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
+ replacements.size(), argReplacements.size(),
+ blockActions.size(), ignoredOps.size(),
+ rootUpdates.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -979,6 +1148,20 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
while (replacements.size() != state.numReplacements)
replacements.pop_back();
+ // Pop all of the newly inserted materializations.
+ while (unresolvedMaterializations.size() !=
+ state.numUnresolvedMaterializations) {
+ UnresolvedMaterialization mat = unresolvedMaterializations.pop_back_val();
+ UnrealizedConversionCastOp op = mat.getOp();
+
+ // If this was a target materialization, drop the mapping that was inserted.
+ if (mat.getKind() == UnresolvedMaterialization::Target) {
+ for (Value input : op->getOperands())
+ mapping.erase(input);
+ }
+ detachNestedAndErase(op);
+ }
+
// Pop all of the newly created operations.
while (createdOps.size() != state.numCreatedOps) {
detachNestedAndErase(createdOps.back());
@@ -1070,25 +1253,27 @@ void ConversionPatternRewriterImpl::undoBlockActions(
}
LogicalResult ConversionPatternRewriterImpl::remapValues(
- Location loc, PatternRewriter &rewriter, TypeConverter *converter,
- Operation::operand_range operands, SmallVectorImpl<Value> &remapped) {
- remapped.reserve(llvm::size(operands));
+ StringRef valueDiagTag, Optional<Location> inputLoc,
+ PatternRewriter &rewriter, ValueRange values,
+ SmallVectorImpl<Value> &remapped) {
+ remapped.reserve(llvm::size(values));
SmallVector<Type, 1> legalTypes;
- for (auto it : llvm::enumerate(operands)) {
+ for (auto it : llvm::enumerate(values)) {
Value operand = it.value();
Type origType = operand.getType();
// If a converter was provided, get the desired legal types for this
// operand.
Type desiredType;
- if (converter) {
+ if (currentTypeConverter) {
// If there is no legal conversion, fail to match this pattern.
legalTypes.clear();
- if (failed(converter->convertType(origType, legalTypes))) {
- return notifyMatchFailure(loc, [=](Diagnostic &diag) {
- diag << "unable to convert type for operand #" << it.index()
- << ", type was " << origType;
+ if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
+ Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
+ return notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
+ diag << "unable to convert type for " << valueDiagTag << " #"
+ << it.index() << ", type was " << origType;
});
}
// TODO: There currently isn't any mechanism to do 1->N type conversion
@@ -1108,18 +1293,13 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
// Handle the case where the conversion was 1->1 and the new operand type
// isn't legal.
Type newOperandType = newOperand.getType();
- if (converter && desiredType && newOperandType != desiredType) {
- // Attempt to materialize a conversion for this new value.
- newOperand = converter->materializeTargetConversion(
- rewriter, loc, desiredType, newOperand);
- if (!newOperand) {
- return notifyMatchFailure(loc, [=](Diagnostic &diag) {
- diag << "unable to materialize a conversion for "
- "operand #"
- << it.index() << ", from " << newOperandType << " to "
- << desiredType;
- });
- }
+ if (currentTypeConverter && desiredType && newOperandType != desiredType) {
+ Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
+ Value castValue = buildUnresolvedTargetMaterialization(
+ operandLoc, newOperand, desiredType, currentTypeConverter,
+ unresolvedMaterializations);
+ mapping.map(mapping.lookupOrDefault(newOperand), castValue);
+ newOperand = castValue;
}
remapped.push_back(newOperand);
}
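For reference, the unresolved target materialization built above is just a builtin.unrealized_conversion_cast op standing in for the user hook until finalization. A minimal sketch of the idea, assuming the usual mlir/IR includes and the mlir namespace (the real helper also records an UnresolvedMaterialization bookkeeping entry):

  // Sketch only: build a placeholder cast from `input` to `desiredType`.
  static Value buildTemporaryCast(OpBuilder &builder, Location loc, Value input,
                                  Type desiredType) {
    SmallVector<Type, 1> resultTypes{desiredType};
    SmallVector<Value, 1> operands{input};
    return builder
        .create<UnrealizedConversionCastOp>(loc, resultTypes, operands)
        ->getResult(0);
  }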
@@ -1148,7 +1328,7 @@ void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
// Type Conversion
FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
- Block *block, TypeConverter &converter,
+ Block *block, TypeConverter *converter,
TypeConverter::SignatureConversion *conversion) {
FailureOr<Block *> result =
conversion ? argConverter.applySignatureConversion(
@@ -1167,11 +1347,8 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
Block *ConversionPatternRewriterImpl::applySignatureConversion(
Region *region, TypeConverter::SignatureConversion &conversion,
TypeConverter *converter) {
- if (!region->empty()) {
- return *convertBlockSignature(&region->front(),
- converter ? *converter : defaultTypeConverter,
- &conversion);
- }
+ if (!region->empty())
+ return *convertBlockSignature(&region->front(), converter, &conversion);
return nullptr;
}
@@ -1186,7 +1363,7 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
return failure();
FailureOr<Block *> newEntry =
- convertBlockSignature(&region->front(), converter, entryConversion);
+ convertBlockSignature(&region->front(), &converter, entryConversion);
return newEntry;
}
@@ -1212,7 +1389,7 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
: const_cast<TypeConverter::SignatureConversion *>(
&blockConversions[blockIdx++]);
- if (failed(convertBlockSignature(&block, converter, blockConversion)))
+ if (failed(convertBlockSignature(&block, &converter, blockConversion)))
return failure();
}
return success();
@@ -1393,7 +1570,20 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
}
Value ConversionPatternRewriter::getRemappedValue(Value key) {
- return impl->mapping.lookupOrDefault(key);
+ SmallVector<Value> remappedValues;
+ if (failed(impl->remapValues("value", /*inputLoc=*/llvm::None, *this, key,
+ remappedValues)))
+ return nullptr;
+ return remappedValues.front();
+}
+
+LogicalResult
+ConversionPatternRewriter::getRemappedValues(ValueRange keys,
+ SmallVectorImpl<Value> &results) {
+ if (keys.empty())
+ return success();
+ return impl->remapValues("value", /*inputLoc=*/llvm::None, *this, keys,
+ results);
}
void ConversionPatternRewriter::notifyBlockCreated(Block *block) {
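As a usage illustration (hypothetical pattern, not part of this patch): getRemappedValue can now be queried for values whose producers have not been converted yet, because a temporary materialization is inserted on demand.

  struct ExamplePattern : public ConversionPattern {
    using ConversionPattern::ConversionPattern;
    LogicalResult
    matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                    ConversionPatternRewriter &rewriter) const override {
      // Query the remapped form of the first operand; if it has not been
      // converted yet, a temporary materialization is created for it.
      Value remapped = rewriter.getRemappedValue(op->getOperand(0));
      if (!remapped)
        return failure();
      rewriter.replaceOp(op, remapped);
      return success();
    }
  };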
@@ -1505,9 +1695,8 @@ ConversionPattern::matchAndRewrite(Operation *op,
// Remap the operands of the operation.
SmallVector<Value, 4> operands;
- if (failed(rewriterImpl.remapValues(op->getLoc(), rewriter,
- getTypeConverter(), op->getOperands(),
- operands))) {
+ if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
+ op->getOperands(), operands))) {
return failure();
}
return matchAndRewrite(op, operands, dialectRewriter);
@@ -1800,7 +1989,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
auto &os = rewriter.getImpl().logger;
os.getOStream() << "\n";
os.startLine() << "* Pattern : '" << op->getName() << " -> (";
- llvm::interleaveComma(pattern.getGeneratedOps(), llvm::dbgs());
+ llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream());
os.getOStream() << ")' {\n";
os.indent();
});
@@ -1879,7 +2068,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockActions(
// directly.
if (auto *converter =
impl.argConverter.getConverter(action.block->getParent())) {
- if (failed(impl.convertBlockSignature(action.block, *converter))) {
+ if (failed(impl.convertBlockSignature(action.block, converter))) {
LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
"block"));
return failure();
@@ -2088,7 +2277,7 @@ unsigned OperationLegalizer::applyCostModelToPatterns(
SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
patternsByDepth.reserve(patterns.size());
for (const Pattern *pattern : patterns) {
- unsigned depth = 0;
+ unsigned depth = 1;
for (auto generatedOp : pattern->getGeneratedOps()) {
unsigned generatedOpDepth = computeOpLegalizationDepth(
generatedOp, minOpPatternDepth, legalizerPatterns);
@@ -2173,6 +2362,12 @@ struct OperationConverter {
legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl);
+ /// Legalize any unresolved type materializations.
+ LogicalResult legalizeUnresolvedMaterializations(
+ ConversionPatternRewriter &rewriter,
+ ConversionPatternRewriterImpl &rewriterImpl,
+ Optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping);
+
/// Legalize an operation result that was marked as "erased".
LogicalResult
legalizeErasedResult(Operation *op, OpResult result,
@@ -2180,12 +2375,11 @@ struct OperationConverter {
/// Legalize an operation result that was replaced with a value of a different
/// type.
- LogicalResult
- legalizeChangedResultType(Operation *op, OpResult result, Value newValue,
- TypeConverter *replConverter,
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &rewriterImpl,
- const BlockAndValueMapping &inverseMapping);
+ LogicalResult legalizeChangedResultType(
+ Operation *op, OpResult result, Value newValue,
+ TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
+ ConversionPatternRewriterImpl &rewriterImpl,
+ const DenseMap<Value, SmallVector<Value>> &inverseMapping);
/// The legalizer to use when converting operations.
OperationLegalizer opLegalizer;
@@ -2236,7 +2430,7 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
ConversionTarget &target = opLegalizer.getTarget();
// Compute the set of operations and blocks to convert.
- std::vector<Operation *> toConvert;
+ SmallVector<Operation *> toConvert;
for (auto *op : ops) {
toConvert.emplace_back(op);
for (auto &region : op->getRegions())
@@ -2277,17 +2471,16 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
LogicalResult
OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
+ Optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
-
- // Legalize converted block arguments.
- if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
+ if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
+ inverseMapping)) ||
+ failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
return failure();
if (rewriterImpl.operationsWithChangedResults.empty())
return success();
- Optional<BlockAndValueMapping> inverseMapping;
-
// Process requested operation replacements.
for (unsigned i = 0, e = rewriterImpl.operationsWithChangedResults.size();
i != e; ++i) {
@@ -2338,22 +2531,290 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
});
return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
};
+ return rewriterImpl.argConverter.materializeLiveConversions(
+ rewriterImpl.mapping, rewriter, findLiveUser);
+}
+
+/// Replace the results of a materialization operation with the given values.
+static void
+replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl,
+ ResultRange matResults, ValueRange values,
+ DenseMap<Value, SmallVector<Value>> &inverseMapping) {
+ matResults.replaceAllUsesWith(values);
+
+ // For each of the materialization results, update the inverse mappings to
+ // point to the replacement values.
+ for (auto it : llvm::zip(matResults, values)) {
+ Value matResult, newValue;
+ std::tie(matResult, newValue) = it;
+ auto inverseMapIt = inverseMapping.find(matResult);
+ if (inverseMapIt == inverseMapping.end())
+ continue;
- // Materialize any necessary conversions for converted block arguments that
- // are still live.
- size_t numCreatedOps = rewriterImpl.createdOps.size();
- if (failed(rewriterImpl.argConverter.materializeLiveConversions(
- rewriterImpl.mapping, rewriter, findLiveUser)))
- return failure();
+ // Update the reverse mapping, or remove the mapping if we couldn't update
+ // it. Not being able to update signals that the mapping would have become
+ // circular (i.e. %foo -> newValue -> %foo), which may occur as values are
+ // propagated through temporary materializations. We simply drop the
+ // mapping, and let the post-conversion replacement logic handle updating
+ // uses.
+ for (Value inverseMapVal : inverseMapIt->second)
+ if (!rewriterImpl.mapping.tryMap(inverseMapVal, newValue))
+ rewriterImpl.mapping.erase(inverseMapVal);
+ }
+}
- // Legalize any newly created operations during argument materialization.
- for (int i : llvm::seq<int>(numCreatedOps, rewriterImpl.createdOps.size())) {
- if (failed(opLegalizer.legalize(rewriterImpl.createdOps[i], rewriter))) {
- return rewriterImpl.createdOps[i]->emitError()
- << "failed to legalize conversion operation generated for block "
- "argument that remained live after conversion";
+/// Compute all of the unresolved materializations that will persist beyond the
+/// conversion process, and thus require inserting a proper user materialization.
+static void computeNecessaryMaterializations(
+ DenseMap<Operation *, UnresolvedMaterialization *> &materializationOps,
+ ConversionPatternRewriter &rewriter,
+ ConversionPatternRewriterImpl &rewriterImpl,
+ DenseMap<Value, SmallVector<Value>> &inverseMapping,
+ SetVector<UnresolvedMaterialization *> &necessaryMaterializations) {
+ auto isLive = [&](Value value) {
+ auto findFn = [&](Operation *user) {
+ auto matIt = materializationOps.find(user);
+ if (matIt != materializationOps.end())
+ return !necessaryMaterializations.count(matIt->second);
+ return rewriterImpl.isOpIgnored(user);
+ };
+ return llvm::find_if_not(value.getUsers(), findFn) != value.user_end();
+ };
+
+ llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
+ [&](Value invalidRoot, Value value, Type type) {
+ // Check to see if the input operation was remapped to a variant of the
+ // output.
+ Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
+ if (remappedValue.getType() == type && remappedValue != invalidRoot)
+ return remappedValue;
+
+ // Check to see if the input is a materialization operation that
+ // provides an inverse conversion. We check blindly for
+ // UnrealizedConversionCastOp here, but doing so has no effect on correctness.
+ auto inputCastOp = value.getDefiningOp<UnrealizedConversionCastOp>();
+ if (inputCastOp && inputCastOp->getNumOperands() == 1)
+ return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0),
+ type);
+
+ return Value();
+ };
+
+ SetVector<UnresolvedMaterialization *> worklist;
+ for (auto &mat : rewriterImpl.unresolvedMaterializations) {
+ materializationOps.try_emplace(mat.getOp(), &mat);
+ worklist.insert(&mat);
+ }
+ while (!worklist.empty()) {
+ UnresolvedMaterialization *mat = worklist.pop_back_val();
+ UnrealizedConversionCastOp op = mat->getOp();
+
+ // We currently only handle target materializations here.
+ assert(op->getNumResults() == 1 && "unexpected materialization type");
+ OpResult opResult = op->getOpResult(0);
+ Type outputType = opResult.getType();
+ Operation::operand_range inputOperands = op.getOperands();
+
+ // Try to forward propagate operands for user conversion casts that result
+ // in the input types of the current cast.
+ for (Operation *user : llvm::make_early_inc_range(opResult.getUsers())) {
+ auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
+ if (!castOp)
+ continue;
+ if (castOp->getResultTypes() == inputOperands.getTypes()) {
+ replaceMaterialization(rewriterImpl, opResult, inputOperands,
+ inverseMapping);
+ necessaryMaterializations.remove(materializationOps.lookup(user));
+ }
+ }
+
+ // Try to avoid materializing a resolved materialization if possible.
+ // Handle the case of a 1-1 materialization.
+ if (inputOperands.size() == 1) {
+ // Check to see if the input operation was remapped to a variant of the
+ // output.
+ Value remappedValue =
+ lookupRemappedValue(opResult, inputOperands[0], outputType);
+ if (remappedValue && remappedValue != opResult) {
+ replaceMaterialization(rewriterImpl, opResult, remappedValue,
+ inverseMapping);
+ necessaryMaterializations.remove(mat);
+ continue;
+ }
+ } else {
+ // TODO: Avoid materializing other types of conversions here.
+ }
+
+ // Check to see if this is an argument materialization.
+ auto isBlockArg = [](Value v) { return v.isa<BlockArgument>(); };
+ if (llvm::any_of(op->getOperands(), isBlockArg) ||
+ llvm::any_of(inverseMapping[op->getResult(0)], isBlockArg)) {
+ mat->setKind(UnresolvedMaterialization::Argument);
+ }
+
+ // If the materialization does not have any live users, we don't need to
+ // generate a user materialization for it.
+ // FIXME: For argument materializations, we currently need to check if any
+ // of the inverse mapped values are used because some patterns expect blind
+ // value replacement even if the types differ in some cases. When those
+ // patterns are fixed, we can drop the argument special case here.
+ bool isMaterializationLive = isLive(opResult);
+ if (mat->getKind() == UnresolvedMaterialization::Argument)
+ isMaterializationLive |= llvm::any_of(inverseMapping[opResult], isLive);
+ if (!isMaterializationLive)
+ continue;
+ if (!necessaryMaterializations.insert(mat))
+ continue;
+
+ // Reprocess input materializations to see if they have an updated status.
+ for (Value input : inputOperands) {
+ if (auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
+ if (auto *mat = materializationOps.lookup(parentOp))
+ worklist.insert(mat);
+ }
+ }
+ }
+}
+
+/// Legalize the given unresolved materialization. Returns success if the
+/// materialization was legalized, failure otherwise.
+static LogicalResult legalizeUnresolvedMaterialization(
+ UnresolvedMaterialization &mat,
+ DenseMap<Operation *, UnresolvedMaterialization *> &materializationOps,
+ ConversionPatternRewriter &rewriter,
+ ConversionPatternRewriterImpl &rewriterImpl,
+ DenseMap<Value, SmallVector<Value>> &inverseMapping) {
+ auto findLiveUser = [&](auto &&users) {
+ auto liveUserIt = llvm::find_if_not(
+ users, [&](Operation *user) { return rewriterImpl.isOpIgnored(user); });
+ return liveUserIt == users.end() ? nullptr : *liveUserIt;
+ };
+
+ llvm::unique_function<Value(Value, Type)> lookupRemappedValue =
+ [&](Value value, Type type) {
+ // Check to see if the input operation was remapped to a variant of the
+ // output.
+ Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
+ if (remappedValue.getType() == type)
+ return remappedValue;
+ return Value();
+ };
+
+ UnrealizedConversionCastOp op = mat.getOp();
+ if (!rewriterImpl.ignoredOps.insert(op))
+ return success();
+
+ // We currently only handle target materializations here.
+ OpResult opResult = op->getOpResult(0);
+ Operation::operand_range inputOperands = op.getOperands();
+ Type outputType = opResult.getType();
+
+ // If any input to this materialization is another materialization, resolve
+ // the input first.
+ for (Value value : op->getOperands()) {
+ auto valueCast = value.getDefiningOp<UnrealizedConversionCastOp>();
+ if (!valueCast)
+ continue;
+
+ auto matIt = materializationOps.find(valueCast);
+ if (matIt != materializationOps.end())
+ if (failed(legalizeUnresolvedMaterialization(
+ *matIt->second, materializationOps, rewriter, rewriterImpl,
+ inverseMapping)))
+ return failure();
+ }
+
+ // Perform a last-ditch attempt to avoid materializing a resolved
+ // materialization if possible.
+ // Handle the case of a 1-1 materialization.
+ if (inputOperands.size() == 1) {
+ // Check to see if the input operation was remapped to a variant of the
+ // output.
+ Value remappedValue = lookupRemappedValue(inputOperands[0], outputType);
+ if (remappedValue && remappedValue != opResult) {
+ replaceMaterialization(rewriterImpl, opResult, remappedValue,
+ inverseMapping);
+ return success();
+ }
+ } else {
+ // TODO: Avoid materializing other types of conversions here.
+ }
+
+ // Try to materialize the conversion.
+ if (TypeConverter *converter = mat.getConverter()) {
+ // FIXME: Determine a suitable insertion location when there are multiple
+ // inputs.
+ if (inputOperands.size() == 1)
+ rewriter.setInsertionPointAfterValue(inputOperands.front());
+ else
+ rewriter.setInsertionPoint(op);
+
+ Value newMaterialization;
+ switch (mat.getKind()) {
+ case UnresolvedMaterialization::Argument:
+ // Try to materialize an argument conversion.
+ // FIXME: The current argument materialization hook expects the original
+ // output type, even though it doesn't use that as the actual output type
+ // of the generated IR. The output type is just used as an indicator of
+ // the type of materialization to do. This behavior is really awkward in
+ // that it diverges from the behavior of the other hooks, and can be
+ // easily misunderstood. We should clean up the argument hooks to better
+ // represent the desired invariants we actually care about.
+ newMaterialization = converter->materializeArgumentConversion(
+ rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands);
+ if (newMaterialization)
+ break;
+
+ // If an argument materialization failed, fallback to trying a target
+ // materialization.
+ LLVM_FALLTHROUGH;
+ case UnresolvedMaterialization::Target:
+ newMaterialization = converter->materializeTargetConversion(
+ rewriter, op->getLoc(), outputType, inputOperands);
+ break;
+ default:
+ llvm_unreachable("unknown materialization kind");
+ }
+ if (newMaterialization) {
+ replaceMaterialization(rewriterImpl, opResult, newMaterialization,
+ inverseMapping);
+ return success();
}
}
+
+ InFlightDiagnostic diag = op->emitError()
+ << "failed to legalize unresolved materialization "
+ "from "
+ << inputOperands.getTypes() << " to " << outputType
+ << " that remained live after conversion";
+ if (Operation *liveUser = findLiveUser(op->getUsers())) {
+ diag.attachNote(liveUser->getLoc())
+ << "see existing live user here: " << *liveUser;
+ }
+ return failure();
+}
+
+LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
+ ConversionPatternRewriter &rewriter,
+ ConversionPatternRewriterImpl &rewriterImpl,
+ Optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping) {
+ if (rewriterImpl.unresolvedMaterializations.empty())
+ return success();
+ inverseMapping = rewriterImpl.mapping.getInverse();
+
+ // As an initial step, compute all of the inserted materializations that we
+ // expect to persist beyond the conversion process.
+ DenseMap<Operation *, UnresolvedMaterialization *> materializationOps;
+ SetVector<UnresolvedMaterialization *> necessaryMaterializations;
+ computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl,
+ *inverseMapping, necessaryMaterializations);
+
+ // Once computed, legalize any necessary materializations.
+ for (auto *mat : necessaryMaterializations) {
+ if (failed(legalizeUnresolvedMaterialization(
+ *mat, materializationOps, rewriter, rewriterImpl, *inverseMapping)))
+ return failure();
+ }
return success();
}
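The user hooks invoked above are the TypeConverter materialization callbacks. A sketch of registering a target materialization, with placeholder logic (a real hook would build a dialect-specific cast rather than another unrealized cast):

  TypeConverter converter;
  converter.addTargetMaterialization(
      [](OpBuilder &builder, Type resultType, ValueRange inputs,
         Location loc) -> Optional<Value> {
        // This sketch only handles 1->1 conversions; returning llvm::None
        // lets any other registered hook try instead.
        if (inputs.size() != 1)
          return llvm::None;
        SmallVector<Type, 1> types{resultType};
        return builder.create<UnrealizedConversionCastOp>(loc, types, inputs)
            ->getResult(0);
      });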
@@ -2378,10 +2839,13 @@ LogicalResult OperationConverter::legalizeErasedResult(
/// Finds a user of the given value, or of any other value that the given value
/// replaced, that was not replaced in the conversion process.
-static Operation *
-findLiveUserOfReplaced(Value value, ConversionPatternRewriterImpl &rewriterImpl,
- const BlockAndValueMapping &inverseMapping) {
- do {
+static Operation *findLiveUserOfReplaced(
+ Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
+ const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
+ SmallVector<Value> worklist(1, initialValue);
+ while (!worklist.empty()) {
+ Value value = worklist.pop_back_val();
+
// Walk the users of this value to see if there are any live users that
// weren't replaced during conversion.
auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
@@ -2389,8 +2853,10 @@ findLiveUserOfReplaced(Value value, ConversionPatternRewriterImpl &rewriterImpl,
});
if (liveUserIt != value.user_end())
return *liveUserIt;
- value = inverseMapping.lookupOrNull(value);
- } while (value != nullptr);
+ auto mapIt = inverseMapping.find(value);
+ if (mapIt != inverseMapping.end())
+ worklist.append(mapIt->second);
+ }
return nullptr;
}
@@ -2398,30 +2864,14 @@ LogicalResult OperationConverter::legalizeChangedResultType(
Operation *op, OpResult result, Value newValue,
TypeConverter *replConverter, ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl,
- const BlockAndValueMapping &inverseMapping) {
+ const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
Operation *liveUser =
findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
if (!liveUser)
return success();
- // If the replacement has a type converter, attempt to materialize a
- // conversion back to the original type.
- if (!replConverter) {
- // TODO: We should emit an error here, similarly to the case where the
- // result is replaced with null. Unfortunately a lot of existing
- // patterns rely on this behavior, so until those patterns are updated
- // we keep the legacy behavior here of just forwarding the new value.
- return success();
- }
-
- // Track the number of created operations so that new ones can be legalized.
- size_t numCreatedOps = rewriterImpl.createdOps.size();
-
- // Materialize a conversion for this live result value.
- Type resultType = result.getType();
- Value convertedValue = replConverter->materializeSourceConversion(
- rewriter, op->getLoc(), resultType, newValue);
- if (!convertedValue) {
+ // Functor used to emit a conversion error for a failed materialization.
+ auto emitConversionError = [&] {
InFlightDiagnostic diag = op->emitError()
<< "failed to materialize conversion for result #"
<< result.getResultNumber() << " of operation '"
@@ -2430,16 +2880,19 @@ LogicalResult OperationConverter::legalizeChangedResultType(
diag.attachNote(liveUser->getLoc())
<< "see existing live user here: " << *liveUser;
return failure();
- }
+ };
- // Legalize all of the newly created conversion operations.
- for (int i : llvm::seq<int>(numCreatedOps, rewriterImpl.createdOps.size())) {
- if (failed(opLegalizer.legalize(rewriterImpl.createdOps[i], rewriter))) {
- return op->emitError("failed to legalize conversion operation generated ")
- << "for result #" << result.getResultNumber() << " of operation '"
- << op->getName() << "' that remained live after conversion";
- }
- }
+ // If the replacement has a type converter, attempt to materialize a
+ // conversion back to the original type.
+ if (!replConverter)
+ return emitConversionError();
+
+ // Materialize a conversion for this live result value.
+ Type resultType = result.getType();
+ Value convertedValue = replConverter->materializeSourceConversion(
+ rewriter, op->getLoc(), resultType, newValue);
+ if (!convertedValue)
+ return emitConversionError();
rewriterImpl.mapping.map(result, convertedValue);
return success();
diff --git a/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir
index 8cdf2222d866d..c21db125a318d 100644
--- a/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir
@@ -322,7 +322,7 @@ func @fcmp(f32, f32) -> () {
func @index_vector(%arg0: vector<4xindex>) {
// CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3]> : vector<4xindex>) : vector<4xi64>
%0 = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
- // CHECK: %[[V:.*]] = llvm.add %1, %[[CST]] : vector<4xi64>
+ // CHECK: %[[V:.*]] = llvm.add %{{.*}}, %[[CST]] : vector<4xi64>
%1 = arith.addi %arg0, %0 : vector<4xindex>
std.return
}
diff --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
index de31cb60c0ddd..0c2aee27ca508 100644
--- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
@@ -161,9 +161,10 @@ module attributes {
spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
} {
+// expected-error @+1 {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'vector<4xi64>', with target type 'vector<4xi32>'}}
func @int_vector4_invalid(%arg0: vector<4xi64>) {
// expected-error @+2 {{bitwidth emulation is not implemented yet on unsigned op}}
- // expected-error @+1 {{op requires the same type for all operands and results}}
+ // expected-note @+1 {{see existing live user here}}
%0 = arith.divui %arg0, %arg0: vector<4xi64>
return
}
diff --git a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
index ac7117e48f2e6..6b3fd20fa15f2 100644
--- a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir
@@ -14,8 +14,7 @@ func @complex_create(%real: f32, %imag: f32) -> complex<f32> {
// CHECK-SAME: (%[[CPLX:.*]]: complex<f32>)
// CHECK-NEXT: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[CPLX]] : complex<f32> to !llvm.struct<(f32, f32)>
// CHECK-NEXT: %[[REAL:.*]] = llvm.extractvalue %[[CAST0]][0] : !llvm.struct<(f32, f32)>
-// CHECK-NEXT: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[CPLX]] : complex<f32> to !llvm.struct<(f32, f32)>
-// CHECK-NEXT: %[[IMAG:.*]] = llvm.extractvalue %[[CAST1]][1] : !llvm.struct<(f32, f32)>
+// CHECK-NEXT: %[[IMAG:.*]] = llvm.extractvalue %[[CAST0]][1] : !llvm.struct<(f32, f32)>
func @complex_extract(%cplx: complex<f32>) {
%real1 = complex.re %cplx : complex<f32>
%imag1 = complex.im %cplx : complex<f32>
@@ -70,8 +69,8 @@ func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
%div = complex.div %lhs, %rhs : complex<f32>
return %div : complex<f32>
}
-// CHECK: %[[CASTED_LHS:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : complex<f32> to ![[C_TY:.*>]]
-// CHECK: %[[CASTED_RHS:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : complex<f32> to ![[C_TY]]
+// CHECK-DAG: %[[CASTED_LHS:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : complex<f32> to ![[C_TY:.*>]]
+// CHECK-DAG: %[[CASTED_RHS:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : complex<f32> to ![[C_TY]]
// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[CASTED_LHS]][0] : ![[C_TY]]
// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[CASTED_LHS]][1] : ![[C_TY]]
@@ -106,8 +105,8 @@ func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
%mul = complex.mul %lhs, %rhs : complex<f32>
return %mul : complex<f32>
}
-// CHECK: %[[CASTED_LHS:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : complex<f32> to ![[C_TY:.*>]]
-// CHECK: %[[CASTED_RHS:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : complex<f32> to ![[C_TY]]
+// CHECK-DAG: %[[CASTED_LHS:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : complex<f32> to ![[C_TY:.*>]]
+// CHECK-DAG: %[[CASTED_RHS:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : complex<f32> to ![[C_TY]]
// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[CASTED_LHS]][0] : ![[C_TY]]
// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[CASTED_LHS]][1] : ![[C_TY]]
diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
index a5bf5737eb0f2..321e6190c066c 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
@@ -4,8 +4,8 @@
// CHECK-LABEL: func @mixed_alloc(
// CHECK: %[[Marg:.*]]: index, %[[Narg:.*]]: index)
func @mixed_alloc(%arg0: index, %arg1: index) -> memref<?x42x?xf32> {
-// CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[Marg]]
-// CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[Narg]]
+// CHECK-DAG: %[[M:.*]] = builtin.unrealized_conversion_cast %[[Marg]]
+// CHECK-DAG: %[[N:.*]] = builtin.unrealized_conversion_cast %[[Narg]]
// CHECK: %[[c42:.*]] = llvm.mlir.constant(42 : index) : i64
// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK-NEXT: %[[st0:.*]] = llvm.mul %[[N]], %[[c42]] : i64
@@ -46,8 +46,8 @@ func @mixed_dealloc(%arg0: memref<?x42x?xf32>) {
// CHECK-LABEL: func @dynamic_alloc(
// CHECK: %[[Marg:.*]]: index, %[[Narg:.*]]: index)
func @dynamic_alloc(%arg0: index, %arg1: index) -> memref<?x?xf32> {
-// CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[Marg]]
-// CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[Narg]]
+// CHECK-DAG: %[[M:.*]] = builtin.unrealized_conversion_cast %[[Marg]]
+// CHECK-DAG: %[[N:.*]] = builtin.unrealized_conversion_cast %[[Narg]]
// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK-NEXT: %[[sz:.*]] = llvm.mul %[[N]], %[[M]] : i64
// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm.ptr<f32>
@@ -73,8 +73,8 @@ func @dynamic_alloc(%arg0: index, %arg1: index) -> memref<?x?xf32> {
// CHECK-LABEL: func @dynamic_alloca
// CHECK: %[[Marg:.*]]: index, %[[Narg:.*]]: index)
func @dynamic_alloca(%arg0: index, %arg1: index) -> memref<?x?xf32> {
-// CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[Marg]]
-// CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[Narg]]
+// CHECK-DAG: %[[M:.*]] = builtin.unrealized_conversion_cast %[[Marg]]
+// CHECK-DAG: %[[N:.*]] = builtin.unrealized_conversion_cast %[[Narg]]
// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK-NEXT: %[[num_elems:.*]] = llvm.mul %[[N]], %[[M]] : i64
// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm.ptr<f32>
@@ -119,7 +119,7 @@ func @dynamic_dealloc(%arg0: memref<?x?xf32>) {
// CHECK-LABEL: func @stdlib_aligned_alloc({{.*}})
// ALIGNED-ALLOC-LABEL: func @stdlib_aligned_alloc({{.*}})
func @stdlib_aligned_alloc(%N : index) -> memref<32x18xf32> {
-// ALIGNED-ALLOC-NEXT: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : i64
+// ALIGNED-ALLOC: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : i64
// ALIGNED-ALLOC-NEXT: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : i64
// ALIGNED-ALLOC-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : i64
// ALIGNED-ALLOC-NEXT: %[[num_elems:.*]] = llvm.mlir.constant(576 : index) : i64
@@ -148,7 +148,7 @@ func @stdlib_aligned_alloc(%N : index) -> memref<32x18xf32> {
%4 = memref.alloc() {alignment = 8} : memref<1024xvector<4xf32>>
// Bump the memref allocation size if its size is not a multiple of alignment.
// ALIGNED-ALLOC: %[[c32:.*]] = llvm.mlir.constant(32 : index) : i64
- // ALIGNED-ALLOC-NEXT: llvm.mlir.constant(1 : index) : i64
+ // ALIGNED-ALLOC: llvm.mlir.constant(1 : index) : i64
// ALIGNED-ALLOC-NEXT: llvm.sub
// ALIGNED-ALLOC-NEXT: llvm.add
// ALIGNED-ALLOC-NEXT: llvm.urem
@@ -167,8 +167,8 @@ func @stdlib_aligned_alloc(%N : index) -> memref<32x18xf32> {
// CHECK-LABEL: func @mixed_load(
// CHECK: %{{.*}}, %[[Iarg:.*]]: index, %[[Jarg:.*]]: index)
func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) {
-// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[Iarg]]
-// CHECK: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]]
+// CHECK-DAG: %[[I:.*]] = builtin.unrealized_conversion_cast %[[Iarg]]
+// CHECK-DAG: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]]
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64
@@ -184,8 +184,8 @@ func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) {
// CHECK-LABEL: func @dynamic_load(
// CHECK: %{{.*}}, %[[Iarg:.*]]: index, %[[Jarg:.*]]: index)
func @dynamic_load(%dynamic : memref<?x?xf32>, %i : index, %j : index) {
-// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[Iarg]]
-// CHECK: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]]
+// CHECK-DAG: %[[I:.*]] = builtin.unrealized_conversion_cast %[[Iarg]]
+// CHECK-DAG: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]]
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64
@@ -201,8 +201,8 @@ func @dynamic_load(%dynamic : memref<?x?xf32>, %i : index, %j : index) {
// CHECK-LABEL: func @prefetch
// CHECK: %{{.*}}, %[[Iarg:.*]]: index, %[[Jarg:.*]]: index)
func @prefetch(%A : memref<?x?xf32>, %i : index, %j : index) {
-// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[Iarg]]
-// CHECK: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]]
+// CHECK-DAG: %[[I:.*]] = builtin.unrealized_conversion_cast %[[Iarg]]
+// CHECK-DAG: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]]
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64
@@ -231,8 +231,8 @@ func @prefetch(%A : memref<?x?xf32>, %i : index, %j : index) {
// CHECK-LABEL: func @dynamic_store
// CHECK: %{{.*}}, %[[Iarg:.*]]: index, %[[Jarg:.*]]: index
func @dynamic_store(%dynamic : memref<?x?xf32>, %i : index, %j : index, %val : f32) {
-// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[Iarg]]
-// CHECK: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]]
+// CHECK-DAG: %[[I:.*]] = builtin.unrealized_conversion_cast %[[Iarg]]
+// CHECK-DAG: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]]
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64
@@ -248,8 +248,8 @@ func @dynamic_store(%dynamic : memref<?x?xf32>, %i : index, %j : index, %val : f
// CHECK-LABEL: func @mixed_store
// CHECK: %{{.*}}, %[[Iarg:.*]]: index, %[[Jarg:.*]]: index
func @mixed_store(%mixed : memref<42x?xf32>, %i : index, %j : index, %val : f32) {
-// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[Iarg]]
-// CHECK: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]]
+// CHECK-DAG: %[[I:.*]] = builtin.unrealized_conversion_cast %[[Iarg]]
+// CHECK-DAG: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]]
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64
@@ -376,12 +376,12 @@ func @mixed_memref_dim(%mixed : memref<42x?x?x13x?xf32>) {
// CHECK-LABEL: @memref_dim_with_dyn_index
// CHECK: %{{.*}}, %[[IDXarg:.*]]: index
func @memref_dim_with_dyn_index(%arg : memref<3x?xf32>, %idx : index) -> index {
+ // CHECK-DAG: %[[IDX:.*]] = builtin.unrealized_conversion_cast %[[IDXarg]]
// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK-DAG: %[[SIZES:.*]] = llvm.extractvalue %{{.*}}[3] : ![[DESCR_TY:.*]]
// CHECK-DAG: %[[SIZES_PTR:.*]] = llvm.alloca %[[C1]] x !llvm.array<2 x i64> : (i64) -> !llvm.ptr<array<2 x i64>>
// CHECK-DAG: llvm.store %[[SIZES]], %[[SIZES_PTR]] : !llvm.ptr<array<2 x i64>>
- // CHECK-DAG: %[[IDX:.*]] = builtin.unrealized_conversion_cast %[[IDXarg]]
// CHECK-DAG: %[[RESULT_PTR:.*]] = llvm.getelementptr %[[SIZES_PTR]][%[[C0]], %[[IDX]]] : (!llvm.ptr<array<2 x i64>>, i64, i64) -> !llvm.ptr<i64>
// CHECK-DAG: %[[RESULT:.*]] = llvm.load %[[RESULT_PTR]] : !llvm.ptr<i64>
%result = memref.dim %arg, %idx : memref<3x?xf32>
@@ -433,12 +433,12 @@ func @memref_reinterpret_cast_unranked_to_dynamic_shape(%offset: index,
// CHECK-SAME: ([[OFFSETarg:%[a-z,0-9]+]]: index,
// CHECK-SAME: [[SIZE_0arg:%[a-z,0-9]+]]: index, [[SIZE_1arg:%[a-z,0-9]+]]: index,
// CHECK-SAME: [[STRIDE_0arg:%[a-z,0-9]+]]: index, [[STRIDE_1arg:%[a-z,0-9]+]]: index,
-// CHECK: [[INPUT:%.*]] = builtin.unrealized_conversion_cast
-// CHECK: [[OFFSET:%.*]] = builtin.unrealized_conversion_cast [[OFFSETarg]]
-// CHECK: [[SIZE_0:%.*]] = builtin.unrealized_conversion_cast [[SIZE_0arg]]
-// CHECK: [[SIZE_1:%.*]] = builtin.unrealized_conversion_cast [[SIZE_1arg]]
-// CHECK: [[STRIDE_0:%.*]] = builtin.unrealized_conversion_cast [[STRIDE_0arg]]
-// CHECK: [[STRIDE_1:%.*]] = builtin.unrealized_conversion_cast [[STRIDE_1arg]]
+// CHECK-DAG: [[OFFSET:%.*]] = builtin.unrealized_conversion_cast [[OFFSETarg]]
+// CHECK-DAG: [[SIZE_0:%.*]] = builtin.unrealized_conversion_cast [[SIZE_0arg]]
+// CHECK-DAG: [[SIZE_1:%.*]] = builtin.unrealized_conversion_cast [[SIZE_1arg]]
+// CHECK-DAG: [[STRIDE_0:%.*]] = builtin.unrealized_conversion_cast [[STRIDE_0arg]]
+// CHECK-DAG: [[STRIDE_1:%.*]] = builtin.unrealized_conversion_cast [[STRIDE_1arg]]
+// CHECK-DAG: [[INPUT:%.*]] = builtin.unrealized_conversion_cast
// CHECK: [[OUT_0:%.*]] = llvm.mlir.undef : [[TY:!.*]]
// CHECK: [[DESCRIPTOR:%.*]] = llvm.extractvalue [[INPUT]][1] : !llvm.struct<(i64, ptr<i8>)>
// CHECK: [[BASE_PTR_PTR:%.*]] = llvm.bitcast [[DESCRIPTOR]] : !llvm.ptr<i8> to !llvm.ptr<ptr<f32>>
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 2b6532b4942d3..a26638a34151f 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -5,14 +5,14 @@
// CHECK-LABEL: func @view(
// CHECK: %[[ARG0F:.*]]: index, %[[ARG1F:.*]]: index, %[[ARG2F:.*]]: index
func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
+ // CHECK: %[[ARG2:.*]] = builtin.unrealized_conversion_cast %[[ARG2F:.*]]
+ // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0F:.*]]
+ // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1F:.*]]
// CHECK: llvm.mlir.constant(2048 : index) : i64
// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
%0 = memref.alloc() : memref<2048xi8>
// Test two dynamic sizes.
- // CHECK: %[[ARG2:.*]] = builtin.unrealized_conversion_cast %[[ARG2F:.*]]
- // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0F:.*]]
- // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1F:.*]]
// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BASE_PTR:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[SHIFTED_BASE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]][%[[ARG2]]] : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
@@ -29,8 +29,6 @@ func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
%1 = memref.view %0[%arg2][%arg0, %arg1] : memref<2048xi8> to memref<?x?xf32>
// Test one dynamic size.
- // CHECK: %[[ARG2:.*]] = builtin.unrealized_conversion_cast %[[ARG2F:.*]]
- // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1F:.*]]
// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BASE_PTR_2:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[SHIFTED_BASE_PTR_2:.*]] = llvm.getelementptr %[[BASE_PTR_2]][%[[ARG2]]] : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
@@ -48,7 +46,6 @@ func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
%3 = memref.view %0[%arg2][%arg1] : memref<2048xi8> to memref<4x?xf32>
// Test static sizes.
- // CHECK: %[[ARG2:.*]] = builtin.unrealized_conversion_cast %[[ARG2F:.*]]
// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BASE_PTR_3:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[SHIFTED_BASE_PTR_3:.*]] = llvm.getelementptr %[[BASE_PTR_3]][%[[ARG2]]] : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
@@ -71,7 +68,6 @@ func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<i8, 4>, ptr<i8, 4>, i64, array<1 x i64>, array<1 x i64>)>
%6 = memref.alloc() : memref<2048xi8, 4>
- // CHECK: %[[ARG2:.*]] = builtin.unrealized_conversion_cast %[[ARG2F:.*]]
// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32, 4>, ptr<f32, 4>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BASE_PTR_4:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<i8, 4>, ptr<i8, 4>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[SHIFTED_BASE_PTR_4:.*]] = llvm.getelementptr %[[BASE_PTR_4]][%[[ARG2]]] : (!llvm.ptr<i8, 4>, i64) -> !llvm.ptr<i8, 4>
@@ -105,21 +101,13 @@ func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
// CHECK32: %[[ARG1f:[a-zA-Z0-9]*]]: index,
// CHECK32: %[[ARG2f:.*]]: index)
func @subview(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index, %arg1 : index, %arg2 : index) {
- // CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
- // CHECK32: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
-
- // CHECK: %[[ARG0a:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
- // CHECK: %[[ARG1a:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
- // CHECK: %[[ARG0b:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
- // CHECK: %[[ARG1b:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
- // CHECK: %[[ARG0c:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
- // CHECK: %[[ARG1c:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
- // CHECK32: %[[ARG0a:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
- // CHECK32: %[[ARG1a:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
- // CHECK32: %[[ARG0b:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
- // CHECK32: %[[ARG1b:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
- // CHECK32: %[[ARG0c:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
- // CHECK32: %[[ARG1c:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
+ // CHECK-DAG: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
+ // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
+ // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
+
+ // CHECK32-DAG: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
+ // CHECK32-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
+ // CHECK32-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BITCAST0:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr<f32> to !llvm.ptr<f32>
@@ -129,16 +117,16 @@ func @subview(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index,
// CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
- // CHECK: %[[OFFINC:.*]] = llvm.mul %[[ARG0a]], %[[STRIDE0]] : i64
+ // CHECK: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64
// CHECK: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : i64
- // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1a]], %[[STRIDE1]] : i64
+ // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i64
// CHECK: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : i64
// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
- // CHECK: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1c]], %[[STRIDE1]] : i64
- // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1b]], %[[DESC2]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i64
+ // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
- // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0c]], %[[STRIDE0]] : i64
- // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0b]], %[[DESC4]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64
+ // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK32: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[BITCAST0:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr<f32> to !llvm.ptr<f32>
@@ -148,16 +136,16 @@ func @subview(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index,
// CHECK32: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
- // CHECK32: %[[OFFINC:.*]] = llvm.mul %[[ARG0a]], %[[STRIDE0]] : i32
+ // CHECK32: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i32
// CHECK32: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : i32
- // CHECK32: %[[OFFINC1:.*]] = llvm.mul %[[ARG1a]], %[[STRIDE1]] : i32
+ // CHECK32: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i32
// CHECK32: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : i32
// CHECK32: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
- // CHECK32: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1c]], %[[STRIDE1]] : i32
- // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1b]], %[[DESC2]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
+ // CHECK32: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i32
+ // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
- // CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0c]], %[[STRIDE0]] : i32
- // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0b]], %[[DESC4]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
+ // CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i32
+ // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
%1 = memref.subview %0[%arg0, %arg1][%arg0, %arg1][%arg0, %arg1] :
memref<64x4xf32, offset: 0, strides: [4, 1]>
@@ -178,21 +166,12 @@ func @subview(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index,
// CHECK32: %[[ARG1f:[a-zA-Z0-9]*]]: index,
// CHECK32: %[[ARG2f:.*]]: index)
func @subview_non_zero_addrspace(%0 : memref<64x4xf32, offset: 0, strides: [4, 1], 3>, %arg0 : index, %arg1 : index, %arg2 : index) {
- // CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
- // CHECK32: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
-
- // CHECK: %[[ARG0a:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
- // CHECK: %[[ARG1a:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
- // CHECK: %[[ARG0b:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
- // CHECK: %[[ARG1b:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
- // CHECK: %[[ARG0c:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
- // CHECK: %[[ARG1c:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
- // CHECK32: %[[ARG0a:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
- // CHECK32: %[[ARG1a:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
- // CHECK32: %[[ARG0b:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
- // CHECK32: %[[ARG1b:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
- // CHECK32: %[[ARG0c:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
- // CHECK32: %[[ARG1c:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
+ // CHECK-DAG: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
+ // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
+ // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
+ // CHECK32-DAG: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
+ // CHECK32-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
+ // CHECK32-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BITCAST0:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr<f32, 3> to !llvm.ptr<f32, 3>
@@ -202,16 +181,16 @@ func @subview_non_zero_addrspace(%0 : memref<64x4xf32, offset: 0, strides: [4, 1
// CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i64, array<2 x i64>, array<2 x i64>)>
- // CHECK: %[[OFFINC:.*]] = llvm.mul %[[ARG0a]], %[[STRIDE0]] : i64
+ // CHECK: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64
// CHECK: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : i64
- // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1a]], %[[STRIDE1]] : i64
+ // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i64
// CHECK: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : i64
// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i64, array<2 x i64>, array<2 x i64>)>
- // CHECK: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1c]], %[[STRIDE1]] : i64
- // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1b]], %[[DESC2]][3, 1] : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i64
+ // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i64, array<2 x i64>, array<2 x i64>)>
- // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0c]], %[[STRIDE0]] : i64
- // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0b]], %[[DESC4]][3, 0] : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64
+ // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK32: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[BITCAST0:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr<f32, 3> to !llvm.ptr<f32, 3>
@@ -221,16 +200,16 @@ func @subview_non_zero_addrspace(%0 : memref<64x4xf32, offset: 0, strides: [4, 1
// CHECK32: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i32, array<2 x i32>, array<2 x i32>)>
- // CHECK32: %[[OFFINC:.*]] = llvm.mul %[[ARG0a]], %[[STRIDE0]] : i32
+ // CHECK32: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i32
// CHECK32: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : i32
- // CHECK32: %[[OFFINC1:.*]] = llvm.mul %[[ARG1a]], %[[STRIDE1]] : i32
+ // CHECK32: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i32
// CHECK32: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : i32
// CHECK32: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i32, array<2 x i32>, array<2 x i32>)>
- // CHECK32: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1c]], %[[STRIDE1]] : i32
- // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1b]], %[[DESC2]][3, 1] : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i32, array<2 x i32>, array<2 x i32>)>
+ // CHECK32: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i32
+ // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i32, array<2 x i32>, array<2 x i32>)>
- // CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0c]], %[[STRIDE0]] : i32
- // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0b]], %[[DESC4]][3, 0] : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i32, array<2 x i32>, array<2 x i32>)>
+ // CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i32
+ // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i32, array<2 x i32>, array<2 x i32>)>
%1 = memref.subview %0[%arg0, %arg1][%arg0, %arg1][%arg0, %arg1] :
memref<64x4xf32, offset: 0, strides: [4, 1], 3>
@@ -251,17 +230,12 @@ func @subview_non_zero_addrspace(%0 : memref<64x4xf32, offset: 0, strides: [4, 1
// CHECK32-SAME: %[[ARG1f:[a-zA-Z0-9]*]]: index
// CHECK32-SAME: %[[ARG2f:[a-zA-Z0-9]*]]: index
func @subview_const_size(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index, %arg1 : index, %arg2 : index) {
- // CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
- // CHECK32: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
-
- // CHECK: %[[ARG0a:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
- // CHECK: %[[ARG1a:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
- // CHECK: %[[ARG0b:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
- // CHECK: %[[ARG1b:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
- // CHECK32: %[[ARG0a:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
- // CHECK32: %[[ARG1a:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
- // CHECK32: %[[ARG0b:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
- // CHECK32: %[[ARG1b:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
+ // CHECK-DAG: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
+ // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
+ // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
+ // CHECK32-DAG: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
+ // CHECK32-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
+ // CHECK32-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BITCAST0:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr<f32> to !llvm.ptr<f32>
@@ -271,17 +245,17 @@ func @subview_const_size(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg
// CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
- // CHECK: %[[OFFINC:.*]] = llvm.mul %[[ARG0a]], %[[STRIDE0]] : i64
+ // CHECK: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64
// CHECK: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : i64
- // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1a]], %[[STRIDE1]] : i64
+ // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i64
// CHECK: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : i64
// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i64)
- // CHECK: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1b]], %[[STRIDE1]] : i64
+ // CHECK: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i64
// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[CST2]], %[[DESC2]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[CST4:.*]] = llvm.mlir.constant(4 : i64)
- // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0b]], %[[STRIDE0]] : i64
+ // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64
// CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[CST4]], %[[DESC4]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK32: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
@@ -292,17 +266,17 @@ func @subview_const_size(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg
// CHECK32: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
- // CHECK32: %[[OFFINC:.*]] = llvm.mul %[[ARG0a]], %[[STRIDE0]] : i32
+ // CHECK32: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i32
// CHECK32: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : i32
- // CHECK32: %[[OFFINC1:.*]] = llvm.mul %[[ARG1a]], %[[STRIDE1]] : i32
+ // CHECK32: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i32
// CHECK32: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : i32
// CHECK32: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[CST2:.*]] = llvm.mlir.constant(2 : i64)
- // CHECK32: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1b]], %[[STRIDE1]] : i32
+ // CHECK32: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i32
// CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[CST2]], %[[DESC2]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[CST4:.*]] = llvm.mlir.constant(4 : i64)
- // CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0b]], %[[STRIDE0]] : i32
+ // CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i32
// CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[CST4]], %[[DESC4]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
%1 = memref.subview %0[%arg0, %arg1][4, 2][%arg0, %arg1] :
@@ -324,17 +298,12 @@ func @subview_const_size(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg
// CHECK32-SAME: %[[ARG1f:[a-zA-Z0-9]*]]: index
// CHECK32-SAME: %[[ARG2f:[a-zA-Z0-9]*]]: index
func @subview_const_stride(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index, %arg1 : index, %arg2 : index) {
- // CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
- // CHECK32: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
-
- // CHECK: %[[ARG0a:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
- // CHECK: %[[ARG1a:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
- // CHECK: %[[ARG0b:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
- // CHECK: %[[ARG1b:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
- // CHECK32: %[[ARG0a:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
- // CHECK32: %[[ARG1a:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
- // CHECK32: %[[ARG0b:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
- // CHECK32: %[[ARG1b:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
+ // CHECK-DAG: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
+ // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
+ // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
+ // CHECK32-DAG: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
+ // CHECK32-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
+ // CHECK32-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BITCAST0:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr<f32> to !llvm.ptr<f32>
@@ -344,16 +313,16 @@ func @subview_const_stride(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %a
// CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
- // CHECK: %[[OFFINC:.*]] = llvm.mul %[[ARG0a]], %[[STRIDE0]] : i64
+ // CHECK: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64
// CHECK: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : i64
- // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1a]], %[[STRIDE1]] : i64
+ // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i64
// CHECK: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : i64
// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i64)
- // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1b]], %[[DESC2]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[CST2]], %[[DESC3]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[CST4:.*]] = llvm.mlir.constant(4 : i64)
- // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0b]], %[[DESC4]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK32: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[BITCAST0:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr<f32> to !llvm.ptr<f32>
@@ -363,16 +332,16 @@ func @subview_const_stride(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %a
// CHECK32: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
- // CHECK32: %[[OFFINC:.*]] = llvm.mul %[[ARG0a]], %[[STRIDE0]] : i32
+ // CHECK32: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i32
// CHECK32: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : i32
- // CHECK32: %[[OFFINC1:.*]] = llvm.mul %[[ARG1a]], %[[STRIDE1]] : i32
+ // CHECK32: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i32
// CHECK32: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : i32
// CHECK32: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[CST2:.*]] = llvm.mlir.constant(2 : i64)
- // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1b]], %[[DESC2]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
+ // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[DESC4:.*]] = llvm.insertvalue %[[CST2]], %[[DESC3]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[CST4:.*]] = llvm.mlir.constant(4 : i64)
- // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0b]], %[[DESC4]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
+ // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
%1 = memref.subview %0[%arg0, %arg1][%arg0, %arg1][1, 2] :
memref<64x4xf32, offset: 0, strides: [4, 1]>
@@ -425,10 +394,10 @@ func @subview_const_stride_and_offset(%0 : memref<64x4xf32, offset: 0, strides:
// CHECK32: %[[ARG1f:[a-zA-Z0-9]*]]: index,
// CHECK32: %[[ARG2f:.*]]: index)
func @subview_mixed_static_dynamic(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index, %arg1 : index, %arg2 : index) {
- // CHECK32: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
- // CHECK32: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
- // CHECK32: %[[ARG2:.*]] = builtin.unrealized_conversion_cast %[[ARG2f]]
- // CHECK32: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
+ // CHECK32-DAG: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
+ // CHECK32-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]]
+ // CHECK32-DAG: %[[ARG2:.*]] = builtin.unrealized_conversion_cast %[[ARG2f]]
+ // CHECK32-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]]
// CHECK32: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i32, array<2 x i32>, array<2 x i32>)>
// CHECK32: %[[BITCAST0:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr<f32> to !llvm.ptr<f32>
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 7896ac527b5d0..90492ff14247d 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -17,13 +17,13 @@ module attributes {
// CHECK-LABEL: @load_store_zero_rank_float
func @load_store_zero_rank_float(%arg0: memref<f32>, %arg1: memref<f32>) {
// CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32> to !spv.ptr<!spv.struct<(!spv.array<1 x f32, stride=4> [0])>, StorageBuffer>
+ // CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32> to !spv.ptr<!spv.struct<(!spv.array<1 x f32, stride=4> [0])>, StorageBuffer>
// CHECK: [[ZERO1:%.*]] = spv.Constant 0 : i32
// CHECK: spv.AccessChain [[ARG0]][
// CHECK-SAME: [[ZERO1]], [[ZERO1]]
// CHECK-SAME: ] :
// CHECK: spv.Load "StorageBuffer" %{{.*}} : f32
%0 = memref.load %arg0[] : memref<f32>
- // CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32> to !spv.ptr<!spv.struct<(!spv.array<1 x f32, stride=4> [0])>, StorageBuffer>
// CHECK: [[ZERO2:%.*]] = spv.Constant 0 : i32
// CHECK: spv.AccessChain [[ARG1]][
// CHECK-SAME: [[ZERO2]], [[ZERO2]]
@@ -36,13 +36,13 @@ func @load_store_zero_rank_float(%arg0: memref<f32>, %arg1: memref<f32>) {
// CHECK-LABEL: @load_store_zero_rank_int
func @load_store_zero_rank_int(%arg0: memref<i32>, %arg1: memref<i32>) {
// CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32> to !spv.ptr<!spv.struct<(!spv.array<1 x i32, stride=4> [0])>, StorageBuffer>
+ // CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32> to !spv.ptr<!spv.struct<(!spv.array<1 x i32, stride=4> [0])>, StorageBuffer>
// CHECK: [[ZERO1:%.*]] = spv.Constant 0 : i32
// CHECK: spv.AccessChain [[ARG0]][
// CHECK-SAME: [[ZERO1]], [[ZERO1]]
// CHECK-SAME: ] :
// CHECK: spv.Load "StorageBuffer" %{{.*}} : i32
%0 = memref.load %arg0[] : memref<i32>
- // CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32> to !spv.ptr<!spv.struct<(!spv.array<1 x i32, stride=4> [0])>, StorageBuffer>
// CHECK: [[ZERO2:%.*]] = spv.Constant 0 : i32
// CHECK: spv.AccessChain [[ARG1]][
// CHECK-SAME: [[ZERO2]], [[ZERO2]]
@@ -55,10 +55,10 @@ func @load_store_zero_rank_int(%arg0: memref<i32>, %arg1: memref<i32>) {
// CHECK-LABEL: func @load_store_unknown_dim
func @load_store_unknown_dim(%i: index, %source: memref<?xi32>, %dest: memref<?xi32>) {
// CHECK: %[[SRC:.+]] = builtin.unrealized_conversion_cast {{.+}} : memref<?xi32> to !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
+ // CHECK: %[[DST:.+]] = builtin.unrealized_conversion_cast {{.+}} : memref<?xi32> to !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
// CHECK: %[[AC0:.+]] = spv.AccessChain %[[SRC]]
// CHECK: spv.Load "StorageBuffer" %[[AC0]]
%0 = memref.load %source[%i] : memref<?xi32>
- // CHECK: %[[DST:.+]] = builtin.unrealized_conversion_cast {{.+}} : memref<?xi32> to !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
// CHECK: %[[AC1:.+]] = spv.AccessChain %[[DST]]
// CHECK: spv.Store "StorageBuffer" %[[AC1]]
memref.store %0, %dest[%i]: memref<?xi32>
@@ -68,8 +68,8 @@ func @load_store_unknown_dim(%i: index, %source: memref<?xi32>, %dest: memref<?x
// CHECK-LABEL: func @load_i1
// CHECK-SAME: (%[[SRC:.+]]: memref<4xi1>, %[[IDX:.+]]: index)
func @load_i1(%src: memref<4xi1>, %i : index) -> i1 {
- // CHECK: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1> to !spv.ptr<!spv.struct<(!spv.array<4 x i8, stride=1> [0])>, StorageBuffer>
- // CHECK: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
+ // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1> to !spv.ptr<!spv.struct<(!spv.array<4 x i8, stride=1> [0])>, StorageBuffer>
+ // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
// CHECK: %[[ZERO_0:.+]] = spv.Constant 0 : i32
// CHECK: %[[ZERO_1:.+]] = spv.Constant 0 : i32
// CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
@@ -89,8 +89,8 @@ func @load_i1(%src: memref<4xi1>, %i : index) -> i1 {
// CHECK-SAME: %[[IDX:.+]]: index
func @store_i1(%dst: memref<4xi1>, %i: index) {
%true = arith.constant true
- // CHECK: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1> to !spv.ptr<!spv.struct<(!spv.array<4 x i8, stride=1> [0])>, StorageBuffer>
- // CHECK: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
+ // CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1> to !spv.ptr<!spv.struct<(!spv.array<4 x i8, stride=1> [0])>, StorageBuffer>
+ // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
// CHECK: %[[ZERO_0:.+]] = spv.Constant 0 : i32
// CHECK: %[[ZERO_1:.+]] = spv.Constant 0 : i32
// CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
@@ -237,8 +237,8 @@ func @store_i1(%arg0: memref<i1>, %value: i1) {
// CHECK-LABEL: @store_i8
// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8)
func @store_i8(%arg0: memref<i8>, %value: i8) {
- // CHECK: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
- // CHECK: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+ // CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
+ // CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
// CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
// CHECK: %[[FOUR:.+]] = spv.Constant 4 : i32
// CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32
@@ -261,9 +261,9 @@ func @store_i8(%arg0: memref<i8>, %value: i8) {
// CHECK-LABEL: @store_i16
// CHECK: (%[[ARG0:.+]]: memref<10xi16>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i16)
func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) {
- // CHECK: %[[ARG2_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : i16 to i32
- // CHECK: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
- // CHECK: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
+ // CHECK-DAG: %[[ARG2_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : i16 to i32
+ // CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+ // CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
// CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
// CHECK: %[[OFFSET:.+]] = spv.Constant 0 : i32
// CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
@@ -350,8 +350,8 @@ func @load_i16(%arg0: memref<i16>) {
// CHECK-LABEL: @store_i8
// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8)
func @store_i8(%arg0: memref<i8>, %value: i8) {
- // CHECK: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
- // CHECK: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+ // CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
+ // CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
// CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
// CHECK: %[[FOUR:.+]] = spv.Constant 4 : i32
// CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32
diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
index e3ed2a2479c15..d38c7dcf85449 100644
--- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -6,7 +6,9 @@ func @master_block_arg() {
omp.master {
// CHECK-NEXT: ^[[BB0:.*]](%[[ARG1:.*]]: i64, %[[ARG2:.*]]: i64):
^bb0(%arg1: index, %arg2: index):
- // CHECK-NEXT: "test.payload"(%[[ARG1]], %[[ARG2]]) : (i64, i64) -> ()
+ // CHECK-DAG: %[[CAST_ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : i64 to index
+ // CHECK-DAG: %[[CAST_ARG2:.*]] = builtin.unrealized_conversion_cast %[[ARG2]] : i64 to index
+ // CHECK-NEXT: "test.payload"(%[[CAST_ARG1]], %[[CAST_ARG2]]) : (index, index) -> ()
"test.payload"(%arg1, %arg2) : (index, index) -> ()
omp.terminator
}
@@ -50,7 +52,9 @@ func @wsloop(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: inde
// CHECK: omp.wsloop (%[[ARG6:.*]], %[[ARG7:.*]]) : i64 = (%[[ARG0]], %[[ARG1]]) to (%[[ARG2]], %[[ARG3]]) step (%[[ARG4]], %[[ARG5]]) {
"omp.wsloop"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) ( {
^bb0(%arg6: index, %arg7: index): // no predecessors
- // CHECK: "test.payload"(%[[ARG6]], %[[ARG7]]) : (i64, i64) -> ()
+ // CHECK-DAG: %[[CAST_ARG6:.*]] = builtin.unrealized_conversion_cast %[[ARG6]] : i64 to index
+ // CHECK-DAG: %[[CAST_ARG7:.*]] = builtin.unrealized_conversion_cast %[[ARG7]] : i64 to index
+ // CHECK: "test.payload"(%[[CAST_ARG6]], %[[CAST_ARG7]]) : (index, index) -> ()
"test.payload"(%arg6, %arg7) : (index, index) -> ()
omp.yield
}) {operand_segment_sizes = dense<[2, 2, 2, 0, 0, 0, 0, 0, 0, 0]> : vector<10xi32>} : (index, index, index, index, index, index) -> ()
diff --git a/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir b/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir
index 2f9c52351c696..4c6b0da4b1dac 100644
--- a/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir
@@ -148,11 +148,13 @@ func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> attributes { ll
// Match the construction of the unranked descriptor.
// CHECK: %[[ALLOCA:.*]] = llvm.alloca
// CHECK: %[[MEMORY:.*]] = llvm.bitcast %[[ALLOCA]]
+ // CHECK: %[[RANK:.*]] = llvm.mlir.constant(2 : i64)
// CHECK: %[[DESC_0:.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr<i8>)>
- // CHECK: %[[DESC_1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC_0]][0]
+ // CHECK: %[[DESC_1:.*]] = llvm.insertvalue %[[RANK]], %[[DESC_0]][0]
// CHECK: %[[DESC_2:.*]] = llvm.insertvalue %[[MEMORY]], %[[DESC_1]][1]
%0 = memref.cast %arg0: memref<4x3xf32> to memref<*xf32>
+
// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : index)
// CHECK: %[[TWO:.*]] = llvm.mlir.constant(2 : index)
// These sizes may depend on the data layout, not matching specific values.
@@ -160,17 +162,14 @@ func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> attributes { ll
// CHECK: %[[IDX_SIZE:.*]] = llvm.mlir.constant
// CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
- // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[DESC_2]][0] : !llvm.struct<(i64, ptr<i8>)>
// CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
// CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
// CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]]
// CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]]
// CHECK: %[[FALSE:.*]] = llvm.mlir.constant(false)
// CHECK: %[[ALLOCATED:.*]] = llvm.call @malloc(%[[ALLOC_SIZE]])
- // CHECK: %[[SOURCE:.*]] = llvm.extractvalue %[[DESC_2]][1]
- // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED]], %[[SOURCE]], %[[ALLOC_SIZE]], %[[FALSE]])
+ // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED]], %[[MEMORY]], %[[ALLOC_SIZE]], %[[FALSE]])
// CHECK: %[[NEW_DESC:.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr<i8>)>
- // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[DESC_2]][0] : !llvm.struct<(i64, ptr<i8>)>
// CHECK: %[[NEW_DESC_1:.*]] = llvm.insertvalue %[[RANK]], %[[NEW_DESC]][0]
// CHECK: %[[NEW_DESC_2:.*]] = llvm.insertvalue %[[ALLOCATED]], %[[NEW_DESC_1]][1]
// CHECK: llvm.return %[[NEW_DESC_2]]
@@ -224,15 +223,13 @@ func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memref<*x
// convention requires the caller to free them and the caller cannot know
// whether they are the same value or not.
// CHECK: %[[ALLOCATED_1:.*]] = llvm.call @malloc(%{{.*}})
- // CHECK: %[[SOURCE_1:.*]] = llvm.extractvalue %[[DESC_2]][1]
- // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_1]], %[[SOURCE_1]], %{{.*}}, %[[FALSE:.*]])
+ // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_1]], %[[MEMORY]], %{{.*}}, %[[FALSE:.*]])
// CHECK: %[[RES_1:.*]] = llvm.mlir.undef
// CHECK: %[[RES_11:.*]] = llvm.insertvalue %{{.*}}, %[[RES_1]][0]
// CHECK: %[[RES_12:.*]] = llvm.insertvalue %[[ALLOCATED_1]], %[[RES_11]][1]
// CHECK: %[[ALLOCATED_2:.*]] = llvm.call @malloc(%{{.*}})
- // CHECK: %[[SOURCE_2:.*]] = llvm.extractvalue %[[DESC_2]][1]
- // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_2]], %[[SOURCE_2]], %{{.*}}, %[[FALSE]])
+ // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_2]], %[[MEMORY]], %{{.*}}, %[[FALSE]])
// CHECK: %[[RES_2:.*]] = llvm.mlir.undef
// CHECK: %[[RES_21:.*]] = llvm.insertvalue %{{.*}}, %[[RES_2]][0]
// CHECK: %[[RES_22:.*]] = llvm.insertvalue %[[ALLOCATED_2]], %[[RES_21]][1]
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index f0a6d2b5c4c32..36a6d793e7214 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -173,9 +173,10 @@ module attributes {
spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
} {
+// expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion}}
func @int_vector4_invalid(%arg0: vector<4xi64>) {
- // expected-error @+2 {{bitwidth emulation is not implemented yet on unsigned op}}
- // expected-error @+1 {{op requires the same type for all operands and results}}
+ // expected-error@below {{bitwidth emulation is not implemented yet on unsigned op}}
+ // expected-note@below {{see existing live user here}}
%0 = arith.divui %arg0, %arg0: vector<4xi64>
return
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 887979ea19be8..1951cf8339b76 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -111,8 +111,8 @@ func @broadcast_vec2d_from_index_vec1d(%arg0: vector<2xindex>) -> vector<3x2xind
}
// CHECK-LABEL: @broadcast_vec2d_from_index_vec1d(
// CHECK-SAME: %[[A:.*]]: vector<2xindex>)
-// CHECK: %[[T0:.*]] = arith.constant dense<0> : vector<3x2xindex>
// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<2xindex> to vector<2xi64>
+// CHECK: %[[T0:.*]] = arith.constant dense<0> : vector<3x2xindex>
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xindex> to !llvm.array<3 x vector<2xi64>>
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<3 x vector<2xi64>>
@@ -128,14 +128,14 @@ func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32> {
// CHECK-LABEL: @broadcast_vec3d_from_vec1d(
// CHECK-SAME: %[[A:.*]]: vector<2xf32>)
// CHECK: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
+// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
-// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
+// CHECK: %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][0] : !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T4:.*]] = llvm.insertvalue %[[A]], %[[T3]][1] : !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T5:.*]] = llvm.insertvalue %[[A]], %[[T4]][2] : !llvm.array<3 x vector<2xf32>>
-// CHECK: %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T5]], %[[T6]][0] : !llvm.array<4 x array<3 x vector<2xf32>>>
// CHECK: %[[T8:.*]] = llvm.insertvalue %[[T5]], %[[T7]][1] : !llvm.array<4 x array<3 x vector<2xf32>>>
// CHECK: %[[T9:.*]] = llvm.insertvalue %[[T5]], %[[T8]][2] : !llvm.array<4 x array<3 x vector<2xf32>>>
@@ -152,16 +152,13 @@ func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf32> {
}
// CHECK-LABEL: @broadcast_vec3d_from_vec2d(
// CHECK-SAME: %[[A:.*]]: vector<3x2xf32>)
-// CHECK: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
+// CHECK: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<4 x array<3 x vector<2xf32>>>
-// CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
-// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T4]], %[[T3]][1] : !llvm.array<4 x array<3 x vector<2xf32>>>
-// CHECK: %[[T6:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
-// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T6]], %[[T5]][2] : !llvm.array<4 x array<3 x vector<2xf32>>>
-// CHECK: %[[T8:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
-// CHECK: %[[T9:.*]] = llvm.insertvalue %[[T8]], %[[T7]][3] : !llvm.array<4 x array<3 x vector<2xf32>>>
+// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T1]], %[[T3]][1] : !llvm.array<4 x array<3 x vector<2xf32>>>
+// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T1]], %[[T5]][2] : !llvm.array<4 x array<3 x vector<2xf32>>>
+// CHECK: %[[T9:.*]] = llvm.insertvalue %[[T1]], %[[T7]][3] : !llvm.array<4 x array<3 x vector<2xf32>>>
// CHECK: %[[T10:.*]] = builtin.unrealized_conversion_cast %[[T9]] : !llvm.array<4 x array<3 x vector<2xf32>>> to vector<4x3x2xf32>
// CHECK: return %[[T10]] : vector<4x3x2xf32>
@@ -187,10 +184,10 @@ func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> {
}
// CHECK-LABEL: @broadcast_stretch_at_start(
// CHECK-SAME: %[[A:.*]]: vector<1x4xf32>)
-// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<3x4xf32>
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<1x4xf32> to !llvm.array<1 x vector<4xf32>>
-// CHECK: %[[T3:.*]] = llvm.extractvalue %[[T2]][0] : !llvm.array<1 x vector<4xf32>>
+// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<3x4xf32>
// CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x4xf32> to !llvm.array<3 x vector<4xf32>>
+// CHECK: %[[T3:.*]] = llvm.extractvalue %[[T2]][0] : !llvm.array<1 x vector<4xf32>>
// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T3]], %[[T4]][0] : !llvm.array<3 x vector<4xf32>>
// CHECK: %[[T6:.*]] = llvm.insertvalue %[[T3]], %[[T5]][1] : !llvm.array<3 x vector<4xf32>>
// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T3]], %[[T6]][2] : !llvm.array<3 x vector<4xf32>>
@@ -205,28 +202,25 @@ func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> {
}
// CHECK-LABEL: @broadcast_stretch_at_end(
// CHECK-SAME: %[[A:.*]]: vector<4x1xf32>)
-// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3xf32>
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1xf32> to !llvm.array<4 x vector<1xf32>>
+// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3xf32>
+// CHECK: %[[T7:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3xf32> to !llvm.array<4 x vector<3xf32>>
// CHECK: %[[T3:.*]] = llvm.extractvalue %[[T2]][0] : !llvm.array<4 x vector<1xf32>>
// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[T5:.*]] = llvm.extractelement %[[T3]]{{\[}}%[[T4]] : i64] : vector<1xf32>
// CHECK: %[[T6:.*]] = splat %[[T5]] : vector<3xf32>
-// CHECK: %[[T7:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3xf32> to !llvm.array<4 x vector<3xf32>>
// CHECK: %[[T8:.*]] = llvm.insertvalue %[[T6]], %[[T7]][0] : !llvm.array<4 x vector<3xf32>>
-// CHECK: %[[T9:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1xf32> to !llvm.array<4 x vector<1xf32>>
-// CHECK: %[[T10:.*]] = llvm.extractvalue %[[T9]][1] : !llvm.array<4 x vector<1xf32>>
+// CHECK: %[[T10:.*]] = llvm.extractvalue %[[T2]][1] : !llvm.array<4 x vector<1xf32>>
// CHECK: %[[T11:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[T12:.*]] = llvm.extractelement %[[T10]]{{\[}}%[[T11]] : i64] : vector<1xf32>
// CHECK: %[[T13:.*]] = splat %[[T12]] : vector<3xf32>
// CHECK: %[[T14:.*]] = llvm.insertvalue %[[T13]], %[[T8]][1] : !llvm.array<4 x vector<3xf32>>
-// CHECK: %[[T15:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1xf32> to !llvm.array<4 x vector<1xf32>>
-// CHECK: %[[T16:.*]] = llvm.extractvalue %[[T15]][2] : !llvm.array<4 x vector<1xf32>>
+// CHECK: %[[T16:.*]] = llvm.extractvalue %[[T2]][2] : !llvm.array<4 x vector<1xf32>>
// CHECK: %[[T17:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[T18:.*]] = llvm.extractelement %[[T16]]{{\[}}%[[T17]] : i64] : vector<1xf32>
// CHECK: %[[T19:.*]] = splat %[[T18]] : vector<3xf32>
// CHECK: %[[T20:.*]] = llvm.insertvalue %[[T19]], %[[T14]][2] : !llvm.array<4 x vector<3xf32>>
-// CHECK: %[[T21:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1xf32> to !llvm.array<4 x vector<1xf32>>
-// CHECK: %[[T22:.*]] = llvm.extractvalue %[[T21]][3] : !llvm.array<4 x vector<1xf32>>
+// CHECK: %[[T22:.*]] = llvm.extractvalue %[[T2]][3] : !llvm.array<4 x vector<1xf32>>
// CHECK: %[[T23:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[T24:.*]] = llvm.extractelement %[[T22]]{{\[}}%[[T23]] : i64] : vector<1xf32>
// CHECK: %[[T25:.*]] = splat %[[T24]] : vector<3xf32>
@@ -242,34 +236,28 @@ func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2xf32>
}
// CHECK-LABEL: @broadcast_stretch_in_middle(
// CHECK-SAME: %[[A:.*]]: vector<4x1x2xf32>) -> vector<4x3x2xf32> {
+// CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1x2xf32> to !llvm.array<4 x array<1 x vector<2xf32>>>
// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
+// CHECK: %[[T9:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
// CHECK: %[[T2:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
-// CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1x2xf32> to !llvm.array<4 x array<1 x vector<2xf32>>>
-// CHECK: %[[T4:.*]] = llvm.extractvalue %[[T3]][0, 0] : !llvm.array<4 x array<1 x vector<2xf32>>>
// CHECK: %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
+// CHECK: %[[T4:.*]] = llvm.extractvalue %[[T3]][0, 0] : !llvm.array<4 x array<1 x vector<2xf32>>>
// CHECK: %[[T6:.*]] = llvm.insertvalue %[[T4]], %[[T5]][0] : !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T4]], %[[T6]][1] : !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T8:.*]] = llvm.insertvalue %[[T4]], %[[T7]][2] : !llvm.array<3 x vector<2xf32>>
-// CHECK: %[[T9:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
// CHECK: %[[T10:.*]] = llvm.insertvalue %[[T8]], %[[T9]][0] : !llvm.array<4 x array<3 x vector<2xf32>>>
-// CHECK: %[[T11:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1x2xf32> to !llvm.array<4 x array<1 x vector<2xf32>>>
-// CHECK: %[[T12:.*]] = llvm.extractvalue %[[T11]][1, 0] : !llvm.array<4 x array<1 x vector<2xf32>>>
-// CHECK: %[[T13:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
-// CHECK: %[[T14:.*]] = llvm.insertvalue %[[T12]], %[[T13]][0] : !llvm.array<3 x vector<2xf32>>
+// CHECK: %[[T12:.*]] = llvm.extractvalue %[[T3]][1, 0] : !llvm.array<4 x array<1 x vector<2xf32>>>
+// CHECK: %[[T14:.*]] = llvm.insertvalue %[[T12]], %[[T5]][0] : !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T15:.*]] = llvm.insertvalue %[[T12]], %[[T14]][1] : !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T16:.*]] = llvm.insertvalue %[[T12]], %[[T15]][2] : !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T17:.*]] = llvm.insertvalue %[[T16]], %[[T10]][1] : !llvm.array<4 x array<3 x vector<2xf32>>>
-// CHECK: %[[T18:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1x2xf32> to !llvm.array<4 x array<1 x vector<2xf32>>>
-// CHECK: %[[T19:.*]] = llvm.extractvalue %[[T18]][2, 0] : !llvm.array<4 x array<1 x vector<2xf32>>>
-// CHECK: %[[T20:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
-// CHECK: %[[T21:.*]] = llvm.insertvalue %[[T19]], %[[T20]][0] : !llvm.array<3 x vector<2xf32>>
+// CHECK: %[[T19:.*]] = llvm.extractvalue %[[T3]][2, 0] : !llvm.array<4 x array<1 x vector<2xf32>>>
+// CHECK: %[[T21:.*]] = llvm.insertvalue %[[T19]], %[[T5]][0] : !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T22:.*]] = llvm.insertvalue %[[T19]], %[[T21]][1] : !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T23:.*]] = llvm.insertvalue %[[T19]], %[[T22]][2] : !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T24:.*]] = llvm.insertvalue %[[T23]], %[[T17]][2] : !llvm.array<4 x array<3 x vector<2xf32>>>
-// CHECK: %[[T25:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1x2xf32> to !llvm.array<4 x array<1 x vector<2xf32>>>
-// CHECK: %[[T26:.*]] = llvm.extractvalue %[[T25]][3, 0] : !llvm.array<4 x array<1 x vector<2xf32>>>
-// CHECK: %[[T27:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
-// CHECK: %[[T28:.*]] = llvm.insertvalue %[[T26]], %[[T27]][0] : !llvm.array<3 x vector<2xf32>>
+// CHECK: %[[T26:.*]] = llvm.extractvalue %[[T3]][3, 0] : !llvm.array<4 x array<1 x vector<2xf32>>>
+// CHECK: %[[T28:.*]] = llvm.insertvalue %[[T26]], %[[T5]][0] : !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T29:.*]] = llvm.insertvalue %[[T26]], %[[T28]][1] : !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T30:.*]] = llvm.insertvalue %[[T26]], %[[T29]][2] : !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T31:.*]] = llvm.insertvalue %[[T30]], %[[T24]][3] : !llvm.array<4 x array<3 x vector<2xf32>>>
@@ -286,11 +274,11 @@ func @outerproduct(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<2x3xf32
// CHECK-SAME: %[[A:.*]]: vector<2xf32>,
// CHECK-SAME: %[[B:.*]]: vector<3xf32>)
// CHECK: %[[T2:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
+// CHECK: %[[T7:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<2x3xf32> to !llvm.array<2 x vector<3xf32>>
// CHECK: %[[T3:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[T4:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T3]] : i64] : vector<2xf32>
// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<3xf32>
// CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
-// CHECK: %[[T7:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<2x3xf32> to !llvm.array<2 x vector<3xf32>>
// CHECK: %[[T8:.*]] = llvm.insertvalue %[[T6]], %[[T7]][0] : !llvm.array<2 x vector<3xf32>>
// CHECK: %[[T9:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[T10:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T9]] : i64] : vector<2xf32>
@@ -309,15 +297,15 @@ func @outerproduct_index(%arg0: vector<2xindex>, %arg1: vector<3xindex>) -> vect
// CHECK-LABEL: @outerproduct_index(
// CHECK-SAME: %[[A:.*]]: vector<2xindex>,
// CHECK-SAME: %[[B:.*]]: vector<3xindex>)
-// CHECK: %[[T0:.*]] = arith.constant dense<0> : vector<2x3xindex>
// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<2xindex> to vector<2xi64>
+// CHECK: %[[T0:.*]] = arith.constant dense<0> : vector<2x3xindex>
+// CHECK: %[[T8:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<2x3xindex> to !llvm.array<2 x vector<3xi64>>
// CHECK: %[[T2:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[T3:.*]] = llvm.extractelement %[[T1]]{{\[}}%[[T2]] : i64] : vector<2xi64>
// CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : i64 to index
// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<3xindex>
// CHECK: %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xindex>
// CHECK: %[[T7:.*]] = builtin.unrealized_conversion_cast %[[T6]] : vector<3xindex> to vector<3xi64>
-// CHECK: %[[T8:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<2x3xindex> to !llvm.array<2 x vector<3xi64>>
// CHECK: %{{.*}} = llvm.insertvalue %[[T7]], %[[T8]][0] : !llvm.array<2 x vector<3xi64>>
// -----
@@ -330,20 +318,19 @@ func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector
// CHECK-SAME: %[[A:.*]]: vector<2xf32>,
// CHECK-SAME: %[[B:.*]]: vector<3xf32>,
// CHECK-SAME: %[[C:.*]]: vector<2x3xf32>) -> vector<2x3xf32>
+// CHECK: %[[T7:.*]] = builtin.unrealized_conversion_cast %[[C]] : vector<2x3xf32> to !llvm.array<2 x vector<3xf32>>
// CHECK: %[[T3:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
+// CHECK: %[[T10:.*]] = builtin.unrealized_conversion_cast %[[T3]] : vector<2x3xf32> to !llvm.array<2 x vector<3xf32>>
// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[T5:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T4]] : i64] : vector<2xf32>
// CHECK: %[[T6:.*]] = splat %[[T5]] : vector<3xf32>
-// CHECK: %[[T7:.*]] = builtin.unrealized_conversion_cast %[[C]] : vector<2x3xf32> to !llvm.array<2 x vector<3xf32>>
// CHECK: %[[T8:.*]] = llvm.extractvalue %[[T7]][0] : !llvm.array<2 x vector<3xf32>>
// CHECK: %[[T9:.*]] = "llvm.intr.fmuladd"(%[[T6]], %[[B]], %[[T8]]) : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %[[T10:.*]] = builtin.unrealized_conversion_cast %[[T3]] : vector<2x3xf32> to !llvm.array<2 x vector<3xf32>>
// CHECK: %[[T11:.*]] = llvm.insertvalue %[[T9]], %[[T10]][0] : !llvm.array<2 x vector<3xf32>>
// CHECK: %[[T12:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[T13:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T12]] : i64] : vector<2xf32>
// CHECK: %[[T14:.*]] = splat %[[T13]] : vector<3xf32>
-// CHECK: %[[T15:.*]] = builtin.unrealized_conversion_cast %[[C]] : vector<2x3xf32> to !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[T16:.*]] = llvm.extractvalue %[[T15]][1] : !llvm.array<2 x vector<3xf32>>
+// CHECK: %[[T16:.*]] = llvm.extractvalue %[[T7]][1] : !llvm.array<2 x vector<3xf32>>
// CHECK: %[[T17:.*]] = "llvm.intr.fmuladd"(%[[T14]], %[[B]], %[[T16]]) : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> vector<3xf32>
// CHECK: %[[T18:.*]] = llvm.insertvalue %[[T17]], %[[T11]][1] : !llvm.array<2 x vector<3xf32>>
// CHECK: %[[T19:.*]] = builtin.unrealized_conversion_cast %[[T18]] : !llvm.array<2 x vector<3xf32>> to vector<2x3xf32>
@@ -370,8 +357,8 @@ func @shuffle_1D_index_direct(%arg0: vector<2xindex>, %arg1: vector<2xindex>) ->
// CHECK-LABEL: @shuffle_1D_index_direct(
// CHECK-SAME: %[[A:.*]]: vector<2xindex>,
// CHECK-SAME: %[[B:.*]]: vector<2xindex>)
-// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<2xindex> to vector<2xi64>
-// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2xindex> to vector<2xi64>
+// CHECK-DAG: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<2xindex> to vector<2xi64>
+// CHECK-DAG: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2xindex> to vector<2xi64>
// CHECK: %[[T2:.*]] = llvm.shufflevector %[[T0]], %[[T1]] [0, 1] : vector<2xi64>, vector<2xi64>
// CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<2xi64> to vector<2xindex>
// CHECK: return %[[T3]] : vector<2xindex>
@@ -417,8 +404,8 @@ func @shuffle_2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32> {
// CHECK-LABEL: @shuffle_2D(
// CHECK-SAME: %[[A:.*]]: vector<1x4xf32>,
// CHECK-SAME: %[[B:.*]]: vector<2x4xf32>)
-// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<1x4xf32> to !llvm.array<1 x vector<4xf32>>
-// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>>
+// CHECK-DAG: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<1x4xf32> to !llvm.array<1 x vector<4xf32>>
+// CHECK-DAG: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>>
// CHECK: %[[u0:.*]] = llvm.mlir.undef : !llvm.array<3 x vector<4xf32>>
// CHECK: %[[e1:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<2 x vector<4xf32>>
// CHECK: %[[i1:.*]] = llvm.insertvalue %[[e1]], %[[u0]][0] : !llvm.array<3 x vector<4xf32>>
@@ -533,8 +520,8 @@ func @insert_index_element_into_vec_1d(%arg0: index, %arg1: vector<4xindex>) ->
// CHECK-LABEL: @insert_index_element_into_vec_1d(
// CHECK-SAME: %[[A:.*]]: index,
// CHECK-SAME: %[[B:.*]]: vector<4xindex>)
-// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : index to i64
-// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<4xindex> to vector<4xi64>
+// CHECK-DAG: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : index to i64
+// CHECK-DAG: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<4xindex> to vector<4xi64>
// CHECK: %[[T3:.*]] = llvm.mlir.constant(3 : i64) : i64
// CHECK: %[[T4:.*]] = llvm.insertelement %[[T0]], %[[T1]][%[[T3]] : i64] : vector<4xi64>
// CHECK: %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T4]] : vector<4xi64> to vector<4xindex>
@@ -845,8 +832,7 @@ func @extract_strided_index_slice1(%arg0: vector<4xindex>) -> vector<2xindex> {
// CHECK-LABEL: @extract_strided_index_slice1(
// CHECK-SAME: %[[A:.*]]: vector<4xindex>)
// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4xindex> to vector<4xi64>
-// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4xindex> to vector<4xi64>
-// CHECK: %[[T2:.*]] = llvm.shufflevector %[[T0]], %[[T1]] [2, 3] : vector<4xi64>, vector<4xi64>
+// CHECK: %[[T2:.*]] = llvm.shufflevector %[[T0]], %[[T0]] [2, 3] : vector<4xi64>, vector<4xi64>
// CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<2xi64> to vector<2xindex>
// CHECK: return %[[T3]] : vector<2xindex>
@@ -875,14 +861,13 @@ func @extract_strided_slice3(%arg0: vector<4x8xf32>) -> vector<2x2xf32> {
}
// CHECK-LABEL: @extract_strided_slice3(
// CHECK-SAME: %[[ARG:.*]]: vector<4x8xf32>)
+// CHECK: %[[A:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x8xf32> to !llvm.array<4 x vector<8xf32>>
// CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_2:.*]] = splat %[[VAL_1]] : vector<2x2xf32>
-// CHECK: %[[A:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x8xf32> to !llvm.array<4 x vector<8xf32>>
+// CHECK: %[[VAL_6:.*]] = builtin.unrealized_conversion_cast %[[VAL_2]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>>
// CHECK: %[[T2:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<8xf32>>
// CHECK: %[[T3:.*]] = llvm.shufflevector %[[T2]], %[[T2]] [2, 3] : vector<8xf32>, vector<8xf32>
-// CHECK: %[[VAL_6:.*]] = builtin.unrealized_conversion_cast %[[VAL_2]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>>
// CHECK: %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[VAL_6]][0] : !llvm.array<2 x vector<2xf32>>
-// CHECK: %[[A:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x8xf32> to !llvm.array<4 x vector<8xf32>>
// CHECK: %[[T5:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vector<8xf32>>
// CHECK: %[[T6:.*]] = llvm.shufflevector %[[T5]], %[[T5]] [2, 3] : vector<8xf32>, vector<8xf32>
// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T6]], %[[T4]][1] : !llvm.array<2 x vector<2xf32>>
@@ -918,8 +903,8 @@ func @insert_strided_slice2(%a: vector<2x2xf32>, %b: vector<4x4xf32>) -> vector<
// CHECK-LABEL: @insert_strided_slice2
//
// Subvector vector<2xf32> @0 into vector<4xf32> @2
+// CHECK: unrealized_conversion_cast %{{.*}} : vector<4x4xf32> to !llvm.array<4 x vector<4xf32>>
// CHECK: llvm.extractvalue {{.*}}[0] : !llvm.array<2 x vector<2xf32>>
-// CHECK-NEXT: unrealized_conversion_cast %{{.*}} : vector<4x4xf32> to !llvm.array<4 x vector<4xf32>>
// CHECK-NEXT: llvm.extractvalue {{.*}}[2] : !llvm.array<4 x vector<4xf32>>
// Element @0 -> element @2
// CHECK-NEXT: arith.constant 0 : index
@@ -935,12 +920,10 @@ func @insert_strided_slice2(%a: vector<2x2xf32>, %b: vector<4x4xf32>) -> vector<
// CHECK-NEXT: arith.constant 3 : index
// CHECK-NEXT: unrealized_conversion_cast %{{.*}} : index to i64
// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i64] : vector<4xf32>
-// CHECK-NEXT: unrealized_conversion_cast %{{.*}} : vector<4x4xf32> to !llvm.array<4 x vector<4xf32>>
// CHECK-NEXT: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x vector<4xf32>>
//
// Subvector vector<2xf32> @1 into vector<4xf32> @3
// CHECK: llvm.extractvalue {{.*}}[1] : !llvm.array<2 x vector<2xf32>>
-// CHECK-NEXT: unrealized_conversion_cast %{{.*}} : vector<4x4xf32> to !llvm.array<4 x vector<4xf32>>
// CHECK-NEXT: llvm.extractvalue {{.*}}[3] : !llvm.array<4 x vector<4xf32>>
// Element @0 -> element @2
// CHECK-NEXT: arith.constant 0 : index
@@ -968,12 +951,11 @@ func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf32>) -
// CHECK-LABEL: @insert_strided_slice3(
// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>,
// CHECK-SAME: %[[B:.*]]: vector<16x4x8xf32>)
-// CHECK: %[[s2:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<16x4x8xf32> to !llvm.array<16 x array<4 x vector<8xf32>>>
+// CHECK-DAG: %[[s2:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<16x4x8xf32> to !llvm.array<16 x array<4 x vector<8xf32>>>
+// CHECK-DAG: %[[s4:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>>
// CHECK: %[[s3:.*]] = llvm.extractvalue %[[s2]][0] : !llvm.array<16 x array<4 x vector<8xf32>>>
-// CHECK: %[[s4:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>>
// CHECK: %[[s5:.*]] = llvm.extractvalue %[[s4]][0] : !llvm.array<2 x vector<4xf32>>
-// CHECK: %[[s6:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<16x4x8xf32> to !llvm.array<16 x array<4 x vector<8xf32>>>
-// CHECK: %[[s7:.*]] = llvm.extractvalue %[[s6]][0, 0] : !llvm.array<16 x array<4 x vector<8xf32>>>
+// CHECK: %[[s7:.*]] = llvm.extractvalue %[[s2]][0, 0] : !llvm.array<16 x array<4 x vector<8xf32>>>
// CHECK: %[[s8:.*]] = arith.constant 0 : index
// CHECK: %[[s9:.*]] = builtin.unrealized_conversion_cast %[[s8]] : index to i64
// CHECK: %[[s10:.*]] = llvm.extractelement %[[s5]]{{\[}}%[[s9]] : i64] : vector<4xf32>
@@ -999,10 +981,8 @@ func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf32>) -
// CHECK: %[[s30:.*]] = builtin.unrealized_conversion_cast %[[s29]] : index to i64
// CHECK: %[[s31:.*]] = llvm.insertelement %[[s28]], %[[s25]]{{\[}}%[[s30]] : i64] : vector<8xf32>
// CHECK: %[[s32:.*]] = llvm.insertvalue %[[s31]], %[[s3]][0] : !llvm.array<4 x vector<8xf32>>
-// CHECK: %[[s33:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>>
-// CHECK: %[[s34:.*]] = llvm.extractvalue %[[s33]][1] : !llvm.array<2 x vector<4xf32>>
-// CHECK: %[[s35:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<16x4x8xf32> to !llvm.array<16 x array<4 x vector<8xf32>>>
-// CHECK: %[[s36:.*]] = llvm.extractvalue %[[s35]][0, 1] : !llvm.array<16 x array<4 x vector<8xf32>>>
+// CHECK: %[[s34:.*]] = llvm.extractvalue %[[s4]][1] : !llvm.array<2 x vector<4xf32>>
+// CHECK: %[[s36:.*]] = llvm.extractvalue %[[s2]][0, 1] : !llvm.array<16 x array<4 x vector<8xf32>>>
// CHECK: %[[s37:.*]] = arith.constant 0 : index
// CHECK: %[[s38:.*]] = builtin.unrealized_conversion_cast %[[s37]] : index to i64
// CHECK: %[[s39:.*]] = llvm.extractelement %[[s34]]{{\[}}%[[s38]] : i64] : vector<4xf32>
@@ -1028,8 +1008,7 @@ func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf32>) -
// CHECK: %[[s59:.*]] = builtin.unrealized_conversion_cast %[[s58]] : index to i64
// CHECK: %[[s60:.*]] = llvm.insertelement %[[s57]], %[[s54]]{{\[}}%[[s59]] : i64] : vector<8xf32>
// CHECK: %[[s61:.*]] = llvm.insertvalue %[[s60]], %[[s32]][1] : !llvm.array<4 x vector<8xf32>>
-// CHECK: %[[s62:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<16x4x8xf32> to !llvm.array<16 x array<4 x vector<8xf32>>>
-// CHECK: %[[s63:.*]] = llvm.insertvalue %[[s61]], %[[s62]][0] : !llvm.array<16 x array<4 x vector<8xf32>>>
+// CHECK: %[[s63:.*]] = llvm.insertvalue %[[s61]], %[[s2]][0] : !llvm.array<16 x array<4 x vector<8xf32>>>
// CHECK: %[[s64:.*]] = builtin.unrealized_conversion_cast %[[s63]] : !llvm.array<16 x array<4 x vector<8xf32>>> to vector<16x4x8xf32>
// CHECK: return %[[s64]] : vector<16x4x8xf32>
@@ -1039,24 +1018,19 @@ func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vect
// CHECK-LABEL: @vector_fma
// CHECK-SAME: %[[A:.*]]: vector<8xf32>
// CHECK-SAME: %[[B:.*]]: vector<2x4xf32>
+ // CHECK: %[[BL:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>>
// CHECK: "llvm.intr.fmuladd"
// CHECK-SAME: (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
%0 = vector.fma %a, %a, %a : vector<8xf32>
- // CHECK: %[[BL:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>>
// CHECK: %[[b00:.*]] = llvm.extractvalue %[[BL]][0] : !llvm.array<2 x vector<4xf32>>
- // CHECK: %[[BL:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>>
// CHECK: %[[b01:.*]] = llvm.extractvalue %[[BL]][0] : !llvm.array<2 x vector<4xf32>>
- // CHECK: %[[BL:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>>
// CHECK: %[[b02:.*]] = llvm.extractvalue %[[BL]][0] : !llvm.array<2 x vector<4xf32>>
// CHECK: %[[B0:.*]] = "llvm.intr.fmuladd"(%[[b00]], %[[b01]], %[[b02]]) :
// CHECK-SAME: (vector<4xf32>, vector<4xf32>, vector<4xf32>) -> vector<4xf32>
// CHECK: llvm.insertvalue %[[B0]], {{.*}}[0] : !llvm.array<2 x vector<4xf32>>
- // CHECK: %[[BL:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>>
// CHECK: %[[b10:.*]] = llvm.extractvalue %[[BL]][1] : !llvm.array<2 x vector<4xf32>>
- // CHECK: %[[BL:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>>
// CHECK: %[[b11:.*]] = llvm.extractvalue %[[BL]][1] : !llvm.array<2 x vector<4xf32>>
- // CHECK: %[[BL:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>>
// CHECK: %[[b12:.*]] = llvm.extractvalue %[[BL]][1] : !llvm.array<2 x vector<4xf32>>
// CHECK: %[[B1:.*]] = "llvm.intr.fmuladd"(%[[b10]], %[[b11]], %[[b12]]) :
// CHECK-SAME: (vector<4xf32>, vector<4xf32>, vector<4xf32>) -> vector<4xf32>
diff --git a/mlir/test/Dialect/ArmSVE/memcpy.mlir b/mlir/test/Dialect/ArmSVE/memcpy.mlir
index e8fdba8245495..93a02ed42b50d 100644
--- a/mlir/test/Dialect/ArmSVE/memcpy.mlir
+++ b/mlir/test/Dialect/ArmSVE/memcpy.mlir
@@ -7,19 +7,18 @@ func @memcopy(%src : memref<?xf32>, %dst : memref<?xf32>, %size : index) {
%vs = arm_sve.vector_scale : index
%step = arith.muli %c4, %vs : index
+ // CHECK: [[SRCMRS:%[0-9]+]] = builtin.unrealized_conversion_cast [[SRC]] : memref<?xf32> to !llvm.struct<(ptr<f32>
+ // CHECK: [[DSTMRS:%[0-9]+]] = builtin.unrealized_conversion_cast [[DST]] : memref<?xf32> to !llvm.struct<(ptr<f32>
// CHECK: scf.for [[LOOPIDX:%arg[0-9]+]] = {{.*}}
scf.for %i0 = %c0 to %size step %step {
- // CHECK: [[SRCMRS:%[0-9]+]] = builtin.unrealized_conversion_cast [[SRC]] : memref<?xf32> to !llvm.struct<(ptr<f32>
// CHECK: [[SRCIDX:%[0-9]+]] = builtin.unrealized_conversion_cast [[LOOPIDX]] : index to i64
// CHECK: [[SRCMEM:%[0-9]+]] = llvm.extractvalue [[SRCMRS]][1] : !llvm.struct<(ptr<f32>
// CHECK-NEXT: [[SRCPTR:%[0-9]+]] = llvm.getelementptr [[SRCMEM]]{{.}}[[SRCIDX]]{{.}} : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
// CHECK-NEXT: [[SRCVPTR:%[0-9]+]] = llvm.bitcast [[SRCPTR]] : !llvm.ptr<f32> to !llvm.ptr<vec<? x 4 x f32>>
// CHECK-NEXT: [[LDVAL:%[0-9]+]] = llvm.load [[SRCVPTR]] : !llvm.ptr<vec<? x 4 x f32>>
%0 = arm_sve.load %src[%i0] : !arm_sve.vector<4xf32> from memref<?xf32>
- // CHECK: [[DSTMRS:%[0-9]+]] = builtin.unrealized_conversion_cast [[DST]] : memref<?xf32> to !llvm.struct<(ptr<f32>
- // CHECK: [[DSTIDX:%[0-9]+]] = builtin.unrealized_conversion_cast [[LOOPIDX]] : index to i64
// CHECK: [[DSTMEM:%[0-9]+]] = llvm.extractvalue [[DSTMRS]][1] : !llvm.struct<(ptr<f32>
- // CHECK-NEXT: [[DSTPTR:%[0-9]+]] = llvm.getelementptr [[DSTMEM]]{{.}}[[DSTIDX]]{{.}} : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+ // CHECK-NEXT: [[DSTPTR:%[0-9]+]] = llvm.getelementptr [[DSTMEM]]{{.}}[[SRCIDX]]{{.}} : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
// CHECK-NEXT: [[DSTVPTR:%[0-9]+]] = llvm.bitcast [[DSTPTR]] : !llvm.ptr<f32> to !llvm.ptr<vec<? x 4 x f32>>
// CHECK-NEXT: llvm.store [[LDVAL]], [[DSTVPTR]] : !llvm.ptr<vec<? x 4 x f32>>
arm_sve.store %0, %dst[%i0] : !arm_sve.vector<4xf32> to memref<?xf32>
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index 18f62bb8ee772..9978eb510ff9b 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -139,8 +139,8 @@ func @dynamic_results(%arg0: tensor<?x?xf32>)
// CHECK-LABEL: func @generic_with_init_tensor(
// CHECK-SAME: %[[ARG0_TENSOR:.*]]: tensor<2x3x4xvector<3x4xi4>>,
// CHECK-SAME: %[[ARG1_TENSOR:.*]]: tensor<3x2xf32>) -> tensor<3x2xf32> {
-// CHECK: %[[ARG0_MEMREF:.*]] = memref.buffer_cast %[[ARG0_TENSOR]] : memref<2x3x4xvector<3x4xi4>>
-// CHECK: %[[ARG1_MEMREF:.*]] = memref.buffer_cast %[[ARG1_TENSOR]] : memref<3x2xf32>
+// CHECK-DAG: %[[ARG0_MEMREF:.*]] = memref.buffer_cast %[[ARG0_TENSOR]] : memref<2x3x4xvector<3x4xi4>>
+// CHECK-DAG: %[[ARG1_MEMREF:.*]] = memref.buffer_cast %[[ARG1_TENSOR]] : memref<3x2xf32>
// CHECK: %[[INIT_BUFFER:.*]] = memref.alloc() : memref<3x2xf32>
// CHECK: linalg.copy(%[[ARG1_MEMREF]], %[[INIT_BUFFER]]) : memref<3x2xf32>, memref<3x2xf32>
// CHECK: linalg.generic
@@ -169,10 +169,11 @@ func private @make_index() -> index
// CHECK-LABEL: func @bufferize_slice(
// CHECK-SAME: %[[T:[0-9a-z]*]]: tensor<?x?xf32>
func @bufferize_slice(%t : tensor<?x?xf32>) -> (tensor<2x3xf32>, tensor<2x?xf32>) {
+ // CHECK: %[[M:.*]] = memref.buffer_cast %[[T]] : memref<?x?xf32>
+
// CHECK: %[[IDX:.*]] = call @make_index() : () -> index
%i0 = call @make_index() : () -> index
- // CHECK: %[[M:.*]] = memref.buffer_cast %[[T]] : memref<?x?xf32>
// CHECK-NEXT: %[[A0:.*]] = memref.alloc() : memref<2x3xf32>
// CHECK-NEXT: %[[SM0:.*]] = memref.subview %[[M]][0, 0] [2, 3] [1, 1]
// CHECK-SAME: memref<?x?xf32> to memref<2x3xf32, #[[$MAP0]]>
@@ -204,6 +205,10 @@ func private @make_index() -> index
// CHECK-SAME: %[[ST1:[0-9a-z]*]]: tensor<2x?xf32>
func @bufferize_insert_slice(%t : tensor<?x?xf32>, %st0 : tensor<2x3xf32>, %st1 : tensor<2x?xf32>) ->
(tensor<?x?xf32>, tensor<?x?xf32>) {
+ // CHECK-DAG: %[[M:.*]] = memref.buffer_cast %[[T]] : memref<?x?xf32>
+ // CHECK-DAG: %[[SM0:.*]] = memref.buffer_cast %[[ST0]] : memref<2x3xf32>
+ // CHECK-DAG: %[[SM1:.*]] = memref.buffer_cast %[[ST1]] : memref<2x?xf32>
+
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
@@ -212,8 +217,6 @@ func @bufferize_insert_slice(%t : tensor<?x?xf32>, %st0 : tensor<2x3xf32>, %st1
// CHECK: %[[IDX:.*]] = call @make_index() : () -> index
- // CHECK-DAG: %[[M:.*]] = memref.buffer_cast %[[T]] : memref<?x?xf32>
- // CHECK-DAG: %[[SM0:.*]] = memref.buffer_cast %[[ST0]] : memref<2x3xf32>
// CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[T]], %[[C0]] : tensor<?x?xf32>
// CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[T]], %[[C1]] : tensor<?x?xf32>
// CHECK-NEXT: %[[M_COPY0:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) : memref<?x?xf32>
@@ -224,7 +227,6 @@ func @bufferize_insert_slice(%t : tensor<?x?xf32>, %st0 : tensor<2x3xf32>, %st1
// CHECK-NEXT: %[[RT0:.*]] = memref.tensor_load %[[M_COPY0]] : memref<?x?xf32>
%t0 = tensor.insert_slice %st0 into %t[0, 0][2, 3][1, 1] : tensor<2x3xf32> into tensor<?x?xf32>
- // CHECK-DAG: %[[SM1:.*]] = memref.buffer_cast %[[ST1]] : memref<2x?xf32>
// CHECK-NEXT: %[[M_COPY1:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) : memref<?x?xf32>
// CHECK-NEXT: linalg.copy(%[[M]], %[[M_COPY1]]) : memref<?x?xf32>, memref<?x?xf32>
// CHECK-NEXT: %[[SUBVIEW1:.*]] = memref.subview %[[M_COPY1]][0, %[[IDX]]] [2, %[[IDX]]] [1, 2]
@@ -285,13 +287,13 @@ func @pad_tensor_dynamic_shape(%arg0: tensor<4x?x2x?xf32>, %arg1: index) -> tens
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[IN_MEMREF:.*]] = memref.buffer_cast %[[IN]] : memref<4x?x2x?xf32>
// CHECK: %[[DIM1:.*]] = tensor.dim %[[IN]], %[[C1]] : tensor<4x?x2x?xf32>
// CHECK: %[[OUT_DIM2:.*]] = arith.addi %[[OFFSET]], %[[C2]] : index
// CHECK: %[[DIM3:.*]] = tensor.dim %[[IN]], %[[C3]] : tensor<4x?x2x?xf32>
// CHECK: %[[OUT_DIM3:.*]] = arith.addi %[[DIM3]], %[[OFFSET]] : index
// CHECK: %[[FILLED:.*]] = memref.alloc(%[[DIM1]], %[[OUT_DIM2]], %[[OUT_DIM3]]) : memref<4x?x?x?xf32>
// CHECK: linalg.fill(%[[CST]], %[[FILLED]]) : f32, memref<4x?x?x?xf32>
-// CHECK: %[[IN_MEMREF:.*]] = memref.buffer_cast %[[IN]] : memref<4x?x2x?xf32>
// CHECK: %[[OUT:.*]] = memref.alloc(%[[DIM1]], %[[OUT_DIM2]], %[[OUT_DIM3]]) : memref<4x?x?x?xf32>
// CHECK: linalg.copy(%[[FILLED]], %[[OUT]]) : memref<4x?x?x?xf32>, memref<4x?x?x?xf32>
// CHECK: %[[INTERIOR:.*]] = memref.subview %[[OUT]][0, 0, %[[OFFSET]], 0] [4, %[[DIM1]], 2, %[[DIM3]]] [1, 1, 1, 1] : memref<4x?x?x?xf32> to memref<4x?x2x?xf32, #map>
diff --git a/mlir/test/Dialect/Linalg/detensorize_0d.mlir b/mlir/test/Dialect/Linalg/detensorize_0d.mlir
index 3551aa18a0547..c9084d6371b30 100644
--- a/mlir/test/Dialect/Linalg/detensorize_0d.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_0d.mlir
@@ -57,8 +57,7 @@ func @detensor_op_sequence(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32
// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]]
// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
// CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val]], %[[arg2_val]]
-// CHECK-DAG: %[[arg1_val2:.*]] = tensor.extract %[[arg1]]
-// CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val2]], %[[detensored_res]]
+// CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val]], %[[detensored_res]]
// CHECK: %[[detensored_res3:.*]] = arith.divf %[[detensored_res]], %[[detensored_res2]]
// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]]
// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_collapse_shape %[[new_tensor_res]]
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
index ebe022d7cbf1b..88ab1478d1e69 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
@@ -73,10 +73,7 @@ func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attribute
// DET-ALL: linalg.yield %{{.*}} : i32
// DET-ALL: } -> tensor<i32>
// DET-ALL: tensor.extract %{{.*}}[] : tensor<i32>
-// DET-ALL: tensor.extract %{{.*}}[] : tensor<i32>
-// DET-ALL: arith.cmpi slt, %{{.*}}, %{{.*}} : i32
-// DET-ALL: tensor.extract %{{.*}}[] : tensor<i32>
-// DET-ALL: tensor.extract %{{.*}}[] : tensor<i32>
+// DET-ALL: cmpi slt, %{{.*}}, %{{.*}} : i32
// DET-ALL: cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
// DET-ALL: ^[[bb2]](%{{.*}}: i32)
// DET-ALL: tensor.from_elements %{{.*}} : tensor<1xi32>
@@ -99,8 +96,7 @@ func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attribute
// DET-CF: ^bb1(%{{.*}}: tensor<10xi32>)
// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor<i32>) {
// DET-CF: tensor.extract %{{.*}}[] : tensor<i32>
-// DET-CF: tensor.extract %{{.*}}[] : tensor<i32>
-// DET-CF: arith.cmpi slt, %{{.*}}, %{{.*}} : i32
+// DET-CF: cmpi slt, %{{.*}}, %{{.*}} : i32
// DET-CF: cond_br %{{.*}}, ^bb2(%{{.*}} : tensor<i32>), ^bb3(%{{.*}} : tensor<i32>)
// DET-CF: ^bb2(%{{.*}}: tensor<i32>)
// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
diff --git a/mlir/test/Dialect/SCF/bufferize.mlir b/mlir/test/Dialect/SCF/bufferize.mlir
index 853727d6305c8..38b8234334546 100644
--- a/mlir/test/Dialect/SCF/bufferize.mlir
+++ b/mlir/test/Dialect/SCF/bufferize.mlir
@@ -4,11 +4,11 @@
// CHECK-SAME: %[[PRED:.*]]: i1,
// CHECK-SAME: %[[TRUE_TENSOR:.*]]: tensor<?xf32>,
// CHECK-SAME: %[[FALSE_TENSOR:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: %[[TRUE_MEMREF:.*]] = memref.buffer_cast %[[TRUE_TENSOR]] : memref<?xf32>
+// CHECK: %[[FALSE_MEMREF:.*]] = memref.buffer_cast %[[FALSE_TENSOR]] : memref<?xf32>
// CHECK: %[[RESULT_MEMREF:.*]] = scf.if %[[PRED]] -> (memref<?xf32>) {
-// CHECK: %[[TRUE_MEMREF:.*]] = memref.buffer_cast %[[TRUE_TENSOR]] : memref<?xf32>
// CHECK: scf.yield %[[TRUE_MEMREF]] : memref<?xf32>
// CHECK: } else {
-// CHECK: %[[FALSE_MEMREF:.*]] = memref.buffer_cast %[[FALSE_TENSOR]] : memref<?xf32>
// CHECK: scf.yield %[[FALSE_MEMREF]] : memref<?xf32>
// CHECK: }
// CHECK: %[[RESULT_TENSOR:.*]] = memref.tensor_load %[[RESULT_MEMREF:.*]] : memref<?xf32>
@@ -29,9 +29,7 @@ func @if(%pred: i1, %true_val: tensor<?xf32>, %false_val: tensor<?xf32>) -> tens
// CHECK-SAME: %[[STEP:.*]]: index) -> tensor<f32> {
// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]] : memref<f32>
// CHECK: %[[RESULT_MEMREF:.*]] = scf.for %[[VAL_6:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ITER:.*]] = %[[MEMREF]]) -> (memref<f32>) {
-// CHECK: %[[TENSOR_ITER:.*]] = memref.tensor_load %[[ITER]] : memref<f32>
-// CHECK: %[[MEMREF_YIELDED:.*]] = memref.buffer_cast %[[TENSOR_ITER]] : memref<f32>
-// CHECK: scf.yield %[[MEMREF_YIELDED]] : memref<f32>
+// CHECK: scf.yield %[[ITER]] : memref<f32>
// CHECK: }
// CHECK: %[[VAL_8:.*]] = memref.tensor_load %[[VAL_9:.*]] : memref<f32>
// CHECK: return %[[VAL_8]] : tensor<f32>
diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
index 83a979cc826b2..e709158bfc2e8 100644
--- a/mlir/test/Dialect/Standard/bufferize.mlir
+++ b/mlir/test/Dialect/Standard/bufferize.mlir
@@ -4,8 +4,8 @@
// CHECK-SAME: %[[PRED:.*]]: i1,
// CHECK-SAME: %[[TRUE_VAL:.*]]: tensor<f32>,
// CHECK-SAME: %[[FALSE_VAL:.*]]: tensor<f32>) -> tensor<f32> {
-// CHECK: %[[TRUE_VAL_MEMREF:.*]] = memref.buffer_cast %[[TRUE_VAL]] : memref<f32>
-// CHECK: %[[FALSE_VAL_MEMREF:.*]] = memref.buffer_cast %[[FALSE_VAL]] : memref<f32>
+// CHECK-DAG: %[[TRUE_VAL_MEMREF:.*]] = memref.buffer_cast %[[TRUE_VAL]] : memref<f32>
+// CHECK-DAG: %[[FALSE_VAL_MEMREF:.*]] = memref.buffer_cast %[[FALSE_VAL]] : memref<f32>
// CHECK: %[[RET_MEMREF:.*]] = select %[[PRED]], %[[TRUE_VAL_MEMREF]], %[[FALSE_VAL_MEMREF]] : memref<f32>
// CHECK: %[[RET:.*]] = memref.tensor_load %[[RET_MEMREF]] : memref<f32>
// CHECK: return %[[RET]] : tensor<f32>
diff --git a/mlir/test/Dialect/Standard/func-bufferize.mlir b/mlir/test/Dialect/Standard/func-bufferize.mlir
index a3f2e15c18b97..4a778cd8ce091 100644
--- a/mlir/test/Dialect/Standard/func-bufferize.mlir
+++ b/mlir/test/Dialect/Standard/func-bufferize.mlir
@@ -2,22 +2,16 @@
// CHECK-LABEL: func @identity(
// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {
-// CHECK: %[[TENSOR:.*]] = memref.tensor_load %[[ARG]] : memref<f32>
-// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]] : memref<f32>
-// CHECK: return %[[MEMREF]] : memref<f32>
+// CHECK: return %[[ARG]] : memref<f32>
func @identity(%arg0: tensor<f32>) -> tensor<f32> {
return %arg0 : tensor<f32>
}
// CHECK-LABEL: func @block_arguments(
// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {
-// CHECK: %[[T1:.*]] = memref.tensor_load %[[ARG]] : memref<f32>
-// CHECK: %[[M1:.*]] = memref.buffer_cast %[[T1]] : memref<f32>
-// CHECK: br ^bb1(%[[M1]] : memref<f32>)
+// CHECK: br ^bb1(%[[ARG]] : memref<f32>)
// CHECK: ^bb1(%[[BBARG:.*]]: memref<f32>):
-// CHECK: %[[T2:.*]] = memref.tensor_load %[[BBARG]] : memref<f32>
-// CHECK: %[[M2:.*]] = memref.buffer_cast %[[T2]] : memref<f32>
-// CHECK: return %[[M2]] : memref<f32>
+// CHECK: return %[[BBARG]] : memref<f32>
func @block_arguments(%arg0: tensor<f32>) -> tensor<f32> {
br ^bb1(%arg0: tensor<f32>)
^bb1(%bbarg: tensor<f32>):
@@ -35,9 +29,7 @@ func @call_source() -> tensor<f32> {
}
// CHECK-LABEL: func @call_sink(
// CHECK-SAME: %[[ARG:.*]]: memref<f32>) {
-// CHECK: %[[TENSOR:.*]] = memref.tensor_load %[[ARG]] : memref<f32>
-// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]] : memref<f32>
-// CHECK: call @sink(%[[MEMREF]]) : (memref<f32>) -> ()
+// CHECK: call @sink(%[[ARG]]) : (memref<f32>) -> ()
// CHECK: return
func private @sink(tensor<f32>)
func @call_sink(%arg0: tensor<f32>) {
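The round-trips removed here were source/target materializations that used to be inserted eagerly at every type mismatch. For orientation, a hedged, simplified sketch of how a bufferization-style converter wires those materializations up (modeled on BufferizeTypeConverter; not the actual implementation):

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Simplified sketch modeled on BufferizeTypeConverter; the real one lives
// in mlir/Transforms/Bufferize.h and also handles unranked tensors.
struct SimpleBufferizeConverter : public TypeConverter {
  SimpleBufferizeConverter() {
    // Fallback: leave all other types unchanged.
    addConversion([](Type type) { return type; });
    addConversion([](RankedTensorType type) -> Type {
      return MemRefType::get(type.getShape(), type.getElementType());
    });
    // tensor -> memref casts, created only where the converted type is
    // still required once conversion finishes.
    addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
                                ValueRange inputs, Location loc) -> Value {
      return builder.create<memref::BufferCastOp>(loc, type, inputs[0]);
    });
    // memref -> tensor casts for original values that stay live.
    addSourceMaterialization([](OpBuilder &builder, TensorType type,
                                ValueRange inputs, Location loc) -> Value {
      return builder.create<memref::TensorLoadOp>(loc, type, inputs[0]);
    });
  }
};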
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index f75b8bd8965c2..f85c07cbbc3ab 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -74,11 +74,11 @@ func @tensor.from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> {
// CHECK-LABEL: func @tensor.generate(
// CHECK-SAME: %[[ARG:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
+// CHECK: %[[CASTED:.*]] = memref.buffer_cast %[[ARG]] : memref<*xf32>
// CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) : memref<?xindex>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
-// CHECK: %[[CASTED:.*]] = memref.buffer_cast %[[ARG]] : memref<*xf32>
// CHECK: %[[ELEM:.*]] = memref.dim %[[CASTED]], %[[I]] : memref<*xf32>
// CHECK: store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex>
// CHECK: scf.yield
diff --git a/mlir/test/Transforms/test-legalize-remapped-value.mlir b/mlir/test/Transforms/test-legalize-remapped-value.mlir
index ff571c93f9386..781f3eba5e826 100644
--- a/mlir/test/Transforms/test-legalize-remapped-value.mlir
+++ b/mlir/test/Transforms/test-legalize-remapped-value.mlir
@@ -1,13 +1,28 @@
// RUN: mlir-opt %s -test-remapped-value | FileCheck %s
// Simple test that exercises ConversionPatternRewriter::getRemappedValue.
+
+// CHECK-LABEL: func @remap_input_1_to_1
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+// CHECK-NEXT: %[[VAL:.*]] = "test.one_variadic_out_one_variadic_in1"(%[[ARG]], %[[ARG]])
+// CHECK-NEXT: "test.one_variadic_out_one_variadic_in1"(%[[VAL]], %[[VAL]])
+
func @remap_input_1_to_1(%arg0: i32) {
%0 = "test.one_variadic_out_one_variadic_in1"(%arg0) : (i32) -> i32
%1 = "test.one_variadic_out_one_variadic_in1"(%0) : (i32) -> i32
"test.return"() : () -> ()
}
-// CHECK-LABEL: func @remap_input_1_to_1
-// CHECK-SAME: (%[[ARG:.*]]: i32)
-// CHECK-NEXT: %[[VAL:.*]] = "test.one_variadic_out_one_variadic_in1"(%[[ARG]], %[[ARG]])
-// CHECK-NEXT: "test.one_variadic_out_one_variadic_in1"(%[[VAL]], %[[VAL]])
+// Test the case where an operation is converted before its operands are.
+
+// CHECK-LABEL: func @remap_unconverted
+// CHECK-NEXT: %[[VAL:.*]] = "test.type_producer"() : () -> f64
+// CHECK-NEXT: "test.type_consumer"(%[[VAL]]) : (f64)
+func @remap_unconverted() {
+ %region_result = "test.remapped_value_region"() ({
+ %result = "test.type_producer"() : () -> f32
+ "test.return"(%result) : (f32) -> ()
+ }) : () -> (f32)
+ "test.type_consumer"(%region_result) : (f32) -> ()
+ "test.return"() : () -> ()
+}
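For reference, the facility this new test exercises is usable from any conversion pattern. A minimal hedged sketch (the ForwardRemappedOperands name and its anchoring on test.type_consumer are illustrative only):

#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Illustrative pattern: query the remapped forms of the operands, which
// as of this commit succeeds even when their producers have not been
// converted yet (temporary casts are inserted as needed).
struct ForwardRemappedOperands : public ConversionPattern {
  ForwardRemappedOperands(TypeConverter &converter, MLIRContext *ctx)
      : ConversionPattern(converter, "test.type_consumer", /*benefit=*/1,
                          ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SmallVector<Value> remapped;
    if (failed(rewriter.getRemappedValues(op->getOperands(), remapped)))
      return failure();
    rewriter.updateRootInPlace(op, [&] { op->setOperands(remapped); });
    return success();
  }
};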
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index 59c62d188d45f..4887a87d0156f 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -10,13 +10,6 @@ func @test_invalid_arg_materialization(
// -----
-// expected-error at below {{failed to legalize conversion operation generated for block argument}}
-func @test_invalid_arg_illegal_materialization(%arg0: i32) {
- "foo.return"(%arg0) : (i32) -> ()
-}
-
-// -----
-
// CHECK-LABEL: func @test_valid_arg_materialization
func @test_valid_arg_materialization(%arg0: i64) {
// CHECK: %[[ARG:.*]] = "test.type_producer"
@@ -67,14 +60,6 @@ func @test_transitive_use_invalid_materialization() {
// -----
-func @test_invalid_result_legalization() {
- // expected-error at below {{failed to legalize conversion operation generated for result #0 of operation 'test.type_producer' that remained live after conversion}}
- %result = "test.type_producer"() : () -> i16
- "foo.return"(%result) : (i16) -> ()
-}
-
-// -----
-
// CHECK-LABEL: func @test_valid_result_legalization
func @test_valid_result_legalization() {
// CHECK: %[[RESULT:.*]] = "test.type_producer"() : () -> f64
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 25c3eb34f849a..3342402740209 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -28,7 +28,6 @@ func @remap_input_1_to_1(%arg0: i64) {
// CHECK-LABEL: func @remap_call_1_to_1(%arg0: f64)
func @remap_call_1_to_1(%arg0: i64) {
// CHECK-NEXT: call @remap_input_1_to_1(%arg0) : (f64) -> ()
- // expected-remark at +1 {{op 'std.call' is not legalizable}}
call @remap_input_1_to_1(%arg0) : (i64) -> ()
// expected-remark at +1 {{op 'std.return' is not legalizable}}
return
@@ -36,7 +35,6 @@ func @remap_call_1_to_1(%arg0: i64) {
// CHECK-LABEL: func @remap_input_1_to_N({{.*}}f16, {{.*}}f16)
func @remap_input_1_to_N(%arg0: f32) -> f32 {
- // CHECK-NEXT: [[CAST:%.*]] = "test.cast"(%arg0, %arg1) : (f16, f16) -> f32
// CHECK-NEXT: "test.return"{{.*}} : (f16, f16) -> ()
"test.return"(%arg0) : (f32) -> ()
}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index c7c51f4160bed..656ec7ef86b6f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1497,13 +1497,25 @@ def TestValidOp : TEST_Op<"valid", [Terminator]>,
def TestMergeBlocksOp : TEST_Op<"merge_blocks"> {
let summary = "merge_blocks operation";
let description = [{
- Test op with multiple blocks that are merged with Dialect Conversion"
+ Test op with multiple blocks that are merged with Dialect Conversion
}];
let regions = (region AnyRegion:$body);
let results = (outs Variadic<AnyType>:$result);
}
+def TestRemappedValueRegionOp : TEST_Op<"remapped_value_region",
+ [SingleBlock]> {
+ let summary = "remapped_value_region operation";
+ let description = [{
+ Test op that remaps values that haven't yet been converted in Dialect
+ Conversion.
+ }];
+
+ let regions = (region SizedRegion<1>:$body);
+ let results = (outs Variadic<AnyType>:$result);
+}
+
def TestSignatureConversionUndoOp : TEST_Op<"signature_conversion_undo"> {
let regions = (region AnyRegion);
}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 3ebebe03b6534..c93d233900a2e 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -429,7 +429,8 @@ struct TestSplitReturnType : public ConversionPattern {
// Check if the first operation is a cast operation; if it is, we use its
// results directly.
auto *defOp = operands[0].getDefiningOp();
- if (auto packerOp = llvm::dyn_cast_or_null<TestCastOp>(defOp)) {
+ if (auto packerOp =
+ llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) {
rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
return success();
}
@@ -586,16 +587,6 @@ struct TestTypeConverter : public TypeConverter {
addConversion(convertType);
addArgumentMaterialization(materializeCast);
addSourceMaterialization(materializeCast);
-
- /// Materialize the cast for one-to-one conversion from i64 to f64.
- const auto materializeOneToOneCast =
- [](OpBuilder &builder, IntegerType resultType, ValueRange inputs,
- Location loc) -> Optional<Value> {
- if (resultType.getWidth() == 42 && inputs.size() == 1)
- return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
- return llvm::None;
- };
- addArgumentMaterialization(materializeOneToOneCast);
}
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
@@ -630,8 +621,6 @@ struct TestTypeConverter : public TypeConverter {
/// 1->N type mappings.
static Optional<Value> materializeCast(OpBuilder &builder, Type resultType,
ValueRange inputs, Location loc) {
- if (inputs.size() == 1)
- return inputs[0];
return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
}
};
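With the 1->1 short-circuit removed, materializeCast now always builds a real test.cast. Worth noting as a hedged aside: a materialization hook may also decline by returning llvm::None, in which case the converter tries the next registered hook. An illustrative variant (materializeSingleCast is a hypothetical name):

#include "TestDialect.h" // Test dialect providing TestCastOp.
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace test;

// Hypothetical hook: handle only single-input casts and decline
// otherwise, leaving 1->N inputs to another registered materialization.
static Optional<Value> materializeSingleCast(OpBuilder &builder,
                                             Type resultType,
                                             ValueRange inputs,
                                             Location loc) {
  if (inputs.size() != 1)
    return llvm::None;
  return builder.create<TestCastOp>(loc, resultType, inputs).getResult();
}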
@@ -684,6 +673,8 @@ struct TestLegalizePatternDriver
return converter.isSignatureLegal(op.getType()) &&
converter.isLegal(&op.getBody());
});
+ target.addDynamicallyLegalOp<CallOp>(
+ [&](CallOp op) { return converter.isLegal(op); });
// TestCreateUnregisteredOp creates an `arith.constant` operation,
// which was intentionally not added to the target to test
@@ -771,6 +762,16 @@ static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
// to get the remapped value of an original value that was replaced using
// ConversionPatternRewriter.
namespace {
+struct TestRemapValueTypeConverter : public TypeConverter {
+ using TypeConverter::TypeConverter;
+
+ TestRemapValueTypeConverter() {
+ addConversion(
+ [](Float32Type type) { return Float64Type::get(type.getContext()); });
+ addConversion([](Type type) { return type; });
+ }
+};
+
/// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with
/// a one-result two-operand OneVResOneVOperandOp1 by replicating its original
/// operand twice.
@@ -802,6 +803,36 @@ struct OneVResOneVOperandOp1Converter
}
};
+/// A pattern that merges an op's region into its parent and replaces the op
+/// with the remapped values of the region's terminator operands.
+struct TestRemapValueInRegion
+ : public OpConversionPattern<TestRemappedValueRegionOp> {
+ using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ Block &block = op.getBody().front();
+ Operation *terminator = block.getTerminator();
+
+ // Merge the block into the parent region.
+ Block *parentBlock = op->getBlock();
+ Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator());
+ rewriter.mergeBlocks(&block, parentBlock, ValueRange());
+ rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange());
+
+ // Replace the results of this operation with the remapped terminator
+ // values.
+ SmallVector<Value> terminatorOperands;
+ if (failed(rewriter.getRemappedValues(terminator->getOperands(),
+ terminatorOperands)))
+ return failure();
+
+ rewriter.eraseOp(terminator);
+ rewriter.replaceOp(op, terminatorOperands);
+ return success();
+ }
+};
+
struct TestRemappedValue
: public mlir::PassWrapper<TestRemappedValue, FunctionPass> {
StringRef getArgument() const final { return "test-remapped-value"; }
@@ -809,18 +840,29 @@ struct TestRemappedValue
return "Test public remapped value mechanism in ConversionPatternRewriter";
}
void runOnFunction() override {
+ TestRemapValueTypeConverter typeConverter;
+
mlir::RewritePatternSet patterns(&getContext());
patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
+ patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>(
+ &getContext());
+ patterns.add<TestRemapValueInRegion>(typeConverter, &getContext());
mlir::ConversionTarget target(getContext());
target.addLegalOp<ModuleOp, FuncOp, TestReturnOp>();
+
+ // Expect the type_producer/type_consumer operations to only operate on f64.
+ target.addDynamicallyLegalOp<TestTypeProducerOp>(
+ [](TestTypeProducerOp op) { return op.getType().isF64(); });
+ target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) {
+ return op.getOperand().getType().isF64();
+ });
+
// We make OneVResOneVOperandOp1 legal only when it has more than one
// operand. This will trigger the conversion that will replace one-operand
// OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
target.addDynamicallyLegalOp<OneVResOneVOperandOp1>(
- [](Operation *op) -> bool {
- return std::distance(op->operand_begin(), op->operand_end()) > 1;
- });
+ [](Operation *op) { return op->getNumOperands() > 1; });
if (failed(mlir::applyFullConversion(getFunction(), target,
std::move(patterns)))) {