[Mlir-commits] [mlir] [MLIR][Vector] Allow any shaped type to be distributed for vector.wa… (PR #114215)

Petr Kurapov llvmlistbot at llvm.org
Wed Oct 30 04:57:23 PDT 2024


https://github.com/kurapov-peter created https://github.com/llvm/llvm-project/pull/114215

…rp_execute_on_lane_0's return values

This is the second part of https://github.com/llvm/llvm-project/pull/112945.

>From 7d5b1b46881a77dc10608e559f37ae25a568184b Mon Sep 17 00:00:00 2001
From: Petr Kurapov <petr.a.kurapov at intel.com>
Date: Thu, 17 Oct 2024 17:29:00 +0000
Subject: [PATCH] [MLIR][Vector] Allow any shaped type to be distributed for
 vector.warp_execute_on_lane_0's return values

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 12 ++++++------
 mlir/test/Dialect/Vector/invalid.mlir    |  6 +++---
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1853ae04f45d90..af5a2a276042ca 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6558,14 +6558,14 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
   // If the types matches there is no distribution.
   if (expanded == distributed)
     return success();
-  auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
-  auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
+  auto expandedVecType = llvm::dyn_cast<ShapedType>(expanded);
+  auto distributedVecType = llvm::dyn_cast<ShapedType>(distributed);
   if (!expandedVecType || !distributedVecType)
-    return op->emitOpError("expected vector type for distributed operands.");
+    return op->emitOpError("expected shaped type for distributed operands.");
   if (expandedVecType.getRank() != distributedVecType.getRank() ||
       expandedVecType.getElementType() != distributedVecType.getElementType())
     return op->emitOpError(
-        "expected distributed vectors to have same rank and element type.");
+        "expected distributed types to have same rank and element type.");
 
   SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
   for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
@@ -6575,8 +6575,8 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
       continue;
     if (eDim % dDim != 0)
       return op->emitOpError()
-             << "expected expanded vector dimension #" << i << " (" << eDim
-             << ") to be a multipler of the distributed vector dimension ("
+             << "expected expanded type dimension #" << i << " (" << eDim
+             << ") to be a multipler of the distributed type dimension ("
              << dDim << ")";
     scales[i] = eDim / dDim;
   }
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 56039d04549aa5..f2b7685d79effb 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1665,7 +1665,7 @@ func.func @warp_2_distributed_dims(%laneid: index) {
 // -----
 
 func.func @warp_2_distributed_dims(%laneid: index) {
-  // expected-error at +1 {{expected expanded vector dimension #1 (8) to be a multipler of the distributed vector dimension (3)}}
+  // expected-error at +1 {{expected expanded type dimension #1 (8) to be a multipler of the distributed type dimension (3)}}
   %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x3xi32>) {
     %0 = arith.constant dense<2>: vector<4x8xi32>
     vector.yield %0 : vector<4x8xi32>
@@ -1676,7 +1676,7 @@ func.func @warp_2_distributed_dims(%laneid: index) {
 // -----
 
 func.func @warp_mismatch_rank(%laneid: index) {
-  // expected-error at +1 {{'vector.warp_execute_on_lane_0' op expected distributed vectors to have same rank and element type.}}
+  // expected-error at +1 {{'vector.warp_execute_on_lane_0' op expected distributed types to have same rank and element type.}}
   %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x4xi32>) {
     %0 = arith.constant dense<2>: vector<128xi32>
     vector.yield %0 : vector<128xi32>
@@ -1687,7 +1687,7 @@ func.func @warp_mismatch_rank(%laneid: index) {
 // -----
 
 func.func @warp_mismatch_rank(%laneid: index) {
-  // expected-error at +1 {{'vector.warp_execute_on_lane_0' op expected vector type for distributed operands.}}
+  // expected-error at +1 {{'vector.warp_execute_on_lane_0' op expected shaped type for distributed operands.}}
   %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (i32) {
     %0 = arith.constant dense<2>: vector<128xi32>
     vector.yield %0 : vector<128xi32>



More information about the Mlir-commits mailing list