[Mlir-commits] [mlir] 1752740 - [mlir][tensor] Fix FoldTensorCastProducerOp for multiple result operations (#93374)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jun 6 22:52:39 PDT 2024
Author: Prashant Kumar
Date: 2024-06-07T11:22:36+05:30
New Revision: 1752740f4b4b752bbe2987a0de398c6f671ceb71
URL: https://github.com/llvm/llvm-project/commit/1752740f4b4b752bbe2987a0de398c6f671ceb71
DIFF: https://github.com/llvm/llvm-project/commit/1752740f4b4b752bbe2987a0de398c6f671ceb71.diff
LOG: [mlir][tensor] Fix FoldTensorCastProducerOp for multiple result operations (#93374)
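For operations that have results in addition to those tied to their dpsInits,
this pattern fails, because it rebuilt the result type list from the dpsInits
alone and so dropped the types of the remaining results.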
E.g.:
```
%13:2 = iree_codegen.ukernel.generic "iree_uk_unpack"
ins(%extracted_slice : tensor<?x1x16x16xf32>) outs(%11 :
tensor<?x?xf32>) ... -> tensor<?x?xf32>, i32
```
The above op has a result (the trailing `i32`) that is not tied to any
dpsInit, so the pattern fails. This PR assumes that the results tied to
dpsInits come first, followed by any results that are not tied to a dpsInit.
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 8545c7b9af8f7..7fc29ec0139c2 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4531,17 +4531,18 @@ struct FoldTensorCastProducerOp
if (!hasTensorCastOperand)
return failure();
- SmallVector<Type, 4> newResultTypes;
- newResultTypes.reserve(op->getNumResults());
+ SmallVector<Type, 4> newResultTypes(op->getResultTypes());
SmallVector<Value, 4> newOperands;
newOperands.reserve(op->getNumOperands());
+ // Assumes that the result has dpsInits followed by nonDpsInits.
+ int64_t dpsInitIdx = 0;
for (OpOperand &opOperand : op->getOpOperands()) {
auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
bool fold = canFoldIntoConsumerOp(tensorCastOp);
newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
if (op.isDpsInit(&opOperand) &&
!llvm::isa<MemRefType>(newOperands.back().getType()))
- newResultTypes.push_back(newOperands.back().getType());
+ newResultTypes[dpsInitIdx++] = newOperands.back().getType();
}
// Clone op.
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index f7fbd3834288b..6b51d0b294bcf 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2523,3 +2523,18 @@ func.func @dim_out_of_bounds() -> vector<7xi32> {
%16 = affine.vector_load %alloc_21[%c1, %c1, %dim] : memref<?x26x2xi32>, vector<7xi32>
return %16 : vector<7xi32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @test_destination_multiple_result(
+// CHECK-SAME: %[[ARG1:.*]]: tensor<2x2xf32>,
+// CHECK-SAME: %[[ARG2:.*]]: tensor<2x2xf32>) -> index {
+// CHECK: %[[RES:.*]]:2 = test.destination_style_op ins(%[[ARG1]] : tensor<2x2xf32>)
+// CHECK-SAME: outs(%[[ARG2]] : tensor<2x2xf32>) -> tensor<2x2xf32>, index
+// CHECK: return %[[RES]]#1 : index
+func.func @test_destination_multiple_result(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> index {
+ %cast = tensor.cast %arg0 : tensor<2x2xf32> to tensor<?x2xf32>
+ %cast_0 = tensor.cast %arg1 : tensor<2x2xf32> to tensor<?x2xf32>
+ %0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
+ return %0#1 : index
+}
More information about the Mlir-commits
mailing list