[Mlir-commits] [mlir] [mlir][XeGPU][Transform] Add vectorlinearize transform pass. (PR #158084)

Md Abdullah Shahneous Bari llvmlistbot at llvm.org
Thu Sep 18 15:06:54 PDT 2025


https://github.com/mshahneo updated https://github.com/llvm/llvm-project/pull/158084

>From fc1656c2eafddf88cbf6312e777fa6b158350cc0 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Thu, 11 Sep 2025 14:05:03 +0000
Subject: [PATCH 1/6] [XeGPU][Transform] Add vectorlinearize transform pass.

Use upstream patterns to create an XeGPU vector-linearize pass
needed for lowering to XeVM. The pass linearizes n-D vectors to
1-D vectors.
---
 .../mlir/Dialect/XeGPU/Transforms/Passes.td   |   9 ++
 .../Dialect/XeGPU/Transforms/CMakeLists.txt   |   1 +
 .../XeGPU/Transforms/XeGPUVectorLinearize.cpp | 111 ++++++++++++++++++
 3 files changed, 121 insertions(+)
 create mode 100644 mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 59dca9f0d852a..77c57ccb0746f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -75,4 +75,13 @@ def XeGPUBlocking: Pass<"xegpu-blocking"> {
       "index::IndexDialect"];
 }
 
+def XeGPUVectorLinearize : Pass<"xegpu-vector-linearize"> {
+  let summary = "Linearize n-D vectors to 1-D vectors";
+  let description = [{
+    This pass linearizes n-D vectors to 1-D vectors for lowering to XeVM.
+  }];
+  let dependentDialects = ["arith::ArithDialect", "memref::MemRefDialect",
+                           "scf::SCFDialect", "vector::VectorDialect"];
+}
+
 #endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 9c178d1d85642..e6f76067094ce 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUUnroll.cpp
   XeGPUWgToSgDistribute.cpp
   XeGPUPropagateLayout.cpp
+  XeGPUVectorLinearize.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
new file mode 100644
index 0000000000000..a6a68716547c9
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
@@ -0,0 +1,111 @@
+//===- XeGPUVectorLinearize.cpp - Linearizes n-D vectors to 1-D vectors
+//-------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include <optional>
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUVECTORLINEARIZE
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-vector-linearize"
+
+using namespace mlir;
+
+namespace {
+struct XeGPUVectorLinearizePass final
+    : public xegpu::impl::XeGPUVectorLinearizeBase<XeGPUVectorLinearizePass> {
+  void runOnOperation() override {
+    auto *context = &getContext();
+
+    // vector.broadcast and vector.gather requires progressive lowering
+    {
+      mlir::RewritePatternSet patterns(&getContext());
+      mlir::vector::populateVectorBroadcastLoweringPatterns(patterns);
+      mlir::vector::populateVectorGatherLoweringPatterns(patterns);
+      mlir::vector::populateVectorGatherToConditionalLoadPatterns(patterns);
+      // Lower vector.transpose using the shuffle-based strategy.
+      // Shuffle16x16 falls back to Shuffle1D for non-16x16 sizes.
+      mlir::vector::populateVectorTransposeLoweringPatterns(
+          patterns, mlir::vector::VectorTransposeLowering::Shuffle16x16);
+      (void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns));
+    }
+
+    // Unroll 2-D <MxN> loads/stores into M <1xN> loads/stores, then linearize.
+    {
+      mlir::RewritePatternSet patterns(&getContext());
+      mlir::vector::UnrollVectorOptions vectorOptions;
+      vectorOptions.setNativeShapeFn(
+          [](mlir::Operation *op) -> std::optional<mlir::SmallVector<int64_t>> {
+            // Only unroll for vector::LoadOp and vector::StoreOp
+            if (mlir::isa<mlir::vector::LoadOp>(op)) {
+              if (auto vecType = mlir::dyn_cast<mlir::VectorType>(
+                      op->getResult(0).getType())) {
+                auto shape = vecType.getShape();
+                if (shape.size() == 2)
+                  return mlir::SmallVector<int64_t>{1, shape[1]};
+              }
+            }
+            if (mlir::isa<mlir::vector::StoreOp>(op)) {
+              if (auto vecType = mlir::dyn_cast<mlir::VectorType>(
+                      op->getOperand(0).getType())) {
+                auto shape = vecType.getShape();
+                if (shape.size() == 2)
+                  return mlir::SmallVector<int64_t>{1, shape[1]};
+              }
+            }
+            return std::nullopt;
+          });
+      mlir::vector::populateVectorUnrollPatterns(patterns, vectorOptions);
+      (void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns));
+    }
+
+    // Use upstream linearization patterns
+    {
+      mlir::MLIRContext &context = getContext();
+      mlir::TypeConverter converter;
+      mlir::RewritePatternSet patterns(&context);
+      mlir::ConversionTarget target(context);
+      mlir::vector::populateForVectorLinearize(converter, target);
+      mlir::vector::populateVectorLinearizeBasePatterns(converter, target,
+                                                        patterns);
+      mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
+          converter, target, patterns);
+      mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
+          converter, patterns, target);
+      if (failed(applyPartialConversion(getOperation(), target,
+                                        std::move(patterns))))
+        return signalPassFailure();
+    }
+
+    mlir::TypeConverter typeConverter;
+    mlir::RewritePatternSet patterns(context);
+    mlir::ConversionTarget target(*context);
+    typeConverter.addConversion([](mlir::Type type) { return type; });
+
+    target.addIllegalOp<mlir::vector::TransposeOp>();
+    target.addLegalOp<mlir::vector::ShapeCastOp>();
+    target.addLegalOp<mlir::vector::ExtractOp>();
+    target.addLegalDialect<mlir::xegpu::XeGPUDialect>();
+  }
+};
+} // namespace

>From 884a06924bde2ccd73633977bcceed9bc10579e0 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Mon, 15 Sep 2025 22:12:20 +0000
Subject: [PATCH 2/6] Address review comments.

Add test case.
---
 .../XeGPU/Transforms/XeGPUVectorLinearize.cpp |  52 +--
 .../Dialect/XeGPU/xegpu-vector-linearize.mlir | 362 ++++++++++++++++++
 2 files changed, 383 insertions(+), 31 deletions(-)
 create mode 100644 mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
index a6a68716547c9..78648042ae127 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
@@ -35,8 +35,6 @@ namespace {
 struct XeGPUVectorLinearizePass final
     : public xegpu::impl::XeGPUVectorLinearizeBase<XeGPUVectorLinearizePass> {
   void runOnOperation() override {
-    auto *context = &getContext();
-
     // vector.broadcast and vector.gather requires progressive lowering
     {
       mlir::RewritePatternSet patterns(&getContext());
@@ -56,30 +54,32 @@ struct XeGPUVectorLinearizePass final
       mlir::vector::UnrollVectorOptions vectorOptions;
       vectorOptions.setNativeShapeFn(
           [](mlir::Operation *op) -> std::optional<mlir::SmallVector<int64_t>> {
-            // Only unroll for vector::LoadOp and vector::StoreOp
-            if (mlir::isa<mlir::vector::LoadOp>(op)) {
-              if (auto vecType = mlir::dyn_cast<mlir::VectorType>(
-                      op->getResult(0).getType())) {
-                auto shape = vecType.getShape();
-                if (shape.size() == 2)
-                  return mlir::SmallVector<int64_t>{1, shape[1]};
-              }
-            }
-            if (mlir::isa<mlir::vector::StoreOp>(op)) {
-              if (auto vecType = mlir::dyn_cast<mlir::VectorType>(
-                      op->getOperand(0).getType())) {
-                auto shape = vecType.getShape();
-                if (shape.size() == 2)
-                  return mlir::SmallVector<int64_t>{1, shape[1]};
-              }
-            }
-            return std::nullopt;
+            auto extractVectorType =
+                [](mlir::Operation *op) -> mlir::VectorType {
+              if (auto loadOp = mlir::dyn_cast<mlir::vector::LoadOp>(op))
+                return mlir::dyn_cast<mlir::VectorType>(
+                    loadOp.getResult().getType());
+              if (auto storeOp = mlir::dyn_cast<mlir::vector::StoreOp>(op))
+                return mlir::dyn_cast<mlir::VectorType>(
+                    storeOp.getValueToStore().getType());
+              return nullptr;
+            };
+
+            auto vecType = extractVectorType(op);
+            if (!vecType)
+              return std::nullopt;
+
+            auto shape = vecType.getShape();
+            if (shape.size() != 2)
+              return std::nullopt;
+
+            return mlir::SmallVector<int64_t>{1, shape[1]};
           });
       mlir::vector::populateVectorUnrollPatterns(patterns, vectorOptions);
       (void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns));
     }
 
-    // Use upstream linearization patterns
+    // Use vector linearization patterns
     {
       mlir::MLIRContext &context = getContext();
       mlir::TypeConverter converter;
@@ -96,16 +96,6 @@ struct XeGPUVectorLinearizePass final
                                         std::move(patterns))))
         return signalPassFailure();
     }
-
-    mlir::TypeConverter typeConverter;
-    mlir::RewritePatternSet patterns(context);
-    mlir::ConversionTarget target(*context);
-    typeConverter.addConversion([](mlir::Type type) { return type; });
-
-    target.addIllegalOp<mlir::vector::TransposeOp>();
-    target.addLegalOp<mlir::vector::ShapeCastOp>();
-    target.addLegalOp<mlir::vector::ExtractOp>();
-    target.addLegalDialect<mlir::xegpu::XeGPUDialect>();
   }
 };
 } // namespace
diff --git a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
new file mode 100644
index 0000000000000..61720884002c2
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
@@ -0,0 +1,362 @@
+// RUN: mlir-opt %s -split-input-file -xegpu-vector-linearize | FileCheck %s
+
+// CHECK-LABEL: @test_linearize
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<2x2xf32>) -> vector<2x2xf32> {
+//       CHECK: %[[T0:.*]] = vector.shape_cast %[[ARG0]] : vector<2x2xf32> to vector<4xf32>
+//       CHECK: %[[CST:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+//       CHECK: %[[T1:.*]] = math.sin %[[T0]] : vector<4xf32>
+//       CHECK: %[[T2:.*]] = arith.addf %[[T0]], %[[CST]] : vector<4xf32>
+//       CHECK: %[[T3:.*]] = arith.addf %[[T2]], %[[T1]] : vector<4xf32>
+//       CHECK: %[[T4:.*]] = vector.shape_cast %[[T3]] : vector<4xf32> to vector<2x2xf32>
+//       CHECK: return %[[T4]] : vector<2x2xf32>
+func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
+  %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
+// Arith and math ops are handled in a generic way; check a few of them.
+  %1 = math.sin %arg0 : vector<2x2xf32>
+  %2 = arith.addf %arg0, %0 :  vector<2x2xf32>
+  %3 = arith.addf %2, %1 :  vector<2x2xf32>
+  return %3 : vector<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_const_novector
+//       CHECK:  %[[R:.*]] = arith.constant 42 : i32
+//       CHECK:  return %[[R]] : i32
+func.func @test_const_novector() -> i32 {
+  %0 = arith.constant 42 : i32
+  return %0 : i32
+}
+
+// -----
+// CHECK-LABEL: test_create_mask
+//       CHECK: vector.create_mask {{.*}} : vector<16xi1>
+func.func @test_create_mask() -> vector<1x16xi1> {
+  %c0 = arith.constant 0 : index
+  %c20 = arith.constant 20 : index
+  %0 = vector.create_mask %c0, %c20 : vector<1x16xi1>
+  return %0 : vector<1x16xi1>
+}
+
+// -----
+// CHECK-LABEL: test_extract_strided_slice
+//  CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<8x16xf32>) -> vector<8x8xf32>
+//       CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<8x16xf32> to vector<128xf32>
+//       CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+//       CHECK: [8, 9, 10, 11, 12, 13, 14, 15,
+//       CHECK: 24, 25, 26, 27, 28, 29, 30, 31,
+//       CHECK: 40, 41, 42, 43, 44, 45, 46, 47,
+//       CHECK: 56, 57, 58, 59, 60, 61, 62, 63,
+//       CHECK: 72, 73, 74, 75, 76, 77, 78, 79,
+//       CHECK: 88, 89, 90, 91, 92, 93, 94, 95,
+//       CHECK: 104, 105, 106, 107, 108, 109, 110, 111,
+//       CHECK: 120, 121, 122, 123, 124, 125, 126, 127] : vector<128xf32>, vector<128xf32>
+//       CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<8x8xf32>
+//       CHECK: return %[[RES]] : vector<8x8xf32>
+func.func @test_extract_strided_slice_1(%arg0 : vector<8x16xf32>) -> vector<8x8xf32> {
+  %0 = vector.extract_strided_slice %arg0 { sizes = [8, 8], strides = [1, 1], offsets = [0, 8]}
+     : vector<8x16xf32> to vector<8x8xf32>
+  return %0 : vector<8x8xf32>
+}
+
+// -----
+// CHECK-LABEL: test_extract_strided_slice_2
+//  CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x32x8xf32>) -> vector<1x8x8xf32>
+//       CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x32x8xf32> to vector<512xf32>
+//       CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+//       CHECK: [448, 449, 450, 451, 452, 453, 454, 455,
+//       CHECK: 456, 457, 458, 459, 460, 461, 462, 463,
+//       CHECK: 464, 465, 466, 467, 468, 469, 470, 471,
+//       CHECK: 472, 473, 474, 475, 476, 477, 478, 479,
+//       CHECK: 480, 481, 482, 483, 484, 485, 486, 487,
+//       CHECK: 488, 489, 490, 491, 492, 493, 494, 495,
+//       CHECK: 496, 497, 498, 499, 500, 501, 502, 503,
+//       CHECK: 504, 505, 506, 507, 508, 509, 510, 511] : vector<512xf32>, vector<512xf32>
+//       CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<1x8x8xf32>
+//       CHECK: return %[[RES]] : vector<1x8x8xf32>
+func.func @test_extract_strided_slice_2(%arg0 : vector<2x32x8xf32>) -> vector<1x8x8xf32> {
+  %0 = vector.extract_strided_slice %arg0 { offsets = [1, 24], strides = [1, 1], sizes = [1, 8] }
+    : vector<2x32x8xf32> to vector<1x8x8xf32>
+  return %0 : vector<1x8x8xf32>
+}
+
+// -----
+// CHECK-LABEL: test_vector_shuffle
+//  CHECK-SAME: (%[[ORIG_ARG1:.*]]: vector<4x4xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>) -> vector<8x4xf32> {
+//       CHECK: %[[ARG2:.*]] = vector.shape_cast %[[ORIG_ARG2]] : vector<4x4xf32> to vector<16xf32>
+//       CHECK: %[[ARG1:.*]] = vector.shape_cast %[[ORIG_ARG1]] : vector<4x4xf32> to vector<16xf32>
+//       CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG1]], %[[ARG2]]
+//       CHECK: [0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23,
+//       CHECK: 8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+//       CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<32xf32> to vector<8x4xf32>
+//       CHECK: return %[[RES]] : vector<8x4xf32>
+func.func @test_vector_shuffle(%arg0: vector<4x4xf32>, %arg1: vector<4x4xf32>) -> vector<8x4xf32> {
+  %0 = vector.shuffle %arg0, %arg1 [0, 4, 1, 5, 2, 6, 3, 7] : vector<4x4xf32>, vector<4x4xf32>
+  return %0 : vector<8x4xf32>
+}
+
+// -----
+// CHECK-LABEL: test_vector_extract
+// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x4xf32>) -> vector<8x4xf32>
+// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x4xf32> to vector<64xf32>
+// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+// CHECK: [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
+// CHECK: 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<64xf32>
+// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<32xf32> to vector<8x4xf32>
+// CHECK: return %[[RES]] : vector<8x4xf32>
+func.func @test_vector_extract(%arg0: vector<2x8x4xf32>) -> vector<8x4xf32> {
+  %0 = vector.extract %arg0[1]: vector<8x4xf32> from vector<2x8x4xf32>
+  return %0 : vector<8x4xf32>
+}
+
+// -----
+// CHECK-LABEL: test_vector_insert
+// CHECK-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32>
+// CHECK: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
+// CHECK: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32>
+// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]]
+// CHECK: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87,
+// CHECK-SAME: 88, 89, 90, 91, 92, 93, 94, 95, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
+// CHECK-SAME: 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<32xf32>
+// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32>
+// CHECK: return %[[RES]] : vector<2x8x4xf32>
+func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> {
+  %0 = vector.insert %arg1, %arg0[0]: vector<8x4xf32> into vector<2x8x4xf32>
+  return %0 : vector<2x8x4xf32>
+}
+
+// -----
+// CHECK-LABEL: test_vector_insert_2d_idx
+// CHECK-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<4xf32>) -> vector<2x8x4xf32>
+// CHECK: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
+// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[SRC]]
+// CHECK: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 64, 65, 66, 67, 16, 17, 18, 19, 20, 21,
+// CHECK-SAME: 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
+// CHECK-SAME: 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<4xf32>
+// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32>
+// CHECK: return %[[RES]] : vector<2x8x4xf32>
+func.func @test_vector_insert_2d_idx(%arg0: vector<2x8x4xf32>, %arg1: vector<4xf32>) -> vector<2x8x4xf32> {
+  %0 = vector.insert %arg1, %arg0[0, 3]: vector<4xf32> into vector<2x8x4xf32>
+  return %0 : vector<2x8x4xf32>
+}
+
+// -----
+// CHECK-LABEL: test_vector_transpose
+// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8xf32>) -> vector<8x2xf32>
+// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8xf32> to vector<16xf32>
+// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
+// CHECK: [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<16xf32>, vector<16xf32>
+// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
+// CHECK: return %[[RES]] : vector<8x2xf32>
+func.func @test_vector_transpose(%arg: vector<2x8xf32>) -> vector<8x2xf32> {
+  %0 = vector.transpose %arg, [1, 0] : vector<2x8xf32> to vector<8x2xf32>
+  return %0 : vector<8x2xf32>
+}
+
+// -----
+// CHECK-LABEL: test_vector_transpose_16x16
+// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+func.func @test_vector_transpose_16x16(%arg: vector<16x16xf32>) -> vector<16x16xf32> {
+  %0 = vector.transpose %arg, [1, 0] : vector<16x16xf32> to vector<16x16xf32>
+  return %0 : vector<16x16xf32>
+}
+
+// -----
+// CHECK-LABEL: func.func @test_vector_store_load_4x4
+// CHECK-SAME: (%[[MEMREF:.*]]: memref<4x4xf32>)
+// CHECK: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[V0:.*]] = vector.load %[[MEMREF]][%[[C0]], %[[C0]]] : memref<4x4xf32>, vector<4xf32>
+// CHECK: %[[V1:.*]] = vector.load %[[MEMREF]][%[[C1]], %[[C0]]] : memref<4x4xf32>, vector<4xf32>
+// CHECK: %[[V2:.*]] = vector.load %[[MEMREF]][%[[C2]], %[[C0]]] : memref<4x4xf32>, vector<4xf32>
+// CHECK: %[[V3:.*]] = vector.load %[[MEMREF]][%[[C3]], %[[C0]]] : memref<4x4xf32>, vector<4xf32>
+// CHECK: vector.store %[[V0]], %[[MEMREF]][%[[C0]], %[[C0]]] : memref<4x4xf32>, vector<4xf32>
+// CHECK: vector.store %[[V1]], %[[MEMREF]][%[[C1]], %[[C0]]] : memref<4x4xf32>, vector<4xf32>
+// CHECK: vector.store %[[V2]], %[[MEMREF]][%[[C2]], %[[C0]]] : memref<4x4xf32>, vector<4xf32>
+// CHECK: vector.store %[[V3]], %[[MEMREF]][%[[C3]], %[[C0]]] : memref<4x4xf32>, vector<4xf32>
+func.func @test_vector_store_load_4x4(%buffer: memref<4x4xf32>) {
+  %c0 = arith.constant 0 : index
+  %0 = vector.load %buffer[%c0, %c0] : memref<4x4xf32>, vector<4x4xf32>
+  vector.store %0, %buffer[%c0, %c0] : memref<4x4xf32>, vector<4x4xf32>
+  return
+}
+
+// -----
+
+func.func @test_vector_store_load_4x4_f16(%buffer: memref<4x4xf16>) {
+  %c0 = arith.constant 0 : index
+  %0 = vector.load %buffer[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
+  vector.store %0, %buffer[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
+  return
+}
+// CHECK-LABEL: func.func @test_vector_store_load_4x4_f16
+// CHECK-SAME: (%[[MEMREF:.*]]: memref<4x4xf16>)
+// CHECK: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[V0:.*]] = vector.load %[[MEMREF]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+// CHECK: %[[V1:.*]] = vector.load %[[MEMREF]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+// CHECK: %[[V2:.*]] = vector.load %[[MEMREF]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+// CHECK: %[[V3:.*]] = vector.load %[[MEMREF]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+// CHECK: vector.store %[[V0]], %[[MEMREF]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+// CHECK: vector.store %[[V1]], %[[MEMREF]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+// CHECK: vector.store %[[V2]], %[[MEMREF]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+// CHECK: vector.store %[[V3]], %[[MEMREF]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+
+// -----
+// CHECK-LABEL: @test_linearize_index
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<2x2xindex>, %[[ARG1:.*]]: vector<2x2xi32>) -> vector<2x2xindex> {
+//       CHECK: %[[T0:.*]] = vector.shape_cast %[[ARG1]] : vector<2x2xi32> to vector<4xi32>
+//       CHECK: %[[T1:.*]] = vector.shape_cast %[[ARG0]] : vector<2x2xindex> to vector<4xindex>
+//       CHECK: %[[CST:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+//       CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[CST]] : vector<4xindex>
+//       CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : vector<4xindex> to vector<4xi32>
+//       CHECK: %[[T4:.*]] = arith.muli %[[T3]], %[[T0]] : vector<4xi32>
+//       CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : vector<4xi32> to vector<4xindex>
+//       CHECK: %[[T6:.*]] = vector.shape_cast %[[T5]] : vector<4xindex> to vector<2x2xindex>
+//       CHECK: return %[[T6]] : vector<2x2xindex>
+func.func @test_linearize_index(%arg0: vector<2x2xindex>, %arg1: vector<2x2xi32>) -> vector<2x2xindex> {
+  %0 = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xindex>
+// Arith and math ops are handled in generic way, check some of them
+  %1 = arith.addi %arg0, %0 :  vector<2x2xindex>
+  %2 = arith.index_cast %1 : vector<2x2xindex> to vector<2x2xi32>
+  %3 = arith.muli %2, %arg1 : vector<2x2xi32>
+  %4 = arith.index_cast %3 : vector<2x2xi32> to vector<2x2xindex>
+  return %4 : vector<2x2xindex>
+}
+
+// -----
+// CHECK-LABEL: @add_kernel_f32
+//       CHECK: %[[CST0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex>
+//       CHECK: %[[CST1:.*]] = arith.constant dense<[16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]> : vector<16xindex>
+//       CHECK: %[[T0:.*]] = vector.splat %{{.*}} : vector<16xindex>
+//       CHECK: %[[T1:.*]] = arith.addi %[[T0]], %[[CST0]] : vector<16xindex>
+//       CHECK: %[[T2:.*]] = arith.addi %[[T0]], %[[CST1]] : vector<16xindex>
+//       CHECK: %[[T3:.*]] = arith.index_cast %[[T1]] : vector<16xindex> to vector<16xi32>
+//       CHECK: %[[T4:.*]] = arith.index_cast %[[T2]] : vector<16xindex> to vector<16xi32>
+//       CHECK: %[[T5:.*]] = vector.splat %{{.*}} : vector<16xi32>
+//       CHECK: %[[T6:.*]] = arith.addi %[[T5]], %[[T3]] : vector<16xi32>
+//       CHECK: %[[T7:.*]] = arith.addi %[[T5]], %[[T4]] : vector<16xi32>
+//       CHECK: %[[T8:.*]] = arith.index_cast %[[T6]] : vector<16xi32> to vector<16xindex>
+//       CHECK: %[[T9:.*]] = arith.index_cast %[[T7]] : vector<16xi32> to vector<16xindex>
+gpu.module @add_kernel_f32 attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Bfloat16ConversionINTEL, BFloat16TypeKHR, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, VectorAnyINTEL, VectorComputeINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_bfloat16, SPV_KHR_expect_assume, SPV_INTEL_bfloat16_conversion, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+  gpu.func @add_kernel_f32(%arg0: memref<*xf32>, %arg1: memref<*xf32>, %arg2: memref<*xf32>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array<i32: 1, 32, 1>, known_grid_size = array<i32: 1, 1, 1>, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+    %cst = arith.constant dense<true> : vector<16xi1>
+    %c32 = arith.constant 32 : index
+    %c1024_i32 = arith.constant 1024 : i32
+    %cst_0 = arith.constant dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex>
+    %cst_1 = arith.constant dense<[[16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]]> : vector<1x16xindex>
+    %thread_id_x = gpu.thread_id  x
+    %thread_id_y = gpu.thread_id  y
+    %block_dim_y = gpu.block_dim  y
+    %0 = arith.muli %thread_id_x, %block_dim_y : index
+    %1 = arith.addi %0, %thread_id_y : index
+    %cast = memref.cast %arg0 : memref<*xf32> to memref<?xf32>
+    %cast_2 = memref.cast %arg1 : memref<*xf32> to memref<?xf32>
+    %cast_3 = memref.cast %arg2 : memref<*xf32> to memref<?xf32>
+    %2 = arith.remsi %1, %c32 : index
+    %3 = arith.muli %2, %c32 : index
+    %4 = vector.splat %3 : vector<1x16xindex>
+    %5 = arith.addi %4, %cst_0 : vector<1x16xindex>
+    %6 = arith.addi %4, %cst_1 : vector<1x16xindex>
+    %7 = arith.index_cast %5 : vector<1x16xindex> to vector<1x16xi32>
+    %8 = arith.index_cast %6 : vector<1x16xindex> to vector<1x16xi32>
+    %block_id_x = gpu.block_id  x
+    %9 = arith.index_cast %block_id_x : index to i32
+    %10 = arith.muli %9, %c1024_i32 : i32
+    %11 = vector.splat %10 : vector<1x16xi32>
+    %12 = arith.addi %11, %7 : vector<1x16xi32>
+    %13 = arith.addi %11, %8 : vector<1x16xi32>
+    %14 = arith.index_cast %12 : vector<1x16xi32> to vector<1x16xindex>
+    %15 = arith.index_cast %13 : vector<1x16xi32> to vector<1x16xindex>
+    %16 = vector.shape_cast %14 : vector<1x16xindex> to vector<16xindex>
+    %17 = xegpu.create_tdesc %cast, %16 : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space =  global, chunk_size = 1 : i64>>
+    %18 = vector.shape_cast %15 : vector<1x16xindex> to vector<16xindex>
+    %19 = xegpu.create_tdesc %cast, %18 : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space =  global, chunk_size = 1 : i64>>
+    %20 = xegpu.load %17, %cst <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space =  global, chunk_size = 1 : i64>>, vector<16xi1> -> vector<16xf32>
+    %21 = vector.shape_cast %20 : vector<16xf32> to vector<1x16xf32>
+    %22 = xegpu.load %19, %cst <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space =  global, chunk_size = 1 : i64>>, vector<16xi1> -> vector<16xf32>
+    %23 = vector.shape_cast %22 : vector<16xf32> to vector<1x16xf32>
+    %24 = xegpu.create_tdesc %cast_2, %16 : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space =  global, chunk_size = 1 : i64>>
+    %25 = xegpu.create_tdesc %cast_2, %18 : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space =  global, chunk_size = 1 : i64>>
+    %26 = xegpu.load %24, %cst <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space =  global, chunk_size = 1 : i64>>, vector<16xi1> -> vector<16xf32>
+    %27 = vector.shape_cast %26 : vector<16xf32> to vector<1x16xf32>
+    %28 = xegpu.load %25, %cst <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space =  global, chunk_size = 1 : i64>>, vector<16xi1> -> vector<16xf32>
+    %29 = vector.shape_cast %28 : vector<16xf32> to vector<1x16xf32>
+    %30 = arith.addf %21, %27 : vector<1x16xf32>
+    %31 = arith.addf %23, %29 : vector<1x16xf32>
+    %32 = xegpu.create_tdesc %cast_3, %16 : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space =  global, chunk_size = 1 : i64>>
+    %33 = xegpu.create_tdesc %cast_3, %18 : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space =  global, chunk_size = 1 : i64>>
+    %34 = vector.shape_cast %30 : vector<1x16xf32> to vector<16xf32>
+    xegpu.store %34, %32, %cst <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space =  global, chunk_size = 1 : i64>>, vector<16xi1>
+    %35 = vector.shape_cast %31 : vector<1x16xf32> to vector<16xf32>
+    xegpu.store %35, %33, %cst <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space =  global, chunk_size = 1 : i64>>, vector<16xi1>
+    gpu.return
+  }
+}

>From f12ddd03cdeed752b7b6784e81642ae6df46cc1c Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Wed, 17 Sep 2025 23:06:48 +0000
Subject: [PATCH 3/6] Address review comments.

Update the test file to remove cases that duplicate the existing vector-linearize tests.
Add new test cases for XeGPU ops, vector.broadcast, and vector.gather.
---
 .../XeGPU/Transforms/XeGPUVectorLinearize.cpp |  67 ++-
 .../Dialect/XeGPU/xegpu-vector-linearize.mlir | 529 ++++++++----------
 2 files changed, 258 insertions(+), 338 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
index 78648042ae127..2bb302f4287c4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
@@ -1,5 +1,4 @@
-//===- XeGPUVectorLinearize.cpp - Linearizes n-D vectors to 1-D vectors
-//-------===//
+//===-- XeGPUVectorLinearize.cpp - Linearizes n-D vectors to 1-D vectors --===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -37,31 +36,29 @@ struct XeGPUVectorLinearizePass final
   void runOnOperation() override {
     // vector.broadcast and vector.gather requires progressive lowering
     {
-      mlir::RewritePatternSet patterns(&getContext());
-      mlir::vector::populateVectorBroadcastLoweringPatterns(patterns);
-      mlir::vector::populateVectorGatherLoweringPatterns(patterns);
-      mlir::vector::populateVectorGatherToConditionalLoadPatterns(patterns);
+      RewritePatternSet patterns(&getContext());
+      vector::populateVectorBroadcastLoweringPatterns(patterns);
+      vector::populateVectorGatherLoweringPatterns(patterns);
+      vector::populateVectorGatherToConditionalLoadPatterns(patterns);
       // vector.transpose lowering
       // Shuffle16x16 will fallback to Shuffle1D for non 16x16 sizes.
-      mlir::vector::populateVectorTransposeLoweringPatterns(
-          patterns, mlir::vector::VectorTransposeLowering::Shuffle16x16);
-      (void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns));
+      vector::populateVectorTransposeLoweringPatterns(
+          patterns, vector::VectorTransposeLowering::Shuffle16x16);
+      if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+        return signalPassFailure();
     }
 
-    // Unroll load store from <<MxN> to M <1xN> load/stores and then linearize
+    // Unroll load store from <MxN> to M <1xN> load/stores and then linearize
     {
-      mlir::RewritePatternSet patterns(&getContext());
-      mlir::vector::UnrollVectorOptions vectorOptions;
+      RewritePatternSet patterns(&getContext());
+      vector::UnrollVectorOptions vectorOptions;
       vectorOptions.setNativeShapeFn(
-          [](mlir::Operation *op) -> std::optional<mlir::SmallVector<int64_t>> {
-            auto extractVectorType =
-                [](mlir::Operation *op) -> mlir::VectorType {
-              if (auto loadOp = mlir::dyn_cast<mlir::vector::LoadOp>(op))
-                return mlir::dyn_cast<mlir::VectorType>(
-                    loadOp.getResult().getType());
-              if (auto storeOp = mlir::dyn_cast<mlir::vector::StoreOp>(op))
-                return mlir::dyn_cast<mlir::VectorType>(
-                    storeOp.getValueToStore().getType());
+          [](Operation *op) -> std::optional<SmallVector<int64_t>> {
+            auto extractVectorType = [](Operation *op) -> VectorType {
+              if (auto loadOp = dyn_cast<vector::LoadOp>(op))
+                return loadOp.getVectorType();
+              if (auto storeOp = dyn_cast<vector::StoreOp>(op))
+                return storeOp.getVectorType();
               return nullptr;
             };
 
@@ -73,25 +70,25 @@ struct XeGPUVectorLinearizePass final
             if (shape.size() != 2)
               return std::nullopt;
 
-            return mlir::SmallVector<int64_t>{1, shape[1]};
+            return SmallVector<int64_t>{1, shape[1]};
           });
-      mlir::vector::populateVectorUnrollPatterns(patterns, vectorOptions);
-      (void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns));
+      vector::populateVectorUnrollPatterns(patterns, vectorOptions);
+      if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+        return signalPassFailure();
     }
 
     // Use vector linearization patterns
     {
-      mlir::MLIRContext &context = getContext();
-      mlir::TypeConverter converter;
-      mlir::RewritePatternSet patterns(&context);
-      mlir::ConversionTarget target(context);
-      mlir::vector::populateForVectorLinearize(converter, target);
-      mlir::vector::populateVectorLinearizeBasePatterns(converter, target,
-                                                        patterns);
-      mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
-          converter, target, patterns);
-      mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
-          converter, patterns, target);
+      MLIRContext &context = getContext();
+      TypeConverter converter;
+      RewritePatternSet patterns(&context);
+      ConversionTarget target(context);
+      vector::populateForVectorLinearize(converter, target);
+      vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
+      vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
+                                                            patterns);
+      scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+                                                           target);
       if (failed(applyPartialConversion(getOperation(), target,
                                         std::move(patterns))))
         return signalPassFailure();
diff --git a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
index 61720884002c2..9985736e2cafb 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
@@ -1,131 +1,5 @@
-// RUN: mlir-opt %s -split-input-file -xegpu-vector-linearize | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -xegpu-vector-linearize -canonicalize | FileCheck %s
 
-// CHECK-LABEL: @test_linearize
-//  CHECK-SAME: (%[[ARG0:.*]]: vector<2x2xf32>) -> vector<2x2xf32> {
-//       CHECK: %[[T0:.*]] = vector.shape_cast %[[ARG0]] : vector<2x2xf32> to vector<4xf32>
-//       CHECK: %[[CST:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
-//       CHECK: %[[T1:.*]] = math.sin %[[T0]] : vector<4xf32>
-//       CHECK: %[[T2:.*]] = arith.addf %[[T0]], %[[CST]] : vector<4xf32>
-//       CHECK: %[[T3:.*]] = arith.addf %[[T2]], %[[T1]] : vector<4xf32>
-//       CHECK: %[[T4:.*]] = vector.shape_cast %[[T3]] : vector<4xf32> to vector<2x2xf32>
-//       CHECK: return %[[T4]] : vector<2x2xf32>
-func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
-  %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
-// Arith and math ops are handled in generic way, check some of them
-  %1 = math.sin %arg0 : vector<2x2xf32>
-  %2 = arith.addf %arg0, %0 :  vector<2x2xf32>
-  %3 = arith.addf %2, %1 :  vector<2x2xf32>
-  return %3 : vector<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: test_const_novector
-//       CHECK:  %[[R:.*]] = arith.constant 42 : i32
-//       CHECK:  return %[[R]] : i32
-func.func @test_const_novector() -> i32 {
-  %0 = arith.constant 42 : i32
-  return %0 : i32
-}
-
-// -----
-// CHECK-LABEL: test_create_mask
-//       CHECK: vector.create_mask {{.*}} : vector<16xi1>
-func.func @test_create_mask() -> vector<1x16xi1> {
-  %c0 = arith.constant 0 : index
-  %c20 = arith.constant 20 : index
-  %0 = vector.create_mask %c0, %c20 : vector<1x16xi1>
-  return %0 : vector<1x16xi1>
-}
-
-// -----
-// CHECK-LABEL: test_extract_strided_slice
-//  CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<8x16xf32>) -> vector<8x8xf32>
-//       CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<8x16xf32> to vector<128xf32>
-//       CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
-//       CHECK: [8, 9, 10, 11, 12, 13, 14, 15,
-//       CHECK: 24, 25, 26, 27, 28, 29, 30, 31,
-//       CHECK: 40, 41, 42, 43, 44, 45, 46, 47,
-//       CHECK: 56, 57, 58, 59, 60, 61, 62, 63,
-//       CHECK: 72, 73, 74, 75, 76, 77, 78, 79,
-//       CHECK: 88, 89, 90, 91, 92, 93, 94, 95,
-//       CHECK: 104, 105, 106, 107, 108, 109, 110, 111,
-//       CHECK: 120, 121, 122, 123, 124, 125, 126, 127] : vector<128xf32>, vector<128xf32>
-//       CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<8x8xf32>
-//       CHECK: return %[[RES]] : vector<8x8xf32>
-func.func @test_extract_strided_slice_1(%arg0 : vector<8x16xf32>) -> vector<8x8xf32> {
-  %0 = vector.extract_strided_slice %arg0 { sizes = [8, 8], strides = [1, 1], offsets = [0, 8]}
-     : vector<8x16xf32> to vector<8x8xf32>
-  return %0 : vector<8x8xf32>
-}
-
-// -----
-// CHECK-LABEL: test_extract_strided_slice_2
-//  CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x32x8xf32>) -> vector<1x8x8xf32>
-//       CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x32x8xf32> to vector<512xf32>
-//       CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
-//       CHECK: [448, 449, 450, 451, 452, 453, 454, 455,
-//       CHECK: 456, 457, 458, 459, 460, 461, 462, 463,
-//       CHECK: 464, 465, 466, 467, 468, 469, 470, 471,
-//       CHECK: 472, 473, 474, 475, 476, 477, 478, 479,
-//       CHECK: 480, 481, 482, 483, 484, 485, 486, 487,
-//       CHECK: 488, 489, 490, 491, 492, 493, 494, 495,
-//       CHECK: 496, 497, 498, 499, 500, 501, 502, 503,
-//       CHECK: 504, 505, 506, 507, 508, 509, 510, 511] : vector<512xf32>, vector<512xf32>
-//       CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<1x8x8xf32>
-//       CHECK: return %[[RES]] : vector<1x8x8xf32>
-func.func @test_extract_strided_slice_2(%arg0 : vector<2x32x8xf32>) -> vector<1x8x8xf32> {
-  %0 = vector.extract_strided_slice %arg0 { offsets = [1, 24], strides = [1, 1], sizes = [1, 8] }
-    : vector<2x32x8xf32> to vector<1x8x8xf32>
-  return %0 : vector<1x8x8xf32>
-}
-
-// -----
-// CHECK-LABEL: test_vector_shuffle
-//  CHECK-SAME: (%[[ORIG_ARG1:.*]]: vector<4x4xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>) -> vector<8x4xf32> {
-//       CHECK: %[[ARG2:.*]] = vector.shape_cast %[[ORIG_ARG2]] : vector<4x4xf32> to vector<16xf32>
-//       CHECK: %[[ARG1:.*]] = vector.shape_cast %[[ORIG_ARG1]] : vector<4x4xf32> to vector<16xf32>
-//       CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG1]], %[[ARG2]]
-//       CHECK: [0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23,
-//       CHECK: 8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
-//       CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<32xf32> to vector<8x4xf32>
-//       CHECK: return %[[RES]] : vector<8x4xf32>
-func.func @test_vector_shuffle(%arg0: vector<4x4xf32>, %arg1: vector<4x4xf32>) -> vector<8x4xf32> {
-  %0 = vector.shuffle %arg0, %arg1 [0, 4, 1, 5, 2, 6, 3, 7] : vector<4x4xf32>, vector<4x4xf32>
-  return %0 : vector<8x4xf32>
-}
-
-// -----
-// CHECK-LABEL: test_vector_extract
-// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x4xf32>) -> vector<8x4xf32>
-// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x4xf32> to vector<64xf32>
-// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
-// CHECK: [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
-// CHECK: 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<64xf32>
-// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<32xf32> to vector<8x4xf32>
-// CHECK: return %[[RES]] : vector<8x4xf32>
-func.func @test_vector_extract(%arg0: vector<2x8x4xf32>) -> vector<8x4xf32> {
-  %0 = vector.extract %arg0[1]: vector<8x4xf32> from vector<2x8x4xf32>
-  return %0 : vector<8x4xf32>
-}
-
-// -----
-// CHECK-LABEL: test_vector_insert
-// CHECK-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32>
-// CHECK: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
-// CHECK: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32>
-// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]]
-// CHECK: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87,
-// CHECK-SAME: 88, 89, 90, 91, 92, 93, 94, 95, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
-// CHECK-SAME: 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<32xf32>
-// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32>
-// CHECK: return %[[RES]] : vector<2x8x4xf32>
-func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> {
-  %0 = vector.insert %arg1, %arg0[0]: vector<8x4xf32> into vector<2x8x4xf32>
-  return %0 : vector<2x8x4xf32>
-}
-
-// -----
 // CHECK-LABEL: test_vector_insert_2d_idx
 // CHECK-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<4xf32>) -> vector<2x8x4xf32>
 // CHECK: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
@@ -157,133 +31,49 @@ func.func @test_vector_transpose(%arg: vector<2x8xf32>) -> vector<8x2xf32> {
 // CHECK-LABEL: test_vector_transpose_16x16
 // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
 // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
-// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+// CHECK-COUNT-62: vector.shuffle
 func.func @test_vector_transpose_16x16(%arg: vector<16x16xf32>) -> vector<16x16xf32> {
   %0 = vector.transpose %arg, [1, 0] : vector<16x16xf32> to vector<16x16xf32>
   return %0 : vector<16x16xf32>
 }
 
-// -----
-// CHECK-LABEL: func.func @test_vector_store_load_4x4
-// CHECK-SAME: (%[[MEMREF:.*]]: memref<4x4xf32>)
-// CHECK: %[[C3:.*]] = arith.constant 3 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[V0:.*]] = vector.load %[[MEMREF]][%[[C0]], %[[C0]]] : memref<4x4xf32>, vector<4xf32>
-// CHECK: %[[V1:.*]] = vector.load %[[MEMREF]][%[[C1]], %[[C0]]] : memref<4x4xf32>, vector<4xf32>
-// CHECK: %[[V2:.*]] = vector.load %[[MEMREF]][%[[C2]], %[[C0]]] : memref<4x4xf32>, vector<4xf32>
-// CHECK: %[[V3:.*]] = vector.load %[[MEMREF]][%[[C3]], %[[C0]]] : memref<4x4xf32>, vector<4xf32>
-// CHECK: vector.store %[[V0]], %[[MEMREF]][%[[C0]], %[[C0]]] : memref<4x4xf32>, vector<4xf32>
-// CHECK: vector.store %[[V1]], %[[MEMREF]][%[[C1]], %[[C0]]] : memref<4x4xf32>, vector<4xf32>
-// CHECK: vector.store %[[V2]], %[[MEMREF]][%[[C2]], %[[C0]]] : memref<4x4xf32>, vector<4xf32>
-// CHECK: vector.store %[[V3]], %[[MEMREF]][%[[C3]], %[[C0]]] : memref<4x4xf32>, vector<4xf32>
-func.func @test_vector_store_load_4x4(%buffer: memref<4x4xf32>) {
-  %c0 = arith.constant 0 : index
-  %0 = vector.load %buffer[%c0, %c0] : memref<4x4xf32>, vector<4x4xf32>
-  vector.store %0, %buffer[%c0, %c0] : memref<4x4xf32>, vector<4x4xf32>
-  return
-}
-
 // -----
 
+// CHECK-LABEL: func.func @test_vector_store_load_4x4_f16
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf16>)
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[LOAD0:.*]] = vector.load %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+// CHECK: %[[LOAD1:.*]] = vector.load %[[ARG0]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+// CHECK: %[[LOAD2:.*]] = vector.load %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+// CHECK: %[[LOAD3:.*]] = vector.load %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+// CHECK: vector.store %[[LOAD0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+// CHECK: vector.store %[[LOAD1]], %[[ARG0]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+// CHECK: vector.store %[[LOAD2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
+// CHECK: vector.store %[[LOAD3]], %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
 func.func @test_vector_store_load_4x4_f16(%buffer: memref<4x4xf16>) {
   %c0 = arith.constant 0 : index
   %0 = vector.load %buffer[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
   vector.store %0, %buffer[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
   return
 }
-// CHECK-LABEL: func.func @test_vector_store_load_4x4_f16
-// CHECK-SAME: (%[[MEMREF:.*]]: memref<4x4xf16>)
-// CHECK: %[[C3:.*]] = arith.constant 3 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[V0:.*]] = vector.load %[[MEMREF]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
-// CHECK: %[[V1:.*]] = vector.load %[[MEMREF]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
-// CHECK: %[[V2:.*]] = vector.load %[[MEMREF]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
-// CHECK: %[[V3:.*]] = vector.load %[[MEMREF]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
-// CHECK: vector.store %[[V0]], %[[MEMREF]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
-// CHECK: vector.store %[[V1]], %[[MEMREF]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
-// CHECK: vector.store %[[V2]], %[[MEMREF]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
-// CHECK: vector.store %[[V3]], %[[MEMREF]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
-
 // -----
-// CHECK-LABEL: @test_linearize_index
-//  CHECK-SAME: (%[[ARG0:.*]]: vector<2x2xindex>, %[[ARG1:.*]]: vector<2x2xi32>) -> vector<2x2xindex> {
-//       CHECK: %[[T0:.*]] = vector.shape_cast %[[ARG1]] : vector<2x2xi32> to vector<4xi32>
-//       CHECK: %[[T1:.*]] = vector.shape_cast %[[ARG0]] : vector<2x2xindex> to vector<4xindex>
-//       CHECK: %[[CST:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-//       CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[CST]] : vector<4xindex>
-//       CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : vector<4xindex> to vector<4xi32>
-//       CHECK: %[[T4:.*]] = arith.muli %[[T3]], %[[T0]] : vector<4xi32>
-//       CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : vector<4xi32> to vector<4xindex>
-//       CHECK: %[[T6:.*]] = vector.shape_cast %[[T5]] : vector<4xindex> to vector<2x2xindex>
-//       CHECK: return %[[T6]] : vector<2x2xindex>
+// CHECK-LABEL: func.func @test_linearize_index
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2x2xindex>, %[[ARG1:.*]]: vector<2x2xi32>) -> vector<2x2xindex>
+// CHECK: %[[CST:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK: %[[CAST1:.*]] = vector.shape_cast %[[ARG1]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[CAST2:.*]] = vector.shape_cast %[[ARG0]] : vector<2x2xindex> to vector<4xindex>
+// CHECK: %[[ADDI:.*]] = arith.addi %[[CAST2]], %[[CST]] : vector<4xindex>
+// CHECK: %[[INDEX_CAST1:.*]] = arith.index_cast %[[ADDI]] : vector<4xindex> to vector<4xi32>
+// CHECK: %[[MULI:.*]] = arith.muli %[[INDEX_CAST1]], %[[CAST1]] : vector<4xi32>
+// CHECK: %[[INDEX_CAST2:.*]] = arith.index_cast %[[MULI]] : vector<4xi32> to vector<4xindex>
+// CHECK: %[[RESULT:.*]] = vector.shape_cast %[[INDEX_CAST2]] : vector<4xindex> to vector<2x2xindex>
+// CHECK: return %[[RESULT]] : vector<2x2xindex>
 func.func @test_linearize_index(%arg0: vector<2x2xindex>, %arg1: vector<2x2xi32>) -> vector<2x2xindex> {
   %0 = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xindex>
-// Arith and math ops are handled in generic way, check some of them
+  // Arith and math ops are handled in generic way, check some of them
   %1 = arith.addi %arg0, %0 :  vector<2x2xindex>
   %2 = arith.index_cast %1 : vector<2x2xindex> to vector<2x2xi32>
   %3 = arith.muli %2, %arg1 : vector<2x2xi32>
@@ -292,71 +82,204 @@ func.func @test_linearize_index(%arg0: vector<2x2xindex>, %arg1: vector<2x2xi32>
 }
 
 // -----
-// CHECK-LABEL: @add_kernel_f32
-//       CHECK: %[[CST0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex>
-//       CHECK: %[[CST1:.*]] = arith.constant dense<[16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]> : vector<16xindex>
-//       CHECK: %[[T0:.*]] = vector.splat %{{.*}} : vector<16xindex>
-//       CHECK: %[[T1:.*]] = arith.addi %[[T0]], %[[CST0]] : vector<16xindex>
-//       CHECK: %[[T2:.*]] = arith.addi %[[T0]], %[[CST1]] : vector<16xindex>
-//       CHECK: %[[T3:.*]] = arith.index_cast %[[T1]] : vector<16xindex> to vector<16xi32>
-//       CHECK: %[[T4:.*]] = arith.index_cast %[[T2]] : vector<16xindex> to vector<16xi32>
-//       CHECK: %[[T5:.*]] = vector.splat %{{.*}} : vector<16xi32>
-//       CHECK: %[[T6:.*]] = arith.addi %[[T5]], %[[T3]] : vector<16xi32>
-//       CHECK: %[[T7:.*]] = arith.addi %[[T5]], %[[T4]] : vector<16xi32>
-//       CHECK: %[[T8:.*]] = arith.index_cast %[[T6]] : vector<16xi32> to vector<16xindex>
-//       CHECK: %[[T9:.*]] = arith.index_cast %[[T7]] : vector<16xi32> to vector<16xindex>
-gpu.module @add_kernel_f32 attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Bfloat16ConversionINTEL, BFloat16TypeKHR, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, VectorAnyINTEL, VectorComputeINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_bfloat16, SPV_KHR_expect_assume, SPV_INTEL_bfloat16_conversion, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
-  gpu.func @add_kernel_f32(%arg0: memref<*xf32>, %arg1: memref<*xf32>, %arg2: memref<*xf32>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array<i32: 1, 32, 1>, known_grid_size = array<i32: 1, 1, 1>, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
-    %cst = arith.constant dense<true> : vector<16xi1>
-    %c32 = arith.constant 32 : index
-    %c1024_i32 = arith.constant 1024 : i32
-    %cst_0 = arith.constant dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex>
-    %cst_1 = arith.constant dense<[[16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]]> : vector<1x16xindex>
-    %thread_id_x = gpu.thread_id  x
-    %thread_id_y = gpu.thread_id  y
-    %block_dim_y = gpu.block_dim  y
-    %0 = arith.muli %thread_id_x, %block_dim_y : index
-    %1 = arith.addi %0, %thread_id_y : index
-    %cast = memref.cast %arg0 : memref<*xf32> to memref<?xf32>
-    %cast_2 = memref.cast %arg1 : memref<*xf32> to memref<?xf32>
-    %cast_3 = memref.cast %arg2 : memref<*xf32> to memref<?xf32>
-    %2 = arith.remsi %1, %c32 : index
-    %3 = arith.muli %2, %c32 : index
-    %4 = vector.splat %3 : vector<1x16xindex>
-    %5 = arith.addi %4, %cst_0 : vector<1x16xindex>
-    %6 = arith.addi %4, %cst_1 : vector<1x16xindex>
-    %7 = arith.index_cast %5 : vector<1x16xindex> to vector<1x16xi32>
-    %8 = arith.index_cast %6 : vector<1x16xindex> to vector<1x16xi32>
-    %block_id_x = gpu.block_id  x
-    %9 = arith.index_cast %block_id_x : index to i32
-    %10 = arith.muli %9, %c1024_i32 : i32
-    %11 = vector.splat %10 : vector<1x16xi32>
-    %12 = arith.addi %11, %7 : vector<1x16xi32>
-    %13 = arith.addi %11, %8 : vector<1x16xi32>
-    %14 = arith.index_cast %12 : vector<1x16xi32> to vector<1x16xindex>
-    %15 = arith.index_cast %13 : vector<1x16xi32> to vector<1x16xindex>
-    %16 = vector.shape_cast %14 : vector<1x16xindex> to vector<16xindex>
-    %17 = xegpu.create_tdesc %cast, %16 : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space =  global, chunk_size = 1 : i64>>
-    %18 = vector.shape_cast %15 : vector<1x16xindex> to vector<16xindex>
-    %19 = xegpu.create_tdesc %cast, %18 : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space =  global, chunk_size = 1 : i64>>
-    %20 = xegpu.load %17, %cst <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space =  global, chunk_size = 1 : i64>>, vector<16xi1> -> vector<16xf32>
-    %21 = vector.shape_cast %20 : vector<16xf32> to vector<1x16xf32>
-    %22 = xegpu.load %19, %cst <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space =  global, chunk_size = 1 : i64>>, vector<16xi1> -> vector<16xf32>
-    %23 = vector.shape_cast %22 : vector<16xf32> to vector<1x16xf32>
-    %24 = xegpu.create_tdesc %cast_2, %16 : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space =  global, chunk_size = 1 : i64>>
-    %25 = xegpu.create_tdesc %cast_2, %18 : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space =  global, chunk_size = 1 : i64>>
-    %26 = xegpu.load %24, %cst <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space =  global, chunk_size = 1 : i64>>, vector<16xi1> -> vector<16xf32>
-    %27 = vector.shape_cast %26 : vector<16xf32> to vector<1x16xf32>
-    %28 = xegpu.load %25, %cst <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space =  global, chunk_size = 1 : i64>>, vector<16xi1> -> vector<16xf32>
-    %29 = vector.shape_cast %28 : vector<16xf32> to vector<1x16xf32>
-    %30 = arith.addf %21, %27 : vector<1x16xf32>
-    %31 = arith.addf %23, %29 : vector<1x16xf32>
-    %32 = xegpu.create_tdesc %cast_3, %16 : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space =  global, chunk_size = 1 : i64>>
-    %33 = xegpu.create_tdesc %cast_3, %18 : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space =  global, chunk_size = 1 : i64>>
-    %34 = vector.shape_cast %30 : vector<1x16xf32> to vector<16xf32>
-    xegpu.store %34, %32, %cst <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space =  global, chunk_size = 1 : i64>>, vector<16xi1>
-    %35 = vector.shape_cast %31 : vector<1x16xf32> to vector<16xf32>
-    xegpu.store %35, %33, %cst <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<write_back>, l3_hint = #xegpu.cache_hint<write_back>}> : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space =  global, chunk_size = 1 : i64>>, vector<16xi1>
+// CHECK-LABEL: func.func @broadcast_stretch_at_start
+// CHECK-SAME: (%[[ARG0:.*]]: vector<1x4xf32>) -> vector<3x4xf32>
+// CHECK: %[[POISON:.*]] = ub.poison : vector<12xf32>
+// CHECK: %[[CAST:.*]] = vector.shape_cast %[[ARG0]] : vector<1x4xf32> to vector<4xf32>
+// CHECK: %[[SHUFFLE1:.*]] = vector.shuffle %[[POISON]], %[[CAST]] [12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11] : vector<12xf32>, vector<4xf32>
+// CHECK: %[[SHUFFLE2:.*]] = vector.shuffle %[[SHUFFLE1]], %[[CAST]] [0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11] : vector<12xf32>, vector<4xf32>
+// CHECK: %[[SHUFFLE3:.*]] = vector.shuffle %[[SHUFFLE2]], %[[CAST]] [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15] : vector<12xf32>, vector<4xf32>
+// CHECK: %[[RESULT:.*]] = vector.shape_cast %[[SHUFFLE3]] : vector<12xf32> to vector<3x4xf32>
+func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> {
+  %0 = vector.broadcast %arg0 : vector<1x4xf32> to vector<3x4xf32>
+  return %0 : vector<3x4xf32>
+}
+
+// -----
+// CHECK-LABEL: func.func @broadcast_stretch_at_end
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4x1xf32>) -> vector<4x3xf32>
+// CHECK: %[[POISON:.*]] = ub.poison : vector<12xf32>
+// CHECK: %[[EXTRACT1:.*]] = vector.extract %[[ARG0]][0, 0] : f32 from vector<4x1xf32>
+// CHECK: %[[BROADCAST1:.*]] = vector.broadcast %[[EXTRACT1]] : f32 to vector<3xf32>
+// CHECK: vector.shuffle
+// CHECK: %[[EXTRACT2:.*]] = vector.extract %[[ARG0]][1, 0] : f32 from vector<4x1xf32>
+// CHECK: %[[BROADCAST2:.*]] = vector.broadcast %[[EXTRACT2]] : f32 to vector<3xf32>
+// CHECK: vector.shuffle
+// CHECK: %[[EXTRACT3:.*]] = vector.extract %[[ARG0]][2, 0] : f32 from vector<4x1xf32>
+// CHECK: %[[BROADCAST3:.*]] = vector.broadcast %[[EXTRACT3]] : f32 to vector<3xf32>
+// CHECK: vector.shuffle
+// CHECK: %[[EXTRACT4:.*]] = vector.extract %[[ARG0]][3, 0] : f32 from vector<4x1xf32>
+// CHECK: %[[BROADCAST4:.*]] = vector.broadcast %[[EXTRACT4]] : f32 to vector<3xf32>
+// CHECK: vector.shuffle
+// CHECK: vector.shape_cast {{.*}} : vector<12xf32> to vector<4x3xf32>
+func.func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> {
+  %0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<4x3xf32>
+  return %0 : vector<4x3xf32>
+}
+
+// -----
+// CHECK-LABEL: func.func @broadcast_stretch_in_middle
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4x1x2xf32>) -> vector<4x3x2xf32>
+// CHECK: ub.poison : vector<6xf32>
+// CHECK: ub.poison : vector<24xf32>
+// CHECK: %[[CAST:.*]] = vector.shape_cast %[[ARG0]] : vector<4x1x2xf32> to vector<8xf32>
+// CHECK-COUNT-20: vector.shuffle
+// CHECK: vector.shape_cast {{.*}} : vector<24xf32> to vector<4x3x2xf32>
+// CHECK-NOT: vector.broadcast
+func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2xf32> {
+  %0 = vector.broadcast %arg0 : vector<4x1x2xf32> to vector<4x3x2xf32>
+  return %0 : vector<4x3x2xf32>
+}
+
+// CHECK-LABEL: func.func @gather_memref_2d
+// CHECK-SAME: (%arg0: memref<?x?xf32>, %arg1: vector<2x3xindex>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
+
+// CHECK: %0 = ub.poison : vector<6xf32>
+// CHECK: %c1 = arith.constant 1 : index
+// CHECK: %c0 = arith.constant 0 : index
+// CHECK: %1 = vector.shape_cast %arg3 : vector<2x3xf32> to vector<6xf32>
+
+// First shuffle + if ladder for row 0
+// CHECK: %2 = vector.shuffle %1, %1 [0, 1, 2]
+// CHECK: %3 = vector.extract %arg2[0, 0]
+// CHECK: %4 = vector.extract %arg1[0, 0]
+// CHECK: %5 = arith.addi %4, %c1
+// CHECK: %6 = scf.if %3 -> (vector<3xf32>) {
+// CHECK:   %{{.*}} = vector.load %arg0[%c0, %5] : memref<?x?xf32>, vector<1xf32>
+// CHECK:   %{{.*}} = vector.extract {{.*}}[0] : f32
+// CHECK:   %{{.*}} = vector.insert {{.*}}, %2 [0] : f32 into vector<3xf32>
+// CHECK:   scf.yield {{.*}} : vector<3xf32>
+// CHECK: } else {
+// CHECK:   scf.yield %2 : vector<3xf32>
+// CHECK: }
+
+// CHECK: %7 = vector.extract %arg2[0, 1]
+// CHECK: %8 = vector.extract %arg1[0, 1]
+// CHECK: %9 = arith.addi %8, %c1
+// CHECK: %10 = scf.if %7 -> (vector<3xf32>)
+
+// ... (similar checks for element [0, 2] of row 0)
+
+// CHECK: %15 = vector.shuffle %0, %{{.*}} [6, 7, 8, 3, 4, 5]
+// CHECK: %16 = vector.shuffle %1, %1 [3, 4, 5]
+
+// Row 1 if ladder checks
+// CHECK: %17 = vector.extract %arg2[1, 0]
+// CHECK: %18 = vector.extract %arg1[1, 0]
+// CHECK: %19 = arith.addi %18, %c1
+// CHECK: %20 = scf.if %17 -> (vector<3xf32>)
+
+// ... (similar checks for elements [1, 1] and [1, 2] of row 1)
+
+// Final reshuffle and cast
+// CHECK: %29 = vector.shuffle %15, %{{.*}} [0, 1, 2, 6, 7, 8]
+// CHECK: %30 = vector.shape_cast %29 : vector<6xf32> to vector<2x3xf32>
+// CHECK: return %30 : vector<2x3xf32>
+func.func @gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x3xindex>, %mask: vector<2x3xi1>, %pass_thru: vector<2x3xf32>) -> vector<2x3xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = vector.gather %base[%c0, %c1][%v], %mask, %pass_thru : memref<?x?xf32>, vector<2x3xindex>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+}
+
+// -----
+// Check for vector linearization in XeGPU dialect.
+// The vector<64xf16> loaded from memory is linearized into 4 vector<8xf16> using vector.shuffle ops.
+// The pattern is similar to the one used in test_vector_transpose_16x16 above.
+gpu.module @test_kernel {
+  // CHECK-LABEL: gpu.func @test_kernel
+  gpu.func @test_kernel(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: memref<32x32xf32>) kernel {
+    %c24 = arith.constant 24 : index
+    %c16 = arith.constant 16 : index
+    %c8 = arith.constant 8 : index
+    %c0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<32x32xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
+    %1 = xegpu.load_nd %0  : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<64xf16>
+    // CHECK: %[[V1:.*]] = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<64xf16>, vector<64xf16>
+    %2 = vector.shape_cast %1 : vector<64xf16> to vector<2x32x1xf16>
+    %3 = vector.extract %2[0] : vector<32x1xf16> from vector<2x32x1xf16>
+    // CHECK: %[[V2:.*]] = vector.shuffle %[[V1]], %[[V1]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xf16>, vector<32xf16>
+    %4 = vector.extract_strided_slice %3 {offsets = [0], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
+    // CHECK: %[[V3:.*]] = vector.shuffle %[[V1]], %[[V1]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xf16>, vector<32xf16>
+    %5 = vector.extract_strided_slice %3 {offsets = [8], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
+    // CHECK: %[[V4:.*]] = vector.shuffle %[[V1]], %[[V1]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xf16>, vector<32xf16>
+    %6 = vector.extract_strided_slice %3 {offsets = [16], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
+    %7 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x32xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
+    %8 = xegpu.load_nd %7 <{packed}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<64xf16>
+    // CHECK: %[[V5:.*]] = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<64xf16>, vector<64xf16>
+    %9 = vector.shape_cast %8 : vector<64xf16> to vector<2x32x1xf16>
+    %10 = vector.extract %9[0] : vector<32x1xf16> from vector<2x32x1xf16>
+    // CHECK: %[[V6:.*]] = vector.shuffle %[[V5]], %[[V5]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<32xf16>, vector<32xf16>
+    %11 = vector.extract_strided_slice %10 {offsets = [0], sizes = [16], strides = [1]} : vector<32x1xf16> to vector<16x1xf16>
+    %12 = vector.extract %9[1] : vector<32x1xf16> from vector<2x32x1xf16>
+    // CHECK: %[[V7:.*]] = vector.shuffle %{{.*}}, %{{.*}} [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf16>, vector<64xf16>
+    // CHECK: %[[V8:.*]] = vector.shuffle %[[V7]], %[[V7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<32xf16>, vector<32xf16>
+    %13 = vector.extract_strided_slice %12 {offsets = [0], sizes = [16], strides = [1]} : vector<32x1xf16> to vector<16x1xf16>
+    // CHECK: %[[V9:.*]] = vector.shuffle %[[V1]], %[[V1]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf16>, vector<32xf16>
+    %14 = vector.extract_strided_slice %3 {offsets = [24], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
+    %15 = vector.extract %2[1] : vector<32x1xf16> from vector<2x32x1xf16>
+    // CHECK: %[[V10:.*]] = vector.shuffle %{{.*}}, %{{.*}} [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf16>, vector<64xf16>
+    // CHECK: %[[V11:.*]] = vector.shuffle %[[V10]], %[[V10]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xf16>, vector<32xf16>
+    %16 = vector.extract_strided_slice %15 {offsets = [0], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
+    // CHECK: %[[V12:.*]] = vector.shuffle %[[V10]], %[[V10]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xf16>, vector<32xf16>
+    %17 = vector.extract_strided_slice %15 {offsets = [8], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
+    // CHECK: %[[V13:.*]] = vector.shuffle %[[V10]], %[[V10]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xf16>, vector<32xf16>
+    %18 = vector.extract_strided_slice %15 {offsets = [16], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
+    // CHECK: %[[V14:.*]] = vector.shuffle %[[V5]], %[[V5]] [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf16>, vector<32xf16>
+    %19 = vector.extract_strided_slice %10 {offsets = [16], sizes = [16], strides = [1]} : vector<32x1xf16> to vector<16x1xf16>
+    // CHECK: %[[V15:.*]] = vector.shuffle %[[V7]], %[[V7]] [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf16>, vector<32xf16>
+    %20 = vector.extract_strided_slice %12 {offsets = [16], sizes = [16], strides = [1]} : vector<32x1xf16> to vector<16x1xf16>
+    // CHECK: %[[V16:.*]] = vector.shuffle %[[V10]], %[[V10]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf16>, vector<32xf16>
+    %21 = vector.extract_strided_slice %15 {offsets = [24], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
+    // CHECK-NOT: vector.shape_cast
+    // CHECK-NOT: vector.extract
+    // CHECK-NOT: vector.extract_strided_slice
+    %22 = vector.shape_cast %4 : vector<8x1xf16> to vector<8xf16>
+    %23 = vector.shape_cast %11 : vector<16x1xf16> to vector<16xf16>
+    %24 = xegpu.dpas %22, %23 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
+    %25 = vector.shape_cast %13 : vector<16x1xf16> to vector<16xf16>
+    %26 = xegpu.dpas %22, %25 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
+    %27 = vector.shape_cast %5 : vector<8x1xf16> to vector<8xf16>
+    %28 = xegpu.dpas %27, %23 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
+    %29 = xegpu.dpas %27, %25 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
+    %30 = vector.shape_cast %6 : vector<8x1xf16> to vector<8xf16>
+    %31 = xegpu.dpas %30, %23 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
+    %32 = xegpu.dpas %30, %25 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
+    %33 = vector.shape_cast %14 : vector<8x1xf16> to vector<8xf16>
+    %34 = xegpu.dpas %33, %23 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
+    %35 = xegpu.dpas %33, %25 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
+    %36 = vector.shape_cast %16 : vector<8x1xf16> to vector<8xf16>
+    %37 = vector.shape_cast %19 : vector<16x1xf16> to vector<16xf16>
+    %38 = xegpu.dpas %36, %37, %24 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
+    %39 = vector.shape_cast %20 : vector<16x1xf16> to vector<16xf16>
+    %40 = xegpu.dpas %36, %39, %26 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
+    %41 = vector.shape_cast %17 : vector<8x1xf16> to vector<8xf16>
+    %42 = xegpu.dpas %41, %37, %28 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
+    %43 = xegpu.dpas %41, %39, %29 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
+    %44 = vector.shape_cast %18 : vector<8x1xf16> to vector<8xf16>
+    %45 = xegpu.dpas %44, %37, %31 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
+    %46 = xegpu.dpas %44, %39, %32 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
+    %47 = vector.shape_cast %21 : vector<8x1xf16> to vector<8xf16>
+    %48 = xegpu.dpas %47, %37, %34 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
+    %49 = xegpu.dpas %47, %39, %35 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
+    %50 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+    xegpu.store_nd %38, %50  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+    %51 = xegpu.create_nd_tdesc %arg2[%c0, %c16] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+    xegpu.store_nd %40, %51  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+    %52 = xegpu.create_nd_tdesc %arg2[%c8, %c0] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+    xegpu.store_nd %42, %52  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+    %53 = xegpu.create_nd_tdesc %arg2[%c8, %c16] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+    xegpu.store_nd %43, %53  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+    %54 = xegpu.create_nd_tdesc %arg2[%c16, %c0] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+    xegpu.store_nd %45, %54  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+    %55 = xegpu.create_nd_tdesc %arg2[%c16, %c16] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+    xegpu.store_nd %46, %55  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+    %56 = xegpu.create_nd_tdesc %arg2[%c24, %c0] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+    xegpu.store_nd %48, %56  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+    %57 = xegpu.create_nd_tdesc %arg2[%c24, %c16] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+    xegpu.store_nd %49, %57  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
     gpu.return
   }
 }
+
+

>From 224d3beb6ff97bc7b289e6b54417958455df833b Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Thu, 18 Sep 2025 17:55:53 +0000
Subject: [PATCH 4/6] Address review comments.

---
 .../Dialect/XeGPU/xegpu-vector-linearize.mlir | 141 ++++++------------
 1 file changed, 48 insertions(+), 93 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
index 9985736e2cafb..ec98172c478ea 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
@@ -183,101 +183,56 @@ func.func @gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x3xindex>, %mask
 }
 
 // -----
-// Check for vector linearization in XeGPU dialect.
-// The vector<64xf16> loaded from memory is linearized into 4 vector<8xf16> using vector.shuffle ops.
-// The pattern is similar to the one used in test_vector_transpose_16x16 above.
+// Check for vector linearization interoperability with XeGPU dialect ops.
+// The `xegpu-vector-linearize` pass does not itself affect the XeGPU ops.
+
+// CHECK: gpu.func @test_kernel(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) kernel {
+// CHECK: %c0 = arith.constant 0 : index
+// CHECK: %cst = arith.constant dense<0.000000e+00> : vector<64xf16>
+// CHECK: %cst_0 = arith.constant dense<5.000000e+00> : vector<64xf32>
+
+// CHECK: %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0]
+// CHECK: %1 = xegpu.load_nd %0
+// CHECK: %2 = vector.shape_cast %1 : vector<8x16xf16> to vector<128xf16>
+// CHECK: %3 = vector.shuffle %2, %cst {{.*}} : vector<128xf16>, vector<64xf16>
+// CHECK: %4 = vector.shape_cast %3 : vector<128xf16> to vector<8x16xf16>
+
+// CHECK: %5 = xegpu.create_nd_tdesc %arg1[%c0, %c0]
+// CHECK: %6 = xegpu.load_nd %5
+// CHECK: %7 = vector.shape_cast %6 : vector<16x16xf16> to vector<256xf16>
+// CHECK: %8 = vector.shuffle %7, %cst {{.*}} : vector<256xf16>, vector<64xf16>
+// CHECK: %9 = vector.shape_cast %8 : vector<256xf16> to vector<16x16xf16>
+
+// CHECK: %10 = xegpu.dpas %4, %9 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK: %11 = vector.shape_cast %10 : vector<8x16xf32> to vector<128xf32>
+// CHECK: %12 = vector.shuffle %11, %11 {{.*}} : vector<128xf32>, vector<128xf32>
+// CHECK: %13 = arith.addf %12, %cst_0 : vector<64xf32>
+// CHECK: %14 = vector.shuffle %11, %13 {{.*}} : vector<128xf32>, vector<64xf32>
+// CHECK: %15 = vector.shape_cast %14 : vector<128xf32> to vector<8x16xf32>
+
+// CHECK: %16 = xegpu.create_nd_tdesc %arg2[%c0, %c0]
+// CHECK: xegpu.store_nd %15, %16
+// CHECK: gpu.return
+
 gpu.module @test_kernel {
-  // CHECK-LABEL: gpu.func @test_kernel
-  gpu.func @test_kernel(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: memref<32x32xf32>) kernel {
-    %c24 = arith.constant 24 : index
-    %c16 = arith.constant 16 : index
-    %c8 = arith.constant 8 : index
+  gpu.func @test_kernel(%A: memref<8x16xf16>, %B: memref<16x16xf16>, %C: memref<8x16xf32>) kernel  {
     %c0 = arith.constant 0 : index
-    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<32x32xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
-    %1 = xegpu.load_nd %0  : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<64xf16>
-    // CHECK: %[[V1:.*]] = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<64xf16>, vector<64xf16>
-    %2 = vector.shape_cast %1 : vector<64xf16> to vector<2x32x1xf16>
-    %3 = vector.extract %2[0] : vector<32x1xf16> from vector<2x32x1xf16>
-    // CHECK: %[[V2:.*]] = vector.shuffle %[[V1]], %[[V1]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xf16>, vector<32xf16>
-    %4 = vector.extract_strided_slice %3 {offsets = [0], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
-    // CHECK: %[[V3:.*]] = vector.shuffle %[[V1]], %[[V1]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xf16>, vector<32xf16>
-    %5 = vector.extract_strided_slice %3 {offsets = [8], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
-    // CHECK: %[[V4:.*]] = vector.shuffle %[[V1]], %[[V1]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xf16>, vector<32xf16>
-    %6 = vector.extract_strided_slice %3 {offsets = [16], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
-    %7 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x32xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
-    %8 = xegpu.load_nd %7 <{packed}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<64xf16>
-    // CHECK: %[[V5:.*]] = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<64xf16>, vector<64xf16>
-    %9 = vector.shape_cast %8 : vector<64xf16> to vector<2x32x1xf16>
-    %10 = vector.extract %9[0] : vector<32x1xf16> from vector<2x32x1xf16>
-    // CHECK: %[[V6:.*]] = vector.shuffle %[[V5]], %[[V5]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<32xf16>, vector<32xf16>
-    %11 = vector.extract_strided_slice %10 {offsets = [0], sizes = [16], strides = [1]} : vector<32x1xf16> to vector<16x1xf16>
-    %12 = vector.extract %9[1] : vector<32x1xf16> from vector<2x32x1xf16>
-    // CHECK: %[[V7:.*]] = vector.shuffle %{{.*}}, %{{.*}} [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf16>, vector<64xf16>
-    // CHECK: %[[V8:.*]] = vector.shuffle %[[V7]], %[[V7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<32xf16>, vector<32xf16>
-    %13 = vector.extract_strided_slice %12 {offsets = [0], sizes = [16], strides = [1]} : vector<32x1xf16> to vector<16x1xf16>
-    // CHECK: %[[V9:.*]] = vector.shuffle %[[V1]], %[[V1]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf16>, vector<32xf16>
-    %14 = vector.extract_strided_slice %3 {offsets = [24], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
-    %15 = vector.extract %2[1] : vector<32x1xf16> from vector<2x32x1xf16>
-    // CHECK: %[[V10:.*]] = vector.shuffle %{{.*}}, %{{.*}} [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf16>, vector<64xf16>
-    // CHECK: %[[V11:.*]] = vector.shuffle %[[V10]], %[[V10]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xf16>, vector<32xf16>
-    %16 = vector.extract_strided_slice %15 {offsets = [0], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
-    // CHECK: %[[V12:.*]] = vector.shuffle %[[V10]], %[[V10]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xf16>, vector<32xf16>
-    %17 = vector.extract_strided_slice %15 {offsets = [8], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
-    // CHECK: %[[V13:.*]] = vector.shuffle %[[V10]], %[[V10]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xf16>, vector<32xf16>
-    %18 = vector.extract_strided_slice %15 {offsets = [16], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
-    // CHECK: %[[V14:.*]] = vector.shuffle %[[V5]], %[[V5]] [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf16>, vector<32xf16>
-    %19 = vector.extract_strided_slice %10 {offsets = [16], sizes = [16], strides = [1]} : vector<32x1xf16> to vector<16x1xf16>
-    // CHECK: %[[V15:.*]] = vector.shuffle %[[V7]], %[[V7]] [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf16>, vector<32xf16>
-    %20 = vector.extract_strided_slice %12 {offsets = [16], sizes = [16], strides = [1]} : vector<32x1xf16> to vector<16x1xf16>
-    // CHECK: %[[V16:.*]] = vector.shuffle %[[V10]], %[[V10]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf16>, vector<32xf16>
-    %21 = vector.extract_strided_slice %15 {offsets = [24], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
-    // CHECK-NOT: vector.shape_cast
-    // CHECK-NOT: vector.extract
-    // CHECK-NOT: vector.extract_strided_slice
-    %22 = vector.shape_cast %4 : vector<8x1xf16> to vector<8xf16>
-    %23 = vector.shape_cast %11 : vector<16x1xf16> to vector<16xf16>
-    %24 = xegpu.dpas %22, %23 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
-    %25 = vector.shape_cast %13 : vector<16x1xf16> to vector<16xf16>
-    %26 = xegpu.dpas %22, %25 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
-    %27 = vector.shape_cast %5 : vector<8x1xf16> to vector<8xf16>
-    %28 = xegpu.dpas %27, %23 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
-    %29 = xegpu.dpas %27, %25 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
-    %30 = vector.shape_cast %6 : vector<8x1xf16> to vector<8xf16>
-    %31 = xegpu.dpas %30, %23 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
-    %32 = xegpu.dpas %30, %25 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
-    %33 = vector.shape_cast %14 : vector<8x1xf16> to vector<8xf16>
-    %34 = xegpu.dpas %33, %23 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
-    %35 = xegpu.dpas %33, %25 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
-    %36 = vector.shape_cast %16 : vector<8x1xf16> to vector<8xf16>
-    %37 = vector.shape_cast %19 : vector<16x1xf16> to vector<16xf16>
-    %38 = xegpu.dpas %36, %37, %24 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
-    %39 = vector.shape_cast %20 : vector<16x1xf16> to vector<16xf16>
-    %40 = xegpu.dpas %36, %39, %26 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
-    %41 = vector.shape_cast %17 : vector<8x1xf16> to vector<8xf16>
-    %42 = xegpu.dpas %41, %37, %28 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
-    %43 = xegpu.dpas %41, %39, %29 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
-    %44 = vector.shape_cast %18 : vector<8x1xf16> to vector<8xf16>
-    %45 = xegpu.dpas %44, %37, %31 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
-    %46 = xegpu.dpas %44, %39, %32 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
-    %47 = vector.shape_cast %21 : vector<8x1xf16> to vector<8xf16>
-    %48 = xegpu.dpas %47, %37, %34 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
-    %49 = xegpu.dpas %47, %39, %35 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
-    %50 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-    xegpu.store_nd %38, %50  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
-    %51 = xegpu.create_nd_tdesc %arg2[%c0, %c16] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-    xegpu.store_nd %40, %51  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
-    %52 = xegpu.create_nd_tdesc %arg2[%c8, %c0] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-    xegpu.store_nd %42, %52  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
-    %53 = xegpu.create_nd_tdesc %arg2[%c8, %c16] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-    xegpu.store_nd %43, %53  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
-    %54 = xegpu.create_nd_tdesc %arg2[%c16, %c0] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-    xegpu.store_nd %45, %54  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
-    %55 = xegpu.create_nd_tdesc %arg2[%c16, %c16] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-    xegpu.store_nd %46, %55  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
-    %56 = xegpu.create_nd_tdesc %arg2[%c24, %c0] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-    xegpu.store_nd %48, %56  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
-    %57 = xegpu.create_nd_tdesc %arg2[%c24, %c16] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-    xegpu.store_nd %49, %57  : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+    %cst_vec_0 = arith.constant dense<0.000000e+00> : vector<8x8xf16>
+    %cst_vec_1 = arith.constant dense<0.000000e+00> : vector<8x8xf16>
+    %cst_vec_2 = arith.constant dense<5.000000e+00> : vector<8x8xf32>
+    %a_tdesc = xegpu.create_nd_tdesc %A[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<array_length = 1>>
+    %a_val = xegpu.load_nd %a_tdesc : !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<array_length = 1>> -> vector<8x16xf16>
+    %a_val_0 = vector.insert_strided_slice %cst_vec_0, %a_val{offsets = [0, 0], sizes = [8, 8], strides = [1, 1]}: vector<8x8xf16> into vector<8x16xf16>
+    %b_tdesc = xegpu.create_nd_tdesc %B[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 1>>
+
+    %b_val = xegpu.load_nd  %b_tdesc : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 1>> -> vector<16x16xf16>
+    %b_val_0 = vector.insert_strided_slice %cst_vec_1, %b_val{offsets = [0, 0], sizes = [8, 8], strides = [1, 1]}: vector<8x8xf16> into vector<16x16xf16>
+    %c_val = xegpu.dpas %a_val_0, %b_val_0 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+    %c_val_0 = vector.extract_strided_slice %c_val {offsets = [0, 0], sizes = [8, 8], strides = [1, 1]} : vector<8x16xf32> to vector<8x8xf32>
+    %c_addf = arith.addf %c_val_0, %cst_vec_2 : vector<8x8xf32>
+    %c_result = vector.insert_strided_slice %c_addf, %c_val {offsets = [0, 0], sizes = [8, 8], strides = [1, 1]} : vector<8x8xf32> into vector<8x16xf32>
+    %c_tdesc = xegpu.create_nd_tdesc %C[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<array_length = 1>>
+    xegpu.store_nd %c_result, %c_tdesc : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
     gpu.return
   }
 }

>From b4f8cbfb33041ac0231b452d516ce6525b9ff34b Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Thu, 18 Sep 2025 21:25:16 +0000
Subject: [PATCH 5/6] Address review comments.

Add vector unroll support for n-D (rank >= 2) load/store.
---
 .../XeGPU/Transforms/XeGPUVectorLinearize.cpp | 36 +++++++++++++++----
 .../Dialect/XeGPU/xegpu-vector-linearize.mlir | 23 ++++++++++++
 2 files changed, 52 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
index 2bb302f4287c4..24da724bf6d81 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
@@ -16,6 +16,8 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
 
 #include <optional>
 
@@ -48,7 +50,8 @@ struct XeGPUVectorLinearizePass final
         return signalPassFailure();
     }
 
-    // Unroll load store from <MxN> to M <1xN> load/stores and then linearize
+    // Unroll load/store from <d1xd2x...xdk> to (d1*d2*...*d(k-1)) slices of
+    // <1x1x...x1xdk>.
     {
       RewritePatternSet patterns(&getContext());
       vector::UnrollVectorOptions vectorOptions;
@@ -62,19 +65,36 @@ struct XeGPUVectorLinearizePass final
               return nullptr;
             };
 
-            auto vecType = extractVectorType(op);
+            VectorType vecType = extractVectorType(op);
             if (!vecType)
               return std::nullopt;
 
-            auto shape = vecType.getShape();
-            if (shape.size() != 2)
+            // Only handle rank >= 2 so we actually unroll something.
+            int64_t rank = vecType.getRank();
+            if (rank < 2)
               return std::nullopt;
 
-            return SmallVector<int64_t>{1, shape[1]};
+            ArrayRef<int64_t> shape = vecType.getShape();
+            // Bail if any of the (rank-1) leading dims are dynamic (can't fully
+            // unroll).
+            for (int64_t i = 0; i < rank - 1; ++i)
+              if (shape[i] == ShapedType::kDynamic) {
+                LLVM_DEBUG(llvm::dbgs()
+                           << "Dynamic leading dim " << i << " in " << vecType
+                           << " prevents full unroll.\n");
+                return std::nullopt;
+              }
+
+            // Produce native shape: 1 x 1 x ... x (original last dim).
+            SmallVector<int64_t> native(rank, 1);
+            native.back() = shape.back();
+            return native;
           });
       vector::populateVectorUnrollPatterns(patterns, vectorOptions);
-      if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+      if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+        LLVM_DEBUG(llvm::dbgs() << "Unroll failed.\n");
         return signalPassFailure();
+      }
     }
 
     // Use vector linearization patterns
@@ -90,8 +110,10 @@ struct XeGPUVectorLinearizePass final
       scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                            target);
       if (failed(applyPartialConversion(getOperation(), target,
-                                        std::move(patterns))))
+                                        std::move(patterns)))) {
+        LLVM_DEBUG(llvm::dbgs() << "Linearization failed.\n");
         return signalPassFailure();
+      }
     }
   }
 };
diff --git a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
index ec98172c478ea..0bb7d7d3d8b1b 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
@@ -59,6 +59,29 @@ func.func @test_vector_store_load_4x4_f16(%buffer: memref<4x4xf16>) {
   vector.store %0, %buffer[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
   return
 }
+
+// -----
+// CHECK-LABEL: func.func @test_vector_store_load_4x4x4
+// CHECK-SAME: (%[[BUF:.*]]: memref<4x4x4xf32>)
+// Constants (order not important)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// All 16 loads of 1-D vector<4xf32> slices (one per leading (i, j) index pair)
+// CHECK-COUNT-16: vector.load {{.*}} : memref<4x4x4xf32>, vector<4xf32>
+// No remaining 3D vector load
+// CHECK-NOT: vector.load {{.*}} : memref<4x4x4xf32>, vector<4x4x4xf32>
+// All 16 stores of 1D vectors
+// CHECK-COUNT-16: vector.store {{.*}} : memref<4x4x4xf32>, vector<4xf32>
+// CHECK: return
+func.func @test_vector_store_load_4x4x4(%buffer: memref<4x4x4xf32>) {
+  %c0 = arith.constant 0 : index
+  %0 = vector.load %buffer[%c0, %c0, %c0] : memref<4x4x4xf32>, vector<4x4x4xf32>
+  vector.store %0, %buffer[%c0, %c0, %c0] : memref<4x4x4xf32>, vector<4x4x4xf32>
+  return
+}
+
 // -----
 // CHECK-LABEL: func.func @test_linearize_index
 // CHECK-SAME: (%[[ARG0:.*]]: vector<2x2xindex>, %[[ARG1:.*]]: vector<2x2xi32>) -> vector<2x2xindex>

>From 6b22d6d3936d5de3b640ae2f65d372c6cf2f3798 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Thu, 18 Sep 2025 22:05:25 +0000
Subject: [PATCH 6/6] Address review comments.

---
 .../mlir/Dialect/XeGPU/Transforms/Passes.td       |  2 +-
 .../XeGPU/Transforms/XeGPUVectorLinearize.cpp     | 15 +++------------
 2 files changed, 4 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 77c57ccb0746f..83b128e2c7cbf 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -81,7 +81,7 @@ def XeGPUVectorLinearize : Pass<"xegpu-vector-linearize"> {
     This pass linearizes n-D vectors to 1-D vectors for lowering to XeVM.
   }];
   let dependentDialects = ["arith::ArithDialect", "memref::MemRefDialect",
-                           "scf::SCFDialect", "vector::VectorDialect"];
+                           "scf::SCFDialect", "ub::UBDialect", "vector::VectorDialect"];
 }
 
 #endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
index 24da724bf6d81..e31c37a2459ad 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
 #include "llvm/Support/raw_ostream.h"
 
 #include <optional>
@@ -75,16 +76,6 @@ struct XeGPUVectorLinearizePass final
               return std::nullopt;
 
             ArrayRef<int64_t> shape = vecType.getShape();
-            // Bail if any of the (rank-1) leading dims are dynamic (can't fully
-            // unroll).
-            for (int64_t i = 0; i < rank - 1; ++i)
-              if (shape[i] == ShapedType::kDynamic) {
-                LLVM_DEBUG(llvm::dbgs()
-                           << "Dynamic leading dim " << i << " in " << vecType
-                           << " prevents full unroll.\n");
-                return std::nullopt;
-              }
-
             // Produce native shape: 1 x 1 x ... x (original last dim).
             SmallVector<int64_t> native(rank, 1);
             native.back() = shape.back();
@@ -92,7 +83,7 @@ struct XeGPUVectorLinearizePass final
           });
       vector::populateVectorUnrollPatterns(patterns, vectorOptions);
       if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
-        LLVM_DEBUG(llvm::dbgs() << "Unroll failed.\n");
+        LDBG() << "Unroll failed.";
         return signalPassFailure();
       }
     }
@@ -111,7 +102,7 @@ struct XeGPUVectorLinearizePass final
                                                            target);
       if (failed(applyPartialConversion(getOperation(), target,
                                         std::move(patterns)))) {
-        LLVM_DEBUG(llvm::dbgs() << "Linearization failed.\n");
+        LDBG() << "Linearization failed.";
         return signalPassFailure();
       }
     }



More information about the Mlir-commits mailing list