[Mlir-commits] [mlir] [mlir][Linalg] Fix Linalg behavior in the context of vector elemental types (PR #71041)

Nicolas Vasilache llvmlistbot at llvm.org
Thu Nov 2 03:32:33 PDT 2023


https://github.com/nicolasvasilache created https://github.com/llvm/llvm-project/pull/71041

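Linalg allows a vector to be used as the elemental type of a tensor or a
memref operand (e.g. tensor<vector<2x4xf32>>). The patch below makes the
structured-op interface and the verifiers treat a bare VectorType operand
like a scalar: getRank/getShape return zero and an empty shape for it, and
region block arguments and linalg.yield operands keep the operand type
itself; an element type is only extracted for MemRefType and
RankedTensorType containers.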

From 29364b5adf6606efacae1b91f43a34bdcc28cc83 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nicolas.vasilache at gmail.com>
Date: Thu, 2 Nov 2023 10:31:30 +0000
Subject: [PATCH] [mlir][Linalg] Fix Linalg behavior in the context of vector
 elemental types

---
 .../Dialect/Linalg/IR/LinalgInterfaces.td     | 28 +++++++++++++++----
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    |  4 ++-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 13 +++++----
 .../Dialect/Linalg/generalize-named-ops.mlir  | 20 +++++++++++++
 mlir/test/Dialect/Linalg/named-ops.mlir       | 11 ++++++++
 5 files changed, 64 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 69ca888a8acdbe0..fbf3f19cde0e9b8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -344,7 +344,7 @@ def LinalgStructuredInterface
     >,
     InterfaceMethod<
       /*desc=*/[{
-        Return the `opOperand` rank or zero for scalars.
+        Return the `opOperand` rank or zero for scalars or vectors not wrapped within a tensor or a memref.
       }],
       /*retTy=*/"int64_t",
       /*methodName=*/"getRank",
@@ -352,9 +352,17 @@ def LinalgStructuredInterface
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         assert(opOperand->getOwner() == this->getOperation());
-        if (auto shapedType =
-              ::llvm::dyn_cast<ShapedType>(opOperand->get().getType()))
+        Type t = opOperand->get().getType();
+        // A VectorType is an elemental type; do not consider its rank for the operand.
+        if (isa<VectorType>(t))
+          return 0;
+        // Tensor and Memref container types have a rank.
+        if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
+          // Failsafe.
+          assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
+                 "expected a ranked tensor or memref in LinalgInterface::getRank");
           return shapedType.getRank();
+        }
         return 0;
       }]
     >,
@@ -384,7 +392,8 @@ def LinalgStructuredInterface
     >,
     InterfaceMethod<
       /*desc=*/[{
-        Return the `opOperand` shape or an empty vector for scalars.
+        Return the `opOperand` shape or an empty vector for scalars or vectors
+        not wrapped within a tensor or a memref.
       }],
       /*retTy=*/"ArrayRef<int64_t>",
       /*methodName=*/"getShape",
@@ -392,9 +401,16 @@ def LinalgStructuredInterface
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         assert(opOperand->getOwner() == this->getOperation());
-        if (auto shapedType =
-              ::llvm::dyn_cast<ShapedType>(opOperand->get().getType()))
+        Type t = opOperand->get().getType();
+        // A VectorType is an elemental type; do not consider its shape for the operand.
+        if (isa<VectorType>(t))
+          return {};
+        if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
+          // Failsafe.
+          assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
+                 "expected a ranked tensor or memref in LinalgInterface::getShape");
           return shapedType.getShape();
+        }
         return {};
       }]
     >,
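
For reference, the rank/shape rule introduced in the two hunks above boils
down to the following standalone sketch (not part of the patch; the free
function operandRank is illustrative only):

  // Mirrors the updated default implementation of
  // LinalgStructuredInterface::getRank above.
  #include "mlir/IR/BuiltinTypes.h"

  static int64_t operandRank(mlir::Type t) {
    // A bare VectorType is an elemental type; it contributes no rank.
    if (mlir::isa<mlir::VectorType>(t))
      return 0;
    // Only ranked tensor/memref containers carry an operand rank.
    if (auto shapedType = mlir::dyn_cast<mlir::ShapedType>(t))
      return shapedType.getRank();
    // Scalars (f32, index, ...) remain rank zero.
    return 0;
  }

For example, operandRank is 0 for f32, 0 for vector<2x4xf32>, 0 for
tensor<vector<2x4xf32>> (a rank-0 tensor holding a vector element), and 3
for memref<7x14x21xf32>.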
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index dfd6b991e7da159..08d46f236f8ab3b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -1130,7 +1130,9 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
                            "arguments as the number of input/output operands");
 
   for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
-    Type elementType = getElementTypeOrSelf(opOperand->get());
+    Type elementType = opOperand->get().getType();
+    if (isa<MemRefType, RankedTensorType>(elementType))
+      elementType = getElementTypeOrSelf(elementType);
     Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
     if (elementType != argType)
       return op->emitOpError("expected type of bb argument #")
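
The same tensor/memref-versus-elemental discrimination drives the
block-argument and yield typing in LinalgOps.cpp below; schematically (a
sketch only, expectedBBArgType is an illustrative name, not part of the
patch):

  #include "mlir/IR/BuiltinTypes.h"
  #include "mlir/IR/TypeUtilities.h"

  // Tensor/memref operands contribute their element type to the body
  // region; scalars and bare vectors are passed through unchanged.
  static mlir::Type expectedBBArgType(mlir::Type t) {
    if (mlir::isa<mlir::MemRefType, mlir::RankedTensorType>(t))
      return mlir::getElementTypeOrSelf(t);
    return t; // f32 -> f32, vector<2x4xf32> -> vector<2x4xf32>
  }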
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5457d51db1cc180..5a593fbb2b6024d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -122,13 +122,12 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
   assert(llvm::all_of(outputTypes,
                       [](Type t) { return llvm::isa<ShapedType>(t); }));
 
-  // TODO: atm all operands go through getElementTypeOrSelf,
-  // reconsider when we have evidence we need to.
   SmallVector<Type, 8> argTypes;
   SmallVector<Location, 8> argLocs;
   for (auto containers : {inputTypes, outputTypes}) {
     for (auto t : containers) {
-      argTypes.push_back(getElementTypeOrSelf(t));
+      argTypes.push_back(
+          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
 
       // TODO: Pass in a proper location here.
       argLocs.push_back(opBuilder.getUnknownLoc());
@@ -826,7 +825,9 @@ static void buildGenericRegion(
   SmallVector<Location, 4> blockArgLocs;
   for (ValueRange container : {inputs, outputs}) {
     for (Value v : container) {
-      blockArgTypes.push_back(getElementTypeOrSelf(v));
+      Type t = v.getType();
+      blockArgTypes.push_back(
+          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
       blockArgLocs.push_back(v.getLoc());
     }
   }
@@ -1927,7 +1928,9 @@ static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
   for (OpOperand &opOperand : op->getOpOperands()) {
     OpOperand *outputOperand =
         linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
-    Type elementType = getElementTypeOrSelf(outputOperand->get().getType());
+    Type elementType = outputOperand->get().getType();
+    if (isa<MemRefType, RankedTensorType>(elementType))
+      elementType = getElementTypeOrSelf(elementType);
     if (opOperand.get().getType() != elementType)
       return op.emitOpError("type of yield operand ")
              << (opOperand.getOperandNumber() + 1) << " ("
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 2259d47eb2b2b0d..e852824cdb73675 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -587,3 +587,23 @@ func.func @generalize_max(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
 // CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
 // CHECK-NEXT:      %[[max:.+]] = arith.maximumf %[[BBARG0]], %[[BBARG1]] : f32
 // CHECK-NEXT:      linalg.yield %[[max]] : f32
+
+// -----
+
+
+// CHECK-LABEL: func @fill_tensor
+func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vector<2x4xf32>>) {
+  %e0 = tensor.empty() : tensor<f32>
+  %0 = linalg.fill ins(%f : f32) outs(%e0 : tensor<f32>) -> tensor<f32>
+// CHECK: linalg.generic
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT:      linalg.yield %[[BBARG0]] : f32
+
+  %e1 = tensor.empty() : tensor<vector<2x4xf32>>
+  %1 = linalg.fill ins(%v : vector<2x4xf32>) outs(%e1 : tensor<vector<2x4xf32>>) -> tensor<vector<2x4xf32>>
+// CHECK: linalg.generic
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: vector<2x4xf32>, %[[BBARG1:.+]]: vector<2x4xf32>)
+// CHECK-NEXT:      linalg.yield %[[BBARG0]] : vector<2x4xf32>
+
+  return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
+}
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 76146b17014ebb5..5ca35155854d332 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1585,3 +1585,14 @@ func.func @max_tensor(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> t
   %1 = linalg.max ins(%arg0, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
   return %1 : tensor<4x8x16xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @fill_tensor
+func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vector<2x4xf32>>) {
+  %e0 = tensor.empty() : tensor<f32>
+  %0 = linalg.fill ins(%f : f32) outs(%e0 : tensor<f32>) -> tensor<f32>
+  %e1 = tensor.empty() : tensor<vector<2x4xf32>>
+  %1 = linalg.fill ins(%v : vector<2x4xf32>) outs(%e1 : tensor<vector<2x4xf32>>) -> tensor<vector<2x4xf32>>
+  return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
+}