[Mlir-commits] [mlir] [MLIR][XeGPU] Switch to 1D representation for SIMT code (PR #135116)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 10 13:31:24 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Chao Chen (chencha3)

<details>
<summary>Changes</summary>

As the title suggests, this PR switches to using 1D vectors to represent SIMT code, so that we can get rid of the unnecessary layout attributes in dpas and tensor_desc for SIMT code. This also relaxes the verifier for SIMT code, since the parameters needed for checking are architecture-specific. The correctness check is supposed to be done by a target object.

---

Patch is 57.81 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/135116.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+8-11) 
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td (+2-1) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+12-14) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+137-91) 
- (modified) mlir/test/Dialect/XeGPU/invalid.mlir (+22-79) 
- (modified) mlir/test/Dialect/XeGPU/ops.mlir (+76-86) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 16a7f63d60c82..5fa18754305ca 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -833,16 +833,16 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]
     data type, the matrices are `A: vector<8x16xf16>`, `B: vector<16x16xf16>`,
     and `C/D: vector<8x16xf32>`. Besides the matrix size requirements, DPAS
     also requires A and B to be loaded with the required data layout. Specially,
-
     VNNI layout is required for B operand. It is achieved via adding `packed`
     attribute to the `load_nd` operator.  Due to the VNNI transformation, B operands
     can be represented as a 3D vector, with the last dimension representing the VNNI
     factor, which is computed as `32/bit_width_of_elem_type`. Thus, `B: vector<16x16xf16>`
     can be represented as `B: vector<8x16x2xf16>`.
 
-    In SIMT mode, DpasOp expects layout attributes `a`, `b`, and `c` (only if acc is used)
-    which describe the data fragment owned by each work-item w.r.t. the tensor descriptor
-    these data are loaded from.
+    In SIMT code, each work-item from a subgroup holds a data fragment for A, B, C and the result,
+    which are represented as 1D vectors. Please refer to [OpenCL Intel extentions]
+    (https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html)
+    for more details about the fragment distribution.
 
     Note: on PVC, the hardware can perform load with VNNI transformation when data
           element type is 16-bit or lower precision, taking 2 or 4 elements from
@@ -850,13 +850,10 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]
   }];
 
   let arguments = (ins
-    XeGPU_DpasOpType : $lhs,
-    XeGPU_DpasOpType : $rhs,
-    Optional<XeGPU_Vector2DType>: $acc,
-    OptionalAttr<XeGPU_LayoutAttr>:$a_layout,
-    OptionalAttr<XeGPU_LayoutAttr>:$b_layout,
-    OptionalAttr<XeGPU_LayoutAttr>:$c_layout);
-  let results = (outs XeGPU_Vector2DType: $result);
+    XeGPU_DpasOprType : $lhs,
+    XeGPU_DpasOprType : $rhs,
+    Optional<XeGPU_DpasResType>: $acc);
+  let results = (outs XeGPU_DpasResType: $result);
 
   let extraClassDeclaration = [{
     VectorType getLhsType() {
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 173f1462fdd73..3cb71788a15ef 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -17,7 +17,8 @@ def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64,
 def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
 def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
 def XeGPU_BaseAddrType: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64, UI32, I64, I32]>;
-def XeGPU_DpasOpType: VectorOfRankAndType<[2, 3], [XeGPU_ScalarType]>;
+def XeGPU_DpasOprType: VectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
+def XeGPU_DpasResType: VectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
 def XeGPU_OffsetType: VectorOfRankAndType<[1], [Index]>;
 def XeGPU_MaskType: AnyTypeOf<[VectorOfRankAndType<[1], [I1]>, I1]>;
 def XeGPU_ValueType: AnyTypeOf<[VectorOfRankAndType<[1,2,3,4], [XeGPU_ScalarType]>, XeGPU_ScalarType]>;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 171a15ce27b59..269e445c3790c 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -10,6 +10,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include <numeric>
 
 namespace mlir {
 namespace xegpu {
@@ -336,19 +337,20 @@ LogicalResult TensorDescType::verify(
 //        [n_distribution_units, lane_data_size]
 FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
   auto layout = llvm::dyn_cast_if_present<LayoutAttr>(getLayout());
-  // If no layout is provided, tensor desc is not used in SIMT mode.
-  if (!layout)
+  // It only works for subgroup level layout, which only has lane_layout
+  // and lane_data, and is to distribute a SIMD code into SIMT code.
+  if (!layout || !layout.isSgLayout())
     return failure();
 
   SmallVector<int64_t> laneData(layout.getLaneData().asArrayRef());
   SmallVector<int64_t> laneLayout(layout.getLaneLayout().asArrayRef());
   auto tdescShape = getShape();
 
-  auto laneDataSize = 1, sgSize = 1;
-  for (auto [laneDim, laneDataDim] : llvm::zip_equal(laneLayout, laneData)) {
-    laneDataSize *= laneDataDim;
-    sgSize *= laneDim;
-  }
+  // compute sgSize by multiply elements of laneLayout
+  // e.g. for 2D layout, sgSize = laneLayout[0] * laneLayout[1]
+  // e.g. for 1D layout, sgSize = laneLayout[0]
+  auto sgSize = std::accumulate(laneLayout.begin(), laneLayout.end(), 1,
+                                std::multiplies<int64_t>());
 
   // Case 1: regular loads/stores
   auto scatterAttr = getEncodingAsScatterTensorDescAttr();
@@ -356,12 +358,9 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
     auto chunkSize = scatterAttr.getChunkSize().getInt();
     // Verify if the first dimension of the tensor descriptor shape is
     // distributable.
-    assert(tdescShape[0] % (laneLayout[0]) == 0 &&
+    assert(tdescShape[0] == laneLayout[0] &&
            "tensor descriptor shape is not distributable");
-    if (chunkSize > 1)
-      return VectorType::get({chunkSize / laneDataSize, laneDataSize},
-                             getElementType());
-    return VectorType::get({laneDataSize}, getElementType());
+    return VectorType::get({chunkSize}, getElementType());
   }
 
   // Case 2: block loads/stores
@@ -376,8 +375,7 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
   // tensorSize must be adjusted for array_length.
   tensorSize *= getArrayLength();
 
-  return VectorType::get({tensorSize / (sgSize * laneDataSize), laneDataSize},
-                         getElementType());
+  return VectorType::get({tensorSize / sgSize}, getElementType());
 }
 
 } // namespace xegpu
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 0d67e3d70f945..1dafc9936107e 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -73,38 +73,6 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
          kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
 }
 
-// Helper to validate value shape of LoadNd and StoreNd ops.
-static LogicalResult
-isArgShapesValid(TensorDescType tdescTy, VectorType valueTy,
-                 ArrayRef<int64_t> adjustedTdescShape,
-                 function_ref<InFlightDiagnostic()> emitError) {
-  auto layout = tdescTy.getLayoutAttr();
-  auto valueShape = valueTy.getShape();
-  // layout not present means IR is in SIMD mode. In this case value shape must
-  // match adjusted tensor descriptor shape.
-  if (!layout)
-    return valueShape == adjustedTdescShape
-               ? success()
-               : emitError()
-                     << "Value shape " << makeString(valueShape)
-                     << " is not consistent with tensor descriptor " << tdescTy;
-
-  // layout present means IR is in SIMT mode. In this case layout determines the
-  // value shape.
-  auto expectedValueShapeOrFailure = tdescTy.getDistributedVectorType();
-  assert(succeeded(expectedValueShapeOrFailure) &&
-         "Failed to compute distributed vector shape for "
-         "tensor descriptor ");
-
-  return valueTy == expectedValueShapeOrFailure.value()
-             ? success()
-             : emitError()
-                   << "Result shape " << makeString(valueShape)
-                   << " is not consistent with distributed vector shape "
-                   << makeString(expectedValueShapeOrFailure.value().getShape())
-                   << " for tensor descriptor " << tdescTy;
-}
-
 // Checks if the given shape is evenly distributed based on the layout
 // and data factors provided by the LayoutAttr. The function ensures that
 // each dimension of the shape can be evenly divided by the corresponding
@@ -302,9 +270,33 @@ LogicalResult LoadNdOp::verify() {
   if (!isReadHintOrNone(getL3HintAttr()))
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
-  auto array_len = tdescTy.getArrayLength();
+  int tdescElems = tdescTy.getNumElements() * tdescTy.getArrayLength();
+  int valueElems = valueTy.getNumElements();
+
+  // If the result vector is 1D and has less elements than the tensor
+  // descriptor, it is supposed to be a SIMT op. The layout attribute in
+  // tensor_desc is not needed.
+  if (valueElems < tdescElems && valueTy.getRank() == 1) {
+    // SIMT mode doesn't need LayoutAttr.
+    if (tdescTy.getLayoutAttr())
+      return emitOpError()
+             << "TensorDesc doesn't need LayoutAttr for SIMT code";
+
+    // For SIMT code, the load is evenly distributed across all lanes in a
+    // subgroup. Since subgroup size is arch dependent, we only check even
+    // distribution here.
+    if (tdescElems % valueElems)
+      return emitOpError()
+             << "Result shape " << makeString(getShapeOf(valueTy))
+             << " is not a valid distribution for tensor descriptor "
+             << tdescTy;
+
+    return success();
+  }
+
+  // Check SIMD mode.
   // adjusted tensor descriptor shape tracks the expected shape of the result.
-  auto adjustedTdescShape = getShapeOf(tdescTy);
+  auto tdescShape = getShapeOf(tdescTy);
   auto valueShape = getShapeOf(valueTy);
 
   if (getTranspose()) {
@@ -316,7 +308,7 @@ LogicalResult LoadNdOp::verify() {
     });
 
     if (valid)
-      transpose(trans, adjustedTdescShape);
+      transpose(trans, tdescShape);
     else
       mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
   }
@@ -325,8 +317,8 @@ LogicalResult LoadNdOp::verify() {
     if (tdescTy.getRank() == 2) {
       const int axis = 0;
       auto vnni_factor = valueShape.back();
-      adjustedTdescShape[axis] /= vnni_factor;
-      adjustedTdescShape.push_back(vnni_factor);
+      tdescShape[axis] /= vnni_factor;
+      tdescShape.push_back(vnni_factor);
     } else {
       mlir::emitWarning(getLoc())
           << "Invalid Packed Attr. It is ignored (available for 2D "
@@ -334,13 +326,18 @@ LogicalResult LoadNdOp::verify() {
     }
   }
 
+  auto array_len = tdescTy.getArrayLength();
   if (array_len > 1) {
-    auto it = adjustedTdescShape.begin();
-    adjustedTdescShape.insert(it, array_len);
+    tdescShape.insert(tdescShape.begin(), array_len);
+  }
+
+  if (tdescShape != valueShape) {
+    return emitOpError() << "Result shape " << makeString(valueShape)
+                         << " is not consistent with tensor descriptor "
+                         << tdescTy;
   }
 
-  return isArgShapesValid(tdescTy, valueTy, adjustedTdescShape,
-                          [&]() { return emitOpError(); });
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -368,11 +365,40 @@ LogicalResult StoreNdOp::verify() {
   if (!isWriteHintOrNone(getL3HintAttr()))
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
+  auto array_len = dstTy.getArrayLength();
+  if (array_len > 1)
+    return emitOpError("array length is not supported by store_nd.\n");
+
+  auto tdescElems = dstTy.getNumElements();
+  auto valueElems = valTy.getNumElements();
+
+  // Similar to LoadNdOp, if the value vector is 1D and has less elements than
+  // the tensor descriptor, it is supposed to be a SIMT op. The layout attribute
+  // in tensor_desc is not needed.
+  if (valTy.getRank() == 1 && valueElems < tdescElems) {
+    // SIMT mode doesn't need LayoutAttr.
+    if (dstTy.getLayoutAttr())
+      return emitOpError()
+             << "TensorDesc doesn't need LayoutAttr for SIMT code";
+
+    if (tdescElems % valueElems) {
+      return emitOpError()
+             << "Value shape " << makeString(getShapeOf(valTy))
+             << " is not a valid distribution for tensor descriptor " << dstTy;
+    }
+    return success();
+  }
+
+  // SIMD code should have the same shape as the tensor descriptor.
   auto tdescShape = getShapeOf(dstTy);
   auto valueShape = getShapeOf(valTy);
+  if (tdescShape != valueShape) {
+    return emitOpError() << "Value shape " << makeString(valueShape)
+                         << " is not consistent with tensor descriptor "
+                         << dstTy;
+  }
 
-  return isArgShapesValid(dstTy, valTy, tdescShape,
-                          [&]() { return emitOpError(); });
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -520,14 +546,41 @@ LogicalResult LoadGatherOp::verify() {
   if (tdescShape[0] != maskShape[0])
     return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
 
+  auto chunkSize = tdescTy.getChunkSize();
+  // for SIMT code, the value should be 1D vector with size of chunkSize.
+  if (valueTy.getRank() == 1 && valueTy.getNumElements() != tdescShape[0]) {
+    if (valueTy.getNumElements() != chunkSize) {
+      return emitOpError()
+             << "Result shape " << makeString(valueShape)
+             << " is not a valid distribution for tensor descriptor "
+             << tdescTy;
+    } else { // valid SIMT code doesn't need LayoutAttr and TransposeAttr.
+      if (tdescTy.getLayoutAttr())
+        return emitOpError()
+               << "TensorDesc doesn't need LayoutAttr for SIMT code";
+      if (getTransposeAttr())
+        return emitOpError() << "doesn't need TransposeAttr for SIMT code";
+    }
+    return success();
+  } else if (valueTy.getRank() == 1 && tdescShape[0] == chunkSize) {
+    // for 1D vector and valueTy.getNumElements() == tdescShape[0] case,
+    // it is a valid SIMT code if chunkSize happens to be the same as
+    // subgroup size, e.g., tensor_desc<16x16xf16, chunkSize = 16>
+    return success();
+  }
+
+  // For SIMD code verification.
   if (tdescTy.getRank() == 2) {
     if (!getTransposeAttr())
       return emitOpError("load of rank-2 tensor has to be transposed.");
     transpose({1, 0}, tdescShape);
   }
 
-  return isArgShapesValid(tdescTy, valueTy, tdescShape,
-                          [&]() { return emitOpError(); });
+  if (tdescShape != valueShape)
+    return emitOpError() << "Result shape " << makeString(valueShape)
+                         << " is not consistent with tensor descriptor "
+                         << tdescTy;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -559,14 +612,42 @@ LogicalResult StoreScatterOp::verify() {
   if (tdescShape[0] != maskShape[0])
     return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
 
+  auto chunkSize = tdescTy.getChunkSize();
+  // for SIMT code, the value should be 1D vector with size of chunkSize.
+  if (valueTy.getRank() == 1 && valueTy.getNumElements() != tdescShape[0]) {
+    if (valueTy.getNumElements() != chunkSize) {
+      return emitOpError()
+             << "Value shape " << makeString(valueShape)
+             << " is not a valid distribution for tensor descriptor "
+             << tdescTy;
+    } else { // valid SIMT code doesn't need LayoutAttr and TransposeAttr.
+      if (tdescTy.getLayoutAttr())
+        return emitOpError()
+               << "TensorDesc doesn't need LayoutAttr for SIMT code";
+      if (getTransposeAttr())
+        return emitOpError() << "doesn't need TransposeAttr for SIMT code";
+    }
+    return success();
+  } else if (valueTy.getRank() == 1 && tdescShape[0] == chunkSize) {
+    // for 1D vector and valueTy.getNumElements() == tdescShape[0] case,
+    // it is a valid SIMT code if chunkSize happens to be the same as
+    // subgroup size, e.g., tensor_desc<16x16xf16, chunkSize = 16>
+    return success();
+  }
+
+  // for SIMD code verification.
   if (tdescTy.getRank() == 2) {
     if (!getTransposeAttr())
       return emitOpError("Store of a rank-2 tensor has to be transposed.");
     transpose({1, 0}, tdescShape);
   }
 
-  return isArgShapesValid(tdescTy, valueTy, tdescShape,
-                          [&]() { return emitOpError(); });
+  if (tdescShape != valueShape)
+    return emitOpError() << "Value shape " << makeString(valueShape)
+                         << " is not consistent with tensor descriptor "
+                         << tdescTy;
+
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -602,51 +683,16 @@ LogicalResult DpasOp::verify() {
   auto rhsShape = getRhsType().getShape();
   auto resShape = getResultType().getShape();
 
-  auto aLayout = getALayoutAttr();
-  auto bLayout = getBLayoutAttr();
-  auto cLayout = getCLayoutAttr();
-
-  // make sure the layout attribute is either set for every available
-  // operand or simply not set at all. C is special, since ACC is optional.
-  auto hasValidLayoutAttrs = [&]() {
-    bool result = (aLayout != nullptr) ^ (bLayout != nullptr);
-    if (hasAcc()) {
-      result |= (aLayout != nullptr) ^ (cLayout != nullptr);
-    }
-    return !result;
-  };
+  if (getAcc()) {
+    if (getAcc().getType() != getResultType())
+      return emitOpError("Expecting the acc type to be the same as result.");
+  }
 
-  if (!hasValidLayoutAttrs())
-    return emitOpError(
-        "layout attributes should be either set for all operands (for SIMT "
-        "code) or not set at all (for SIMD code).");
-
-  // query the scope from aLayout (a valid setting).
-  if (aLayout) {
-    // In SIMT mode, All data fragments must be 2D
-    if (lhsRank != 2 || rhsRank != 2 || resRank != 2)
-      return emitOpError("expecting lhs, rhs, and result to be a 2D vector.");
-
-    auto laneLayoutA = aLayout.getLaneLayout();
-    auto laneLayoutB = bLayout.getLaneLayout();
-    auto laneLayoutC = cLayout.getLaneLayout();
-    // Obtain the expanded shapes of the operands and result using lane_layout.
-    // NOTE: For B, get rid of the packed dimension for the expanded shape.
-    SmallVector<int64_t> expandedShapeA = {lhsShape[0] * laneLayoutA[0],
-                                           lhsShape[1] * laneLayoutA[1]};
-    SmallVector<int64_t> expandedShapeB = {
-        rhsShape[0] * rhsShape[1] * laneLayoutB[0], 1 * laneLayoutB[1]};
-    SmallVector<int64_t> expandedShapeC = {resShape[0] * laneLayoutC[0],
-                                           resShape[1] * laneLayoutC[1]};
-    auto bK = expandedShapeB[0];
-    if (bK != expandedShapeA[1])
-      return emitOpError("K-dimension mismatch.");
-    if (expandedShapeA[0] != expandedShapeC[0])
-      return emitOpError("M-dimension mismatch.");
-    if (expandedShapeB[1] != expandedShapeC[1])
-      return emitOpError("N-dimension mismatch.");
-  } else { // For other scopes, operands' shape should match the mxkxn
-           // semantics.
+  // SIMT code: skip the check since lack of semantic info at this level.
+  // Users need to ensure the correctness.
+  if (lhsRank == 1 && rhsRank == 1 && resRank == 1) {
+    return success();
+  } else { // SIMD code
     if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)
       return emitOpError(
           "expecting lhs and result to be a 2D vector, and rhs to be either "
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 48df33a591908..a02427b6e317b 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -79,25 +79,10 @@ func.func @test_load_nd_vc_3(%src: memref<8x16xf16>) {
 
 // -----
 func.func @test_load_nd_layout(%src: memref<24x32xf32>) {
-  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
-    !xegpu.tensor_desc<8x16xf32,   #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-  // expected-error at +1 {{Result shape [8, 2] is not consistent with distributed vector shape [8, 1]}}
-  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
-      l2_hint = #xegpu.cache_hint<uncached>}>
-    : !xegpu.tensor_desc<8x16xf32,   #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    -> vector<8x2xf32>
-  return
-}
-
-// -----
-func.func @test...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/135116


More information about the Mlir-commits mailing list