[Mlir-commits] [mlir] [mlir][Linalg] Fix Linalg behavior in the context of vector elemental… (PR #71041)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 2 03:33:03 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Nicolas Vasilache (nicolasvasilache)
<details>
<summary>Changes</summary>
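… types (continuation of the truncated title: "Fix Linalg behavior in the context of vector elemental types")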
---
Full diff: https://github.com/llvm/llvm-project/pull/71041.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td (+22-6)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp (+3-1)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+8-5)
- (modified) mlir/test/Dialect/Linalg/generalize-named-ops.mlir (+20)
- (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+11)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 69ca888a8acdbe0..fbf3f19cde0e9b8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -344,7 +344,7 @@ def LinalgStructuredInterface
>,
InterfaceMethod<
/*desc=*/[{
- Return the `opOperand` rank or zero for scalars.
+ Return the `opOperand` rank or zero for scalars or vectors not wrapped within a tensor or a memref.
}],
/*retTy=*/"int64_t",
/*methodName=*/"getRank",
@@ -352,9 +352,17 @@ def LinalgStructuredInterface
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation());
- if (auto shapedType =
- ::llvm::dyn_cast<ShapedType>(opOperand->get().getType()))
+ Type t = opOperand->get().getType();
+ // A VectorType is an elemental type: the operand itself has rank 0.
+ if (isa<VectorType>(t))
+ return 0;
+ // Tensor and Memref container types have a rank.
+ if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
+ // Only ranked tensors and memrefs are expected past this point.
+ assert((isa<MemRefType, RankedTensorType>(t)) &&
+ "expected a ranked tensor or memref in LinalgInterface::getRank");
return shapedType.getRank();
+ }
return 0;
}]
>,
@@ -384,7 +392,8 @@ def LinalgStructuredInterface
>,
InterfaceMethod<
/*desc=*/[{
- Return the `opOperand` shape or an empty vector for scalars.
+ Return the `opOperand` shape or an empty vector for scalars or vectors
+ not wrapped within a tensor or a memref.
}],
/*retTy=*/"ArrayRef<int64_t>",
/*methodName=*/"getShape",
@@ -392,9 +401,16 @@ def LinalgStructuredInterface
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation());
- if (auto shapedType =
- ::llvm::dyn_cast<ShapedType>(opOperand->get().getType()))
+ Type t = opOperand->get().getType();
+ // A VectorType is an elemental type: the operand itself has an empty shape.
+ if (isa<VectorType>(t))
+ return {};
+ if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
+ // Only ranked tensors and memrefs are expected past this point.
+ assert((isa<MemRefType, RankedTensorType>(t)) &&
+ "expected a ranked tensor or memref in LinalgInterface::getShape");
return shapedType.getShape();
+ }
return {};
}]
>,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index dfd6b991e7da159..08d46f236f8ab3b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -1130,7 +1130,9 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
"arguments as the number of input/output operands");
for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
- Type elementType = getElementTypeOrSelf(opOperand->get());
+ Type elementType = opOperand->get().getType();
+ if (isa<MemRefType, RankedTensorType>(elementType))
+ elementType = getElementTypeOrSelf(elementType);
Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
if (elementType != argType)
return op->emitOpError("expected type of bb argument #")
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5457d51db1cc180..5a593fbb2b6024d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -122,13 +122,12 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
assert(llvm::all_of(outputTypes,
[](Type t) { return llvm::isa<ShapedType>(t); }));
- // TODO: atm all operands go through getElementTypeOrSelf,
- // reconsider when we have evidence we need to.
SmallVector<Type, 8> argTypes;
SmallVector<Location, 8> argLocs;
for (auto containers : {inputTypes, outputTypes}) {
for (auto t : containers) {
- argTypes.push_back(getElementTypeOrSelf(t));
+ argTypes.push_back(
+ isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
// TODO: Pass in a proper location here.
argLocs.push_back(opBuilder.getUnknownLoc());
@@ -826,7 +825,9 @@ static void buildGenericRegion(
SmallVector<Location, 4> blockArgLocs;
for (ValueRange container : {inputs, outputs}) {
for (Value v : container) {
- blockArgTypes.push_back(getElementTypeOrSelf(v));
+ Type t = v.getType();
+ blockArgTypes.push_back(
+ isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
blockArgLocs.push_back(v.getLoc());
}
}
@@ -1927,7 +1928,9 @@ static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
for (OpOperand &opOperand : op->getOpOperands()) {
OpOperand *outputOperand =
linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
- Type elementType = getElementTypeOrSelf(outputOperand->get().getType());
+ Type elementType = outputOperand->get().getType();
+ if (isa<MemRefType, RankedTensorType>(elementType))
+ elementType = getElementTypeOrSelf(elementType);
if (opOperand.get().getType() != elementType)
return op.emitOpError("type of yield operand ")
<< (opOperand.getOperandNumber() + 1) << " ("
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 2259d47eb2b2b0d..e852824cdb73675 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -587,3 +587,23 @@ func.func @generalize_max(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
// CHECK-NEXT: %[[max:.+]] = arith.maximumf %[[BBARG0]], %[[BBARG1]] : f32
// CHECK-NEXT: linalg.yield %[[max]] : f32
+
+// -----
+
+// The vector<2x4xf32> element type must be kept as-is in the generic region.
+// CHECK-LABEL: func @fill_tensor
+func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vector<2x4xf32>>) {
+ %e0 = tensor.empty() : tensor<f32>
+ %0 = linalg.fill ins(%f : f32) outs(%e0 : tensor<f32>) -> tensor<f32>
+// CHECK: linalg.generic
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT: linalg.yield %[[BBARG0]] : f32
+
+ %e1 = tensor.empty() : tensor<vector<2x4xf32>>
+ %1 = linalg.fill ins(%v : vector<2x4xf32>) outs(%e1 : tensor<vector<2x4xf32>>) -> tensor<vector<2x4xf32>>
+// CHECK: linalg.generic
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: vector<2x4xf32>, %[[BBARG1:.+]]: vector<2x4xf32>)
+// CHECK-NEXT: linalg.yield %[[BBARG0]] : vector<2x4xf32>
+
+ return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
+}
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 76146b17014ebb5..5ca35155854d332 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1585,3 +1585,14 @@ func.func @max_tensor(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> t
%1 = linalg.max ins(%arg0, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
return %1 : tensor<4x8x16xf32>
}
+
+// -----
+// CHECK-LABEL: func @fill_tensor
+// CHECK: linalg.fill ins(%{{.*}} : vector<2x4xf32>) outs(%{{.*}} : tensor<vector<2x4xf32>>)
+func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vector<2x4xf32>>) {
+ %e0 = tensor.empty() : tensor<f32>
+ %0 = linalg.fill ins(%f : f32) outs(%e0 : tensor<f32>) -> tensor<f32>
+ %e1 = tensor.empty() : tensor<vector<2x4xf32>>
+ %1 = linalg.fill ins(%v : vector<2x4xf32>) outs(%e1 : tensor<vector<2x4xf32>>) -> tensor<vector<2x4xf32>>
+ return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/71041
More information about the Mlir-commits mailing list