[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for elementwise ops in Wg to Sg distribute pass (PR #142797)
Nishant Patel
llvmlistbot at llvm.org
Fri Jun 6 08:50:17 PDT 2025
https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/142797
>From 08f7eb9752e682cb6b3b6c4f40fad613a0b0d940 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 2 Jun 2025 18:05:59 +0000
Subject: [PATCH 1/2] Add support for elementwise ops in Wg to Sg distribute pass
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 223 ++++
.../XeGPU/xegpu-wg-to-sg-elemwise.mlir | 971 ++++++++++++++++++
2 files changed, 1194 insertions(+)
create mode 100644 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
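Sketch of the rewrite this patch performs (shapes taken from the tests
below; the SSA value names are illustrative only):

  // Workgroup level, before distribution:
  %r = arith.addf %a, %b
    {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
    : vector<24x32xf32>

  // Subgroup level, after distribution: each subgroup computes its own
  // 12x8 tile, and sg_layout/sg_data are dropped from the result layout:
  %r_sg = arith.addf %a_sg, %b_sg
    {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
    : vector<12x8xf32>
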
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 3bf76af674ba0..972394a7b40ad 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -8,15 +8,18 @@
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Transforms/DialectConversion.h"
+#include <optional>
namespace mlir {
namespace xegpu {
@@ -314,6 +317,179 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
}
};
+// This pattern matches elementwise ops (unary/binary) in math/arith dialects
+// with 1D or 2D vector types
+template <typename Op>
+struct WgToSgElementwiseOp : public OpConversionPattern<Op> {
+ using OpConversionPattern<Op>::OpConversionPattern;
+ using OneToNOpAdaptor = typename OpConversionPattern<Op>::OneToNOpAdaptor;
+
+ LogicalResult
+ matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // All operands/results must be 1D or 2D vectors
+ auto resultType = dyn_cast<VectorType>(op.getResult().getType());
+ if (!resultType || (resultType.getRank() != 1 && resultType.getRank() != 2))
+ return rewriter.notifyMatchFailure(
+ op, "Result type is not a 1D or 2D vector");
+
+ ArrayRef<int64_t> shape = resultType.getShape();
+ for (Value operand : op->getOperands()) {
+ auto operandType = dyn_cast<VectorType>(operand.getType());
+ if (!operandType || operandType.getRank() != resultType.getRank() ||
+ operandType.getShape() != shape) {
+ return rewriter.notifyMatchFailure(
+ op, "Operand type is not a 1D or 2D vector with the same shape as "
+ "result type");
+ }
+ }
+
+ // Check for layout attribute with sgLayout
+ auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+ if (!layout || !layout.getSgLayout())
+ return rewriter.notifyMatchFailure(
+ op, "Operation does not have a valid layout attribute for subgroup "
+ "distribution");
+
+ // Extract sgShape from layout
+ SmallVector<int64_t> sgShape;
+ if (auto sgDataAttr = layout.getSgData()) {
+ sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
+ } else {
+ auto sgLayoutArr = layout.getSgLayout();
+ sgShape.reserve(shape.size());
+ for (size_t i = 0; i < shape.size(); ++i) {
+ assert(sgLayoutArr[i] != 0 && "sgLayout elements must be non-zero");
+ sgShape.push_back(shape[i] / sgLayoutArr[i]);
+ }
+ }
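+ // For example (shapes from the tests below): shape = [24, 32] with
+ // sg_layout = [2, 4] and no sg_data gives sgShape = [12, 8]; an explicit
+ // sg_data = [12, 8] is used unchanged.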
+
+ // With the OneToNOpAdaptor, each original operand maps to a list of
+ // subgroup-level values; all lists must have the same length.
+ size_t numVariants = adaptor.getOperands().empty()
+ ? 0
+ : adaptor.getOperands().front().size();
+ for (auto &operandVec : adaptor.getOperands())
+ if (operandVec.size() != numVariants)
+ return rewriter.notifyMatchFailure(
+ op, "Operand lists have mismatched sizes");
+
+ SmallVector<Value> newResults;
+
+ // resultType was already verified above to be a 1D or 2D vector.
+ VectorType newResultType =
+ VectorType::get(sgShape, resultType.getElementType());
+
+ for (size_t i = 0; i < numVariants; ++i) {
+ SmallVector<Value> operands;
+ for (auto &operandVec : adaptor.getOperands())
+ operands.push_back(operandVec[i]);
+
+ auto newOp = rewriter.create<Op>(op.getLoc(), newResultType, operands);
+
+ // Copy all attributes except "layout", and add "layout_result_0" with
+ // sgLayout/data dropped
+ for (auto attr : op->getAttrs()) {
+ if (attr.getName() != "layout")
+ newOp->setAttr(attr.getName(), attr.getValue());
+ }
+ newOp->setAttr("layout_result_0", layout.dropSgLayoutAndData());
+
+ newResults.push_back(newOp.getResult());
+ }
+
+ rewriter.replaceOpWithMultiple(op, {newResults});
+ return success();
+ }
+};
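+
+ // Example of the 1:N case (see the *_rr_assignment tests below): with
+ // sg_layout = [4, 4] and sg_data = [2, 2] on vector<24x32xf32>, each
+ // subgroup owns (24 / (4 * 2)) * (32 / (4 * 2)) = 3 * 4 = 12 tiles, so the
+ // adaptor supplies 12 values per operand and the loop above emits 12
+ // subgroup-level ops on vector<2x2xf32>.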
+
+// ---- ARITH ops ----
+using WgToSgAddFOp = WgToSgElementwiseOp<arith::AddFOp>;
+using WgToSgSubFOp = WgToSgElementwiseOp<arith::SubFOp>;
+using WgToSgNegFOp = WgToSgElementwiseOp<arith::NegFOp>;
+using WgToSgAddIOp = WgToSgElementwiseOp<arith::AddIOp>;
+using WgToSgSubIOp = WgToSgElementwiseOp<arith::SubIOp>;
+using WgToSgMulFOp = WgToSgElementwiseOp<arith::MulFOp>;
+using WgToSgMulIOp = WgToSgElementwiseOp<arith::MulIOp>;
+using WgToSgShLIOp = WgToSgElementwiseOp<arith::ShLIOp>;
+using WgToSgShRSIOp = WgToSgElementwiseOp<arith::ShRSIOp>;
+using WgToSgShRUIOp = WgToSgElementwiseOp<arith::ShRUIOp>;
+using WgToSgDivFOp = WgToSgElementwiseOp<arith::DivFOp>;
+using WgToSgDivSIOp = WgToSgElementwiseOp<arith::DivSIOp>;
+using WgToSgDivUIOp = WgToSgElementwiseOp<arith::DivUIOp>;
+using WgToSgMaximumFOp = WgToSgElementwiseOp<arith::MaximumFOp>;
+using WgToSgMinimumFOp = WgToSgElementwiseOp<arith::MinimumFOp>;
+using WgToSgRemSIOp = WgToSgElementwiseOp<arith::RemSIOp>;
+using WgToSgRemUIOp = WgToSgElementwiseOp<arith::RemUIOp>;
+using WgToSgTruncFOp = WgToSgElementwiseOp<arith::TruncFOp>;
+using WgToSgTruncIOp = WgToSgElementwiseOp<arith::TruncIOp>;
+using WgToSgExtFOp = WgToSgElementwiseOp<arith::ExtFOp>;
+using WgToSgExtSIOp = WgToSgElementwiseOp<arith::ExtSIOp>;
+using WgToSgExtUIOp = WgToSgElementwiseOp<arith::ExtUIOp>;
+using WgToSgSIToFPOp = WgToSgElementwiseOp<arith::SIToFPOp>;
+using WgToSgUIToFPOp = WgToSgElementwiseOp<arith::UIToFPOp>;
+using WgToSgFPToSIOp = WgToSgElementwiseOp<arith::FPToSIOp>;
+using WgToSgFPToUIOp = WgToSgElementwiseOp<arith::FPToUIOp>;
+using WgToSgIndexCastUIOp = WgToSgElementwiseOp<arith::IndexCastUIOp>;
+using WgToSgIndexCastOp = WgToSgElementwiseOp<arith::IndexCastOp>;
+using WgToSgBitcastOp = WgToSgElementwiseOp<arith::BitcastOp>;
+using WgToSgCmpIOp = WgToSgElementwiseOp<arith::CmpIOp>;
+using WgToSgCmpFOp = WgToSgElementwiseOp<arith::CmpFOp>;
+using WgToSgAndIOp = WgToSgElementwiseOp<arith::AndIOp>;
+using WgToSgCeilDivSIOp = WgToSgElementwiseOp<arith::CeilDivSIOp>;
+using WgToSgCeilDivUIOp = WgToSgElementwiseOp<arith::CeilDivUIOp>;
+using WgToSgFloorDivSIOp = WgToSgElementwiseOp<arith::FloorDivSIOp>;
+using WgToSgMaxNumFOp = WgToSgElementwiseOp<arith::MaxNumFOp>;
+using WgToSgMaxSIOp = WgToSgElementwiseOp<arith::MaxSIOp>;
+using WgToSgMaxUIOp = WgToSgElementwiseOp<arith::MaxUIOp>;
+using WgToSgMinNumFOp = WgToSgElementwiseOp<arith::MinNumFOp>;
+using WgToSgMinSIOp = WgToSgElementwiseOp<arith::MinSIOp>;
+using WgToSgMinUIOp = WgToSgElementwiseOp<arith::MinUIOp>;
+using WgToSgOrIOp = WgToSgElementwiseOp<arith::OrIOp>;
+using WgToSgRemFOp = WgToSgElementwiseOp<arith::RemFOp>;
+using WgToSgSelectOp = WgToSgElementwiseOp<arith::SelectOp>;
+using WgToSgXOrIOp = WgToSgElementwiseOp<arith::XOrIOp>;
+
+// ---- MATH ops ----
+using WgToSgExpOp = WgToSgElementwiseOp<math::ExpOp>;
+using WgToSgSqrtOp = WgToSgElementwiseOp<math::SqrtOp>;
+using WgToSgAbsFOp = WgToSgElementwiseOp<math::AbsFOp>;
+using WgToSgCosOp = WgToSgElementwiseOp<math::CosOp>;
+using WgToSgCoshOp = WgToSgElementwiseOp<math::CoshOp>;
+using WgToSgAcosOp = WgToSgElementwiseOp<math::AcosOp>;
+using WgToSgAcoshOp = WgToSgElementwiseOp<math::AcoshOp>;
+using WgToSgSinOp = WgToSgElementwiseOp<math::SinOp>;
+using WgToSgSinhOp = WgToSgElementwiseOp<math::SinhOp>;
+using WgToSgAsinOp = WgToSgElementwiseOp<math::AsinOp>;
+using WgToSgAsinhOp = WgToSgElementwiseOp<math::AsinhOp>;
+using WgToSgTanOp = WgToSgElementwiseOp<math::TanOp>;
+using WgToSgTanhOp = WgToSgElementwiseOp<math::TanhOp>;
+using WgToSgAtanOp = WgToSgElementwiseOp<math::AtanOp>;
+using WgToSgAtan2Op = WgToSgElementwiseOp<math::Atan2Op>;
+using WgToSgAtanhOp = WgToSgElementwiseOp<math::AtanhOp>;
+using WgToSgErfOp = WgToSgElementwiseOp<math::ErfOp>;
+using WgToSgLogOp = WgToSgElementwiseOp<math::LogOp>;
+using WgToSgLog2Op = WgToSgElementwiseOp<math::Log2Op>;
+using WgToSgFloorOp = WgToSgElementwiseOp<math::FloorOp>;
+using WgToSgCeilOp = WgToSgElementwiseOp<math::CeilOp>;
+using WgToSgPowFOp = WgToSgElementwiseOp<math::PowFOp>;
+using WgToSgRsqrtOp = WgToSgElementwiseOp<math::RsqrtOp>;
+using WgToSgAbsIOp = WgToSgElementwiseOp<math::AbsIOp>;
+using WgToSgCbrtOp = WgToSgElementwiseOp<math::CbrtOp>;
+using WgToSgCopySignOp = WgToSgElementwiseOp<math::CopySignOp>;
+using WgToSgCtPopOp = WgToSgElementwiseOp<math::CtPopOp>;
+using WgToSgErfcOp = WgToSgElementwiseOp<math::ErfcOp>;
+using WgToSgExp2Op = WgToSgElementwiseOp<math::Exp2Op>;
+using WgToSgExpM1Op = WgToSgElementwiseOp<math::ExpM1Op>;
+using WgToSgFPowIOp = WgToSgElementwiseOp<math::FPowIOp>;
+using WgToSgIPowIOp = WgToSgElementwiseOp<math::IPowIOp>;
+using WgToSgLog10Op = WgToSgElementwiseOp<math::Log10Op>;
+using WgToSgLog1pOp = WgToSgElementwiseOp<math::Log1pOp>;
+using WgToSgRoundOp = WgToSgElementwiseOp<math::RoundOp>;
+using WgToSgRoundEvenOp = WgToSgElementwiseOp<math::RoundEvenOp>;
+using WgToSgTruncOp = WgToSgElementwiseOp<math::TruncOp>;
+
} // namespace
namespace mlir {
@@ -322,6 +498,27 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
patterns.getContext());
+ // Add elementwise operations that can be distributed to subgroups
+ patterns.add<
+ WgToSgAddFOp, WgToSgSubFOp, WgToSgExpOp, WgToSgSqrtOp, WgToSgAbsFOp,
+ WgToSgCosOp, WgToSgCoshOp, WgToSgAcosOp, WgToSgAcoshOp, WgToSgSinOp,
+ WgToSgSinhOp, WgToSgAsinOp, WgToSgAsinhOp, WgToSgTanOp, WgToSgTanhOp,
+ WgToSgAtanOp, WgToSgAtan2Op, WgToSgAtanhOp, WgToSgErfOp, WgToSgLogOp,
+ WgToSgLog2Op, WgToSgFloorOp, WgToSgCeilOp, WgToSgPowFOp, WgToSgRsqrtOp,
+ WgToSgNegFOp, WgToSgAddIOp, WgToSgSubIOp, WgToSgMulFOp, WgToSgMulIOp,
+ WgToSgShLIOp, WgToSgShRSIOp, WgToSgShRUIOp, WgToSgDivFOp, WgToSgDivSIOp,
+ WgToSgDivUIOp, WgToSgMaximumFOp, WgToSgMinimumFOp, WgToSgRemSIOp,
+ WgToSgRemUIOp, WgToSgTruncFOp, WgToSgTruncIOp, WgToSgExtFOp,
+ WgToSgExtSIOp, WgToSgExtUIOp, WgToSgSIToFPOp, WgToSgUIToFPOp,
+ WgToSgFPToSIOp, WgToSgFPToUIOp, WgToSgIndexCastUIOp, WgToSgIndexCastOp,
+ WgToSgBitcastOp, WgToSgCmpIOp, WgToSgCmpFOp, WgToSgAndIOp,
+ WgToSgCeilDivSIOp, WgToSgCeilDivUIOp, WgToSgFloorDivSIOp, WgToSgMaxNumFOp,
+ WgToSgMaxSIOp, WgToSgMaxUIOp, WgToSgMinNumFOp, WgToSgMinSIOp,
+ WgToSgMinUIOp, WgToSgOrIOp, WgToSgRemFOp, WgToSgSelectOp, WgToSgXOrIOp,
+ WgToSgAbsIOp, WgToSgCbrtOp, WgToSgCopySignOp, WgToSgCtPopOp, WgToSgErfcOp,
+ WgToSgExp2Op, WgToSgExpM1Op, WgToSgFPowIOp, WgToSgIPowIOp, WgToSgLog10Op,
+ WgToSgLog1pOp, WgToSgRoundOp, WgToSgRoundEvenOp, WgToSgTruncOp>(
+ patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -368,6 +565,32 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
return isLegal(layout);
});
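+ // Elementwise ops in the math/arith dialects are illegal (and hence
+ // rewritten) only when they operate on workgroup-level vectors carrying an
+ // sg_layout; every other case is reported legal below and left untouched.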
+ target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
+ [=](Operation *op) -> std::optional<bool> {
+ // Only unary and binary elementwise ops are candidates; everything
+ // else stays legal.
+ if (op->getNumOperands() < 1 || op->getNumOperands() > 2)
+ return true;
+
+ // The result must be a 1D or 2D vector, matching what the conversion
+ // pattern accepts.
+ VectorType resultType =
+ dyn_cast<VectorType>(op->getResult(0).getType());
+ if (!resultType ||
+ (resultType.getRank() != 1 && resultType.getRank() != 2))
+ return true;
+
+ // All operands must be vectors with the same rank and shape as the
+ // result.
+ for (Value operand : op->getOperands()) {
+ VectorType operandType = dyn_cast<VectorType>(operand.getType());
+ if (!operandType || operandType.getRank() != resultType.getRank() ||
+ operandType.getShape() != resultType.getShape()) {
+ return true;
+ }
+ }
+
+ // getAttrOfType already returns the typed attribute (or null), so no
+ // extra dyn_cast_or_null is needed.
+ auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout");
+ return isLegal(layout);
+ });
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
new file mode 100644
index 0000000000000..c45312e4c2d74
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
@@ -0,0 +1,971 @@
+// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
+
+gpu.module @test_elementwise_ops {
+ // CHECK-LABEL: gpu.func @test_elemwise_ops
+ gpu.func @test_elemwise_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+ -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
+ -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32>
+ -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32>
+ -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+
+ %load_a = xegpu.load_nd %tdesc_a
+ : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xf32>
+ %load_b = xegpu.load_nd %tdesc_b
+ : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xf32>
+ %load_c = xegpu.load_nd %tdesc_c
+ : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xi32>
+ %load_d = xegpu.load_nd %tdesc_d
+ : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xi32>
+
+ // Floating point ops
+ // CHECK: arith.addf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.subf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.exp {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.sqrt {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.absf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.cos {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.cosh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.acos {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.acosh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.sin {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.sinh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.asin {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.asinh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.tan {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.tanh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.atan {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.atan2 {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.atanh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.erf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.log {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.log2 {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.floor {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.ceil {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.rsqrt {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.mulf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.divf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.maximumf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.minimumf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.addi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.subi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.muli {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.shli {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.shrsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.shrui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.divsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.divui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.remsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.remui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ %addf = arith.addf %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %subf = arith.subf %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %exp = math.exp %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %sqrt = math.sqrt %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %absf = math.absf %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cos = math.cos %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cosh = math.cosh %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %acos = math.acos %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %acosh = math.acosh %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %sin = math.sin %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %sinh = math.sinh %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %asin = math.asin %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %asinh = math.asinh %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %tan = math.tan %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %tanh = math.tanh %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %atan = math.atan %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %atan2 = math.atan2 %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %atanh = math.atanh %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %erf = math.erf %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %log = math.log %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %log2 = math.log2 %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %floor = math.floor %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %ceil = math.ceil %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %powf = math.powf %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %rsqrt = math.rsqrt %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %negf = arith.negf %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %mulf = arith.mulf %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %divf = arith.divf %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %maximumf = arith.maximumf %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %minimumf = arith.minimumf %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+
+ // Integer ops
+ %addi = arith.addi %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %subi = arith.subi %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %muli = arith.muli %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %shli = arith.shli %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %shrsi = arith.shrsi %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %shrui = arith.shrui %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %divsi = arith.divsi %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %divui = arith.divui %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %remsi = arith.remsi %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %remui = arith.remui %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+
+ gpu.return
+ }
+
+ // 1 to N decomposition of elementwise operations
+ // CHECK-LABEL: gpu.func @test_elemwise_ops_sg_rr_assignment
+ gpu.func @test_elemwise_ops_sg_rr_assignment(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+ -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
+ -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32>
+ -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32>
+ -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+
+ %load_a = xegpu.load_nd %tdesc_a
+ : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ -> vector<24x32xf32>
+ %load_b = xegpu.load_nd %tdesc_b
+ : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ -> vector<24x32xf32>
+ %load_c = xegpu.load_nd %tdesc_c
+ : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ -> vector<24x32xi32>
+ %load_d = xegpu.load_nd %tdesc_d
+ : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ -> vector<24x32xi32>
+
+ // Floating point ops
+ // CHECK-COUNT-12: arith.addf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: arith.subf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: math.exp {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: math.sqrt {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: math.absf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: math.cos {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: math.cosh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: math.acos {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: math.acosh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: math.sin {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: math.sinh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: math.asin {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: math.asinh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: math.tan {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: math.tanh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: math.atan {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: math.atan2 {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: math.atanh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: math.erf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: math.log {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: math.log2 {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: math.floor {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: math.ceil {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: math.rsqrt {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: arith.mulf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: arith.divf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: arith.maximumf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: arith.minimumf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: arith.addi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-COUNT-12: arith.subi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-COUNT-12: arith.muli {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-COUNT-12: arith.shli {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-COUNT-12: arith.shrsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-COUNT-12: arith.shrui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-COUNT-12: arith.divsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-COUNT-12: arith.divui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-COUNT-12: arith.remsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-COUNT-12: arith.remui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ %addf = arith.addf %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %subf = arith.subf %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %exp = math.exp %load_a
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %sqrt = math.sqrt %load_a
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %absf = math.absf %load_a
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cos = math.cos %load_a
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cosh = math.cosh %load_a
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %acos = math.acos %load_a
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %acosh = math.acosh %load_a
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %sin = math.sin %load_a
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %sinh = math.sinh %load_a
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %asin = math.asin %load_a
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %asinh = math.asinh %load_a
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %tan = math.tan %load_a
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %tanh = math.tanh %load_a
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %atan = math.atan %load_a
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %atan2 = math.atan2 %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %atanh = math.atanh %load_a
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %erf = math.erf %load_a
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %log = math.log %load_a
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %log2 = math.log2 %load_a
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %floor = math.floor %load_a
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %ceil = math.ceil %load_a
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %powf = math.powf %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %rsqrt = math.rsqrt %load_a
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %negf = arith.negf %load_a
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %mulf = arith.mulf %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %divf = arith.divf %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %maximumf = arith.maximumf %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %minimumf = arith.minimumf %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+
+ // Integer ops
+ %addi = arith.addi %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %subi = arith.subi %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %muli = arith.muli %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %shli = arith.shli %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %shrsi = arith.shrsi %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %shrui = arith.shrui %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %divsi = arith.divsi %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %divui = arith.divui %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %remsi = arith.remsi %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %remui = arith.remui %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+
+ gpu.return
+ }
+
+ // CHECK-LABEL: gpu.func @test_all_type_conversion_ops
+ gpu.func @test_all_type_conversion_ops(
+ %a: memref<24x32xf32>, %b: memref<24x32xi32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+ -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xi32>
+ -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32>
+ -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32>
+ -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+
+ %load_a = xegpu.load_nd %tdesc_a
+ : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xf32>
+ %load_b = xegpu.load_nd %tdesc_b
+ : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xi32>
+ %load_c = xegpu.load_nd %tdesc_c
+ : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xi32>
+ %load_d = xegpu.load_nd %tdesc_d
+ : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xi32>
+
+ // CHECK: arith.truncf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32> to vector<12x8xf16>
+ // CHECK: arith.trunci {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32> to vector<12x8xi16>
+ // CHECK: arith.extf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf16> to vector<12x8xf32>
+ // CHECK: arith.extsi {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi16> to vector<12x8xi32>
+ // CHECK: arith.extui {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi16> to vector<12x8xi32>
+ // CHECK: arith.sitofp {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32> to vector<12x8xf32>
+ // CHECK: arith.uitofp {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32> to vector<12x8xf32>
+ // CHECK: arith.fptosi {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32> to vector<12x8xi32>
+ // CHECK: arith.fptoui {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32> to vector<12x8xi32>
+ // CHECK: arith.index_castui {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32> to vector<12x8xindex>
+ // CHECK: arith.index_cast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xindex> to vector<12x8xi32>
+ // CHECK: arith.bitcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32> to vector<12x8xf32>
+ // TruncFOp: f32 -> f16
+ %truncf = arith.truncf %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32> to vector<24x32xf16>
+ // TruncIOp: i32 -> i16
+ %trunci = arith.trunci %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32> to vector<24x32xi16>
+ // ExtFOp: f16 -> f32 (first produce an f16 input via truncf)
+ %truncf16 = arith.truncf %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32> to vector<24x32xf16>
+ %extf = arith.extf %truncf16
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf16> to vector<24x32xf32>
+ // ExtSIOp: i16 -> i32
+ %extsi = arith.extsi %trunci
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi16> to vector<24x32xi32>
+ // ExtUIOp: i16 -> i32 (unsigned)
+ %extui = arith.extui %trunci
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi16> to vector<24x32xi32>
+ // SIToFPOp: i32 -> f32
+ %sitofp = arith.sitofp %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32> to vector<24x32xf32>
+ // UIToFPOp: i32 -> f32 (unsigned)
+ %uitofp = arith.uitofp %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32> to vector<24x32xf32>
+ // FPToSIOp: f32 -> i32
+ %fptosi = arith.fptosi %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32> to vector<24x32xi32>
+ // FPToUIOp: f32 -> i32 (unsigned)
+ %fptoui = arith.fptoui %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32> to vector<24x32xi32>
+ // IndexCastUIOp: i32 -> index
+ %indexcastui = arith.index_castui %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32> to vector<24x32xindex>
+ // IndexCastOp: index -> i32
+ %indexcast = arith.index_cast %indexcastui
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xindex> to vector<24x32xi32>
+ // BitcastOp: i32 -> f32
+ %bitcast = arith.bitcast %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32> to vector<24x32xf32>
+ gpu.return
+ }
+
+ // CHECK-LABEL: gpu.func @test_all_type_conversion_ops_rr_assignment
+ gpu.func @test_all_type_conversion_ops_rr_assignment(
+ %a: memref<24x32xf32>, %b: memref<24x32xi32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+ -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xi32>
+ -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32>
+ -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32>
+ -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+
+ %load_a = xegpu.load_nd %tdesc_a
+ : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ -> vector<24x32xf32>
+ %load_b = xegpu.load_nd %tdesc_b
+ : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ -> vector<24x32xi32>
+ %load_c = xegpu.load_nd %tdesc_c
+ : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ -> vector<24x32xi32>
+ %load_d = xegpu.load_nd %tdesc_d
+ : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ -> vector<24x32xi32>
+
+ // CHECK-COUNT-12: arith.truncf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32> to vector<2x2xf16>
+ // CHECK-COUNT-12: arith.trunci {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32> to vector<2x2xi16>
+ // CHECK-COUNT-12: arith.extf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf16> to vector<2x2xf32>
+ // CHECK-COUNT-12: arith.extsi {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi16> to vector<2x2xi32>
+ // CHECK-COUNT-12: arith.extui {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi16> to vector<2x2xi32>
+ // CHECK-COUNT-12: arith.sitofp {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32> to vector<2x2xf32>
+ // CHECK-COUNT-12: arith.uitofp {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32> to vector<2x2xf32>
+ // CHECK-COUNT-12: arith.fptosi {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32> to vector<2x2xi32>
+ // CHECK-COUNT-12: arith.fptoui {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32> to vector<2x2xi32>
+ // CHECK-COUNT-12: arith.index_castui {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32> to vector<2x2xindex>
+ // CHECK-COUNT-12: arith.index_cast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xindex> to vector<2x2xi32>
+ // CHECK-COUNT-12: arith.bitcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32> to vector<2x2xf32>
+ // TruncFOp: f32 -> f16
+ %truncf = arith.truncf %load_a
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32> to vector<24x32xf16>
+ // TruncIOp: i32 -> i16
+ %trunci = arith.trunci %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi32> to vector<24x32xi16>
+ // ExtFOp: f16 -> f32 (first produce an f16 input via truncf)
+ %truncf16 = arith.truncf %load_a
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32> to vector<24x32xf16>
+ %extf = arith.extf %truncf16
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf16> to vector<24x32xf32>
+ // ExtSIOp: i16 -> i32
+ %extsi = arith.extsi %trunci
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi16> to vector<24x32xi32>
+ // ExtUIOp: i16 -> i32 (unsigned)
+ %extui = arith.extui %trunci
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi16> to vector<24x32xi32>
+ // SIToFPOp: i32 -> f32
+ %sitofp = arith.sitofp %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi32> to vector<24x32xf32>
+ // UIToFPOp: i32 -> f32 (unsigned)
+ %uitofp = arith.uitofp %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi32> to vector<24x32xf32>
+ // FPToSIOp: f32 -> i32
+ %fptosi = arith.fptosi %load_a
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32> to vector<24x32xi32>
+ // FPToUIOp: f32 -> i32 (unsigned)
+ %fptoui = arith.fptoui %load_a
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32> to vector<24x32xi32>
+ // IndexCastUIOp: i32 -> index
+ %indexcastui = arith.index_castui %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi32> to vector<24x32xindex>
+ // IndexCastOp: index -> i32
+ %indexcast = arith.index_cast %indexcastui
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xindex> to vector<24x32xi32>
+ // BitcastOp: i32 -> f32
+ %bitcast = arith.bitcast %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi32> to vector<24x32xf32>
+ gpu.return
+ }
+
+ // CHECK-LABEL: gpu.func @test_cmp_ops
+ gpu.func @test_cmp_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+ -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
+ -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32>
+ -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32>
+ -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+
+ %load_a = xegpu.load_nd %tdesc_a
+ : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xf32>
+ %load_b = xegpu.load_nd %tdesc_b
+ : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xf32>
+ %load_c = xegpu.load_nd %tdesc_c
+ : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xi32>
+ %load_d = xegpu.load_nd %tdesc_d
+ : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xi32>
+
+ // CHECK: arith.cmpi eq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.cmpi ne, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.cmpi slt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.cmpi sle, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.cmpi sgt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.cmpi sge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.cmpi ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.cmpi ule, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.cmpi ugt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.cmpi uge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.cmpf oeq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.cmpf ogt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.cmpf oge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.cmpf olt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.cmpf ole, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.cmpf one, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.cmpf ord, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.cmpf ueq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.cmpf ugt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.cmpf uge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.cmpf ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.cmpf ule, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.cmpf une, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.cmpf uno, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // Integer comparisons
+ %cmpi_eq = arith.cmpi eq, %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %cmpi_ne = arith.cmpi ne, %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %cmpi_slt = arith.cmpi slt, %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %cmpi_sle = arith.cmpi sle, %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %cmpi_sgt = arith.cmpi sgt, %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %cmpi_sge = arith.cmpi sge, %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %cmpi_ult = arith.cmpi ult, %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %cmpi_ule = arith.cmpi ule, %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %cmpi_ugt = arith.cmpi ugt, %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %cmpi_uge = arith.cmpi uge, %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+
+ // Floating point comparisons
+ %cmpf_oeq = arith.cmpf oeq, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf_ogt = arith.cmpf ogt, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf_oge = arith.cmpf oge, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf_olt = arith.cmpf olt, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf_ole = arith.cmpf ole, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf_one = arith.cmpf one, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf_ord = arith.cmpf ord, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf_ueq = arith.cmpf ueq, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf_ugt = arith.cmpf ugt, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf_uge = arith.cmpf uge, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf_ult = arith.cmpf ult, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf_ule = arith.cmpf ule, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf_une = arith.cmpf une, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf_uno = arith.cmpf uno, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ gpu.return
+ }
+
+ // CHECK-LABEL: gpu.func @test_cmp_ops_rr_assignment
+ gpu.func @test_cmp_ops_rr_assignment(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+ -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
+ -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32>
+ -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32>
+ -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+
+ %load_a = xegpu.load_nd %tdesc_a
+ : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ -> vector<24x32xf32>
+ %load_b = xegpu.load_nd %tdesc_b
+ : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ -> vector<24x32xf32>
+ %load_c = xegpu.load_nd %tdesc_c
+ : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ -> vector<24x32xi32>
+ %load_d = xegpu.load_nd %tdesc_d
+ : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ -> vector<24x32xi32>
+
+ // CHECK-COUNT-12: arith.cmpi eq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-COUNT-12: arith.cmpi ne, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-COUNT-12: arith.cmpi slt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-COUNT-12: arith.cmpi sle, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-COUNT-12: arith.cmpi sgt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-COUNT-12: arith.cmpi sge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-COUNT-12: arith.cmpi ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-COUNT-12: arith.cmpi ule, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-COUNT-12: arith.cmpi ugt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-COUNT-12: arith.cmpi uge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // Floating point comparisons
+ // CHECK-COUNT-12: arith.cmpf oeq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: arith.cmpf ogt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: arith.cmpf oge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: arith.cmpf olt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: arith.cmpf ole, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: arith.cmpf one, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: arith.cmpf ord, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: arith.cmpf ueq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: arith.cmpf ugt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: arith.cmpf uge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: arith.cmpf ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: arith.cmpf ule, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: arith.cmpf une, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-COUNT-12: arith.cmpf uno, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+
+ // Integer comparisons
+ %cmpi_eq = arith.cmpi eq, %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %cmpi_ne = arith.cmpi ne, %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %cmpi_slt = arith.cmpi slt, %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %cmpi_sle = arith.cmpi sle, %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %cmpi_sgt = arith.cmpi sgt, %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %cmpi_sge = arith.cmpi sge, %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %cmpi_ult = arith.cmpi ult, %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %cmpi_ule = arith.cmpi ule, %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %cmpi_ugt = arith.cmpi ugt, %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %cmpi_uge = arith.cmpi uge, %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+
+ // Floating point comparisons
+ %cmpf_oeq = arith.cmpf oeq, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf_ogt = arith.cmpf ogt, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf_oge = arith.cmpf oge, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf_olt = arith.cmpf olt, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf_ole = arith.cmpf ole, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf_one = arith.cmpf one, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf_ord = arith.cmpf ord, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf_ueq = arith.cmpf ueq, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf_ugt = arith.cmpf ugt, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf_uge = arith.cmpf uge, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf_ult = arith.cmpf ult, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf_ule = arith.cmpf ule, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf_une = arith.cmpf une, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf_uno = arith.cmpf uno, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+
+ gpu.return
+ }
+
+ gpu.func @test_extra_elemwise_ops(
+ %a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>, %e: memref<24x32xi1>) {
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+ -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
+ -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32>
+ -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32>
+ -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc_e = xegpu.create_nd_tdesc %e[0, 0] : memref<24x32xi1>
+ -> !xegpu.tensor_desc<24x32xi1, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+
+ %load_a = xegpu.load_nd %tdesc_a
+ : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xf32>
+ %load_b = xegpu.load_nd %tdesc_b
+ : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xf32>
+ %load_c = xegpu.load_nd %tdesc_c
+ : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xi32>
+ %load_d = xegpu.load_nd %tdesc_d
+ : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xi32>
+ %load_e = xegpu.load_nd %tdesc_e
+ : !xegpu.tensor_desc<24x32xi1, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xi1>
+
+ // CHECK: arith.andi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.ori {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.xori {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.ceildivsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.ceildivui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.floordivsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.maxnumf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.maxsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.maxui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.minnumf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.minsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.minui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.remf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.cmpf ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.absi {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: math.cbrt {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.copysign {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.ctpop {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: math.erfc {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.exp2 {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.expm1 {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.fpowi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>, vector<12x8xi32>
+ // CHECK: math.ipowi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: math.log10 {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.log1p {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.round {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.roundeven {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.trunc {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // arith ops
+ %andi = arith.andi %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %ori = arith.ori %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %xori = arith.xori %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %ceildivsi = arith.ceildivsi %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %ceildivui = arith.ceildivui %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %floordivsi = arith.floordivsi %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %maxnumf = arith.maxnumf %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %maxsi = arith.maxsi %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %maxui = arith.maxui %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %minnumf = arith.minnumf %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %minsi = arith.minsi %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %minui = arith.minui %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %remf = arith.remf %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cmpf = arith.cmpf ult, %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+
+ // math ops
+ %absi = math.absi %load_c
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %cbrt = math.cbrt %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %copysign = math.copysign %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %ctpop = math.ctpop %load_c
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %erfc = math.erfc %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %exp2 = math.exp2 %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %expm1 = math.expm1 %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %fpowi = math.fpowi %load_a, %load_c
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>, vector<24x32xi32>
+ %ipowi = math.ipowi %load_c, %load_d
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ %log10 = math.log10 %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %log1p = math.log1p %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %round = math.round %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %roundeven = math.roundeven %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %trunc = math.trunc %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ gpu.return
+ }
+}
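
A note on the replication counts the tests above encode. In the functions using sg_layout = [2, 4] with sg_data = [12, 8], the subgroups tile vector<24x32x...> exactly (2 * 12 = 24, 4 * 8 = 32), so each elementwise op is materialized once per subgroup and a plain CHECK suffices. In the *_rr_assignment functions, sg_data = [2, 2] under sg_layout = [4, 4] yields more tiles than subgroups, and tiles are handed out round-robin. The sketch below is a minimal, standalone C++ model of that arithmetic; only the resulting count of 12 is fixed by the CHECK-COUNT-12 lines, while the hard-coded constants and names are illustrative:

#include <cassert>
#include <cstdint>

int main() {
  // Layout values copied from the *_rr_assignment tests.
  const int64_t shape[2] = {24, 32};  // vector<24x32x...>
  const int64_t sgLayout[2] = {4, 4}; // sg_layout = [4, 4] -> 16 subgroups
  const int64_t sgData[2] = {2, 2};   // sg_data = [2, 2]   -> 2x2 tiles

  const int64_t tiles =
      (shape[0] / sgData[0]) * (shape[1] / sgData[1]); // 12 * 16 = 192
  const int64_t subgroups = sgLayout[0] * sgLayout[1]; // 16
  assert(tiles % subgroups == 0);
  const int64_t opsPerSubgroup = tiles / subgroups; // 192 / 16 = 12
  assert(opsPerSubgroup == 12); // matches the CHECK-COUNT-12 lines above
  return 0;
}
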
>From e215e22faeb6e1e8123ebedf914f17053d3d3851 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 6 Jun 2025 15:17:08 +0000
Subject: [PATCH 2/2] Clean up
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 6 +-
.../XeGPU/xegpu-wg-to-sg-elemwise.mlir | 105 +++++++++++++++---
2 files changed, 92 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 972394a7b40ad..771642f1a34e9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -317,8 +317,7 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
}
};
-// This pattern matches elementwise ops (unary/binary) in math/arith dialects
-// with 1D or 2D vector types
+// This pattern transforms elementwise ops (unary/binary) in the math/arith dialects
template <typename Op>
struct WgToSgElementwiseOp : public OpConversionPattern<Op> {
using OpConversionPattern<Op>::OpConversionPattern;
@@ -344,7 +343,6 @@ struct WgToSgElementwiseOp : public OpConversionPattern<Op> {
}
}
- // Check for layout attribute with sgLayout
auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
if (!layout || !layout.getSgLayout())
return rewriter.notifyMatchFailure(
@@ -364,7 +362,6 @@ struct WgToSgElementwiseOp : public OpConversionPattern<Op> {
}
}
- // Each operand is a list of values
size_t numVariants = adaptor.getOperands().empty()
? 0
: adaptor.getOperands().front().size();
@@ -586,7 +583,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
}
}
- // check layout attribute
auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(
op->getAttrOfType<xegpu::LayoutAttr>("layout"));
return isLegal(layout);
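
As an aside for reviewers, the OneToNOpAdaptor logic referenced above admits a compact, MLIR-free model: each workgroup-level operand has already been converted into a list of subgroup-level values, and variant i of the elementwise op is rebuilt from the i-th entry of every list. This is a sketch under that reading of the pattern; the Value alias and the helper name are illustrative stand-ins, not code from the patch:

#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

using Value = std::string; // stand-in for mlir::Value

// Build one operand vector per variant: variant i takes the i-th value
// from each operand list, mirroring the numVariants loop in the pattern.
std::vector<std::vector<Value>>
expandOneToN(const std::vector<std::vector<Value>> &operandLists) {
  std::size_t numVariants =
      operandLists.empty() ? 0 : operandLists.front().size();
  std::vector<std::vector<Value>> variants;
  for (std::size_t i = 0; i < numVariants; ++i) {
    std::vector<Value> operands;
    for (const auto &list : operandLists) {
      assert(list.size() == numVariants && "operand lists must agree");
      operands.push_back(list[i]);
    }
    variants.push_back(operands); // one subgroup-level op per variant
  }
  return variants;
}

int main() {
  // Two operands, each split into three subgroup values -> three new ops.
  auto variants = expandOneToN({{"a0", "a1", "a2"}, {"b0", "b1", "b2"}});
  assert(variants.size() == 3);
  assert(variants[1][0] == "a1" && variants[1][1] == "b1");
  return 0;
}
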
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
index c45312e4c2d74..85767f4f2bd67 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
@@ -1,8 +1,8 @@
// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
-gpu.module @test_elementwise_ops {
- // CHECK-LABEL: test_elemwise_ops
- gpu.func @test_elemwise_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
+gpu.module @elementwise_ops {
+ // CHECK-LABEL: elemwise_ops
+ gpu.func @elemwise_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
@@ -193,8 +193,8 @@ gpu.module @test_elementwise_ops {
}
// 1 to N decomposition of elementwise operations
- // CHECK-LABEL: test_elemwise_ops_sg_rr_assignment
- gpu.func @test_elemwise_ops_sg_rr_assignment(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
+ // CHECK-LABEL: elemwise_ops_rr_assignment
+ gpu.func @elemwise_ops_rr_assignment(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
@@ -219,45 +219,85 @@ gpu.module @test_elementwise_ops {
// Floating point ops
// CHECK-COUNT-12: arith.addf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: arith.addf
// CHECK-COUNT-12: arith.subf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: arith.subf
// CHECK-COUNT-12: math.exp {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: math.exp
// CHECK-COUNT-12: math.sqrt {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: math.sqrt
// CHECK-COUNT-12: math.absf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: math.absf
// CHECK-COUNT-12: math.cos {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: math.cos
// CHECK-COUNT-12: math.cosh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: math.cosh
// CHECK-COUNT-12: math.acos {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: math.acos
// CHECK-COUNT-12: math.acosh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: math.acosh
// CHECK-COUNT-12: math.sin {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: math.sin
// CHECK-COUNT-12: math.sinh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: math.sinh
// CHECK-COUNT-12: math.asin {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: math.asin
// CHECK-COUNT-12: math.asinh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: math.asinh
// CHECK-COUNT-12: math.tan {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: math.tan
// CHECK-COUNT-12: math.tanh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: math.tanh
// CHECK-COUNT-12: math.atan {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: math.atan
// CHECK-COUNT-12: math.atan2 {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: math.atan2
// CHECK-COUNT-12: math.atanh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: math.atanh
// CHECK-COUNT-12: math.erf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: math.erf
// CHECK-COUNT-12: math.log {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: math.log
// CHECK-COUNT-12: math.log2 {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: math.log2
// CHECK-COUNT-12: math.floor {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: math.floor
// CHECK-COUNT-12: math.ceil {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: math.ceil
// CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: math.powf
// CHECK-COUNT-12: math.rsqrt {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: math.rsqrt
// CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: arith.negf
// CHECK-COUNT-12: arith.mulf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: arith.mulf
// CHECK-COUNT-12: arith.divf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: arith.divf
// CHECK-COUNT-12: arith.maximumf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: arith.maximumf
// CHECK-COUNT-12: arith.minimumf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: arith.minimumf
// CHECK-COUNT-12: arith.addi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-NOT: arith.addi
// CHECK-COUNT-12: arith.subi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-NOT: arith.subi
// CHECK-COUNT-12: arith.muli {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-NOT: arith.muli
// CHECK-COUNT-12: arith.shli {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-NOT: arith.shli
// CHECK-COUNT-12: arith.shrsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-NOT: arith.shrsi
// CHECK-COUNT-12: arith.shrui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-NOT: arith.shrui
// CHECK-COUNT-12: arith.divsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-NOT: arith.divsi
// CHECK-COUNT-12: arith.divui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-NOT: arith.divui
// CHECK-COUNT-12: arith.remsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-NOT: arith.remsi
// CHECK-COUNT-12: arith.remui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-NOT: arith.remui
%addf = arith.addf %load_a, %load_b
{layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
: vector<24x32xf32>
@@ -384,8 +424,8 @@ gpu.module @test_elementwise_ops {
gpu.return
}
- // CHECK-LABEL: test_all_type_conversion_ops
- gpu.func @test_all_type_conversion_ops(
+ // CHECK-LABEL: type_conversion_ops
+ gpu.func @type_conversion_ops(
%a: memref<24x32xf32>, %b: memref<24x32xi32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
@@ -476,8 +516,8 @@ gpu.module @test_elementwise_ops {
}
- // CHECK-LABEL: gpu.func @test_all_type_conversion_ops_rr_assignment
- gpu.func @test_all_type_conversion_ops_rr_assignment(
+ // CHECK-LABEL: gpu.func @type_conversion_ops_rr_assignment
+ gpu.func @type_conversion_ops_rr_assignment(
%a: memref<24x32xf32>, %b: memref<24x32xi32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
@@ -502,17 +542,29 @@ gpu.module @test_elementwise_ops {
-> vector<24x32xi32>
// CHECK-COUNT-12: arith.truncf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32> to vector<2x2xf16>
+ // CHECK-NOT: arith.truncf
// CHECK-COUNT-12: arith.trunci {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32> to vector<2x2xi16>
+ // CHECK-NOT: arith.trunci
// CHECK-COUNT-12: arith.extf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf16> to vector<2x2xf32>
+ // CHECK-NOT: arith.extf
// CHECK-COUNT-12: arith.extsi {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi16> to vector<2x2xi32>
+ // CHECK-NOT: arith.extsi
// CHECK-COUNT-12: arith.extui {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi16> to vector<2x2xi32>
+ // CHECK-NOT: arith.extui
// CHECK-COUNT-12: arith.sitofp {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32> to vector<2x2xf32>
+ // CHECK-NOT: arith.sitofp
// CHECK-COUNT-12: arith.uitofp {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32> to vector<2x2xf32>
+ // CHECK-NOT: arith.uitofp
// CHECK-COUNT-12: arith.fptosi {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32> to vector<2x2xi32>
+ // CHECK-NOT: arith.fptosi
// CHECK-COUNT-12: arith.fptoui {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32> to vector<2x2xi32>
+ // CHECK-NOT: arith.fptoui
// CHECK-COUNT-12: arith.index_castui {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32> to vector<2x2xindex>
+ // CHECK-NOT: arith.index_castui
// CHECK-COUNT-12: arith.index_cast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xindex> to vector<2x2xi32>
+ // CHECK-NOT: arith.index_cast
// CHECK-COUNT-12: arith.bitcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32> to vector<2x2xf32>
+ // CHECK-NOT: arith.bitcast
// TruncFOp: f32 -> f16
%truncf = arith.truncf %load_a
{layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
@@ -567,8 +619,8 @@ gpu.module @test_elementwise_ops {
gpu.return
}
- // CHECK-LABEL: gpu.func @test_cmp_ops
- gpu.func @test_cmp_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
+ // CHECK-LABEL: gpu.func @comparison_ops
+ gpu.func @comparison_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
@@ -693,8 +745,8 @@ gpu.module @test_elementwise_ops {
gpu.return
}
- // CHECK-LABEL: gpu.func @test_cmp_ops_rr_assignment
- gpu.func @test_cmp_ops_rr_assignment(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
+ // CHECK-LABEL: gpu.func @comparison_ops_rr_assignment
+ gpu.func @comparison_ops_rr_assignment(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
@@ -718,30 +770,54 @@ gpu.module @test_elementwise_ops {
-> vector<24x32xi32>
// CHECK-COUNT-12: arith.cmpi eq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-NOT: arith.cmpi eq
// CHECK-COUNT-12: arith.cmpi ne, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-NOT: arith.cmpi ne
// CHECK-COUNT-12: arith.cmpi slt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-NOT: arith.cmpi slt
// CHECK-COUNT-12: arith.cmpi sle, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-NOT: arith.cmpi sle
// CHECK-COUNT-12: arith.cmpi sgt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-NOT: arith.cmpi sgt
// CHECK-COUNT-12: arith.cmpi sge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-NOT: arith.cmpi sge
// CHECK-COUNT-12: arith.cmpi ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-NOT: arith.cmpi ult
// CHECK-COUNT-12: arith.cmpi ule, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-NOT: arith.cmpi ule
// CHECK-COUNT-12: arith.cmpi ugt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-NOT: arith.cmpi ugt
// CHECK-COUNT-12: arith.cmpi uge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xi32>
+ // CHECK-NOT: arith.cmpi uge
// Floating point comparisons
// CHECK-COUNT-12: arith.cmpf oeq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: arith.cmpf oeq
// CHECK-COUNT-12: arith.cmpf ogt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: arith.cmpf ogt
// CHECK-COUNT-12: arith.cmpf oge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: arith.cmpf oge
// CHECK-COUNT-12: arith.cmpf olt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: arith.cmpf olt
// CHECK-COUNT-12: arith.cmpf ole, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: arith.cmpf ole
// CHECK-COUNT-12: arith.cmpf one, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: arith.cmpf one
// CHECK-COUNT-12: arith.cmpf ord, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: arith.cmpf ord
// CHECK-COUNT-12: arith.cmpf ueq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: arith.cmpf ueq
// CHECK-COUNT-12: arith.cmpf ugt, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: arith.cmpf ugt
// CHECK-COUNT-12: arith.cmpf uge, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: arith.cmpf uge
// CHECK-COUNT-12: arith.cmpf ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: arith.cmpf ult
// CHECK-COUNT-12: arith.cmpf ule, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: arith.cmpf ule
// CHECK-COUNT-12: arith.cmpf une, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: arith.cmpf une
// CHECK-COUNT-12: arith.cmpf uno, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+ // CHECK-NOT: arith.cmpf uno
// Integer comparisons
%cmpi_eq = arith.cmpi eq, %load_c, %load_d
@@ -822,7 +898,8 @@ gpu.module @test_elementwise_ops {
gpu.return
}
- gpu.func @test_extra_elemwise_ops(
+ // CHECK-LABEL: gpu.func @elementwise_ops
+ gpu.func @elementwise_ops(
%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>, %e: memref<24x32xi1>) {
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>